001    /**
002     * Copyright (c) 2000-present Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.servlet.filters.invoker;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    import com.liferay.portal.kernel.servlet.ServletContextPool;
020    import com.liferay.portal.kernel.util.GetterUtil;
021    import com.liferay.portal.kernel.util.InstanceFactory;
022    import com.liferay.portal.kernel.util.ObjectValuePair;
023    import com.liferay.portal.kernel.util.StringPool;
024    import com.liferay.portal.kernel.util.StringUtil;
025    import com.liferay.portal.kernel.util.Validator;
026    import com.liferay.portal.kernel.xml.Document;
027    import com.liferay.portal.kernel.xml.Element;
028    import com.liferay.portal.kernel.xml.UnsecureSAXReaderUtil;
029    import com.liferay.registry.Registry;
030    import com.liferay.registry.RegistryUtil;
031    import com.liferay.registry.ServiceReference;
032    import com.liferay.registry.ServiceTracker;
033    import com.liferay.registry.ServiceTrackerCustomizer;
034    import com.liferay.registry.util.StringPlus;
035    
036    import java.io.InputStream;
037    
038    import java.util.ArrayList;
039    import java.util.HashMap;
040    import java.util.List;
041    import java.util.Map;
042    import java.util.Set;
043    import java.util.concurrent.ConcurrentHashMap;
044    import java.util.concurrent.ConcurrentMap;
045    import java.util.concurrent.CopyOnWriteArraySet;
046    
047    import javax.servlet.Filter;
048    import javax.servlet.FilterChain;
049    import javax.servlet.FilterConfig;
050    import javax.servlet.ServletContext;
051    import javax.servlet.ServletException;
052    import javax.servlet.http.HttpServletRequest;
053    
054    /**
055     * @author Mika Koivisto
056     * @author Brian Wing Shun Chan
057     */
058    public class InvokerFilterHelper {
059    
060            public void destroy() {
061                    _serviceTracker.close();
062    
063                    for (List<FilterMapping> filterMappings : _filterMappingsMap.values()) {
064                            FilterMapping filterMapping = filterMappings.get(0);
065    
066                            Filter filter = filterMapping.getFilter();
067    
068                            try {
069                                    filter.destroy();
070                            }
071                            catch (Exception e) {
072                                    _log.error(e, e);
073                            }
074                    }
075    
076                    _filterMappingsMap.clear();
077                    _filterNames.clear();
078    
079                    clearFilterChainsCache();
080            }
081    
082            public void init(FilterConfig filterConfig) throws ServletException {
083                    try {
084                            ServletContext servletContext = filterConfig.getServletContext();
085    
086                            readLiferayFilterWebXML(servletContext, "/WEB-INF/liferay-web.xml");
087    
088                            Registry registry = RegistryUtil.getRegistry();
089    
090                            String servletContextName = GetterUtil.getString(
091                                    servletContext.getServletContextName());
092    
093                            com.liferay.registry.Filter filter = registry.getFilter(
094                                    "(&(objectClass=" + Filter.class.getName() +
095                                            ")(servlet-context-name=" + servletContextName +
096                                                    ")(servlet-filter-name=*))");
097    
098                            _serviceTracker = registry.trackServices(
099                                    filter, new FilterServiceTrackerCustomizer());
100    
101                            _serviceTracker.open();
102                    }
103                    catch (Exception e) {
104                            _log.error(e, e);
105    
106                            throw new ServletException(e);
107                    }
108            }
109    
110            public void registerFilterMapping(
111                    FilterMapping filterMapping, String filterName, boolean after) {
112    
113                    while (true) {
114                            List<FilterMapping> oldFilterMappings = _filterMappingsMap.get(
115                                    filterName);
116    
117                            List<FilterMapping> newFilterMappings = null;
118    
119                            if (oldFilterMappings == null) {
120                                    newFilterMappings = new ArrayList<>();
121                            }
122                            else {
123                                    newFilterMappings = new ArrayList<>(oldFilterMappings);
124                            }
125    
126                            if (after) {
127                                    newFilterMappings.add(filterMapping);
128                            }
129                            else {
130                                    newFilterMappings.add(0, filterMapping);
131                            }
132    
133                            if (newFilterMappings.size() == 1) {
134                                    if (_filterMappingsMap.putIfAbsent(
135                                                    filterName, newFilterMappings) == null) {
136    
137                                            _filterNames.add(filterName);
138    
139                                            break;
140                                    }
141                            }
142                            else if (_filterMappingsMap.replace(
143                                                    filterName, oldFilterMappings, newFilterMappings)) {
144    
145                                    break;
146                            }
147                    }
148            }
149    
150            public void unregisterFilterMapping(FilterMapping filterMapping) {
151                    String filterName = filterMapping.getFilterName();
152    
153                    while (true) {
154                            List<FilterMapping> oldFilterMappings = _filterMappingsMap.get(
155                                    filterName);
156    
157                            List<FilterMapping> newFilterMappings = new ArrayList<>(
158                                    oldFilterMappings);
159    
160                            newFilterMappings.remove(filterMapping);
161    
162                            if (newFilterMappings.isEmpty()) {
163                                    if (_filterMappingsMap.remove(filterName, oldFilterMappings)) {
164                                            _filterNames.remove(filterName);
165    
166                                            break;
167                                    }
168                            }
169                            else if (_filterMappingsMap.replace(
170                                                    filterName, oldFilterMappings, newFilterMappings)) {
171    
172                                    break;
173                            }
174                    }
175            }
176    
177            public void unregisterFilterMappings(String filterName) {
178                    List<FilterMapping> filterMappings = _filterMappingsMap.remove(
179                            filterName);
180    
181                    if (filterMappings == null) {
182                            return;
183                    }
184    
185                    FilterMapping filterMapping = filterMappings.get(0);
186    
187                    Filter filter = filterMapping.getFilter();
188    
189                    if (filter != null) {
190                            try {
191                                    filter.destroy();
192                            }
193                            catch (Exception e) {
194                                    _log.error(e, e);
195                            }
196                    }
197    
198                    _filterNames.remove(filterName);
199    
200                    clearFilterChainsCache();
201            }
202    
203            public void updateFilterMappings(String filterName, Filter filter) {
204                    while (true) {
205                            List<FilterMapping> oldFilterMappings = _filterMappingsMap.get(
206                                    filterName);
207    
208                            if (oldFilterMappings == null) {
209                                    if (_log.isWarnEnabled()) {
210                                            _log.warn(
211                                                    "No filter mappings for filter name " + filterName);
212                                    }
213    
214                                    return;
215                            }
216    
217                            List<FilterMapping> newFilterMappings = new ArrayList<>();
218    
219                            for (FilterMapping oldFilterMapping : oldFilterMappings) {
220                                    newFilterMappings.add(oldFilterMapping.replaceFilter(filter));
221                            }
222    
223                            if (_filterMappingsMap.replace(
224                                            filterName, oldFilterMappings, newFilterMappings)) {
225    
226                                    break;
227                            }
228                    }
229            }
230    
231            protected void addInvokerFilter(InvokerFilter invokerFilter) {
232                    _invokerFilters.add(invokerFilter);
233            }
234    
235            protected void clearFilterChainsCache() {
236                    for (InvokerFilter invokerFilter : _invokerFilters) {
237                            invokerFilter.clearFilterChainsCache();
238                    }
239            }
240    
241            protected InvokerFilterChain createInvokerFilterChain(
242                    HttpServletRequest request, Dispatcher dispatcher, String uri,
243                    FilterChain filterChain) {
244    
245                    InvokerFilterChain invokerFilterChain = new InvokerFilterChain(
246                            filterChain);
247    
248                    for (String filterName : _filterNames) {
249                            List<FilterMapping> filterMappings = _filterMappingsMap.get(
250                                    filterName);
251    
252                            if (filterMappings == null) {
253                                    continue;
254                            }
255    
256                            for (FilterMapping filterMapping : filterMappings) {
257                                    if (filterMapping.isMatch(request, dispatcher, uri)) {
258                                            invokerFilterChain.addFilter(filterMapping.getFilter());
259                                    }
260                            }
261                    }
262    
263                    return invokerFilterChain;
264            }
265    
266            protected Filter initFilter(
267                    ServletContext servletContext, String filterClassName,
268                    String filterName, FilterConfig filterConfig) {
269    
270                    ClassLoader pluginClassLoader = servletContext.getClassLoader();
271    
272                    Thread currentThread = Thread.currentThread();
273    
274                    ClassLoader contextClassLoader = currentThread.getContextClassLoader();
275    
276                    if (contextClassLoader != pluginClassLoader) {
277                            currentThread.setContextClassLoader(pluginClassLoader);
278                    }
279    
280                    try {
281                            Filter filter = (Filter)InstanceFactory.newInstance(
282                                    pluginClassLoader, filterClassName);
283    
284                            filter.init(filterConfig);
285    
286                            return filter;
287                    }
288                    catch (Exception e) {
289                            _log.error("Unable to initialize filter " + filterClassName, e);
290    
291                            return null;
292                    }
293                    finally {
294                            if (contextClassLoader != pluginClassLoader) {
295                                    currentThread.setContextClassLoader(contextClassLoader);
296                            }
297                    }
298            }
299    
300            protected void readLiferayFilterWebXML(
301                            ServletContext servletContext, String path)
302                    throws Exception {
303    
304                    InputStream inputStream = servletContext.getResourceAsStream(path);
305    
306                    if (inputStream == null) {
307                            return;
308                    }
309    
310                    Document document = UnsecureSAXReaderUtil.read(inputStream, true);
311    
312                    Element rootElement = document.getRootElement();
313    
314                    Map<String, ObjectValuePair<Filter, FilterConfig>>
315                            filterObjectValuePairs = new HashMap<>();
316    
317                    for (Element filterElement : rootElement.elements("filter")) {
318                            String filterName = filterElement.elementText("filter-name");
319                            String filterClassName = filterElement.elementText("filter-class");
320    
321                            Map<String, String> initParameterMap = new HashMap<>();
322    
323                            List<Element> initParamElements = filterElement.elements(
324                                    "init-param");
325    
326                            for (Element initParamElement : initParamElements) {
327                                    String name = initParamElement.elementText("param-name");
328                                    String value = initParamElement.elementText("param-value");
329    
330                                    initParameterMap.put(name, value);
331                            }
332    
333                            FilterConfig filterConfig = new InvokerFilterConfig(
334                                    servletContext, filterName, initParameterMap);
335    
336                            Filter filter = initFilter(
337                                    servletContext, filterClassName, filterName, filterConfig);
338    
339                            if (filter != null) {
340                                    filterObjectValuePairs.put(
341                                            filterName, new ObjectValuePair<>(filter, filterConfig));
342                            }
343                    }
344    
345                    List<Element> filterMappingElements = rootElement.elements(
346                            "filter-mapping");
347    
348                    for (Element filterMappingElement : filterMappingElements) {
349                            String filterName = filterMappingElement.elementText("filter-name");
350    
351                            List<String> urlPatterns = new ArrayList<>();
352    
353                            List<Element> urlPatternElements = filterMappingElement.elements(
354                                    "url-pattern");
355    
356                            for (Element urlPatternElement : urlPatternElements) {
357                                    urlPatterns.add(urlPatternElement.getTextTrim());
358                            }
359    
360                            List<String> dispatchers = new ArrayList<>(4);
361    
362                            List<Element> dispatcherElements = filterMappingElement.elements(
363                                    "dispatcher");
364    
365                            for (Element dispatcherElement : dispatcherElements) {
366                                    String dispatcher = StringUtil.toUpperCase(
367                                            dispatcherElement.getTextTrim());
368    
369                                    dispatchers.add(dispatcher);
370                            }
371    
372                            ObjectValuePair<Filter, FilterConfig> filterObjectValuePair =
373                                    filterObjectValuePairs.get(filterName);
374    
375                            if (filterObjectValuePair == null) {
376                                    if (_log.isWarnEnabled()) {
377                                            _log.warn(
378                                                    "No filter and filter config for filter name " +
379                                                            filterName);
380                                    }
381    
382                                    continue;
383                            }
384    
385                            FilterMapping filterMapping = new FilterMapping(
386                                    filterName, filterObjectValuePair.getKey(),
387                                    filterObjectValuePair.getValue(), urlPatterns, dispatchers);
388    
389                            registerFilterMapping(filterMapping, filterName, true);
390                    }
391            }
392    
393            private static final Log _log = LogFactoryUtil.getLog(
394                    InvokerFilterHelper.class);
395    
396            private final ConcurrentMap<String, List<FilterMapping>>
397                    _filterMappingsMap = new ConcurrentHashMap<>();
398            private final Set<String> _filterNames = new CopyOnWriteArraySet<>();
399            private final List<InvokerFilter> _invokerFilters = new ArrayList<>();
400            private ServiceTracker<Filter, FilterMapping> _serviceTracker;
401    
402            private class FilterServiceTrackerCustomizer
403                    implements ServiceTrackerCustomizer<Filter, FilterMapping> {
404    
405                    @Override
406                    public FilterMapping addingService(
407                            ServiceReference<Filter> serviceReference) {
408    
409                            Registry registry = RegistryUtil.getRegistry();
410    
411                            Filter filter = registry.getService(serviceReference);
412    
413                            String afterFilter = GetterUtil.getString(
414                                    serviceReference.getProperty("after-filter"));
415                            String beforeFilter = GetterUtil.getString(
416                                    serviceReference.getProperty("before-filter"));
417                            List<String> dispatchers = StringPlus.asList(
418                                    serviceReference.getProperty("dispatcher"));
419                            String servletContextName = GetterUtil.getString(
420                                    serviceReference.getProperty("servlet-context-name"),
421                                    StringPool.BLANK);
422                            String servletFilterName = GetterUtil.getString(
423                                    serviceReference.getProperty("servlet-filter-name"));
424                            List<String> urlPatterns = StringPlus.asList(
425                                    serviceReference.getProperty("url-pattern"));
426    
427                            String positionFilterName = beforeFilter;
428                            boolean after = false;
429    
430                            if (Validator.isNotNull(afterFilter)) {
431                                    positionFilterName = afterFilter;
432                                    after = true;
433                            }
434    
435                            Map<String, String> initParameterMap = new HashMap<>();
436    
437                            Map<String, Object> properties = serviceReference.getProperties();
438    
439                            for (String key : properties.keySet()) {
440                                    if (!key.startsWith("init.param.")) {
441                                            continue;
442                                    }
443    
444                                    String value = GetterUtil.getString(
445                                            serviceReference.getProperty(key));
446    
447                                    initParameterMap.put(key, value);
448                            }
449    
450                            ServletContext servletContext = ServletContextPool.get(
451                                    servletContextName);
452    
453                            FilterConfig filterConfig = new InvokerFilterConfig(
454                                    servletContext, servletFilterName, initParameterMap);
455    
456                            try {
457                                    filter.init(filterConfig);
458                            }
459                            catch (ServletException se) {
460                                    _log.error(se, se);
461    
462                                    registry.ungetService(serviceReference);
463    
464                                    return null;
465                            }
466    
467                            updateFilterMappings(servletFilterName, filter);
468    
469                            FilterMapping filterMapping = new FilterMapping(
470                                    servletFilterName, filter, filterConfig, urlPatterns,
471                                    dispatchers);
472    
473                            registerFilterMapping(filterMapping, positionFilterName, after);
474    
475                            clearFilterChainsCache();
476    
477                            return filterMapping;
478                    }
479    
480                    @Override
481                    public void modifiedService(
482                            ServiceReference<Filter> serviceReference,
483                            FilterMapping filterMapping) {
484    
485                            removedService(serviceReference, filterMapping);
486    
487                            addingService(serviceReference);
488                    }
489    
490                    @Override
491                    public void removedService(
492                            ServiceReference<Filter> serviceReference,
493                            FilterMapping filterMapping) {
494    
495                            Registry registry = RegistryUtil.getRegistry();
496    
497                            registry.ungetService(serviceReference);
498    
499                            unregisterFilterMappings(
500                                    GetterUtil.getString(
501                                            serviceReference.getProperty("servlet-filter-name")));
502                    }
503    
504            }
505    
506    }