001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.upgrade;
016    
017    import com.liferay.portal.kernel.dao.jdbc.DataAccess;
018    import com.liferay.portal.kernel.log.Log;
019    import com.liferay.portal.kernel.log.LogFactoryUtil;
020    import com.liferay.portal.kernel.util.LoggingTimer;
021    import com.liferay.portal.kernel.util.StringBundler;
022    
023    import java.io.IOException;
024    
025    import java.sql.Connection;
026    import java.sql.PreparedStatement;
027    import java.sql.ResultSet;
028    import java.sql.SQLException;
029    
030    import java.util.ArrayList;
031    import java.util.List;
032    import java.util.concurrent.Callable;
033    import java.util.concurrent.ExecutorService;
034    import java.util.concurrent.Executors;
035    import java.util.concurrent.Future;
036    
037    /**
038     * @author Brian Wing Shun Chan
039     */
040    public abstract class BaseUpgradeCompanyId extends UpgradeProcess {
041    
042            @Override
043            protected void doUpgrade() throws Exception {
044                    List<Callable<Void>> callables = new ArrayList<>();
045    
046                    for (TableUpdater tableUpdater : getTableUpdaters()) {
047                            if (!hasColumn(tableUpdater.getTableName(), "companyId")) {
048                                    tableUpdater.setCreateCompanyIdColumn(true);
049                            }
050    
051                            callables.add(tableUpdater);
052                    }
053    
054                    ExecutorService executorService = Executors.newFixedThreadPool(
055                            callables.size());
056    
057                    try {
058                            List<Future<Void>> futures = executorService.invokeAll(callables);
059    
060                            for (Future<Void> future : futures) {
061                                    future.get();
062                            }
063                    }
064                    finally {
065                            executorService.shutdown();
066                    }
067            }
068    
069            protected abstract TableUpdater[] getTableUpdaters();
070    
071            protected class TableUpdater implements Callable<Void> {
072    
073                    public TableUpdater(
074                            String tableName, String foreignTableName,
075                            String foreignColumnName) {
076    
077                            _tableName = tableName;
078    
079                            _columnName = foreignColumnName;
080    
081                            _foreignNamesArray = new String[][] {
082                                    new String[] {foreignTableName, foreignColumnName}
083                            };
084                    }
085    
086                    public TableUpdater(
087                            String tableName, String columnName, String[][] foreignNamesArray) {
088    
089                            _tableName = tableName;
090                            _columnName = columnName;
091                            _foreignNamesArray = foreignNamesArray;
092                    }
093    
094                    @Override
095                    public final Void call() throws Exception {
096                            try (LoggingTimer loggingTimer = new LoggingTimer(_tableName);
097                                    Connection connection =
098                                            DataAccess.getUpgradeOptimizedConnection()) {
099    
100                                    if (_createCompanyIdColumn) {
101                                            if (_log.isInfoEnabled()) {
102                                                    _log.info(
103                                                            "Adding column companyId to table " + _tableName);
104                                            }
105    
106                                            runSQL(
107                                                    connection,
108                                                    "alter table " + _tableName +" add companyId LONG");
109                                    }
110                                    else {
111                                            if (_log.isInfoEnabled()) {
112                                                    _log.info(
113                                                            "Skipping the creation of companyId column for " +
114                                                                    "table " + _tableName);
115                                            }
116                                    }
117    
118                                    update(connection);
119                            }
120    
121                            return null;
122                    }
123    
124                    public String getTableName() {
125                            return _tableName;
126                    }
127    
128                    public void setCreateCompanyIdColumn(boolean createCompanyIdColumn) {
129                            _createCompanyIdColumn = createCompanyIdColumn;
130                    }
131    
132                    public void update(Connection connection)
133                            throws IOException, SQLException {
134    
135                            for (String[] foreignNames : _foreignNamesArray) {
136                                    runSQL(
137                                            connection,
138                                            getUpdateSQL(connection, foreignNames[0], foreignNames[1]));
139                            }
140                    }
141    
142                    protected List<Long> getCompanyIds(Connection connection)
143                            throws SQLException {
144    
145                            List<Long> companyIds = new ArrayList<>();
146    
147                            try (PreparedStatement ps = connection.prepareStatement(
148                                            "select companyId from Company");
149                                    ResultSet rs = ps.executeQuery()) {
150    
151                                    while (rs.next()) {
152                                            long companyId = rs.getLong(1);
153    
154                                            companyIds.add(companyId);
155                                    }
156                            }
157    
158                            return companyIds;
159                    }
160    
161                    protected String getSelectSQL(
162                                    Connection connection, String foreignTableName,
163                                    String foreignColumnName)
164                            throws SQLException {
165    
166                            List<Long> companyIds = getCompanyIds(connection);
167    
168                            if (companyIds.size() == 1) {
169                                    return String.valueOf(companyIds.get(0));
170                            }
171    
172                            StringBundler sb = new StringBundler(10);
173    
174                            sb.append("select max(companyId) from ");
175                            sb.append(foreignTableName);
176                            sb.append(" where ");
177                            sb.append(foreignTableName);
178                            sb.append(".");
179                            sb.append(foreignColumnName);
180                            sb.append(" = ");
181                            sb.append(_tableName);
182                            sb.append(".");
183                            sb.append(_columnName);
184    
185                            return sb.toString();
186                    }
187    
188                    protected String getUpdateSQL(
189                                    Connection connection, String foreignTableName,
190                                    String foreignColumnName)
191                            throws SQLException {
192    
193                            return getUpdateSQL(
194                                    getSelectSQL(connection, foreignTableName, foreignColumnName));
195                    }
196    
197                    protected String getUpdateSQL(String selectSQL) {
198                            StringBundler sb = new StringBundler(5);
199    
200                            sb.append("update ");
201                            sb.append(_tableName);
202                            sb.append(" set companyId = (");
203                            sb.append(selectSQL);
204                            sb.append(")");
205    
206                            return sb.toString();
207                    }
208    
209                    private final String _columnName;
210                    private boolean _createCompanyIdColumn;
211                    private final String[][] _foreignNamesArray;
212                    private final String _tableName;
213    
214            }
215    
216            private static final Log _log = LogFactoryUtil.getLog(
217                    BaseUpgradeCompanyId.class);
218    
219    }