/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.security;

import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.TextInputCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.ipc.ProtobufRpcEngine;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos;
import org.apache.hadoop.security.SaslInputStream;
import org.apache.hadoop.security.SaslOutputStream;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.authentication.util.KerberosName;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.ProtoUtil;

@InterfaceAudience.LimitedPrivate(value={"HDFS", "MapReduce"})
@InterfaceStability.Evolving
public class SaslRpcClient {
    public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
    private final SaslRpcServer.AuthMethod authMethod;
    private final SaslClient saslClient;
    private final boolean fallbackAllowed;
    private static final RpcHeaderProtos.RpcRequestHeaderProto saslHeader = ProtoUtil.makeRpcRequestHeader(RPC.RpcKind.RPC_PROTOCOL_BUFFER, RpcHeaderProtos.RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET, Server.AuthProtocol.SASL.callId);
    private static final RpcHeaderProtos.RpcSaslProto negotiateRequest = RpcHeaderProtos.RpcSaslProto.newBuilder().setState(RpcHeaderProtos.RpcSaslProto.SaslState.NEGOTIATE).build();

    public SaslRpcClient(SaslRpcServer.AuthMethod method, Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed) throws IOException {
        this.authMethod = method;
        this.fallbackAllowed = fallbackAllowed;
        String saslUser = null;
        String saslProtocol = null;
        String saslServerName = null;
        Map<String, String> saslProperties = SaslRpcServer.SASL_PROPS;
        SaslClientCallbackHandler saslCallback = null;
        switch (method) {
            case TOKEN: {
                saslProtocol = "";
                saslServerName = "default";
                saslCallback = new SaslClientCallbackHandler(token);
                break;
            }
            case KERBEROS: {
                if (serverPrincipal == null || serverPrincipal.isEmpty()) {
                    throw new IOException("Failed to specify server's Kerberos principal name");
                }
                KerberosName name = new KerberosName(serverPrincipal);
                saslProtocol = name.getServiceName();
                saslServerName = name.getHostName();
                if (saslServerName != null) break;
                throw new IOException("Kerberos principal name does NOT have the expected hostname part: " + serverPrincipal);
            }
            default: {
                throw new IOException("Unknown authentication method " + (Object)((Object)method));
            }
        }
        String mechanism = method.getMechanismName();
        if (LOG.isDebugEnabled()) {
            LOG.debug((Object)("Creating SASL " + mechanism + "(" + (Object)((Object)this.authMethod) + ") " + " client to authenticate to service at " + saslServerName));
        }
        this.saslClient = Sasl.createSaslClient(new String[]{mechanism}, saslUser, saslProtocol, saslServerName, saslProperties, saslCallback);
        if (this.saslClient == null) {
            throw new IOException("Unable to find SASL client implementation");
        }
    }

