001/*
002 * (C) Copyright 2018 Nuxeo (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     Florent Guillaume
018 */
019package org.nuxeo.ecm.platform.web.common.requestcontroller.filter;
020
021import static com.google.common.net.HttpHeaders.ORIGIN;
022import static com.google.common.net.HttpHeaders.REFERER;
023import static org.apache.commons.lang3.StringUtils.isBlank;
024
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Objects;
031
032import javax.servlet.Filter;
033import javax.servlet.FilterChain;
034import javax.servlet.FilterConfig;
035import javax.servlet.ServletException;
036import javax.servlet.ServletRequest;
037import javax.servlet.ServletResponse;
038import javax.servlet.http.HttpServletRequest;
039import javax.servlet.http.HttpServletRequestWrapper;
040import javax.servlet.http.HttpServletResponse;
041
042import org.apache.commons.logging.Log;
043import org.apache.commons.logging.LogFactory;
044import org.apache.http.client.methods.HttpGet;
045import org.apache.http.client.methods.HttpHead;
046import org.apache.http.client.methods.HttpOptions;
047import org.apache.http.client.methods.HttpTrace;
048import org.nuxeo.ecm.platform.web.common.requestcontroller.service.RequestControllerManager;
049import org.nuxeo.ecm.platform.web.common.vh.VirtualHostHelper;
050import org.nuxeo.runtime.api.Framework;
051
052import com.thetransactioncompany.cors.CORSConfiguration;
053import com.thetransactioncompany.cors.CORSFilter;
054import com.thetransactioncompany.cors.Origin;
055
056/**
057 * Nuxeo CORS and CSRF filter, returning CORS configuration and preventing CSRF attacks by rejecting dubious requests.
058 *
059 * @since 5.7.2 for CORS
060 * @since 10.1 for CSRF
061 */
062public class NuxeoCorsCsrfFilter implements Filter {
063
064    private static final Log log = LogFactory.getLog(NuxeoCorsCsrfFilter.class);
065
066    public static final String GET = HttpGet.METHOD_NAME;
067
068    public static final String HEAD = HttpHead.METHOD_NAME;
069
070    public static final String OPTIONS = HttpOptions.METHOD_NAME;
071
072    public static final String TRACE = HttpTrace.METHOD_NAME;
073
074    public static final List<String> SCHEMES_ALLOWED = Arrays.asList("moz-extension", "chrome-extension");
075
076    @Override
077    public void init(FilterConfig filterConfig) {
078        // nothing to do
079    }
080
081    @Override
082    public void destroy() {
083        // nothing to do
084    }
085
086    @Override
087    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
088            throws IOException, ServletException {
089        HttpServletRequest request = (HttpServletRequest) servletRequest;
090        HttpServletResponse response = (HttpServletResponse) servletResponse;
091
092        RequestControllerManager service = Framework.getService(RequestControllerManager.class);
093        CORSFilter corsFilter = service.getCorsFilterForRequest(request);
094        CORSConfiguration corsConfig = corsFilter == null ? null : corsFilter.getConfiguration();
095        String method = request.getMethod();
096        URI sourceURI = getSourceURI(request);
097        URI targetURI = getTargetURI(request);
098        if (log.isDebugEnabled()) {
099            log.debug("Method: " + method + ", source: " + sourceURI + ", target: " + targetURI);
100        }
101
102        boolean allow;
103        if (GET.equals(method) || HEAD.equals(method) || OPTIONS.equals(method) || TRACE.equals(method)) {
104            // safe method according to RFC 7231 4.2.1
105            log.debug("Safe method: allow");
106            allow = true;
107        } else if (sourceAndTargetMatch(sourceURI, targetURI)) {
108            // source and target match, or not provided
109            log.debug("Source and target match: allow");
110            if (targetURI == null) {
111                // misconfigured server or proxy headers
112                log.error("Cannot determine target URL for CSRF check");
113            }
114            allow = true;
115        } else if (corsConfig == null) {
116            // source not known by CORS config: be safe and disallow
117            log.debug("URL not covered by CORS config: disallow cross-site request");
118            allow = false;
119        } else if (!corsConfig.isAllowedOrigin(new Origin(sourceURI.toString()))) {
120            // not in allowed CORS origins
121            log.debug("Origin not allowed by CORS config: disallow cross-site request");
122            allow = false;
123        } else if (!corsConfig.isSupportedMethod(method)) {
124            // not in allowed CORS methods
125            log.debug("Method not allowed by CORS config: disallow cross-site request");
126            allow = false;
127        } else {
128            log.debug("Origin and method allowed by CORS config: allow cross-site request");
129            allow = true;
130        }
131
132        if (allow) {
133            if (corsFilter == null) {
134                chain.doFilter(request, response);
135            } else {
136                request = maybeIgnoreWhitelistedOrigin(request);
137                corsFilter.doFilter(request, response, chain);
138            }
139            return;
140        }
141
142        // disallowed cross-site request
143        String message = "CSRF check failure";
144        log.warn(message + ": source: " + sourceURI + " does not match target: " + targetURI
145                + " and not allowed by CORS config");
146        response.sendError(HttpServletResponse.SC_FORBIDDEN, message);
147    }
148
149    /** Gets the source URI: the URI of the page from which the request is actually coming. */
150    public URI getSourceURI(HttpServletRequest request) {
151        String source = request.getHeader(ORIGIN);
152        if (isBlank(source)) {
153            source = request.getHeader(REFERER);
154        }
155        if (isBlank(source)) {
156            return null;
157        }
158        source = source.trim();
159        if ("null".equals(source)) {
160            // RFC 6454 7.1 origin-list-or-null
161            return null;
162        }
163        if (source.contains(" ")) {
164            // RFC 6454 7.1 origin-list
165            // keep only the first origin to simplify the logic (nobody sends two origins anyway)
166            source = source.substring(0, source.indexOf(' '));
167        }
168        try {
169            return new URI(source);
170        } catch (URISyntaxException e) {
171            return null;
172        }
173    }
174
175    /** Gets the target URI: the URI to which the browser is connecting. */
176    public URI getTargetURI(HttpServletRequest request) {
177        String baseURL = VirtualHostHelper.getServerURL(request, false);
178        if (baseURL == null) {
179            return null;
180        }
181        try {
182            return new URI(baseURL);
183        } catch (URISyntaxException e) {
184            return null;
185        }
186    }
187
188    public boolean sourceAndTargetMatch(URI sourceURI, URI targetURI) {
189        if (sourceURI == null || targetURI == null) {
190            return true;
191        }
192        if (isWhitelistedScheme(sourceURI)) {
193            return true;
194        }
195        return Objects.equals(sourceURI.getScheme(), targetURI.getScheme()) //
196                && Objects.equals(sourceURI.getHost(), targetURI.getHost()) //
197                && sourceURI.getPort() == targetURI.getPort();
198    }
199
200    protected HttpServletRequest maybeIgnoreWhitelistedOrigin(HttpServletRequest request) {
201        String origin = request.getHeader(ORIGIN);
202        if (origin == null) {
203            return request;
204        }
205        URI uri;
206        try {
207            uri = new URI(origin);
208        } catch (URISyntaxException e) {
209            return request;
210        }
211        if (!isWhitelistedScheme(uri)) {
212            return request;
213        }
214        // wrap request to pretend that the Origin is absent
215        return new IgnoredOriginRequestWrapper(request);
216    }
217
218    protected boolean isWhitelistedScheme(URI uri) {
219        return SCHEMES_ALLOWED.contains(uri.getScheme());
220    }
221
222    /**
223     * Wrapper for the request to hide the Origin header.
224     *
225     * @since 10.2
226     */
227    public static class IgnoredOriginRequestWrapper extends HttpServletRequestWrapper {
228
229        public IgnoredOriginRequestWrapper(HttpServletRequest request) {
230            super(request);
231        }
232
233        @Override
234        public String getHeader(String name) {
235            if (ORIGIN.equalsIgnoreCase(name)) {
236                return null;
237            }
238            return super.getHeader(name);
239        }
240    }
241
242}