001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.test.rule;
016    
017    import com.liferay.portal.kernel.concurrent.ConcurrentReferenceKeyHashMap;
018    import com.liferay.portal.kernel.memory.FinalizeManager;
019    import com.liferay.portal.kernel.test.ReflectionTestUtil;
020    import com.liferay.portal.kernel.test.rule.callback.TestCallback;
021    
022    import java.util.Deque;
023    import java.util.LinkedList;
024    import java.util.Map;
025    
026    import org.junit.internal.runners.statements.ExpectException;
027    import org.junit.internal.runners.statements.FailOnTimeout;
028    import org.junit.internal.runners.statements.InvokeMethod;
029    import org.junit.internal.runners.statements.RunAfters;
030    import org.junit.internal.runners.statements.RunBefores;
031    import org.junit.rules.TestRule;
032    import org.junit.runner.Description;
033    import org.junit.runners.model.Statement;
034    
035    /**
036     * @author Shuyang Zhou
037     */
038    public class BaseTestRule<C, M>
039            implements ArquillianClassRuleHandler, TestRule {
040    
041            public BaseTestRule(TestCallback<C, M> testCallback) {
042                    _testCallback = testCallback;
043            }
044    
045            @Override
046            public final Statement apply(
047                    Statement statement, final Description description) {
048    
049                    String methodName = description.getMethodName();
050    
051                    if (methodName != null) {
052                            return new StatementWrapper(statement) {
053    
054                                    @Override
055                                    public void evaluate() throws Throwable {
056                                            Object target = inspectTarget(statement);
057    
058                                            M m = _testCallback.beforeMethod(description, target);
059    
060                                            try {
061                                                    statement.evaluate();
062                                            }
063                                            finally {
064                                                    _testCallback.afterMethod(description, m, target);
065                                            }
066                                    }
067    
068                            };
069                    }
070    
071                    boolean arquillianTest = ArquillianUtil.isArquillianTest(description);
072    
073                    if (!arquillianTest) {
074                            return new StatementWrapper(statement) {
075    
076                                    @Override
077                                    public void evaluate() throws Throwable {
078                                            C c = _testCallback.beforeClass(description);
079    
080                                            try {
081                                                    statement.evaluate();
082                                            }
083                                            finally {
084                                                    _testCallback.afterClass(description, c);
085                                            }
086                                    }
087    
088                            };
089                    }
090    
091                    return new StatementWrapper(statement) {
092    
093                            @Override
094                            public void evaluate() throws Throwable {
095                                    Class<?> clazz = description.getTestClass();
096    
097                                    if (_handleBeforeClassThreadLocal.get()) {
098                                            Deque<Object> deque = _classCarryOnMap.get(clazz);
099    
100                                            if (deque == null) {
101                                                    deque = new LinkedList<>();
102    
103                                                    _classCarryOnMap.put(clazz, deque);
104                                            }
105    
106                                            deque.addLast(_testCallback.beforeClass(description));
107                                    }
108    
109                                    try {
110                                            statement.evaluate();
111                                    }
112                                    finally {
113                                            if (_handleAfterClassThreadLocal.get()) {
114                                                    Deque<Object> deque = _classCarryOnMap.get(clazz);
115    
116                                                    _testCallback.afterClass(
117                                                            description, (C)deque.removeLast());
118    
119                                                    if (deque.isEmpty()) {
120                                                            _classCarryOnMap.remove(clazz);
121                                                    }
122                                            }
123                                    }
124                            }
125    
126                    };
127            }
128    
129            @Override
130            public void handleAfterClass(boolean enable) {
131                    _handleAfterClassThreadLocal.set(enable);
132            }
133    
134            @Override
135            public void handleBeforeClass(boolean enable) {
136                    _handleBeforeClassThreadLocal.set(enable);
137            }
138    
139            public static abstract class StatementWrapper extends Statement {
140    
141                    public StatementWrapper(Statement statement) {
142                            this.statement = statement;
143                    }
144    
145                    public Statement getStatement() {
146                            return statement;
147                    }
148    
149                    protected final Statement statement;
150    
151            }
152    
153            protected Object inspectTarget(Statement statement) {
154                    while (statement instanceof StatementWrapper) {
155                            StatementWrapper statementWrapper = (StatementWrapper)statement;
156    
157                            statement = statementWrapper.getStatement();
158                    }
159    
160                    if ((statement instanceof InvokeMethod) ||
161                            (statement instanceof RunAfters) ||
162                            (statement instanceof RunBefores)) {
163    
164                            return ReflectionTestUtil.getFieldValue(statement, "target");
165                    }
166                    else if (statement instanceof ExpectException) {
167                            return inspectTarget(
168                                    ReflectionTestUtil.<Statement>getFieldValue(statement, "next"));
169                    }
170                    else if (statement instanceof FailOnTimeout) {
171                            return inspectTarget(
172                                    ReflectionTestUtil.<Statement>getFieldValue(
173                                            statement, "originalStatement"));
174                    }
175    
176                    throw new IllegalStateException("Unknow statement " + statement);
177            }
178    
179            private static final Map<Class<?>, Deque<Object>> _classCarryOnMap =
180                    new ConcurrentReferenceKeyHashMap<>(
181                            FinalizeManager.WEAK_REFERENCE_FACTORY);
182    
183            private final ThreadLocal<Boolean> _handleAfterClassThreadLocal =
184                    new ThreadLocal<Boolean>() {
185    
186                            @Override
187                            protected Boolean initialValue() {
188                                    return Boolean.FALSE;
189                            }
190    
191                    };
192    
193            private final ThreadLocal<Boolean> _handleBeforeClassThreadLocal =
194                    new ThreadLocal<Boolean>() {
195    
196                            @Override
197                            protected Boolean initialValue() {
198                                    return Boolean.FALSE;
199                            }
200    
201                    };
202    
203            private final TestCallback<C, M> _testCallback;
204    
205    }