001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.test.jdbc;
016    
017    import com.liferay.portal.kernel.concurrent.ConcurrentHashSet;
018    import com.liferay.portal.kernel.dao.db.DB;
019    import com.liferay.portal.kernel.dao.db.DBFactoryUtil;
020    import com.liferay.portal.kernel.dao.jdbc.DataAccess;
021    import com.liferay.portal.kernel.io.unsync.UnsyncStringReader;
022    import com.liferay.portal.kernel.log.Log;
023    import com.liferay.portal.kernel.log.LogFactoryUtil;
024    import com.liferay.portal.kernel.util.FileUtil;
025    import com.liferay.portal.kernel.util.StringUtil;
026    import com.liferay.portal.upgrade.util.Table;
027    
028    import java.io.File;
029    
030    import java.sql.Connection;
031    import java.sql.DatabaseMetaData;
032    import java.sql.ResultSet;
033    
034    import java.util.ArrayList;
035    import java.util.Arrays;
036    import java.util.List;
037    import java.util.Set;
038    import java.util.concurrent.ConcurrentHashMap;
039    import java.util.concurrent.ConcurrentMap;
040    
041    import net.sf.jsqlparser.expression.Expression;
042    import net.sf.jsqlparser.parser.CCJSqlParserManager;
043    import net.sf.jsqlparser.parser.JSqlParser;
044    import net.sf.jsqlparser.statement.Statement;
045    import net.sf.jsqlparser.statement.delete.Delete;
046    import net.sf.jsqlparser.statement.insert.Insert;
047    import net.sf.jsqlparser.statement.update.Update;
048    import net.sf.jsqlparser.util.TablesNamesFinder;
049    
050    /**
051     * @author Shuyang Zhou
052     */
053    public class ResetDatabaseUtil {
054    
055            public static synchronized boolean initialize() {
056                    if (_initialized) {
057                            reloadDatabase();
058    
059                            return false;
060                    }
061    
062                    dumpDatabase();
063    
064                    _initialized = true;
065    
066                    return true;
067            }
068    
069            public static void processSQL(Connection connection, String sql)
070                    throws Exception {
071    
072                    if (!_recording) {
073                            return;
074                    }
075    
076                    List<String> tableNames = _getModifiedTableNames(sql);
077    
078                    if (tableNames == null) {
079                            return;
080                    }
081    
082                    for (String tableName : tableNames) {
083                            tableName = StringUtil.toLowerCase(tableName);
084    
085                            Table table = _cachedTables.get(tableName);
086    
087                            if (table == null) {
088                                    _log.error(
089                                            "Unable to get table " + tableName + " from cache " +
090                                                    _cachedTables.keySet());
091    
092                                    continue;
093                            }
094    
095                            if (_modifiedTables.putIfAbsent(tableName, table) ==
096                                            null) {
097    
098                                    table.generateTempFile(connection);
099                            }
100                    }
101            }
102    
103            public static void resetModifiedTables() {
104                    _recording = false;
105    
106                    Connection connection = null;
107    
108                    try {
109                            connection = DataAccess.getUpgradeOptimizedConnection();
110    
111                            for (Table table : _modifiedTables.values()) {
112                                    DB db = DBFactoryUtil.getDB();
113    
114                                    db.runSQL(connection, table.getDeleteSQL());
115    
116                                    table.populateTable();
117    
118                                    String tempFileName = table.getTempFileName();
119    
120                                    if (tempFileName != null) {
121                                            FileUtil.delete(tempFileName);
122                                    }
123                            }
124                    }
125                    catch (Exception e) {
126                            throw new RuntimeException(e);
127                    }
128                    finally {
129                            DataAccess.cleanUp(connection);
130    
131                            _modifiedTables.clear();
132                    }
133            }
134    
135            public static void startRecording() {
136                    _recording = true;
137            }
138    
139            protected static void dumpDatabase() {
140                    Connection connection = null;
141                    ResultSet tableResultSet = null;
142    
143                    try {
144                            connection = DataAccess.getUpgradeOptimizedConnection();
145    
146                            DatabaseMetaData databaseMetaData = connection.getMetaData();
147    
148                            tableResultSet = databaseMetaData.getTables(null, null, null, null);
149    
150                            while (tableResultSet.next()) {
151                                    String tableName = tableResultSet.getString("TABLE_NAME");
152    
153                                    ResultSet columnResultSet = databaseMetaData.getColumns(
154                                            null, null, tableName, null);
155    
156                                    List<Object[]> columns = new ArrayList<Object[]>();
157    
158                                    try {
159                                            while (columnResultSet.next()) {
160                                                    columns.add(
161                                                            new Object[] {
162                                                                    columnResultSet.getString("COLUMN_NAME"),
163                                                                    columnResultSet.getInt("DATA_TYPE")});
164                                            }
165                                    }
166                                    finally {
167                                            DataAccess.cleanUp(columnResultSet);
168                                    }
169    
170                                    Table table = new Table(
171                                            tableName, columns.toArray(new Object[columns.size()][]));
172    
173                                    table.generateTempFile(connection);
174    
175                                    String tempFileName = table.getTempFileName();
176    
177                                    if (tempFileName != null) {
178                                            File tempFile = new File(tempFileName);
179    
180                                            tempFile.deleteOnExit();
181                                    }
182    
183                                    _tables.add(table);
184    
185                                    _cachedTables.put(
186                                            StringUtil.toLowerCase(tableName),
187                                            new Table(
188                                                    tableName,
189                                                    columns.toArray(new Object[columns.size()][])));
190                            }
191                    }
192                    catch (Exception e) {
193                            throw new RuntimeException(e);
194                    }
195                    finally {
196                            DataAccess.cleanUp(connection, null, tableResultSet);
197                    }
198            }
199    
200            protected static void reloadDatabase() {
201                    Connection connection = null;
202    
203                    try {
204                            connection = DataAccess.getUpgradeOptimizedConnection();
205    
206                            for (Table table : _tables) {
207                                    DB db = DBFactoryUtil.getDB();
208    
209                                    db.runSQL(connection, table.getDeleteSQL());
210    
211                                    table.populateTable();
212                            }
213                    }
214                    catch (Exception e) {
215                            throw new RuntimeException(e);
216                    }
217                    finally {
218                            DataAccess.cleanUp(connection);
219                    }
220            }
221    
222            private static List<String> _getModifiedTableNames(String sql) {
223                    Statement statement = null;
224    
225                    try {
226                            statement = _jSqlParser.parse(new UnsyncStringReader(sql));
227                    }
228                    catch (Exception e) {
229                            return null;
230                    }
231    
232                    if (statement instanceof Delete) {
233                            Delete delete = (Delete)statement;
234    
235                            Expression expression = delete.getWhere();
236    
237                            if (expression == null) {
238    
239                                    // Workaround for
240                                    // https://github.com/JSQLParser/JSqlParser/pull/55
241    
242                                    net.sf.jsqlparser.schema.Table table = delete.getTable();
243    
244                                    return Arrays.asList(table.getName());
245                            }
246    
247                            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
248    
249                            return tablesNamesFinder.getTableList(delete);
250                    }
251    
252                    if (statement instanceof Insert) {
253                            Insert insert = (Insert)statement;
254    
255                            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
256    
257                            return tablesNamesFinder.getTableList(insert);
258                    }
259    
260                    if (statement instanceof Update) {
261                            Update update = (Update)statement;
262    
263                            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
264    
265                            return tablesNamesFinder.getTableList(update);
266                    }
267    
268                    return null;
269            }
270    
271            private static final Log _log = LogFactoryUtil.getLog(
272                    ResetDatabaseUtil.class);
273    
274            private static final ConcurrentMap<String, Table> _cachedTables =
275                    new ConcurrentHashMap<String, Table>();
276            private static boolean _initialized;
277            private static final JSqlParser _jSqlParser = new CCJSqlParserManager();
278            private static final ConcurrentMap<String, Table> _modifiedTables =
279                    new ConcurrentHashMap<String, Table>();
280            private static volatile boolean _recording;
281            private static final Set<Table> _tables = new ConcurrentHashSet<Table>();
282    
283    }