001
014
015 package com.liferay.portal.test.jdbc;
016
017 import com.liferay.portal.kernel.concurrent.ConcurrentHashSet;
018 import com.liferay.portal.kernel.dao.db.DB;
019 import com.liferay.portal.kernel.dao.db.DBFactoryUtil;
020 import com.liferay.portal.kernel.dao.jdbc.DataAccess;
021 import com.liferay.portal.kernel.io.unsync.UnsyncStringReader;
022 import com.liferay.portal.kernel.log.Log;
023 import com.liferay.portal.kernel.log.LogFactoryUtil;
024 import com.liferay.portal.kernel.util.FileUtil;
025 import com.liferay.portal.kernel.util.StringUtil;
026 import com.liferay.portal.upgrade.util.Table;
027
028 import java.io.File;
029
030 import java.sql.Connection;
031 import java.sql.DatabaseMetaData;
032 import java.sql.ResultSet;
033
034 import java.util.ArrayList;
035 import java.util.Arrays;
036 import java.util.List;
037 import java.util.Set;
038 import java.util.concurrent.ConcurrentHashMap;
039 import java.util.concurrent.ConcurrentMap;
040
041 import net.sf.jsqlparser.expression.Expression;
042 import net.sf.jsqlparser.parser.CCJSqlParserManager;
043 import net.sf.jsqlparser.parser.JSqlParser;
044 import net.sf.jsqlparser.statement.Statement;
045 import net.sf.jsqlparser.statement.delete.Delete;
046 import net.sf.jsqlparser.statement.insert.Insert;
047 import net.sf.jsqlparser.statement.update.Update;
048 import net.sf.jsqlparser.util.TablesNamesFinder;
049
050
053 public class ResetDatabaseUtil {
054
055 public static synchronized boolean initialize() {
056 if (_initialized) {
057 reloadDatabase();
058
059 return false;
060 }
061
062 dumpDatabase();
063
064 _initialized = true;
065
066 return true;
067 }
068
069 public static void processSQL(Connection connection, String sql)
070 throws Exception {
071
072 if (!_recording) {
073 return;
074 }
075
076 List<String> tableNames = _getModifiedTableNames(sql);
077
078 if (tableNames == null) {
079 return;
080 }
081
082 for (String tableName : tableNames) {
083 tableName = StringUtil.toLowerCase(tableName);
084
085 Table table = _cachedTables.get(tableName);
086
087 if (table == null) {
088 _log.error(
089 "Unable to get table " + tableName + " from cache " +
090 _cachedTables.keySet());
091
092 continue;
093 }
094
095 if (_modifiedTables.putIfAbsent(tableName, table) == null) {
096 table.generateTempFile(connection);
097 }
098 }
099 }
100
101 public static void resetModifiedTables() {
102 _recording = false;
103
104 Connection connection = null;
105
106 try {
107 connection = DataAccess.getUpgradeOptimizedConnection();
108
109 for (Table table : _modifiedTables.values()) {
110 DB db = DBFactoryUtil.getDB();
111
112 db.runSQL(connection, table.getDeleteSQL());
113
114 table.populateTable();
115
116 String tempFileName = table.getTempFileName();
117
118 if (tempFileName != null) {
119 FileUtil.delete(tempFileName);
120 }
121 }
122 }
123 catch (Exception e) {
124 throw new RuntimeException(e);
125 }
126 finally {
127 DataAccess.cleanUp(connection);
128
129 _modifiedTables.clear();
130 }
131 }
132
133 public static void startRecording() {
134 _recording = true;
135 }
136
137 protected static void dumpDatabase() {
138 Connection connection = null;
139 ResultSet tableResultSet = null;
140
141 try {
142 connection = DataAccess.getUpgradeOptimizedConnection();
143
144 DatabaseMetaData databaseMetaData = connection.getMetaData();
145
146 tableResultSet = databaseMetaData.getTables(null, null, null, null);
147
148 while (tableResultSet.next()) {
149 String tableName = tableResultSet.getString("TABLE_NAME");
150
151 ResultSet columnResultSet = databaseMetaData.getColumns(
152 null, null, tableName, null);
153
154 List<Object[]> columns = new ArrayList<Object[]>();
155
156 try {
157 while (columnResultSet.next()) {
158 columns.add(
159 new Object[] {
160 columnResultSet.getString("COLUMN_NAME"),
161 columnResultSet.getInt("DATA_TYPE")});
162 }
163 }
164 finally {
165 DataAccess.cleanUp(columnResultSet);
166 }
167
168 Table table = new Table(
169 tableName, columns.toArray(new Object[columns.size()][]));
170
171 table.generateTempFile(connection);
172
173 String tempFileName = table.getTempFileName();
174
175 if (tempFileName != null) {
176 File tempFile = new File(tempFileName);
177
178 tempFile.deleteOnExit();
179 }
180
181 _tables.add(table);
182
183 _cachedTables.put(
184 StringUtil.toLowerCase(tableName),
185 new Table(
186 tableName,
187 columns.toArray(new Object[columns.size()][])));
188 }
189 }
190 catch (Exception e) {
191 throw new RuntimeException(e);
192 }
193 finally {
194 DataAccess.cleanUp(connection, null, tableResultSet);
195 }
196 }
197
198 protected static void reloadDatabase() {
199 Connection connection = null;
200
201 try {
202 connection = DataAccess.getUpgradeOptimizedConnection();
203
204 for (Table table : _tables) {
205 DB db = DBFactoryUtil.getDB();
206
207 db.runSQL(connection, table.getDeleteSQL());
208
209 table.populateTable();
210 }
211 }
212 catch (Exception e) {
213 throw new RuntimeException(e);
214 }
215 finally {
216 DataAccess.cleanUp(connection);
217 }
218 }
219
220 private static List<String> _getModifiedTableNames(String sql) {
221 Statement statement = null;
222
223 try {
224 statement = _jSqlParser.parse(new UnsyncStringReader(sql));
225 }
226 catch (Exception e) {
227 return null;
228 }
229
230 if (statement instanceof Delete) {
231 Delete delete = (Delete)statement;
232
233 Expression expression = delete.getWhere();
234
235 if (expression == null) {
236
237
238
239
240 net.sf.jsqlparser.schema.Table table = delete.getTable();
241
242 return Arrays.asList(table.getName());
243 }
244
245 TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
246
247 return tablesNamesFinder.getTableList(delete);
248 }
249
250 if (statement instanceof Insert) {
251 Insert insert = (Insert)statement;
252
253 TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
254
255 return tablesNamesFinder.getTableList(insert);
256 }
257
258 if (statement instanceof Update) {
259 Update update = (Update)statement;
260
261 TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
262
263 return tablesNamesFinder.getTableList(update);
264 }
265
266 return null;
267 }
268
269 private static final Log _log = LogFactoryUtil.getLog(
270 ResetDatabaseUtil.class);
271
272 private static final ConcurrentMap<String, Table> _cachedTables =
273 new ConcurrentHashMap<String, Table>();
274 private static boolean _initialized;
275 private static final JSqlParser _jSqlParser = new CCJSqlParserManager();
276 private static final ConcurrentMap<String, Table> _modifiedTables =
277 new ConcurrentHashMap<String, Table>();
278 private static volatile boolean _recording;
279 private static final Set<Table> _tables = new ConcurrentHashSet<Table>();
280
281 }