001/*
002 * Copyright (c) 2006-2011 Nuxeo SA (http://nuxeo.com/) and others.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *     bstefanescu
011 */
012package org.nuxeo.ecm.webengine.jaxrs.servlet;
013
014import java.io.IOException;
015import java.util.Collections;
016import java.util.Enumeration;
017
018import javax.servlet.Filter;
019import javax.servlet.FilterChain;
020import javax.servlet.ServletConfig;
021import javax.servlet.ServletContext;
022import javax.servlet.ServletException;
023import javax.servlet.ServletRequest;
024import javax.servlet.ServletResponse;
025import javax.servlet.http.HttpServlet;
026import javax.servlet.http.HttpServletRequest;
027
028import org.nuxeo.ecm.webengine.jaxrs.servlet.config.ServletDescriptor;
029import org.nuxeo.ecm.webengine.jaxrs.servlet.mapping.Path;
030
031/**
032 * @author <a href="mailto:bs@nuxeo.com">Bogdan Stefanescu</a>
033 */
034public class RequestChain {
035
036    protected HttpServlet servlet;
037
038    protected FilterSet[] filters;
039
040    /**
041     * Create a new request chain given the target servlet and an optional list of filter sets.
042     *
043     * @param servlet the target
044     * @param filters the filter sets
045     */
046    public RequestChain(HttpServlet servlet, FilterSet[] filters) {
047        if (servlet == null) {
048            throw new IllegalArgumentException("No target servlet defined");
049        }
050        this.filters = filters == null ? new FilterSet[0] : filters;
051        this.servlet = servlet;
052    }
053
054    public FilterSet[] getFilters() {
055        return filters;
056    }
057
058    public HttpServlet getServlet() {
059        return servlet;
060    }
061
062    public void init(ServletDescriptor sd, ServletConfig config) throws ServletException {
063        for (FilterSet filterSet : filters) {
064            filterSet.init(config);
065        }
066        if (servlet instanceof ManagedServlet) {
067            ((ManagedServlet) servlet).setDescriptor(sd);
068        }
069        servlet.init(new ServletConfigAdapter(sd, config));
070    }
071
072    public void execute(ServletRequest request, ServletResponse response) throws IOException, ServletException {
073        if (filters.length == 0 || (request instanceof HttpServletRequest == false)) {
074            servlet.service(request, response);
075            return;
076        }
077        String pathInfo = ((HttpServletRequest) request).getPathInfo();
078        Path path = pathInfo == null || pathInfo.length() == 0 ? Path.ROOT : Path.parse(pathInfo);
079        for (FilterSet filterSet : filters) {
080            if (filterSet.matches(path)) {
081                new ServletFilterChain(servlet, filterSet.getFilters()).doFilter(request, response);
082                return; // avoid running the servlet twice
083            }
084        }
085        // if not filters matched just run the target servlet.
086        servlet.service(request, response);
087    }
088
089    public void destroy() {
090        if (servlet != null) {
091            servlet.destroy();
092            servlet = null;
093        }
094        for (FilterSet filterSet : filters) {
095            filterSet.destroy();
096        }
097        filters = null;
098    }
099
100    public static class ServletFilterChain implements FilterChain {
101
102        protected final HttpServlet servlet;
103
104        protected final Filter[] filters;
105
106        protected int filterIndex;
107
108        public ServletFilterChain(HttpServlet servlet, Filter[] filters) {
109            this.servlet = servlet;
110            this.filters = filters;
111            filterIndex = 0;
112
113        }
114
115        @Override
116        public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
117            if (filterIndex < filters.length) {
118                Filter filter = filters[filterIndex++];
119                filter.doFilter(request, response, this);
120            } else {
121                servlet.service(request, response);
122            }
123        }
124    }
125
126    static class ServletConfigAdapter implements ServletConfig {
127        protected final ServletConfig config;
128
129        protected final ServletDescriptor sd;
130
131        public ServletConfigAdapter(ServletDescriptor sd, ServletConfig config) {
132            this.config = config;
133            this.sd = sd;
134        }
135
136        @Override
137        public String getInitParameter(String key) {
138            return sd.getInitParams().get(key);
139        }
140
141        @Override
142        public Enumeration<String> getInitParameterNames() {
143            return Collections.enumeration(sd.getInitParams().keySet());
144        }
145
146        @Override
147        public ServletContext getServletContext() {
148            return config.getServletContext();
149        }
150
151        @Override
152        public String getServletName() {
153            return sd.getName();
154        }
155    }
156
157}