/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.boot.actuate.trace;

import java.io.IOException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.actuate.trace.TraceProperties;
import org.springframework.boot.actuate.trace.TraceRepository;
import org.springframework.boot.autoconfigure.web.servlet.error.ErrorAttributes;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.filter.OncePerRequestFilter;

public class WebRequestTraceFilter
extends OncePerRequestFilter
implements Ordered {
    private static final Log logger = LogFactory.getLog(WebRequestTraceFilter.class);
    private boolean dumpRequests = false;
    private int order = 0x7FFFFFF5;
    private final TraceRepository repository;
    private ErrorAttributes errorAttributes;
    private final TraceProperties properties;

    public WebRequestTraceFilter(TraceRepository repository, TraceProperties properties) {
        this.repository = repository;
        this.properties = properties;
    }

    public void setDumpRequests(boolean dumpRequests) {
        this.dumpRequests = dumpRequests;
    }

    public int getOrder() {
        return this.order;
    }

    public void setOrder(int order) {
        this.order = order;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        long startTime = System.nanoTime();
        Map<String, Object> trace = this.getTrace(request);
        this.logTrace(request, trace);
        int status = HttpStatus.INTERNAL_SERVER_ERROR.value();
        try {
            filterChain.doFilter((ServletRequest)request, (ServletResponse)response);
            status = response.getStatus();
            this.addTimeTaken(trace, startTime);
        }
        catch (Throwable throwable) {
            this.addTimeTaken(trace, startTime);
            this.enhanceTrace(trace, (HttpServletResponse)(status == response.getStatus() ? response : new CustomStatusResponseWrapper(response, status)));
            this.repository.add(trace);
            throw throwable;
        }
        this.enhanceTrace(trace, (HttpServletResponse)(status == response.getStatus() ? response : new CustomStatusResponseWrapper(response, status)));
        this.repository.add(trace);
    }

    protected Map<String, Object> getTrace(HttpServletRequest request) {
        HttpSession session = request.getSession(false);
        Throwable exception = (Throwable)request.getAttribute("javax.servlet.error.exception");
        Principal userPrincipal = request.getUserPrincipal();
        LinkedHashMap<String, Object> trace = new LinkedHashMap<String, Object>();
        LinkedHashMap<String, Map<String, Object>> headers = new LinkedHashMap<String, Map<String, Object>>();
        trace.put("method", request.getMethod());
        trace.put("path", request.getRequestURI());
        trace.put("headers", headers);
        if (this.isIncluded(TraceProperties.Include.REQUEST_HEADERS)) {
            headers.put("request", this.getRequestHeaders(request));
        }
        this.add(trace, TraceProperties.Include.PATH_INFO, "pathInfo", request.getPathInfo());
        this.add(trace, TraceProperties.Include.PATH_TRANSLATED, "pathTranslated", request.getPathTranslated());
        this.add(trace, TraceProperties.Include.CONTEXT_PATH, "contextPath", request.getContextPath());
        this.add(trace, TraceProperties.Include.USER_PRINCIPAL, "userPrincipal", userPrincipal == null ? null : userPrincipal.getName());
        if (this.isIncluded(TraceProperties.Include.PARAMETERS)) {
            trace.put("parameters", this.getParameterMapCopy(request));
        }
        this.add(trace, TraceProperties.Include.QUERY_STRING, "query", request.getQueryString());
        this.add(trace, TraceProperties.Include.AUTH_TYPE, "authType", request.getAuthType());
        this.add(trace, TraceProperties.Include.REMOTE_ADDRESS, "remoteAddress", request.getRemoteAddr());
        this.add(trace, TraceProperties.Include.SESSION_ID, "sessionId", session == null ? null : session.getId());
        this.add(trace, TraceProperties.Include.REMOTE_USER, "remoteUser", request.getRemoteUser());
        if (this.isIncluded(TraceProperties.Include.ERRORS) && exception != null && this.errorAttributes != null) {
            trace.put("error", this.errorAttributes.getErrorAttributes((WebRequest)new ServletWebRequest(request), true));
        }
        return trace;
    }

    private Map<String, Object> getRequestHeaders(HttpServletRequest request) {
        LinkedHashMap<String, Object> headers = new LinkedHashMap<String, Object>();
        Set<String> excludedHeaders = this.getExcludeHeaders();
        Enumeration names = request.getHeaderNames();
        while (names.hasMoreElements()) {
            String name = (String)names.nextElement();
            if (excludedHeaders.contains(name.toLowerCase())) continue;
            headers.put(name, this.getHeaderValue(request, name));
        }
        this.postProcessRequestHeaders(headers);
        return headers;
    }

    private Set<String> getExcludeHeaders() {
        HashSet<String> excludedHeaders = new HashSet<String>();
        if (!this.isIncluded(TraceProperties.Include.COOKIES)) {
            excludedHeaders.add("cookie");
        }
        if (!this.isIncluded(TraceProperties.Include.AUTHORIZATION_HEADER)) {
            excludedHeaders.add("authorization");
        }
        return excludedHeaders;
    }

    private Object getHeaderValue(HttpServletRequest request, String name) {
        ArrayList value = Collections.list(request.getHeaders(name));
        if (value.size() == 1) {
            return value.get(0);
        }
        if (value.isEmpty()) {
            return "";
        }
        return value;
    }

    private Map<String, String[]> getParameterMapCopy(HttpServletRequest request) {
        return new LinkedHashMap<String, String[]>(request.getParameterMap());
    }

    protected void postProcessRequestHeaders(Map<String, Object> headers) {
    }

    private void addTimeTaken(Map<String, Object> trace, long startTime) {
        long timeTaken = System.nanoTime() - startTime;
        this.add(trace, TraceProperties.Include.TIME_TAKEN, "timeTaken", "" + TimeUnit.NANOSECONDS.toMillis(timeTaken));
    }

    protected void enhanceTrace(Map<String, Object> trace, HttpServletResponse response) {
        if (this.isIncluded(TraceProperties.Include.RESPONSE_HEADERS)) {
            Map headers = (Map)trace.get("headers");
            headers.put("response", this.getResponseHeaders(response));
        }
    }

    private Map<String, String> getResponseHeaders(HttpServletResponse response) {
        LinkedHashMap<String, String> headers = new LinkedHashMap<String, String>();
        for (String header : response.getHeaderNames()) {
            String value = response.getHeader(header);
            headers.put(header, value);
        }
        if (!this.isIncluded(TraceProperties.Include.COOKIES)) {
            headers.remove("Set-Cookie");
        }
        headers.put("status", String.valueOf(response.getStatus()));
        return headers;
    }

    private void logTrace(HttpServletRequest request, Map<String, Object> trace) {
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Processing request " + request.getMethod() + " " + request.getRequestURI()));
            if (this.dumpRequests) {
                logger.trace((Object)("Headers: " + trace.get("headers")));
            }
        }
    }

    private void add(Map<String, Object> trace, TraceProperties.Include include, String name, Object value) {
        if (this.isIncluded(include) && value != null) {
            trace.put(name, value);
        }
    }

    private boolean isIncluded(TraceProperties.Include include) {
        return this.properties.getInclude().contains((Object)include);
    }

    public void setErrorAttributes(ErrorAttributes errorAttributes) {
        this.errorAttributes = errorAttributes;
    }

    private static final class CustomStatusResponseWrapper
    extends HttpServletResponseWrapper {
        private final int status;

        private CustomStatusResponseWrapper(HttpServletResponse response, int status) {
            super(response);
            this.status = status;
        }

        public int getStatus() {
            return this.status;
        }
    }
}

