001/* 002 * Copyright (c) 2006-2011 Nuxeo SA (http://nuxeo.com/) and others. 003 * 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * bstefanescu 011 */ 012package org.nuxeo.ecm.webengine.jaxrs.servlet; 013 014import java.io.IOException; 015import java.util.Collections; 016import java.util.Enumeration; 017 018import javax.servlet.Filter; 019import javax.servlet.FilterChain; 020import javax.servlet.ServletConfig; 021import javax.servlet.ServletContext; 022import javax.servlet.ServletException; 023import javax.servlet.ServletRequest; 024import javax.servlet.ServletResponse; 025import javax.servlet.http.HttpServlet; 026import javax.servlet.http.HttpServletRequest; 027 028import org.nuxeo.ecm.webengine.jaxrs.servlet.config.ServletDescriptor; 029import org.nuxeo.ecm.webengine.jaxrs.servlet.mapping.Path; 030 031/** 032 * @author <a href="mailto:bs@nuxeo.com">Bogdan Stefanescu</a> 033 */ 034public class RequestChain { 035 036 protected HttpServlet servlet; 037 038 protected FilterSet[] filters; 039 040 /** 041 * Create a new request chain given the target servlet and an optional list of filter sets. 042 * 043 * @param servlet the target 044 * @param filters the filter sets 045 */ 046 public RequestChain(HttpServlet servlet, FilterSet[] filters) { 047 if (servlet == null) { 048 throw new IllegalArgumentException("No target servlet defined"); 049 } 050 this.filters = filters == null ? new FilterSet[0] : filters; 051 this.servlet = servlet; 052 } 053 054 public FilterSet[] getFilters() { 055 return filters; 056 } 057 058 public HttpServlet getServlet() { 059 return servlet; 060 } 061 062 public void init(ServletDescriptor sd, ServletConfig config) throws ServletException { 063 for (FilterSet filterSet : filters) { 064 filterSet.init(config); 065 } 066 if (servlet instanceof ManagedServlet) { 067 ((ManagedServlet) servlet).setDescriptor(sd); 068 } 069 servlet.init(new ServletConfigAdapter(sd, config)); 070 } 071 072 public void execute(ServletRequest request, ServletResponse response) throws IOException, ServletException { 073 if (filters.length == 0 || (request instanceof HttpServletRequest == false)) { 074 servlet.service(request, response); 075 return; 076 } 077 String pathInfo = ((HttpServletRequest) request).getPathInfo(); 078 Path path = pathInfo == null || pathInfo.length() == 0 ? Path.ROOT : Path.parse(pathInfo); 079 for (FilterSet filterSet : filters) { 080 if (filterSet.matches(path)) { 081 new ServletFilterChain(servlet, filterSet.getFilters()).doFilter(request, response); 082 return; // avoid running the servlet twice 083 } 084 } 085 // if not filters matched just run the target servlet. 086 servlet.service(request, response); 087 } 088 089 public void destroy() { 090 if (servlet != null) { 091 servlet.destroy(); 092 servlet = null; 093 } 094 for (FilterSet filterSet : filters) { 095 filterSet.destroy(); 096 } 097 filters = null; 098 } 099 100 public static class ServletFilterChain implements FilterChain { 101 102 protected final HttpServlet servlet; 103 104 protected final Filter[] filters; 105 106 protected int filterIndex; 107 108 public ServletFilterChain(HttpServlet servlet, Filter[] filters) { 109 this.servlet = servlet; 110 this.filters = filters; 111 filterIndex = 0; 112 113 } 114 115 @Override 116 public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { 117 if (filterIndex < filters.length) { 118 Filter filter = filters[filterIndex++]; 119 filter.doFilter(request, response, this); 120 } else { 121 servlet.service(request, response); 122 } 123 } 124 } 125 126 static class ServletConfigAdapter implements ServletConfig { 127 protected final ServletConfig config; 128 129 protected final ServletDescriptor sd; 130 131 public ServletConfigAdapter(ServletDescriptor sd, ServletConfig config) { 132 this.config = config; 133 this.sd = sd; 134 } 135 136 @Override 137 public String getInitParameter(String key) { 138 return sd.getInitParams().get(key); 139 } 140 141 @Override 142 public Enumeration<String> getInitParameterNames() { 143 return Collections.enumeration(sd.getInitParams().keySet()); 144 } 145 146 @Override 147 public ServletContext getServletContext() { 148 return config.getServletContext(); 149 } 150 151 @Override 152 public String getServletName() { 153 return sd.getName(); 154 } 155 } 156 157}