    public boolean saslConnect(InputStream inS, OutputStream outS) throws IOException {
        DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
        DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS));
        boolean inSasl = false;
        this.sendSaslMessage(outStream, negotiateRequest);
        boolean done = false;
        do {
            int totalLen = inStream.readInt();
            ProtobufRpcEngine.RpcResponseMessageWrapper responseWrapper = new ProtobufRpcEngine.RpcResponseMessageWrapper();
            responseWrapper.readFields(inStream);
            RpcHeaderProtos.RpcResponseHeaderProto header = (RpcHeaderProtos.RpcResponseHeaderProto)responseWrapper.getMessageHeader();
            switch (header.getStatus()) {
                case ERROR: 
                case FATAL: {
                    throw new RemoteException(header.getExceptionClassName(), header.getErrorMsg());
                }
            }
            if (totalLen != responseWrapper.getLength()) {
                throw new SaslException("Received malformed response length");
            }
            if (header.getCallId() != Server.AuthProtocol.SASL.callId) {
                throw new SaslException("Non-SASL response during negotiation");
            }
            RpcHeaderProtos.RpcSaslProto saslMessage = RpcHeaderProtos.RpcSaslProto.parseFrom(responseWrapper.getMessageBytes());
            if (LOG.isDebugEnabled()) {
                LOG.debug((Object)("Received SASL message " + saslMessage));
            }
            RpcHeaderProtos.RpcSaslProto.Builder response = null;
            switch (saslMessage.getState()) {
                case NEGOTIATE: {
                    inSasl = true;
                    String clientAuthMethod = this.authMethod.toString();
                    RpcHeaderProtos.RpcSaslProto.SaslAuth saslAuthType = null;
                    for (RpcHeaderProtos.RpcSaslProto.SaslAuth authType : saslMessage.getAuthsList()) {
                        if (!clientAuthMethod.equals(authType.getMethod())) continue;
                        saslAuthType = authType;
                        break;
                    }
                    if (saslAuthType == null) {
                        saslAuthType = RpcHeaderProtos.RpcSaslProto.SaslAuth.newBuilder().setMethod(clientAuthMethod).setMechanism(this.saslClient.getMechanismName()).build();
                    }
                    byte[] challengeToken = null;
                    if (saslAuthType != null && saslAuthType.hasChallenge()) {
                        challengeToken = saslAuthType.getChallenge().toByteArray();
                        saslAuthType = RpcHeaderProtos.RpcSaslProto.SaslAuth.newBuilder(saslAuthType).clearChallenge().build();
                    } else if (this.saslClient.hasInitialResponse()) {
                        challengeToken = new byte[]{};
                    }
                    byte[] responseToken = challengeToken != null ? this.saslClient.evaluateChallenge(challengeToken) : new byte[]{};
                    response = this.createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState.INITIATE, responseToken);
                    response.addAuths(saslAuthType);
                    break;
                }
                case CHALLENGE: {
                    inSasl = true;
                    byte[] responseToken = this.saslEvaluateToken(saslMessage, false);
                    response = this.createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState.RESPONSE, responseToken);
                    break;
                }
                case SUCCESS: {
                    if (inSasl && this.saslEvaluateToken(saslMessage, true) != null) {
                        throw new SaslException("SASL client generated spurious token");
                    }
                    done = true;
                    break;
                }
                default: {
                    throw new SaslException("RPC client doesn't support SASL " + (Object)((Object)saslMessage.getState()));
                }
            }
            if (response == null) continue;
            this.sendSaslMessage(outStream, response.build());
        } while (!done);
        if (!inSasl && !this.fallbackAllowed) {
            throw new IOException("Server asks us to fall back to SIMPLE auth, but this client is configured to only allow secure connections.");
        }
        return inSasl;
    }

    private void sendSaslMessage(DataOutputStream out, RpcHeaderProtos.RpcSaslProto message) throws IOException {
        if (LOG.isDebugEnabled()) {
            LOG.debug((Object)("Sending sasl message " + message));
        }
        ProtobufRpcEngine.RpcRequestMessageWrapper request = new ProtobufRpcEngine.RpcRequestMessageWrapper(saslHeader, (Message)message);
        out.writeInt(request.getLength());
        request.write(out);
        out.flush();
    }

    private byte[] saslEvaluateToken(RpcHeaderProtos.RpcSaslProto saslResponse, boolean done) throws SaslException {
        byte[] saslToken = null;
        if (saslResponse.hasToken()) {
            saslToken = saslResponse.getToken().toByteArray();
            saslToken = this.saslClient.evaluateChallenge(saslToken);
        } else if (!done) {
            throw new SaslException("Challenge contains no token");
        }
        if (done && !this.saslClient.isComplete()) {
            throw new SaslException("Client is out of sync with server");
        }
        return saslToken;
    }

    private RpcHeaderProtos.RpcSaslProto.Builder createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState state, byte[] responseToken) {
        RpcHeaderProtos.RpcSaslProto.Builder response = RpcHeaderProtos.RpcSaslProto.newBuilder();
        response.setState(state);
        if (responseToken != null) {
            response.setToken(ByteString.copyFrom((byte[])responseToken));
        }
        return response;
    }

    public InputStream getInputStream(InputStream in) throws IOException {
        if (!this.saslClient.isComplete()) {
            throw new IOException("Sasl authentication exchange hasn't completed yet");
        }
        return new SaslInputStream(in, this.saslClient);
    }

    public OutputStream getOutputStream(OutputStream out) throws IOException {
        if (!this.saslClient.isComplete()) {
            throw new IOException("Sasl authentication exchange hasn't completed yet");
        }
        return new SaslOutputStream(out, this.saslClient);
    }

    public void dispose() throws SaslException {
        this.saslClient.dispose();
    }

    private static class SaslClientCallbackHandler
    implements CallbackHandler {
        private final String userName;
        private final char[] userPassword;

        public SaslClientCallbackHandler(Token<? extends TokenIdentifier> token) {
            this.userName = SaslRpcServer.encodeIdentifier(token.getIdentifier());
            this.userPassword = SaslRpcServer.encodePassword(token.getPassword());
        }

        @Override
        public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
            NameCallback nc = null;
            PasswordCallback pc = null;
            TextInputCallback rc = null;
            for (Callback callback : callbacks) {
                if (callback instanceof RealmChoiceCallback) continue;
                if (callback instanceof NameCallback) {
                    nc = (NameCallback)callback;
                    continue;
                }
                if (callback instanceof PasswordCallback) {
                    pc = (PasswordCallback)callback;
                    continue;
                }
                if (callback instanceof RealmCallback) {
                    rc = (RealmCallback)callback;
                    continue;
                }
                throw new UnsupportedCallbackException(callback, "Unrecognized SASL client callback");
            }
            if (nc != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug((Object)("SASL client callback: setting username: " + this.userName));
                }
                nc.setName(this.userName);
            }
            if (pc != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug((Object)"SASL client callback: setting userPassword");
                }
                pc.setPassword(this.userPassword);
            }
            if (rc != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug((Object)("SASL client callback: setting realm: " + rc.getDefaultText()));
                }
                rc.setText(rc.getDefaultText());
            }
        }
    }
}

