/*
 * Decompiled with CFR 0.152.
 */
package org.overlord.commons.auth.filters;

import java.io.IOException;
import java.io.StringReader;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.Principal;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLInputFactory;
import org.apache.commons.codec.binary.Base64;
import org.overlord.commons.auth.Messages;
import org.overlord.commons.auth.filters.SimplePrincipal;
import org.overlord.commons.auth.util.SAMLBearerTokenUtil;
import org.picketlink.identity.federation.core.parsers.saml.SAMLAssertionParser;
import org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil;
import org.picketlink.identity.federation.saml.v2.assertion.AssertionType;
import org.picketlink.identity.federation.saml.v2.assertion.AttributeStatementType;
import org.picketlink.identity.federation.saml.v2.assertion.NameIDType;
import org.picketlink.identity.federation.saml.v2.assertion.StatementAbstractType;
import org.picketlink.identity.federation.saml.v2.assertion.SubjectType;
import org.w3c.dom.Document;

/**
 * Servlet filter that authenticates requests from an HTTP BASIC
 * {@code Authorization} header.
 *
 * <p>When the decoded username is the literal {@code SAML-BEARER-TOKEN}, the
 * password portion is treated as a SAML assertion and handed to
 * {@link #doSamlLogin(String, HttpServletRequest)}; otherwise the credentials
 * are passed to {@link HttpServletRequest#login(String, String)}.
 *
 * <p>Supported init parameters (each falls back to the matching
 * {@code default*()} method when absent or blank): {@code realm},
 * {@code allowedIssuers} (comma separated), {@code signatureRequired},
 * {@code keystorePath}, {@code keystorePassword}, {@code keyAlias},
 * {@code keyPassword} and {@code wrapRequest}.
 */
