001/*
002 * (C) Copyright 2006-2011 Nuxeo SA (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     bstefanescu
018 */
019package org.nuxeo.ecm.webengine.jaxrs.servlet;
020
021import java.io.IOException;
022import java.util.Collections;
023import java.util.Enumeration;
024
025import javax.servlet.Filter;
026import javax.servlet.FilterChain;
027import javax.servlet.ServletConfig;
028import javax.servlet.ServletContext;
029import javax.servlet.ServletException;
030import javax.servlet.ServletRequest;
031import javax.servlet.ServletResponse;
032import javax.servlet.http.HttpServlet;
033import javax.servlet.http.HttpServletRequest;
034
035import org.nuxeo.ecm.webengine.jaxrs.servlet.config.ServletDescriptor;
036import org.nuxeo.ecm.webengine.jaxrs.servlet.mapping.Path;
037
038/**
039 * @author <a href="mailto:bs@nuxeo.com">Bogdan Stefanescu</a>
040 */
041public class RequestChain {
042
043    protected HttpServlet servlet;
044
045    protected FilterSet[] filters;
046
047    /**
048     * Create a new request chain given the target servlet and an optional list of filter sets.
049     *
050     * @param servlet the target
051     * @param filters the filter sets
052     */
053    public RequestChain(HttpServlet servlet, FilterSet[] filters) {
054        if (servlet == null) {
055            throw new IllegalArgumentException("No target servlet defined");
056        }
057        this.filters = filters == null ? new FilterSet[0] : filters;
058        this.servlet = servlet;
059    }
060
061    public FilterSet[] getFilters() {
062        return filters;
063    }
064
065    public HttpServlet getServlet() {
066        return servlet;
067    }
068
069    public void init(ServletDescriptor sd, ServletConfig config) throws ServletException {
070        for (FilterSet filterSet : filters) {
071            filterSet.init(config);
072        }
073        if (servlet instanceof ManagedServlet) {
074            ((ManagedServlet) servlet).setDescriptor(sd);
075        }
076        servlet.init(new ServletConfigAdapter(sd, config));
077    }
078
079    public void execute(ServletRequest request, ServletResponse response) throws IOException, ServletException {
080        if (filters.length == 0 || (request instanceof HttpServletRequest == false)) {
081            servlet.service(request, response);
082            return;
083        }
084        String pathInfo = ((HttpServletRequest) request).getPathInfo();
085        Path path = pathInfo == null || pathInfo.length() == 0 ? Path.ROOT : Path.parse(pathInfo);
086        for (FilterSet filterSet : filters) {
087            if (filterSet.matches(path)) {
088                new ServletFilterChain(servlet, filterSet.getFilters()).doFilter(request, response);
089                return; // avoid running the servlet twice
090            }
091        }
092        // if not filters matched just run the target servlet.
093        servlet.service(request, response);
094    }
095
096    public void destroy() {
097        if (servlet != null) {
098            servlet.destroy();
099            servlet = null;
100        }
101        for (FilterSet filterSet : filters) {
102            filterSet.destroy();
103        }
104        filters = null;
105    }
106
107    public static class ServletFilterChain implements FilterChain {
108
109        protected final HttpServlet servlet;
110
111        protected final Filter[] filters;
112
113        protected int filterIndex;
114
115        public ServletFilterChain(HttpServlet servlet, Filter[] filters) {
116            this.servlet = servlet;
117            this.filters = filters;
118            filterIndex = 0;
119
120        }
121
122        @Override
123        public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
124            if (filterIndex < filters.length) {
125                Filter filter = filters[filterIndex++];
126                filter.doFilter(request, response, this);
127            } else {
128                servlet.service(request, response);
129            }
130        }
131    }
132
133    static class ServletConfigAdapter implements ServletConfig {
134        protected final ServletConfig config;
135
136        protected final ServletDescriptor sd;
137
138        public ServletConfigAdapter(ServletDescriptor sd, ServletConfig config) {
139            this.config = config;
140            this.sd = sd;
141        }
142
143        @Override
144        public String getInitParameter(String key) {
145            return sd.getInitParams().get(key);
146        }
147
148        @Override
149        public Enumeration<String> getInitParameterNames() {
150            return Collections.enumeration(sd.getInitParams().keySet());
151        }
152
153        @Override
154        public ServletContext getServletContext() {
155            return config.getServletContext();
156        }
157
158        @Override
159        public String getServletName() {
160            return sd.getName();
161        }
162    }
163
164}