001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import java.util.HashMap;
018    import java.util.HashSet;
019    import java.util.Map;
020    import java.util.Set;
021    import java.util.concurrent.atomic.AtomicInteger;
022    
/**
 * A {@link ThreadLocal} whose values are kept in one of two per-thread hash
 * maps owned by this class: one for short-lived instances and one for
 * long-lived instances, selected at construction time.
 *
 * <p>
 * {@link #get()}, {@link #set(Object)}, and {@link #remove()} never use the
 * storage inherited from {@link ThreadLocal}; they read and write the current
 * thread's map for this instance's category. Because every value of a
 * category lives in one map, all of them can be dropped at once ({@link
 * #clearLongLivedThreadLocals()}, {@link #clearShortLivedThreadLocals()}),
 * snapshotted ({@link #getLongLivedThreadLocals()}, {@link
 * #getShortLivedThreadLocals()}), and written back into the current thread
 * ({@link #setThreadLocals(Map, Map)}).
 * </p>
 *
 * @author Shuyang Zhou
 */
public class CentralizedThreadLocal<T> extends ThreadLocal<T> {

	/**
	 * Drops the current thread's long-lived map, discarding every long-lived
	 * value bound to this thread. A new empty map is created on next access.
	 */
	public static void clearLongLivedThreadLocals() {
		_longLivedThreadLocals.remove();
	}

	/**
	 * Drops the current thread's short-lived map, discarding every
	 * short-lived value bound to this thread. A new empty map is created on
	 * next access.
	 */
	public static void clearShortLivedThreadLocals() {
		_shortLivedThreadLocals.remove();
	}

	/**
	 * Returns a snapshot of the current thread's long-lived values.
	 *
	 * <p>
	 * Each value is passed through its thread local's {@link #copy(Object)};
	 * entries for which that returns {@code null} are omitted. The result is a
	 * new {@link HashMap} that the caller may modify freely.
	 * </p>
	 *
	 * @return a new map from thread local to (copied) value
	 */
	public static Map<CentralizedThreadLocal<?>, Object>
		getLongLivedThreadLocals() {

		return _toMap(_longLivedThreadLocals.get());
	}

	/**
	 * Returns a snapshot of the current thread's short-lived values.
	 *
	 * <p>
	 * Each value is passed through its thread local's {@link #copy(Object)};
	 * entries for which that returns {@code null} are omitted. The result is a
	 * new {@link HashMap} that the caller may modify freely.
	 * </p>
	 *
	 * @return a new map from thread local to (copied) value
	 */
	public static Map<CentralizedThreadLocal<?>, Object>
		getShortLivedThreadLocals() {

		return _toMap(_shortLivedThreadLocals.get());
	}

	/**
	 * Writes the given values into the current thread's maps.
	 *
	 * <p>
	 * Entries of {@code longLivedThreadLocals} go into the long-lived map and
	 * entries of {@code shortLivedThreadLocals} into the short-lived map,
	 * regardless of each key's own category. Existing values for the same keys
	 * are replaced; other existing values are kept. A {@code null} map is
	 * treated as empty.
	 * </p>
	 *
	 * @param longLivedThreadLocals values to put into the long-lived map, for
	 *        example a snapshot returned by {@link
	 *        #getLongLivedThreadLocals()}; may be {@code null}
	 * @param shortLivedThreadLocals values to put into the short-lived map,
	 *        for example a snapshot returned by {@link
	 *        #getShortLivedThreadLocals()}; may be {@code null}
	 */
	public static void setThreadLocals(
		Map<CentralizedThreadLocal<?>, Object> longLivedThreadLocals,
		Map<CentralizedThreadLocal<?>, Object> shortLivedThreadLocals) {

		_putAll(_longLivedThreadLocals.get(), longLivedThreadLocals);
		_putAll(_shortLivedThreadLocals.get(), shortLivedThreadLocals);
	}

	/**
	 * Creates a thread local in the given category.
	 *
	 * @param shortLived {@code true} to store values in the short-lived map,
	 *        {@code false} to store them in the long-lived map
	 */
	public CentralizedThreadLocal(boolean shortLived) {
		_shortLived = shortLived;

		// Each category draws hash codes from its own sequence stepping by the
		// odd constant _HASH_INCREMENT (the same value java.lang.ThreadLocal
		// uses), so any run of 2^k consecutively created instances of one
		// category lands in distinct buckets of a 2^k-slot table

		if (shortLived) {
			_hashCode = _shortLivedNextHashCode.getAndAdd(_HASH_INCREMENT);
		}
		else {
			_hashCode = _longLivedNextHashCode.getAndAdd(_HASH_INCREMENT);
		}
	}

	/**
	 * Compares by identity, consistent with {@link #hashCode()}.
	 */
	@Override
	public boolean equals(Object obj) {
		return this == obj;
	}