public class SamlBearerTokenAuthFilter
implements Filter {
    /**
     * Principal established by a successful SAML login on the current thread.
     * It is cleared at the start and at the end of every
     * {@link #doFilter(ServletRequest, ServletResponse, FilterChain)} call.
     */
    public static final ThreadLocal<SimplePrincipal> TL_principal = new ThreadLocal<SimplePrincipal>();
    /**
     * Sentinel returned by a successful BASIC login; compared by identity in
     * {@link #doFilterChain} to pass the original request through untouched.
     */
    private static final SimplePrincipal NO_PROXY = new SimplePrincipal(null);
    private String realm;
    private Set<String> allowedIssuers;
    private boolean signatureRequired;
    private String keystorePath;
    private String keystorePassword;
    private String keyAlias;
    private String keyPassword;
    private boolean wrapRequest;

    /**
     * Reads the filter configuration, falling back to the {@code default*()}
     * methods for every parameter that is missing or blank.
     *
     * @param config the filter configuration supplied by the container
     * @throws ServletException declared by the {@link Filter} contract
     */
    @Override
    public void init(FilterConfig config) throws ServletException {
        String parameter = initParameter(config, "realm");
        this.realm = parameter != null ? parameter : this.defaultRealm();

        parameter = initParameter(config, "allowedIssuers");
        if (parameter != null) {
            this.allowedIssuers = new HashSet<String>();
            for (String issuer : parameter.split(",")) {
                // Trim so that "a, b" yields "b" rather than " b"; skip empty entries.
                String trimmed = issuer.trim();
                if (trimmed.length() > 0) {
                    this.allowedIssuers.add(trimmed);
                }
            }
        } else {
            this.allowedIssuers = this.defaultAllowedIssuers();
        }

        parameter = initParameter(config, "signatureRequired");
        // Trim before parsing: Boolean.parseBoolean(" true") would otherwise be false.
        this.signatureRequired = parameter != null ? Boolean.parseBoolean(parameter.trim()) : this.defaultSignatureRequired();

        parameter = initParameter(config, "keystorePath");
        this.keystorePath = parameter != null ? parameter : this.defaultKeystorePath();

        // Passwords are kept verbatim (not trimmed) since whitespace may be significant.
        parameter = initParameter(config, "keystorePassword");
        this.keystorePassword = parameter != null ? parameter : this.defaultKeystorePassword();

        parameter = initParameter(config, "keyAlias");
        this.keyAlias = parameter != null ? parameter : this.defaultKeyAlias();

        parameter = initParameter(config, "keyPassword");
        this.keyPassword = parameter != null ? parameter : this.defaultKeyPassword();

        parameter = initParameter(config, "wrapRequest");
        this.wrapRequest = parameter != null ? Boolean.parseBoolean(parameter.trim()) : this.defaultWrapRequest();
    }

    /**
     * Returns the raw (untrimmed) value of an init parameter, or {@code null}
     * when the parameter is absent or contains only whitespace.
     */
    private static String initParameter(FilterConfig config, String name) {
        String value = config.getInitParameter(name);
        return value != null && value.trim().length() > 0 ? value : null;
    }

    /** @return the keystore password used when none is configured ({@code null}) */
    protected String defaultKeystorePassword() {
        return null;
    }

    /** @return the key alias used when none is configured ({@code null}) */
    protected String defaultKeyAlias() {
        return null;
    }

    /** @return the key password used when none is configured ({@code null}) */
    protected String defaultKeyPassword() {
        return null;
    }

    /** @return the keystore path used when none is configured ({@code null}) */
    protected String defaultKeystorePath() {
        return null;
    }

    /** @return whether a signature is required when not configured ({@code false}) */
    protected boolean defaultSignatureRequired() {
        return false;
    }

    /** @return whether to wrap (rather than proxy) the request when not configured ({@code false}) */
    protected boolean defaultWrapRequest() {
        return false;
    }

    /** @return the allowed issuers used when none are configured (an empty set) */
    protected Set<String> defaultAllowedIssuers() {
        return Collections.emptySet();
    }

    /** @return the realm used when none is configured ({@code "Overlord"}) */
    protected String defaultRealm() {
        return "Overlord";
    }

    /**
     * Authenticates the request and, on success, continues the filter chain.
     * Requests without usable credentials, or whose login yields no principal,
     * receive a 401 response with a {@code WWW-Authenticate} challenge.
     *
     * @param request the incoming request (cast to {@link HttpServletRequest})
     * @param response the outgoing response (cast to {@link HttpServletResponse})
     * @param chain the remaining filter chain
     * @throws IOException if login or downstream processing fails with an I/O error
     * @throws ServletException if downstream processing fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        TL_principal.remove();
        try {
            HttpServletRequest req = (HttpServletRequest)request;
            HttpServletResponse resp = (HttpServletResponse)response;
            Creds credentials = this.parseAuthorizationHeader(req.getHeader("Authorization"));
            if (credentials == null) {
                this.sendAuthResponse(resp);
                return;
            }
            SimplePrincipal principal = this.login(credentials, req, resp);
            if (principal != null) {
                this.doFilterChain(request, response, chain, principal);
            } else {
                this.sendAuthResponse(resp);
            }
        }
        finally {
            // Do not leave the principal attached to a pooled container thread
            // once this request has finished.
            TL_principal.remove();
        }
    }

    /**
     * Continues the chain. For the BASIC-login sentinel the original request is
     * passed on unchanged; otherwise the request is wrapped or proxied
     * (depending on the {@code wrapRequest} setting) so that it reports the
     * given principal.
     *
     * @param request the original request
     * @param response the response
     * @param chain the remaining filter chain
     * @param principal the authenticated principal
     * @throws IOException if downstream processing fails with an I/O error
     * @throws ServletException if downstream processing fails
     */
    protected void doFilterChain(ServletRequest request, ServletResponse response, FilterChain chain, SimplePrincipal principal) throws IOException, ServletException {
        if (principal == NO_PROXY) {
            chain.doFilter(request, response);
        } else {
            HttpServletRequest hsr = this.wrapRequest ? this.wrapTheRequest(request, principal) : this.proxyRequest(request, principal);
            chain.doFilter((ServletRequest)hsr, response);
        }
    }

    /**
     * Wraps the request in an {@link HttpServletRequestWrapper} whose
     * principal, remote user and role checks are answered by {@code principal}.
     */
    private HttpServletRequest wrapTheRequest(ServletRequest request, final SimplePrincipal principal) {
        return new HttpServletRequestWrapper((HttpServletRequest)request){

            @Override
            public Principal getUserPrincipal() {
                return principal;
            }

            @Override
            public boolean isUserInRole(String role) {
                return principal.getRoles().contains(role);
            }

            @Override
            public String getRemoteUser() {
                return principal.getName();
            }
        };
    }

    /**
     * Creates a dynamic proxy implementing {@link HttpServletRequest} that
     * answers {@code getUserPrincipal}, {@code getRemoteUser} and
     * {@code isUserInRole} from {@code principal} and delegates every other
     * call to the original request.
     */
    private HttpServletRequest proxyRequest(final ServletRequest request, final SimplePrincipal principal) {
        InvocationHandler handler = new InvocationHandler(){

            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                String name = method.getName();
                if ("getUserPrincipal".equals(name)) {
                    return principal;
                }
                if ("getRemoteUser".equals(name)) {
                    return principal.getName();
                }
                if ("isUserInRole".equals(name)) {
                    String role = (String)args[0];
                    return principal.getRoles().contains(role);
                }
                try {
                    return method.invoke((Object)request, args);
                }
                catch (InvocationTargetException e) {
                    // Surface the delegate's own exception instead of letting the
                    // proxy wrap it in an UndeclaredThrowableException.
                    Throwable cause = e.getCause();
                    throw cause != null ? cause : e;
                }
            }
        };
        return (HttpServletRequest)Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[]{HttpServletRequest.class}, handler);
    }

    /**
     * Decodes a BASIC {@code Authorization} header.
     *
     * @param authHeader the raw header value, possibly {@code null}
     * @return the decoded credentials, or {@code null} if the header is absent
     *         or does not use the BASIC scheme. When the decoded data has no
     *         {@code ':'} after its first character, the whole string becomes
     *         the username and the password is {@code null}.
     */
    private Creds parseAuthorizationHeader(String authHeader) {
        // regionMatches(ignoreCase=true) is locale independent, unlike toUpperCase().
        if (authHeader == null || !authHeader.regionMatches(true, 0, "BASIC ", 0, 6)) {
            return null;
        }
        try {
            String userpassEncoded = authHeader.substring(6);
            byte[] decoded = Base64.decodeBase64((String)userpassEncoded);
            String data = new String(decoded, "UTF-8");
            int sepIdx = data.indexOf(':');
            if (sepIdx > 0) {
                String username = data.substring(0, sepIdx);
                String password = data.substring(sepIdx + 1);
                return new Creds(username, password);
            }
            return new Creds(data, null);
        }
        catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Dispatches to SAML or BASIC login depending on the username.
     *
     * @param credentials the decoded credentials
     * @param request the current request
     * @param response the current response (unused by this implementation)
     * @return the authenticated principal, or {@code null} if login failed
     * @throws IOException if the SAML login fails with an I/O error
     */
    protected SimplePrincipal login(Creds credentials, HttpServletRequest request, HttpServletResponse response) throws IOException {
        if ("SAML-BEARER-TOKEN".equals(credentials.username)) {
            return this.doSamlLogin(credentials.password, request);
        }
        return this.doBasicLogin(credentials.username, credentials.password, request);
    }

    /**
     * Parses the SAML assertion, validates it through
     * {@code SAMLBearerTokenUtil.validateAssertion}, optionally checks its
     * signature, and converts it into a principal.
     *
     * @param assertionData the assertion XML; {@code null} yields {@code null}
     * @param request the current request
     * @return the principal built from the assertion, or {@code null} when no
     *         assertion data was supplied
     * @throws IOException if the signature is invalid or the keystore/key pair
     *         cannot be loaded; also propagates any {@code IOException} raised
     *         while processing the assertion
     * @throws RuntimeException wrapping any other failure
     */
    protected SimplePrincipal doSamlLogin(String assertionData, HttpServletRequest request) throws IOException {
        if (assertionData == null) {
            // No token was supplied (header had no password part): treat it as a
            // failed login so the caller issues an auth challenge.
            return null;
        }
        try {
            Document samlAssertion = DocumentUtil.getDocument((String)assertionData);
            // The assertion is untrusted input: disable DTDs and external
            // entities on the StAX parser to prevent XXE.
            XMLInputFactory inputFactory = XMLInputFactory.newInstance();
            inputFactory.setProperty(XMLInputFactory.SUPPORT_DTD, Boolean.FALSE);
            inputFactory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, Boolean.FALSE);
            XMLEventReader xmlEventReader = inputFactory.createXMLEventReader(new StringReader(assertionData));
            AssertionType assertion = (AssertionType)new SAMLAssertionParser().parse(xmlEventReader);
            SAMLBearerTokenUtil.validateAssertion(assertion, request, this.allowedIssuers);
            if (this.signatureRequired) {
                KeyPair keyPair = this.getKeyPair(assertion);
                if (!SAMLBearerTokenUtil.isSAMLAssertionSignatureValid(samlAssertion, keyPair)) {
                    throw new IOException(Messages.getString("SamlBearerTokenAuthFilter.InvalidSig"));
                }
            }
            return this.consumeAssertion(assertion);
        }
        catch (IOException e) {
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads the configured keystore and resolves the configured key alias.
     *
     * @param assertion the assertion being verified (currently unused)
     * @throws IOException if the keystore or key pair cannot be obtained
     */
    private KeyPair getKeyPair(AssertionType assertion) throws IOException {
        KeyStore keystore = this.loadKeystore();
        try {
            return SAMLBearerTokenUtil.getKeyPair(keystore, this.keyAlias, this.keyPassword);
        }
        catch (Exception e) {
            // Keep the cause attached instead of printing it to stderr.
            throw new IOException(Messages.getString("SamlBearerTokenAuthFilter.FailedToGetKeyPair") + this.keyAlias, e);
        }
    }

    /**
     * Loads the keystore from the configured path and password.
     *
     * @throws IOException if the keystore cannot be loaded
     */
    private KeyStore loadKeystore() throws IOException {
        try {
            return SAMLBearerTokenUtil.loadKeystore(this.keystorePath, this.keystorePassword);
        }
        catch (Exception e) {
            // Keep the cause attached instead of printing it to stderr.
            throw new IOException(Messages.getString("SamlBearerTokenAuthFilter.ErrorLoadingKeystore") + e.getMessage(), e);
        }
    }

    /**
     * Builds a principal named after the assertion's subject NameID, adds one
     * role per non-null value of every attribute named {@code Role}, and
     * publishes the principal in {@link #TL_principal}.
     *
     * @throws Exception if the assertion does not have the expected structure
     *         (e.g. a missing subject or a non-NameID base id)
     */
    private SimplePrincipal consumeAssertion(AssertionType assertion) throws Exception {
        SubjectType samlSubjectType = assertion.getSubject();
        String samlSubject = ((NameIDType)samlSubjectType.getSubType().getBaseID()).getValue();
        SimplePrincipal identity = new SimplePrincipal(samlSubject);
        for (StatementAbstractType statement : assertion.getStatements()) {
            if (!(statement instanceof AttributeStatementType)) {
                continue;
            }
            for (AttributeStatementType.ASTChoiceType astChoiceType : ((AttributeStatementType)statement).getAttributes()) {
                if (astChoiceType.getAttribute() == null || !"Role".equals(astChoiceType.getAttribute().getName())) {
                    continue;
                }
                for (Object roleValue : astChoiceType.getAttribute().getAttributeValue()) {
                    if (roleValue != null) {
                        identity.addRole(roleValue.toString());
                    }
                }
            }
        }
        TL_principal.set(identity);
        return identity;
    }

    /**
     * Attempts a container login with the given username and password.
     *
     * @param username the username
     * @param password the password
     * @param request the current request
     * @return the {@code NO_PROXY} sentinel on success, or {@code null} if the
     *         container login threw any exception
     * @throws IOException never thrown by this implementation; declared for overriders
     */
    protected SimplePrincipal doBasicLogin(String username, String password, HttpServletRequest request) throws IOException {
        try {
            request.login(username, password);
            return NO_PROXY;
        }
        catch (Exception ignored) {
            // Any login failure is reported to the caller as "not authenticated",
            // which results in a 401 challenge.
            return null;
        }
    }

    /**
     * Sends a 401 response carrying a BASIC {@code WWW-Authenticate} challenge
     * for the configured realm.
     */
    private void sendAuthResponse(HttpServletResponse response) throws IOException {
        response.setHeader("WWW-Authenticate", String.format("BASIC realm=\"%1$s\"", this.realm));
        response.sendError(401);
    }

    /** No resources to release. */
    @Override
    public void destroy() {
    }

    /**
     * Username/password pair decoded from the {@code Authorization} header.
     * The password may be {@code null}.
     */
    protected static class Creds {
        public String username;
        public String password;

        public Creds(String username, String password) {
            this.username = username;
            this.password = password;
        }
    }
}

