001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.lang.reflect.Field;
021    import java.lang.reflect.Modifier;
022    
023    import java.util.HashMap;
024    import java.util.HashSet;
025    import java.util.Map;
026    import java.util.Set;
027    
028    /**
029     * @author Shuyang Zhou
030     */
031    public class DefaultThreadLocalBinder implements ThreadLocalBinder {
032    
033            public void afterPropertiesSet() throws Exception {
034                    if (_threadLocalSources == null) {
035                            throw new IllegalArgumentException("Thread local sources is null");
036                    }
037    
038                    init(getClassLoader());
039            }
040    
041            @Override
042            public void bind() {
043                    Map<ThreadLocal<?>, ?> threadLocalValues = _threadLocalValues.get();
044    
045                    for (Map.Entry<ThreadLocal<?>, ?> entry :
046                                    threadLocalValues.entrySet()) {
047    
048                            ThreadLocal<Object> threadLocal =
049                                    (ThreadLocal<Object>)entry.getKey();
050                            Object value = entry.getValue();
051    
052                            threadLocal.set(value);
053                    }
054            }
055    
056            @Override
057            public void cleanUp() {
058                    for (ThreadLocal<?> threadLocal : _threadLocals) {
059                            threadLocal.remove();
060                    }
061            }
062    
063            public ClassLoader getClassLoader() {
064                    if (_classLoader == null) {
065                            Thread currentThread = Thread.currentThread();
066    
067                            _classLoader = currentThread.getContextClassLoader();
068                    }
069    
070                    return _classLoader;
071            }
072    
073            public void init(ClassLoader classLoader) throws Exception {
074                    for (Map.Entry<String, String> entry : _threadLocalSources.entrySet()) {
075                            String className = entry.getKey();
076                            String fieldName = entry.getValue();
077    
078                            Class<?> clazz = classLoader.loadClass(className);
079    
080                            Field field = ReflectionUtil.getDeclaredField(clazz, fieldName);
081    
082                            if (!ThreadLocal.class.isAssignableFrom(field.getType())) {
083                                    if (_log.isWarnEnabled()) {
084                                            _log.warn(
085                                                    fieldName +
086                                                            " is not of type ThreadLocal. Skip binding.");
087                                    }
088    
089                                    continue;
090                            }
091    
092                            if (!Modifier.isStatic(field.getModifiers())) {
093                                    if (_log.isWarnEnabled()) {
094                                            _log.warn(
095                                                    fieldName +
096                                                            " is not a static ThreadLocal. Skip binding.");
097                                    }
098    
099                                    continue;
100                            }
101    
102                            ThreadLocal<?> threadLocal = (ThreadLocal<?>)field.get(null);
103    
104                            if (threadLocal == null) {
105                                    if (_log.isWarnEnabled()) {
106                                            _log.warn(fieldName + " is not initialized. Skip binding.");
107                                    }
108    
109                                    continue;
110                            }
111    
112                            _threadLocals.add(threadLocal);
113                    }
114            }
115    
116            @Override
117            public void record() {
118                    Map<ThreadLocal<?>, Object> threadLocalValues = new HashMap<>();
119    
120                    for (ThreadLocal<?> threadLocal : _threadLocals) {
121                            Object value = threadLocal.get();
122    
123                            threadLocalValues.put(threadLocal, value);
124                    }
125    
126                    _threadLocalValues.set(threadLocalValues);
127            }
128    
129            public void setClassLoader(ClassLoader classLoader) {
130                    _classLoader = classLoader;
131            }
132    
133            public void setThreadLocalSources(Map<String, String> threadLocalSources) {
134                    _threadLocalSources = threadLocalSources;
135            }
136    
137            private static final Log _log = LogFactoryUtil.getLog(
138                    DefaultThreadLocalBinder.class);
139    
140            private static final ThreadLocal<Map<ThreadLocal<?>, ?>>
141                    _threadLocalValues = new AutoResetThreadLocal<Map<ThreadLocal<?>, ?>>(
142                            DefaultThreadLocalBinder.class + "._threadLocalValueMap") {
143    
144                            @Override
145                            protected Map<ThreadLocal<?>, ?> copy(
146                                    Map<ThreadLocal<?>, ?> threadLocalValueMap) {
147    
148                                    return threadLocalValueMap;
149                            }
150    
151                    };
152    
153            private ClassLoader _classLoader;
154            private final Set<ThreadLocal<?>> _threadLocals = new HashSet<>();
155            private Map<String, String> _threadLocalSources;
156    
157    }