001
014
015 package com.liferay.portal.kernel.util;
016
017 import com.liferay.portal.kernel.log.Log;
018 import com.liferay.portal.kernel.log.LogFactoryUtil;
019
020 import java.lang.reflect.Field;
021 import java.lang.reflect.Modifier;
022
023 import java.util.HashMap;
024 import java.util.HashSet;
025 import java.util.Map;
026 import java.util.Set;
027
028
031 public class DefaultThreadLocalBinder implements ThreadLocalBinder {
032
033 public void afterPropertiesSet() throws Exception {
034 if (_threadLocalSources == null) {
035 throw new IllegalArgumentException("Thread local sources is null");
036 }
037
038 init(getClassLoader());
039 }
040
041 @Override
042 public void bind() {
043 Map<ThreadLocal<?>, ?> threadLocalValues = _threadLocalValues.get();
044
045 for (Map.Entry<ThreadLocal<?>, ?> entry :
046 threadLocalValues.entrySet()) {
047
048 ThreadLocal<Object> threadLocal =
049 (ThreadLocal<Object>)entry.getKey();
050 Object value = entry.getValue();
051
052 threadLocal.set(value);
053 }
054 }
055
056 @Override
057 public void cleanUp() {
058 for (ThreadLocal<?> threadLocal : _threadLocals) {
059 threadLocal.remove();
060 }
061 }
062
063 public ClassLoader getClassLoader() {
064 if (_classLoader == null) {
065 Thread currentThread = Thread.currentThread();
066
067 _classLoader = currentThread.getContextClassLoader();
068 }
069
070 return _classLoader;
071 }
072
073 public void init(ClassLoader classLoader) throws Exception {
074 for (Map.Entry<String, String> entry : _threadLocalSources.entrySet()) {
075 String className = entry.getKey();
076 String fieldName = entry.getValue();
077
078 Class<?> clazz = classLoader.loadClass(className);
079
080 Field field = ReflectionUtil.getDeclaredField(clazz, fieldName);
081
082 if (!ThreadLocal.class.isAssignableFrom(field.getType())) {
083 if (_log.isWarnEnabled()) {
084 _log.warn(
085 fieldName +
086 " is not of type ThreadLocal. Skip binding.");
087 }
088
089 continue;
090 }
091
092 if (!Modifier.isStatic(field.getModifiers())) {
093 if (_log.isWarnEnabled()) {
094 _log.warn(
095 fieldName +
096 " is not a static ThreadLocal. Skip binding.");
097 }
098
099 continue;
100 }
101
102 ThreadLocal<?> threadLocal = (ThreadLocal<?>)field.get(null);
103
104 if (threadLocal == null) {
105 if (_log.isWarnEnabled()) {
106 _log.warn(fieldName + " is not initialized. Skip binding.");
107 }
108
109 continue;
110 }
111
112 _threadLocals.add(threadLocal);
113 }
114 }
115
116 @Override
117 public void record() {
118 Map<ThreadLocal<?>, Object> threadLocalValues = new HashMap<>();
119
120 for (ThreadLocal<?> threadLocal : _threadLocals) {
121 Object value = threadLocal.get();
122
123 threadLocalValues.put(threadLocal, value);
124 }
125
126 _threadLocalValues.set(threadLocalValues);
127 }
128
129 public void setClassLoader(ClassLoader classLoader) {
130 _classLoader = classLoader;
131 }
132
133 public void setThreadLocalSources(Map<String, String> threadLocalSources) {
134 _threadLocalSources = threadLocalSources;
135 }
136
137 private static final Log _log = LogFactoryUtil.getLog(
138 DefaultThreadLocalBinder.class);
139
140 private static final ThreadLocal<Map<ThreadLocal<?>, ?>>
141 _threadLocalValues = new AutoResetThreadLocal<Map<ThreadLocal<?>, ?>>(
142 DefaultThreadLocalBinder.class + "._threadLocalValueMap") {
143
144 @Override
145 protected Map<ThreadLocal<?>, ?> copy(
146 Map<ThreadLocal<?>, ?> threadLocalValueMap) {
147
148 return threadLocalValueMap;
149 }
150
151 };
152
153 private ClassLoader _classLoader;
154 private final Set<ThreadLocal<?>> _threadLocals = new HashSet<>();
155 private Map<String, String> _threadLocalSources;
156
157 }