001/*
002 * (C) Copyright 2018 Nuxeo (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     Florent Guillaume
018 */
019package org.nuxeo.ecm.platform.web.common.requestcontroller.filter;
020
021import static com.google.common.net.HttpHeaders.ORIGIN;
022import static com.google.common.net.HttpHeaders.REFERER;
023import static javax.servlet.http.HttpServletResponse.SC_FORBIDDEN;
024import static javax.servlet.http.HttpServletResponse.SC_OK;
025import static org.apache.commons.lang3.StringUtils.isBlank;
026
027import java.io.IOException;
028import java.io.Serializable;
029import java.net.URI;
030import java.net.URISyntaxException;
031import java.security.SecureRandom;
032import java.util.ArrayList;
033import java.util.Arrays;
034import java.util.HashSet;
035import java.util.List;
036import java.util.Map;
037import java.util.Objects;
038import java.util.Random;
039import java.util.Set;
040
041import javax.servlet.Filter;
042import javax.servlet.FilterChain;
043import javax.servlet.FilterConfig;
044import javax.servlet.ServletException;
045import javax.servlet.ServletRequest;
046import javax.servlet.ServletResponse;
047import javax.servlet.http.HttpServletRequest;
048import javax.servlet.http.HttpServletRequestWrapper;
049import javax.servlet.http.HttpServletResponse;
050import javax.servlet.http.HttpSession;
051
052import org.apache.commons.lang3.RandomStringUtils;
053import org.apache.commons.lang3.StringUtils;
054import org.apache.commons.logging.Log;
055import org.apache.commons.logging.LogFactory;
056import org.apache.http.client.methods.HttpGet;
057import org.apache.http.client.methods.HttpHead;
058import org.apache.http.client.methods.HttpOptions;
059import org.apache.http.client.methods.HttpTrace;
060import org.nuxeo.ecm.platform.web.common.requestcontroller.service.RequestControllerManager;
061import org.nuxeo.ecm.platform.web.common.vh.VirtualHostHelper;
062import org.nuxeo.runtime.api.Framework;
063import org.nuxeo.runtime.services.config.ConfigurationService;
064
065import com.thetransactioncompany.cors.CORSConfiguration;
066import com.thetransactioncompany.cors.CORSFilter;
067import com.thetransactioncompany.cors.Origin;
068
069/**
070 * Nuxeo CORS and CSRF filter, returning CORS configuration and preventing CSRF attacks by rejecting dubious requests.
071 *
072 * @since 5.7.2 for CORS
073 * @since 10.1 for CSRF
074 */
075public class NuxeoCorsCsrfFilter implements Filter {
076
077    private static final Log log = LogFactory.getLog(NuxeoCorsCsrfFilter.class);
078
079    public static final String GET = HttpGet.METHOD_NAME;
080
081    public static final String HEAD = HttpHead.METHOD_NAME;
082
083    public static final String OPTIONS = HttpOptions.METHOD_NAME;
084
085    public static final String TRACE = HttpTrace.METHOD_NAME;
086
087    // safe methods according to RFC 7231 4.2.1
088    protected static final Set<String> SAFE_METHODS = new HashSet<>(Arrays.asList(GET, HEAD, OPTIONS, TRACE));
089
090    // RFC 6454
091    // 6.2 If the origin is not a scheme/host/port triple, then return the string null
092    // 7.3 Whenever a user agent issues an HTTP request from a "privacy-sensitive" context,
093    // the user agent MUST send the value "null" in the Origin header field.
094    public static final String ORIGIN_NULL = "null";
095
096    // marker for privacy-sensitive origins
097    public static final URI PRIVACY_SENSITIVE = URI.create("privacy-sensitive:///");
098
099    public static final List<String> SCHEMES_ALLOWED = Arrays.asList("moz-extension", "chrome-extension");
100
101    /**
102     * Allows to disable strict CORS checks when a request has Origin: null.
103     * <p>
104     * This may happen for local files, or for a JavaScript-triggered redirect. Setting this to false may expose the
105     * application to CSRF problems from files locally hosted on the user's disk.
106     *
107     * @since 10.3
108     */
109    public static final String ALLOW_NULL_ORIGIN_PROP = "nuxeo.cors.allowNullOrigin";
110
111    /**
112     * Configuration property (namespace) for CSRF tokens.
113     *
114     * @since 10.3
115     */
116    public static final String CSRF_TOKEN_NS_PROP = "nuxeo.csrf.token";
117
118    /**
119     * Allows enforcing the use of a CSRF token. Configuration property (under the {@value #CSRF_TOKEN_NS_PROP}
120     * namespace).
121     *
122     * @since 10.3
123     */
124    public static final String CSRF_TOKEN_ENABLED_SUBPROP = "enabled";
125
126    /** @since 10.3 */
127    public static final String CSRF_TOKEN_ENABLED_DEFAULT = "false";
128
129    /**
130     * Allows definition of endpoints for which no CSRF token check is done. Configuration <em>list</em> property (under
131     * the {@value #CSRF_TOKEN_NS_PROP} namespace).
132     *
133     * @since 10.3
134     */
135    public static final String CSRF_TOKEN_SKIP_SUBPROP = "skip";
136
137    /**
138     * Session attribute in which token is stored.
139     *
140     * @since 10.3
141     */
142    public static final String CSRF_TOKEN_ATTRIBUTE = "NuxeoCSRFToken";
143
144    /**
145     * Request header to pass a token, or fetch one.
146     *
147     * @since 10.3
148     */
149    public static final String CSRF_TOKEN_HEADER = "CSRF-Token";
150
151    /**
152     * Pseudo-value to fetch a token.
153     *
154     * @since 10.3
155     */
156    public static final String CSRF_TOKEN_FETCH = "fetch";
157
158    /**
159     * Pseudo-value to denote an invalid token.
160     *
161     * @since 10.3
162     */
163    public static final String CSRF_TOKEN_INVALID = "invalid";
164
165    /**
166     * Request parameter to pass a token.
167     *
168     * @since 10.3
169     */
170    public static final String CSRF_TOKEN_PARAM = "csrf-token";
171
172    protected static final Random RANDOM = new SecureRandom();
173
174    protected boolean allowNullOrigin;
175
176    protected boolean csrfTokenEnabled;
177
178    protected List<String> csrfTokenSkipPaths;
179
180    @Override
181    public void init(FilterConfig filterConfig) {
182        ConfigurationService configurationService = Framework.getService(ConfigurationService.class);
183        allowNullOrigin = configurationService.isBooleanTrue(ALLOW_NULL_ORIGIN_PROP);
184        Map<String, Serializable> csrfTokenConfig = configurationService.getProperties(CSRF_TOKEN_NS_PROP);
185        csrfTokenEnabled = Boolean.parseBoolean(StringUtils.defaultString(
186                (String) csrfTokenConfig.get(CSRF_TOKEN_ENABLED_SUBPROP), CSRF_TOKEN_ENABLED_DEFAULT));
187        csrfTokenSkipPaths = new ArrayList<>();
188        Serializable skipPaths = csrfTokenConfig.get(CSRF_TOKEN_SKIP_SUBPROP);
189        if (skipPaths instanceof String[]) {
190            csrfTokenSkipPaths.addAll(Arrays.asList((String[]) skipPaths));
191        }
192    }
193
194    @Override
195    public void destroy() {
196        // nothing to do
197    }
198
199    @Override
200    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
201            throws IOException, ServletException {
202        HttpServletRequest request = (HttpServletRequest) servletRequest;
203        HttpServletResponse response = (HttpServletResponse) servletResponse;
204
205        if (manageCSRFToken(request, response)) {
206            return;
207        }
208
209        RequestControllerManager service = Framework.getService(RequestControllerManager.class);
210        CORSFilter corsFilter = service.getCorsFilterForRequest(request);
211        CORSConfiguration corsConfig = corsFilter == null ? null : corsFilter.getConfiguration();
212        String method = request.getMethod();
213        URI sourceURI = getSourceURI(request);
214        URI targetURI = getTargetURI(request);
215        if (log.isDebugEnabled()) {
216            log.debug("Method: " + method + ", source: " + sourceURI + ", target: " + targetURI);
217        }
218
219        boolean allow;
220        if (isSafeMethod(method)) {
221            // safe method according to RFC 7231 4.2.1
222            log.debug("Safe method: allow");
223            allow = true;
224        } else if (sourceAndTargetMatch(sourceURI, targetURI)) {
225            // source and target match, or not provided
226            log.debug("Source and target match: allow");
227            if (targetURI == null) {
228                // misconfigured server or proxy headers
229                log.error("Cannot determine target URL for CSRF check");
230            }
231            allow = true;
232        } else if (corsConfig == null) {
233            // source not known by CORS config: be safe and disallow
234            log.debug("URL not covered by CORS config: disallow cross-site request");
235            allow = false;
236        } else if (!corsConfig.isAllowedOrigin(originFromURI(sourceURI))) {
237            // not in allowed CORS origins
238            log.debug("Origin not allowed by CORS config: disallow cross-site request");
239            allow = false;
240        } else if (!corsConfig.isSupportedMethod(method)) {
241            // not in allowed CORS methods
242            log.debug("Method not allowed by CORS config: disallow cross-site request");
243            allow = false;
244        } else {
245            log.debug("Origin and method allowed by CORS config: allow cross-site request");
246            allow = true;
247        }
248
249        if (allow) {
250            if (corsFilter == null) {
251                chain.doFilter(request, response);
252            } else {
253                request = maybeIgnoreWhitelistedOrigin(request);
254                corsFilter.doFilter(request, response, chain);
255            }
256            return;
257        }
258
259        // disallowed cross-site request
260        String message = "CSRF check failure";
261        log.warn(message + ": source: " + sourceURI + " does not match target: " + targetURI
262                + " and not allowed by CORS config");
263        response.sendError(HttpServletResponse.SC_FORBIDDEN, message);
264    }
265
266    /**
267     * Check safe method according to RFC 7231 4.2.1.
268     */
269    protected boolean isSafeMethod(String method) {
270        return SAFE_METHODS.contains(method);
271    }
272
273    /**
274     * Manages the CSRF token.
275     * <p>
276     * This method may return a response with token fetch information or with an error if needed, in which case it will
277     * return {@code true}.
278     *
279     * @return {@code true} if the caller doesn't need to do more work (a response has been sent)
280     * @since 10.3
281     */
282    protected boolean manageCSRFToken(HttpServletRequest request, HttpServletResponse response) throws IOException {
283        if (!csrfTokenEnabled) {
284            log.debug("No CSRF token check configured");
285            return false; // no check to do
286        }
287
288        String method = request.getMethod();
289        String path = request.getServletPath();
290        if (path == null) {
291            path = "";
292        }
293        String pathInfo = request.getPathInfo();
294        if (pathInfo != null) {
295            path += pathInfo;
296        }
297        String requestToken = request.getHeader(CSRF_TOKEN_HEADER);
298
299        // token fetch request
300        if (GET.equals(method) && path.isEmpty() && CSRF_TOKEN_FETCH.equals(requestToken)) {
301            HttpSession session = request.getSession(); // create if needed
302            String token = (String) session.getAttribute(CSRF_TOKEN_ATTRIBUTE);
303            if (token == null) {
304                token = generateNewToken();
305                session.setAttribute(CSRF_TOKEN_ATTRIBUTE, token);
306            }
307            log.debug("Returning CSRF token fetch");
308            response.setHeader(CSRF_TOKEN_HEADER, token);
309            response.setStatus(SC_OK);
310            return true;
311
312        }
313
314        // do we need to check the token?
315        if (isSafeMethod(method)) {
316            log.debug("No CSRF token check on safe method");
317            return false;
318        }
319
320        // is the endpoint specially configured to skip the token check?
321        if (csrfTokenSkipPaths.contains(path)) {
322            log.debug("No CSRF token check on configured endpoint");
323            return false;
324        }
325
326        // check the token
327        HttpSession session = request.getSession(false);
328        String token;
329        if (session == null || (token = (String) session.getAttribute(CSRF_TOKEN_ATTRIBUTE)) == null) {
330            log.debug("Error, no session or no CSRF token in session");
331            String message = "CSRF check failure";
332            log.warn(message + ": invalid token");
333            response.setHeader(CSRF_TOKEN_HEADER, CSRF_TOKEN_INVALID);
334            response.sendError(SC_FORBIDDEN, message);
335            return true;
336        }
337        if (StringUtils.isEmpty(requestToken)) {
338            // allow request parameter to contain the token too
339            requestToken = request.getParameter(CSRF_TOKEN_PARAM);
340        }
341        if (!token.equals(requestToken)) {
342            log.debug("Error, CSRF token does not match");
343            String message = "CSRF check failure";
344            log.warn(message + ": invalid token");
345            response.setHeader(CSRF_TOKEN_HEADER, CSRF_TOKEN_INVALID);
346            response.sendError(SC_FORBIDDEN, message);
347            return true;
348        }
349
350        // token is ok, proceed
351        log.debug("CSRF token matches");
352        return false;
353    }
354
355    protected String generateNewToken() {
356        return RandomStringUtils.random(40, 0, 0, true, true, null, RANDOM);
357    }
358
359    /**
360     * Gets the source URI: the URI of the page from which the request is actually coming.
361     * <p>
362     * {@code null} is returned is there is no header.
363     * <p>
364     * {@link #PRIVACY_SENSITIVE} is returned is there is a null origin (RFC 6454 7.3, "privacy-sensitive" context)
365     * unless configured to be ignored.
366     */
367    public URI getSourceURI(HttpServletRequest request) {
368        String source = request.getHeader(ORIGIN);
369        if (isBlank(source)) {
370            source = request.getHeader(REFERER);
371        }
372        if (isBlank(source)) {
373            return null;
374        }
375        source = source.trim();
376        if (ORIGIN_NULL.equals(source)) {
377            return allowNullOrigin ? null : PRIVACY_SENSITIVE;
378        }
379        if (source.contains(" ")) {
380            // RFC 6454 7.1 origin-list
381            // keep only the first origin to simplify the logic (nobody sends two origins anyway)
382            source = source.substring(0, source.indexOf(' '));
383        }
384        try {
385            return new URI(source); // NOSONAR (URI is not opened as a stream)
386        } catch (URISyntaxException e) {
387            return null;
388        }
389    }
390
391    /** Gets the target URI: the URI to which the browser is connecting. */
392    public URI getTargetURI(HttpServletRequest request) {
393        String baseURL = VirtualHostHelper.getServerURL(request, false);
394        if (baseURL == null) {
395            return null;
396        }
397        try {
398            return new URI(baseURL); // NOSONAR (URI is not opened as a stream)
399        } catch (URISyntaxException e) {
400            return null;
401        }
402    }
403
404    public boolean sourceAndTargetMatch(URI sourceURI, URI targetURI) {
405        if (sourceURI == null || targetURI == null) {
406            return true;
407        }
408        if (isWhitelistedScheme(sourceURI)) {
409            return true;
410        }
411        return Objects.equals(sourceURI.getScheme(), targetURI.getScheme()) //
412                && Objects.equals(sourceURI.getHost(), targetURI.getHost()) //
413                && sourceURI.getPort() == targetURI.getPort();
414    }
415
416    /**
417     * Gets an Origin from a URI. Strips the path and query (which may be present in Referer headers).
418     */
419    protected Origin originFromURI(URI uri) {
420        // remove path, query and fragment
421        try {
422            uri = new URI(uri.getScheme(), null, uri.getHost(), uri.getPort(), null, null, null);
423        } catch (URISyntaxException e) {
424            // keep passed-in URI
425        }
426        return new Origin(uri.toString());
427    }
428
429    protected HttpServletRequest maybeIgnoreWhitelistedOrigin(HttpServletRequest request) {
430        String origin = request.getHeader(ORIGIN);
431        if (origin == null) {
432            return request;
433        }
434        URI uri;
435        try {
436            uri = new URI(origin); // NOSONAR (URI is not opened as a stream)
437        } catch (URISyntaxException e) {
438            return request;
439        }
440        if (!isWhitelistedScheme(uri)) {
441            return request;
442        }
443        // wrap request to pretend that the Origin is absent
444        return new IgnoredOriginRequestWrapper(request);
445    }
446
447    protected boolean isWhitelistedScheme(URI uri) {
448        return SCHEMES_ALLOWED.contains(uri.getScheme());
449    }
450
451    /**
452     * Wrapper for the request to hide the Origin header.
453     *
454     * @since 10.2
455     */
456    public static class IgnoredOriginRequestWrapper extends HttpServletRequestWrapper {
457
458        public IgnoredOriginRequestWrapper(HttpServletRequest request) {
459            super(request);
460        }
461
462        @Override
463        public String getHeader(String name) {
464            if (ORIGIN.equalsIgnoreCase(name)) {
465                return null;
466            }
467            return super.getHeader(name);
468        }
469    }
470
471}