001
014
015 package com.liferay.portal.kernel.test.rule;
016
017 import com.liferay.portal.kernel.concurrent.ConcurrentReferenceKeyHashMap;
018 import com.liferay.portal.kernel.memory.FinalizeManager;
019 import com.liferay.portal.kernel.test.ReflectionTestUtil;
020 import com.liferay.portal.kernel.test.rule.callback.TestCallback;
021
022 import java.util.Deque;
023 import java.util.LinkedList;
024 import java.util.Map;
025
026 import org.junit.internal.runners.statements.ExpectException;
027 import org.junit.internal.runners.statements.FailOnTimeout;
028 import org.junit.internal.runners.statements.InvokeMethod;
029 import org.junit.internal.runners.statements.RunAfters;
030 import org.junit.internal.runners.statements.RunBefores;
031 import org.junit.rules.TestRule;
032 import org.junit.runner.Description;
033 import org.junit.runner.RunWith;
034 import org.junit.runner.Runner;
035 import org.junit.runners.model.Statement;
036
037
040 public class BaseTestRule<C, M>
041 implements ArquillianClassRuleHandler, TestRule {
042
043 public BaseTestRule(TestCallback<C, M> testCallback) {
044 _testCallback = testCallback;
045 }
046
047 @Override
048 public final Statement apply(
049 Statement statement, final Description description) {
050
051 String methodName = description.getMethodName();
052
053 if (methodName != null) {
054 return new StatementWrapper(statement) {
055
056 @Override
057 public void evaluate() throws Throwable {
058 Object target = inspectTarget(statement);
059
060 M m = _testCallback.beforeMethod(description, target);
061
062 try {
063 statement.evaluate();
064 }
065 finally {
066 _testCallback.afterMethod(description, m, target);
067 }
068 }
069
070 };
071 }
072
073 boolean arquillianTest = _isArquillianTest(description);
074
075 if (!arquillianTest) {
076 return new StatementWrapper(statement) {
077
078 @Override
079 public void evaluate() throws Throwable {
080 C c = _testCallback.beforeClass(description);
081
082 try {
083 statement.evaluate();
084 }
085 finally {
086 _testCallback.afterClass(description, c);
087 }
088 }
089
090 };
091 }
092
093 return new StatementWrapper(statement) {
094
095 @Override
096 public void evaluate() throws Throwable {
097 Class<?> clazz = description.getTestClass();
098
099 if (_handleBeforeClassThreadLocal.get()) {
100 Deque<Object> deque = _classCarryOnMap.get(clazz);
101
102 if (deque == null) {
103 deque = new LinkedList<>();
104
105 _classCarryOnMap.put(clazz, deque);
106 }
107
108 deque.addLast(_testCallback.beforeClass(description));
109 }
110
111 try {
112 statement.evaluate();
113 }
114 finally {
115 if (_handleAfterClassThreadLocal.get()) {
116 Deque<Object> deque = _classCarryOnMap.get(clazz);
117
118 _testCallback.afterClass(
119 description, (C)deque.removeLast());
120
121 if (deque.isEmpty()) {
122 _classCarryOnMap.remove(clazz);
123 }
124 }
125 }
126 }
127
128 };
129 }
130
131 @Override
132 public void handleAfterClass(boolean enable) {
133 _handleAfterClassThreadLocal.set(enable);
134 }
135
136 @Override
137 public void handleBeforeClass(boolean enable) {
138 _handleBeforeClassThreadLocal.set(enable);
139 }
140
141 public static abstract class StatementWrapper extends Statement {
142
143 public StatementWrapper(Statement statement) {
144 this.statement = statement;
145 }
146
147 public Statement getStatement() {
148 return statement;
149 }
150
151 protected final Statement statement;
152
153 }
154
155 protected Object inspectTarget(Statement statement) {
156 while (statement instanceof StatementWrapper) {
157 StatementWrapper statementWrapper = (StatementWrapper)statement;
158
159 statement = statementWrapper.getStatement();
160 }
161
162 if ((statement instanceof InvokeMethod) ||
163 (statement instanceof RunAfters) ||
164 (statement instanceof RunBefores)) {
165
166 return ReflectionTestUtil.getFieldValue(statement, "target");
167 }
168 else if (statement instanceof ExpectException) {
169 return inspectTarget(
170 ReflectionTestUtil.<Statement>getFieldValue(statement, "next"));
171 }
172 else if (statement instanceof FailOnTimeout) {
173 return inspectTarget(
174 ReflectionTestUtil.<Statement>getFieldValue(
175 statement, "originalStatement"));
176 }
177
178 throw new IllegalStateException("Unknow statement " + statement);
179 }
180
181 private static boolean _isArquillianTest(Description description) {
182 RunWith runWith = description.getAnnotation(RunWith.class);
183
184 if (runWith == null) {
185 return false;
186 }
187
188 Class<? extends Runner> runnerClass = runWith.value();
189
190 String runnerClassName = runnerClass.getName();
191
192 if (runnerClassName.equals(
193 "com.liferay.arquillian.extension.junit.bridge.junit." +
194 "Arquillian")) {
195
196 return true;
197 }
198
199 return false;
200 }
201
202 private static final Map<Class<?>, Deque<Object>> _classCarryOnMap =
203 new ConcurrentReferenceKeyHashMap<>(
204 FinalizeManager.WEAK_REFERENCE_FACTORY);
205
206 private final ThreadLocal<Boolean> _handleAfterClassThreadLocal =
207 new ThreadLocal<Boolean>() {
208
209 @Override
210 protected Boolean initialValue() {
211 return Boolean.FALSE;
212 }
213
214 };
215
216 private final ThreadLocal<Boolean> _handleBeforeClassThreadLocal =
217 new ThreadLocal<Boolean>() {
218
219 @Override
220 protected Boolean initialValue() {
221 return Boolean.FALSE;
222 }
223
224 };
225
226 private final TestCallback<C, M> _testCallback;
227
228 }