001
014
015 package com.liferay.portal.kernel.test.rule;
016
017 import com.liferay.portal.kernel.exception.PortalException;
018 import com.liferay.portal.kernel.exception.SystemException;
019 import com.liferay.portal.kernel.test.ReflectionTestUtil;
020 import com.liferay.portal.kernel.test.rule.BaseTestRule.StatementWrapper;
021 import com.liferay.portal.kernel.transaction.Propagation;
022 import com.liferay.portal.kernel.transaction.TransactionConfig;
023 import com.liferay.portal.kernel.transaction.TransactionInvokerUtil;
024 import com.liferay.portal.kernel.transaction.Transactional;
025 import com.liferay.portal.kernel.util.ReflectionUtil;
026
027 import java.lang.reflect.Method;
028
029 import java.util.ArrayList;
030 import java.util.List;
031 import java.util.concurrent.Callable;
032
033 import org.junit.internal.runners.statements.RunAfters;
034 import org.junit.internal.runners.statements.RunBefores;
035 import org.junit.rules.RunRules;
036 import org.junit.rules.TestRule;
037 import org.junit.runner.Description;
038 import org.junit.runners.model.FrameworkMethod;
039 import org.junit.runners.model.Statement;
040
041
044 public class TransactionalTestRule implements TestRule {
045
046 public static final TransactionalTestRule INSTANCE =
047 new TransactionalTestRule();
048
049 public TransactionalTestRule() {
050 this(Propagation.SUPPORTS);
051 }
052
053 public TransactionalTestRule(Propagation propagation) {
054 _transactionConfig = TransactionConfig.Factory.create(
055 propagation,
056 new Class<?>[] {PortalException.class, SystemException.class});
057 }
058
059 @Override
060 public Statement apply(Statement statement, final Description description) {
061 Statement currentStatement = statement;
062
063 while (true) {
064 if (currentStatement instanceof StatementWrapper) {
065 StatementWrapper statementWrapper =
066 (StatementWrapper)currentStatement;
067
068 currentStatement = statementWrapper.getStatement();
069
070 continue;
071 }
072
073 if (currentStatement instanceof RunRules) {
074 currentStatement = ReflectionTestUtil.getFieldValue(
075 currentStatement, "statement");
076
077 continue;
078 }
079
080 if (currentStatement instanceof RunBefores) {
081 replaceFrameworkMethods(currentStatement, "befores");
082
083 currentStatement = ReflectionTestUtil.getFieldValue(
084 currentStatement, "next");
085
086 continue;
087 }
088
089 if (currentStatement instanceof RunAfters) {
090 replaceFrameworkMethods(currentStatement, "afters");
091
092 currentStatement = ReflectionTestUtil.getFieldValue(
093 currentStatement, "next");
094
095 continue;
096 }
097
098 return new StatementWrapper(statement) {
099
100 @Override
101 public void evaluate() throws Throwable {
102 TransactionInvokerUtil.invoke(
103 getTransactionConfig(
104 description.getAnnotation(Transactional.class)),
105 new Callable<Void>() {
106
107 @Override
108 public Void call() throws Exception {
109 try {
110 statement.evaluate();
111 }
112 catch (Throwable t) {
113 ReflectionUtil.throwException(t);
114 }
115
116 return null;
117 }
118
119 });
120 }
121
122 };
123 }
124 }
125
126 public TransactionConfig getTransactionConfig(Transactional transactional) {
127 if (transactional != null) {
128 return TransactionConfig.Factory.create(
129 transactional.isolation(), transactional.propagation(),
130 transactional.readOnly(), transactional.timeout(),
131 transactional.rollbackFor(),
132 transactional.rollbackForClassName(),
133 transactional.noRollbackFor(),
134 transactional.noRollbackForClassName());
135 }
136
137 return _transactionConfig;
138 }
139
140 protected void replaceFrameworkMethods(Statement statement, String name) {
141 List<FrameworkMethod> newFrameworkMethods = new ArrayList<>();
142
143 List<FrameworkMethod> frameworkMethods =
144 ReflectionTestUtil.<List<FrameworkMethod>>getFieldValue(
145 statement, name);
146
147 for (FrameworkMethod frameworkMethod : frameworkMethods) {
148 if (frameworkMethod instanceof TransactionalFrameworkMethod) {
149 newFrameworkMethods.add(frameworkMethod);
150
151 continue;
152 }
153
154 Transactional transactional = frameworkMethod.getAnnotation(
155 Transactional.class);
156
157 if (transactional == null) {
158 newFrameworkMethods.add(
159 new TransactionalFrameworkMethod(
160 frameworkMethod.getMethod(), _transactionConfig));
161 }
162 else {
163 newFrameworkMethods.add(
164 new TransactionalFrameworkMethod(
165 frameworkMethod.getMethod(),
166 getTransactionConfig(transactional)));
167 }
168 }
169
170 ReflectionTestUtil.setFieldValue(statement, name, newFrameworkMethods);
171 }
172
173 protected static class TransactionalFrameworkMethod
174 extends FrameworkMethod {
175
176 @Override
177 public Object invokeExplosively(
178 final Object target, final Object... params)
179 throws Throwable {
180
181 return TransactionInvokerUtil.invoke(
182 _transactionConfig,
183 new Callable<Object>() {
184
185 @Override
186 public Object call() throws Exception {
187 try {
188 return TransactionalFrameworkMethod.super.invokeExplosively(
189 target, params);
190 }
191 catch (Throwable t) {
192 ReflectionUtil.throwException(t);
193 }
194
195 return null;
196 }
197
198 });
199 }
200
201 protected TransactionalFrameworkMethod(
202 Method method, TransactionConfig transactionConfig) {
203
204 super(method);
205
206 _transactionConfig = transactionConfig;
207 }
208
209 private final TransactionConfig _transactionConfig;
210
211 }
212
213 private final TransactionConfig _transactionConfig;
214
215 }