	/**
	 * Returns the current thread's value for this thread local. If there is no
	 * entry yet, {@link #initialValue()} is called and its result (even
	 * {@code null}) is stored and returned.
	 */
	@Override
	public T get() {
		ThreadLocalMap threadLocalMap = _getThreadLocalMap();

		Entry entry = threadLocalMap.getEntry(this);

		if (entry != null) {

			// Only set(T) and initialValue() store values under this key, so
			// the stored value is a T

			@SuppressWarnings("unchecked")
			T storedValue = (T)entry._value;

			return storedValue;
		}

		T value = initialValue();

		threadLocalMap.putEntry(this, value);

		return value;
	}

	/**
	 * Returns the hash code assigned at construction from this category's
	 * golden-ratio sequence.
	 */
	@Override
	public int hashCode() {
		return _hashCode;
	}

	/**
	 * Removes the current thread's entry for this thread local, if any.
	 */
	@Override
	public void remove() {
		ThreadLocalMap threadLocalMap = _getThreadLocalMap();

		threadLocalMap.removeEntry(this);
	}

	/**
	 * Stores {@code value} as the current thread's value for this thread
	 * local, replacing any previous value.
	 */
	@Override
	public void set(T value) {
		ThreadLocalMap threadLocalMap = _getThreadLocalMap();

		threadLocalMap.putEntry(this, value);
	}

	/**
	 * Returns the value to include in a snapshot taken by {@link
	 * #getLongLivedThreadLocals()} or {@link #getShortLivedThreadLocals()}, or
	 * {@code null} to leave this thread local out of the snapshot.
	 *
	 * <p>
	 * This implementation returns {@code value} itself when its class is
	 * exactly one of {@code Boolean}, {@code Byte}, {@code Character},
	 * {@code Short}, {@code Integer}, {@code Long}, {@code Float},
	 * {@code Double}, or {@code String}; for any other value, including
	 * {@code null}, it returns {@code null}. Subclasses holding other types
	 * may override this to return a copy.
	 * </p>
	 *
	 * @param  value the current thread's value, possibly {@code null}
	 * @return the value to snapshot, or {@code null} to skip this entry
	 */
	protected T copy(T value) {
		if ((value != null) && _immutableTypes.contains(value.getClass())) {
			return value;
		}

		return null;
	}

	/**
	 * Puts every entry of {@code threadLocals} into {@code threadLocalMap};
	 * does nothing when {@code threadLocals} is {@code null}.
	 */
	private static void _putAll(
		ThreadLocalMap threadLocalMap,
		Map<CentralizedThreadLocal<?>, Object> threadLocals) {

		if (threadLocals == null) {
			return;
		}

		for (Map.Entry<CentralizedThreadLocal<?>, Object> entry :
				threadLocals.entrySet()) {

			threadLocalMap.putEntry(entry.getKey(), entry.getValue());
		}
	}

	/**
	 * Builds a snapshot of every entry in {@code threadLocalMap}, passing each
	 * value through its key's {@link #copy(Object)} and skipping {@code null}
	 * results.
	 */
	private static Map<CentralizedThreadLocal<?>, Object> _toMap(
		ThreadLocalMap threadLocalMap) {

		Map<CentralizedThreadLocal<?>, Object> map = new HashMap<>(
			threadLocalMap._table.length);

		for (Entry bucket : threadLocalMap._table) {

			// Colliding keys are chained through _next, so walk the whole
			// chain; looking only at the bucket head would drop entries

			for (Entry entry = bucket; entry != null; entry = entry._next) {

				// copy(T) only ever receives the value stored under its own
				// key, so viewing the key as CentralizedThreadLocal<Object> is
				// safe

				@SuppressWarnings("unchecked")
				CentralizedThreadLocal<Object> centralizedThreadLocal =
					(CentralizedThreadLocal<Object>)entry._key;

				Object value = centralizedThreadLocal.copy(entry._value);

				if (value != null) {
					map.put(centralizedThreadLocal, value);
				}
			}
		}

		return map;
	}

	/**
	 * Returns the current thread's map for this instance's category.
	 */
	private ThreadLocalMap _getThreadLocalMap() {
		if (_shortLived) {
			return _shortLivedThreadLocals.get();
		}

		return _longLivedThreadLocals.get();
	}

	private static final int _HASH_INCREMENT = 0x61c88647;

	private static final Set<Class<?>> _immutableTypes = new HashSet<>();
	private static final AtomicInteger _longLivedNextHashCode =
		new AtomicInteger();
	private static final ThreadLocal<ThreadLocalMap> _longLivedThreadLocals =
		new ThreadLocalMapThreadLocal();
	private static final AtomicInteger _shortLivedNextHashCode =
		new AtomicInteger();
	private static final ThreadLocal<ThreadLocalMap> _shortLivedThreadLocals =
		new ThreadLocalMapThreadLocal();

