001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.test.jdbc;
016    
017    import com.liferay.portal.kernel.concurrent.ConcurrentHashSet;
018    import com.liferay.portal.kernel.dao.db.DB;
019    import com.liferay.portal.kernel.dao.db.DBFactoryUtil;
020    import com.liferay.portal.kernel.dao.jdbc.DataAccess;
021    import com.liferay.portal.kernel.io.unsync.UnsyncStringReader;
022    import com.liferay.portal.kernel.log.Log;
023    import com.liferay.portal.kernel.log.LogFactoryUtil;
024    import com.liferay.portal.kernel.util.FileUtil;
025    import com.liferay.portal.kernel.util.StringUtil;
026    import com.liferay.portal.upgrade.util.Table;
027    
028    import java.io.File;
029    
030    import java.sql.Connection;
031    import java.sql.DatabaseMetaData;
032    import java.sql.ResultSet;
033    
034    import java.util.ArrayList;
035    import java.util.Arrays;
036    import java.util.List;
037    import java.util.Set;
038    import java.util.concurrent.ConcurrentHashMap;
039    import java.util.concurrent.ConcurrentMap;
040    
041    import net.sf.jsqlparser.expression.Expression;
042    import net.sf.jsqlparser.parser.CCJSqlParserManager;
043    import net.sf.jsqlparser.parser.JSqlParser;
044    import net.sf.jsqlparser.statement.Statement;
045    import net.sf.jsqlparser.statement.delete.Delete;
046    import net.sf.jsqlparser.statement.insert.Insert;
047    import net.sf.jsqlparser.statement.update.Update;
048    import net.sf.jsqlparser.util.TablesNamesFinder;
049    
050    /**
051     * @author Shuyang Zhou
052     */
053    public class ResetDatabaseUtil {
054    
055            public static synchronized boolean initialize() {
056                    if (_initialized) {
057                            reloadDatabase();
058    
059                            return false;
060                    }
061    
062                    dumpDatabase();
063    
064                    _initialized = true;
065    
066                    return true;
067            }
068    
069            public static void processSQL(Connection connection, String sql)
070                    throws Exception {
071    
072                    if (!_recording) {
073                            return;
074                    }
075    
076                    List<String> tableNames = _getModifiedTableNames(sql);
077    
078                    if (tableNames == null) {
079                            return;
080                    }
081    
082                    for (String tableName : tableNames) {
083                            tableName = StringUtil.toLowerCase(tableName);
084    
085                            Table table = _cachedTables.get(tableName);
086    
087                            if (table == null) {
088                                    _log.error(
089                                            "Unable to get table " + tableName + " from cache " +
090                                                    _cachedTables.keySet());
091    
092                                    continue;
093                            }
094    
095                            if (_modifiedTables.putIfAbsent(tableName, table) == null) {
096                                    table.generateTempFile(connection);
097                            }
098                    }
099            }
100    
101            public static void resetModifiedTables() {
102                    _recording = false;
103    
104                    Connection connection = null;
105    
106                    try {
107                            connection = DataAccess.getUpgradeOptimizedConnection();
108    
109                            for (Table table : _modifiedTables.values()) {
110                                    DB db = DBFactoryUtil.getDB();
111    
112                                    db.runSQL(connection, table.getDeleteSQL());
113    
114                                    table.populateTable();
115    
116                                    String tempFileName = table.getTempFileName();
117    
118                                    if (tempFileName != null) {
119                                            FileUtil.delete(tempFileName);
120                                    }
121                            }
122                    }
123                    catch (Exception e) {
124                            throw new RuntimeException(e);
125                    }
126                    finally {
127                            DataAccess.cleanUp(connection);
128    
129                            _modifiedTables.clear();
130                    }
131            }
132    
133            public static void startRecording() {
134                    _recording = true;
135            }
136    
137            protected static void dumpDatabase() {
138                    Connection connection = null;
139                    ResultSet tableResultSet = null;
140    
141                    try {
142                            connection = DataAccess.getUpgradeOptimizedConnection();
143    
144                            DatabaseMetaData databaseMetaData = connection.getMetaData();
145    
146                            tableResultSet = databaseMetaData.getTables(null, null, null, null);
147    
148                            while (tableResultSet.next()) {
149                                    String tableName = tableResultSet.getString("TABLE_NAME");
150    
151                                    ResultSet columnResultSet = databaseMetaData.getColumns(
152                                            null, null, tableName, null);
153    
154                                    List<Object[]> columns = new ArrayList<Object[]>();
155    
156                                    try {
157                                            while (columnResultSet.next()) {
158                                                    columns.add(
159                                                            new Object[] {
160                                                                    columnResultSet.getString("COLUMN_NAME"),
161                                                                    columnResultSet.getInt("DATA_TYPE")});
162                                            }
163                                    }
164                                    finally {
165                                            DataAccess.cleanUp(columnResultSet);
166                                    }
167    
168                                    Table table = new Table(
169                                            tableName, columns.toArray(new Object[columns.size()][]));
170    
171                                    table.generateTempFile(connection);
172    
173                                    String tempFileName = table.getTempFileName();
174    
175                                    if (tempFileName != null) {
176                                            File tempFile = new File(tempFileName);
177    
178                                            tempFile.deleteOnExit();
179                                    }
180    
181                                    _tables.add(table);
182    
183                                    _cachedTables.put(
184                                            StringUtil.toLowerCase(tableName),
185                                            new Table(
186                                                    tableName,
187                                                    columns.toArray(new Object[columns.size()][])));
188                            }
189                    }
190                    catch (Exception e) {
191                            throw new RuntimeException(e);
192                    }
193                    finally {
194                            DataAccess.cleanUp(connection, null, tableResultSet);
195                    }
196            }
197    
198            protected static void reloadDatabase() {
199                    Connection connection = null;
200    
201                    try {
202                            connection = DataAccess.getUpgradeOptimizedConnection();
203    
204                            for (Table table : _tables) {
205                                    DB db = DBFactoryUtil.getDB();
206    
207                                    db.runSQL(connection, table.getDeleteSQL());
208    
209                                    table.populateTable();
210                            }
211                    }
212                    catch (Exception e) {
213                            throw new RuntimeException(e);
214                    }
215                    finally {
216                            DataAccess.cleanUp(connection);
217                    }
218            }
219    
220            private static List<String> _getModifiedTableNames(String sql) {
221                    Statement statement = null;
222    
223                    try {
224                            statement = _jSqlParser.parse(new UnsyncStringReader(sql));
225                    }
226                    catch (Exception e) {
227                            return null;
228                    }
229    
230                    if (statement instanceof Delete) {
231                            Delete delete = (Delete)statement;
232    
233                            Expression expression = delete.getWhere();
234    
235                            if (expression == null) {
236    
237                                    // Workaround for
238                                    // https://github.com/JSQLParser/JSqlParser/pull/55
239    
240                                    net.sf.jsqlparser.schema.Table table = delete.getTable();
241    
242                                    return Arrays.asList(table.getName());
243                            }
244    
245                            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
246    
247                            return tablesNamesFinder.getTableList(delete);
248                    }
249    
250                    if (statement instanceof Insert) {
251                            Insert insert = (Insert)statement;
252    
253                            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
254    
255                            return tablesNamesFinder.getTableList(insert);
256                    }
257    
258                    if (statement instanceof Update) {
259                            Update update = (Update)statement;
260    
261                            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
262    
263                            return tablesNamesFinder.getTableList(update);
264                    }
265    
266                    return null;
267            }
268    
269            private static final Log _log = LogFactoryUtil.getLog(
270                    ResetDatabaseUtil.class);
271    
272            private static final ConcurrentMap<String, Table> _cachedTables =
273                    new ConcurrentHashMap<String, Table>();
274            private static boolean _initialized;
275            private static final JSqlParser _jSqlParser = new CCJSqlParserManager();
276            private static final ConcurrentMap<String, Table> _modifiedTables =
277                    new ConcurrentHashMap<String, Table>();
278            private static volatile boolean _recording;
279            private static final Set<Table> _tables = new ConcurrentHashSet<Table>();
280    
281    }