001    /**
002     * Copyright (c) 2000-2012 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.lang.reflect.Field;
021    import java.lang.reflect.Modifier;
022    
023    import java.util.HashMap;
024    import java.util.HashSet;
025    import java.util.Map;
026    import java.util.Set;
027    
028    /**
029     * @author Shuyang Zhou
030     */
031    public class DefaultThreadLocalBinder implements ThreadLocalBinder {
032    
033            public void afterPropertiesSet() throws Exception {
034                    if (_threadLocalSources == null) {
035                            throw new IllegalArgumentException("Thread local sources is null");
036                    }
037    
038                    init(getClassLoader());
039            }
040    
041            public void bind() {
042                    Map<ThreadLocal<?>, ?> threadLocalValues = _threadLocalValues.get();
043    
044                    for (Map.Entry<ThreadLocal<?>, ?> entry :
045                                    threadLocalValues.entrySet()) {
046    
047                            ThreadLocal<Object> threadLocal =
048                                    (ThreadLocal<Object>)entry.getKey();
049                            Object value = entry.getValue();
050    
051                            threadLocal.set(value);
052                    }
053            }
054    
055            public void cleanUp() {
056                    for (ThreadLocal<?> threadLocal : _threadLocals) {
057                            threadLocal.remove();
058                    }
059            }
060    
061            public ClassLoader getClassLoader() {
062                    if (_classLoader == null) {
063                            Thread currentThread = Thread.currentThread();
064    
065                            _classLoader = currentThread.getContextClassLoader();
066                    }
067    
068                    return _classLoader;
069            }
070    
071            public void init(ClassLoader classLoader) throws Exception {
072                    for (Map.Entry<String, String> entry : _threadLocalSources.entrySet()) {
073                            String className = entry.getKey();
074                            String fieldName = entry.getValue();
075    
076                            Class<?> clazz = classLoader.loadClass(className);
077    
078                            Field field = ReflectionUtil.getDeclaredField(clazz, fieldName);
079    
080                            if (!ThreadLocal.class.isAssignableFrom(field.getType())) {
081                                    if (_log.isWarnEnabled()) {
082                                            _log.warn(
083                                                    fieldName +
084                                                            " is not type of ThreadLocal. Skip binding.");
085                                    }
086    
087                                    continue;
088                            }
089    
090                            if (!Modifier.isStatic(field.getModifiers())) {
091                                    if (_log.isWarnEnabled()) {
092                                            _log.warn(
093                                                    fieldName +
094                                                            " is not a static ThreadLocal. Skip binding.");
095                                    }
096    
097                                    continue;
098                            }
099    
100                            ThreadLocal<?> threadLocal = (ThreadLocal<?>)field.get(null);
101    
102                            if (threadLocal == null) {
103                                    if (_log.isWarnEnabled()) {
104                                            _log.warn(fieldName + " is not initialized. Skip binding.");
105                                    }
106    
107                                    continue;
108                            }
109    
110                            _threadLocals.add(threadLocal);
111                    }
112    
113            }
114    
115            public void record() {
116                    Map<ThreadLocal<?>, Object> threadLocalValues =
117                            new HashMap<ThreadLocal<?>, Object>();
118    
119                    for (ThreadLocal<?> threadLocal : _threadLocals) {
120                            Object value = threadLocal.get();
121    
122                            threadLocalValues.put(threadLocal, value);
123                    }
124    
125                    _threadLocalValues.set(threadLocalValues);
126            }
127    
128            public void setClassLoader(ClassLoader classLoader) {
129                    _classLoader = classLoader;
130            }
131    
132            public void setThreadLocalSources(Map<String, String> threadLocalSources) {
133                    _threadLocalSources = threadLocalSources;
134            }
135    
136            private static Log _log = LogFactoryUtil.getLog(
137                    DefaultThreadLocalBinder.class);
138    
139            private static ThreadLocal<Map<ThreadLocal<?>, ?>> _threadLocalValues =
140                    new AutoResetThreadLocal<Map<ThreadLocal<?>, ?>>(
141                            DefaultThreadLocalBinder.class + "._threadLocalValueMap") {
142    
143                            @Override
144                            protected Map<ThreadLocal<?>, ?> copy(
145                                    Map<ThreadLocal<?>, ?> threadLocalValueMap) {
146    
147                                    return threadLocalValueMap;
148                            }
149    
150                    };
151    
152            private ClassLoader _classLoader;
153            private Set<ThreadLocal<?>> _threadLocals = new HashSet<ThreadLocal<?>>();
154            private Map<String, String> _threadLocalSources;
155    
156    }