	static {
		_immutableTypes.add(Boolean.class);
		_immutableTypes.add(Byte.class);
		_immutableTypes.add(Character.class);
		_immutableTypes.add(Short.class);
		_immutableTypes.add(Integer.class);
		_immutableTypes.add(Long.class);
		_immutableTypes.add(Float.class);
		_immutableTypes.add(Double.class);
		_immutableTypes.add(String.class);
	}

	private final int _hashCode;
	private final boolean _shortLived;

	/**
	 * A node in a {@link ThreadLocalMap} bucket chain.
	 */
	private static class Entry {

		public Entry(CentralizedThreadLocal<?> key, Object value, Entry next) {
			_key = key;
			_value = value;
			_next = next;
		}

		private CentralizedThreadLocal<?> _key;
		private Entry _next;
		private Object _value;

	}

	/**
	 * A separately chained hash table keyed by {@link CentralizedThreadLocal}
	 * identity. The capacity starts at 16, doubles when the entry count
	 * reaches two thirds of it, and stops growing at {@code 1 << 30}. Bucket
	 * indexes are computed by masking the key's hash code with
	 * {@code capacity - 1}.
	 */
	private static class ThreadLocalMap {

		/**
		 * Rehashes every entry into a table of {@code newCapacity} buckets.
		 * Once the table has reached {@code _MAXIMUM_CAPACITY} it is left as
		 * is, and the threshold is raised so no further resize is attempted.
		 *
		 * @param newCapacity the new table length; expected to be a power of
		 *        two, since bucket indexes are computed by masking
		 */
		public void expand(int newCapacity) {
			if (_table.length == _MAXIMUM_CAPACITY) {
				_threshold = Integer.MAX_VALUE;

				return;
			}

			Entry[] newTable = new Entry[newCapacity];

			for (int i = 0; i < _table.length; i++) {
				Entry entry = _table[i];

				if (entry == null) {
					continue;
				}

				_table[i] = null;

				do {
					Entry nextEntry = entry._next;

					int index = entry._key._hashCode & (newCapacity - 1);

					entry._next = newTable[index];

					newTable[index] = entry;

					entry = nextEntry;
				}
				while (entry != null);
			}

			_table = newTable;

			// Compute in long arithmetic: newCapacity * 2 overflows int once
			// newCapacity reaches _MAXIMUM_CAPACITY (1 << 30)

			_threshold = (int)((long)newCapacity * 2 / 3);
		}

		/**
		 * Returns the entry for {@code key}, or {@code null} if there is none.
		 */
		public Entry getEntry(CentralizedThreadLocal<?> key) {
			int index = key._hashCode & (_table.length - 1);

			for (Entry entry = _table[index]; entry != null;
					entry = entry._next) {

				if (entry._key == key) {
					return entry;
				}
			}

			return null;
		}

		/**
		 * Stores {@code value} under {@code key}, replacing an existing value
		 * or prepending a new entry to the bucket chain (and growing the table
		 * when the threshold is reached).
		 */
		public void putEntry(CentralizedThreadLocal<?> key, Object value) {
			int index = key._hashCode & (_table.length - 1);

			for (Entry entry = _table[index]; entry != null;
					entry = entry._next) {

				if (entry._key == key) {
					entry._value = value;

					return;
				}
			}

			_table[index] = new Entry(key, value, _table[index]);

			if (_size++ >= _threshold) {
				expand(2 * _table.length);
			}
		}

		/**
		 * Unlinks the entry for {@code key} from its bucket chain, if present.
		 */
		public void removeEntry(CentralizedThreadLocal<?> key) {
			int index = key._hashCode & (_table.length - 1);

			Entry previousEntry = null;

			Entry entry = _table[index];

			while (entry != null) {
				Entry nextEntry = entry._next;

				if (entry._key == key) {
					_size--;

					if (previousEntry == null) {
						_table[index] = nextEntry;
					}
					else {
						previousEntry._next = nextEntry;
					}

					return;
				}

				previousEntry = entry;
				entry = nextEntry;
			}
		}

		private static final int _INITIAL_CAPACITY = 16;

		private static final int _MAXIMUM_CAPACITY = 1 << 30;

		private int _size;
		private Entry[] _table = new Entry[_INITIAL_CAPACITY];
		private int _threshold = _INITIAL_CAPACITY * 2 / 3;

	}

	/**
	 * Supplies each thread with its own empty {@link ThreadLocalMap} on first
	 * access.
	 */
	private static class ThreadLocalMapThreadLocal
		extends ThreadLocal<ThreadLocalMap> {

		@Override
		protected ThreadLocalMap initialValue() {
			return new ThreadLocalMap();
		}

	}

}