/*
 * Decompiled with CFR 0.152.
 */
package org.jboss.as.controller.remote;

import java.io.Closeable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.net.InetAddress;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.function.Supplier;
import org.jboss.as.controller.AccessAuditContext;
import org.jboss.as.controller.ModelController;
import org.jboss.as.controller.access.InVmAccess;
import org.jboss.as.controller.client.Operation;
import org.jboss.as.controller.client.OperationMessageHandler;
import org.jboss.as.controller.client.OperationResponse;
import org.jboss.as.controller.logging.ControllerLogger;
import org.jboss.as.controller.remote.IdentityAddressProtocolUtil;
import org.jboss.as.controller.remote.OperationAttachmentsProxy;
import org.jboss.as.controller.remote.ResponseAttachmentInputStreamSupport;
import org.jboss.as.controller.remote.TransactionalProtocolClient;
import org.jboss.as.protocol.StreamUtils;
import org.jboss.as.protocol.mgmt.ActiveOperation;
import org.jboss.as.protocol.mgmt.FlushableDataOutput;
import org.jboss.as.protocol.mgmt.ManagementChannelAssociation;
import org.jboss.as.protocol.mgmt.ManagementProtocolHeader;
import org.jboss.as.protocol.mgmt.ManagementRequestContext;
import org.jboss.as.protocol.mgmt.ManagementRequestHandler;
import org.jboss.as.protocol.mgmt.ManagementRequestHandlerFactory;
import org.jboss.as.protocol.mgmt.ManagementRequestHeader;
import org.jboss.as.protocol.mgmt.ManagementResponseHeader;
import org.jboss.as.protocol.mgmt.ProtocolUtils;
import org.jboss.dmr.ModelNode;
import org.wildfly.security.auth.server.SecurityIdentity;
import org.wildfly.security.manager.WildFlySecurityManager;

/**
 * Server-side handler for the transactional management protocol.
 * <p>
 * Incoming requests are resolved by operation id to handlers that:
 * <ul>
 *   <li>execute an operation;</li>
 *   <li>complete (commit or roll back) a prepared operation;</li>
 *   <li>read or close response attachment streams.</li>
 * </ul>
 * Execution is two-phase. When the operation is prepared, a "prepared" response is sent and the
 * executing thread blocks until a complete-tx request commits or rolls back the transaction. The
 * final outcome is then sent as the response to that complete-tx request.
 */
public class TransactionalProtocolOperationHandler implements ManagementRequestHandlerFactory {

    // Request operation ids dispatched by resolveHandler.
    private static final byte EXECUTE_TX_REQUEST = 0x47;
    private static final byte COMPLETE_TX_REQUEST = 0x4E;
    private static final byte GET_CHUNKED_INPUTSTREAM_REQUEST = 0x4F;
    private static final byte CLOSE_INPUTSTREAM_REQUEST = 0x44;

    // Response type byte written ahead of the response ModelNode.
    private static final byte PARAM_OPERATION_FAILED = 0x49;
    private static final byte PARAM_OPERATION_COMPLETED = 0x4A;
    private static final byte PARAM_OPERATION_PREPARED = 0x4B;

    // Markers read from an execute request, and the "commit" value of a complete-tx request.
    private static final byte PARAM_OPERATION = 0x61;
    private static final byte PARAM_INPUTSTREAMS_LENGTH = 0x65;
    private static final byte PARAM_IN_VM_CALL = 0x51;
    private static final byte PARAM_COMMIT = 0x70;

    // Trailer byte terminating every response written by sendResponse.
    private static final byte RESPONSE_END = 0x24;

    private final ModelController controller;
    private final ManagementChannelAssociation channelAssociation;
    private final ResponseAttachmentInputStreamSupport responseAttachmentSupport;

    public TransactionalProtocolOperationHandler(ModelController controller, ManagementChannelAssociation channelAssociation, ResponseAttachmentInputStreamSupport responseAttachmentSupport) {
        this.controller = controller;
        this.channelAssociation = channelAssociation;
        this.responseAttachmentSupport = responseAttachmentSupport;
    }

    /**
     * Selects the handler for an incoming request based on its operation id.
     *
     * @param handlers the handler chain, used to register active operations or delegate onward
     * @param request the request header
     * @return the handler for the request, or the next handler in the chain for unknown ids
     */
    @Override
    public ManagementRequestHandler<?, ?> resolveHandler(ManagementRequestHandlerFactory.RequestHandlerChain handlers, ManagementRequestHeader request) {
        switch (request.getOperationId()) {
            case EXECUTE_TX_REQUEST: {
                ExecuteRequestContext executeRequestContext = new ExecuteRequestContext(this.responseAttachmentSupport);
                try {
                    executeRequestContext.operation = handlers.registerActiveOperation(request.getBatchId(), executeRequestContext, executeRequestContext);
                } catch (IllegalStateException ise) {
                    // An active operation is already registered for this batch id. The COMPLETE_TX_REQUEST
                    // branch below registers one when it arrives first, so a cancel beat the request here.
                    return new AbortOperationHandler(true);
                }
                return new ExecuteRequestHandler();
            }
            case COMPLETE_TX_REQUEST: {
                ExecuteRequestContext executeRequestContext = new ExecuteRequestContext(this.responseAttachmentSupport);
                try {
                    // Registration only succeeds when no execute request is known yet for this batch id.
                    executeRequestContext.operation = handlers.registerActiveOperation(request.getBatchId(), executeRequestContext, executeRequestContext);
                    return new AbortOperationHandler(false);
                } catch (IllegalStateException expected) {
                    // Normal case: the execute request is already registered.
                    return new CompleteTxOperationHandler();
                }
            }
            case GET_CHUNKED_INPUTSTREAM_REQUEST: {
                handlers.registerActiveOperation(request.getBatchId(), null);
                return this.responseAttachmentSupport.getReadHandler();
            }
            case CLOSE_INPUTSTREAM_REQUEST: {
                handlers.registerActiveOperation(request.getBatchId(), null);
                return this.responseAttachmentSupport.getCloseHandler();
            }
            default:
                return handlers.resolveNext();
        }
    }

    /**
     * Executes the operation on the controller. This is a hook for subclasses; {@code context} is
     * not used by this implementation.
     */
    protected OperationResponse internalExecute(Operation operation, ManagementRequestContext<?> context, OperationMessageHandler messageHandler, ModelController.OperationTransactionControl control) {
        return this.controller.execute(operation, messageHandler, control);
    }

    /**
     * Writes a response to the request held by {@code context}.
     * <p>
     * The write is handed to {@code context.executeAsync(..., false)}, and the caller waits on a
     * latch until it has finished.
     *
     * @param context the request to respond to
     * @param responseType the response type byte written before the response node
     * @param response the response payload
     * @throws IOException if writing the response failed on the async task. This is only detected
     *         if the task was accepted and the wait was not interrupted.
     */
    static void sendResponse(final ManagementRequestContext<ExecuteRequestContext> context, final byte responseType, final ModelNode response) throws IOException {
        final IOExceptionHolder exceptionHolder = new IOExceptionHolder();
        final CountDownLatch latch = new CountDownLatch(1);
        final boolean accepted = context.executeAsync(new ManagementRequestContext.AsyncTask<ExecuteRequestContext>() {
            @Override
            public void execute(ManagementRequestContext<ExecuteRequestContext> asyncContext) throws Exception {
                FlushableDataOutput output = null;
                try {
                    ControllerLogger.MGMT_OP_LOGGER.tracef("Transmitting response for %d", asyncContext.getOperationId());
                    ManagementResponseHeader header = ManagementResponseHeader.create(asyncContext.getRequestHeader());
                    output = asyncContext.writeMessage(header);
                    output.writeByte(responseType);
                    response.writeExternal(output);
                    output.writeByte(RESPONSE_END);
                    output.close();
                } catch (IOException e) {
                    // Hand the failure back to the waiting caller, which rethrows it.
                    exceptionHolder.exception = e;
                } finally {
                    StreamUtils.safeClose(output);
                    latch.countDown();
                }
            }
        }, false);
        if (accepted) {
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            if (exceptionHolder.exception != null) {
                throw exceptionHolder.exception;
            }
        }
    }

    /** Builds a response node with outcome {@code failed} and the given failure description. */
    private static ModelNode failureResponse(String description) {
        ModelNode response = new ModelNode();
        response.get("outcome").set("failed");
        response.get("failure-description").set(description);
        return response;
    }

    /** Returns the privileged strategy when a security manager is checking, else the plain one. */
    private static Execution privilegedExecution() {
        return WildFlySecurityManager.isChecking() ? Execution.PRIVILEGED : Execution.NON_PRIVILEGED;
    }

    /** Strategy for running a {@link Supplier}, either directly or inside {@code doPrivileged}. */
    private interface Execution {
        Execution NON_PRIVILEGED = new Execution() {
            @Override
            public <T> T execute(Supplier<T> supplier) {
                return supplier.get();
            }
        };

        Execution PRIVILEGED = new Execution() {
            @Override
            public <T> T execute(Supplier<T> supplier) {
                try {
                    // The explicit target type is required: a bare lambda is ambiguous between the
                    // PrivilegedAction and PrivilegedExceptionAction overloads of doPrivileged.
                    return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> NON_PRIVILEGED.execute(supplier));
                } catch (PrivilegedActionException e) {
                    Throwable cause = e.getCause();
                    if (cause instanceof RuntimeException) {
                        throw (RuntimeException) cause;
                    }
                    if (cause instanceof Error) {
                        throw (Error) cause;
                    }
                    throw new RuntimeException(cause);
                }
            }
        };

        <T> T execute(Supplier<T> supplier);
    }

    /** Carries an IOException from the async response writer back to sendResponse. */
    private static class IOExceptionHolder {
        private IOException exception;

        private IOExceptionHolder() {
        }
    }

    /**
     * Per-request state shared by the execute, complete-tx and abort handlers.
     * <p>
     * All state transitions happen in {@code synchronized} methods. {@code responseChannel} is the
     * request context that should receive the next response. It is set to {@code null} once a
     * response has been sent on it.
     */
    private static class ExecuteRequestContext implements ActiveOperation.CompletedCallback<Void> {
        private ActiveOperation<Void, ExecuteRequestContext> operation;
        private boolean prepared;
        private boolean rollbackOnPrepare;
        private ModelController.OperationTransaction activeTx;
        private ManagementRequestContext<ExecuteRequestContext> responseChannel;
        // Released once the prepared transaction has been committed or rolled back.
        private final CountDownLatch txCompletedLatch = new CountDownLatch(1);
        private boolean txCompleted;
        // A final response produced after prepare but before any complete-tx request supplied a channel.
        private OperationResponse postPrepareRaceResponse;
        final ResponseAttachmentInputStreamSupport streamSupport;

        ExecuteRequestContext(ResponseAttachmentInputStreamSupport streamSupport) {
            this.streamSupport = streamSupport;
        }

        Integer getOperationId() {
            return this.operation.getOperationId();
        }

        ActiveOperation.ResultHandler<Void> getResultHandler() {
            return this.operation.getResultHandler();
        }

        @Override
        public void completed(Void result) {
        }

        /**
         * Handles failure of the active operation.
         * <ul>
         *   <li>After prepare: rolls back any active transaction and releases the executing thread.</li>
         *   <li>Before prepare, with a live response channel: marks the tx for rollback on prepare and
         *       sends a failure response.</li>
         * </ul>
         */
        @Override
        public synchronized void failed(Exception e) {
            if (this.prepared) {
                ModelController.OperationTransaction transaction = this.activeTx;
                this.activeTx = null;
                if (transaction != null) {
                    try {
                        transaction.rollback();
                    } finally {
                        this.txCompletedLatch.countDown();
                    }
                }
            } else if (this.responseChannel != null) {
                this.rollbackOnPrepare = true;
                String message = e.getMessage() != null ? e.getMessage() : "failure before rollback " + e.getClass().getName();
                ModelNode response = failureResponse(message);
                ControllerLogger.MGMT_OP_LOGGER.tracef("sending pre-prepare failed response for %d  --- interrupted: %s", this.getOperationId(), Thread.currentThread().isInterrupted());
                try {
                    sendResponse(this.responseChannel, PARAM_OPERATION_FAILED, response);
                    this.responseChannel = null;
                } catch (IOException ignored) {
                    ControllerLogger.MGMT_OP_LOGGER.failedSendingFailedResponse(ignored, response, this.getOperationId());
                }
            }
        }

        @Override
        public void cancelled() {
        }

        /** Records the execute request as the channel for the first response. */
        synchronized void initialize(ManagementRequestContext<ExecuteRequestContext> context) {
            assert !this.prepared;
            assert this.activeTx == null;
            this.responseChannel = context;
            ControllerLogger.MGMT_OP_LOGGER.tracef("Initialized for %d", this.getOperationId());
        }

        /**
         * Called when the operation has been prepared.
         * <p>
         * If a rollback was already requested, the transaction is rolled back and no prepared response
         * is sent; doExecute reports the outcome via {@link #completed(OperationResponse)} once
         * execution returns. Otherwise the transaction is kept and a prepared response is sent.
         */
        synchronized void prepare(ModelController.OperationTransaction tx, ModelNode result) {
            assert !this.prepared;
            this.prepared = true;
            if (this.rollbackOnPrepare) {
                try {
                    tx.rollback();
                    ControllerLogger.MGMT_OP_LOGGER.tracef("rolled back on prepare for %d  --- interrupted: %s", this.getOperationId(), Thread.currentThread().isInterrupted());
                } finally {
                    this.txCompletedLatch.countDown();
                }
            } else {
                assert this.activeTx == null;
                assert this.responseChannel != null;
                this.activeTx = tx;
                ControllerLogger.MGMT_OP_LOGGER.tracef("sending prepared response for %d  --- interrupted: %s", this.getOperationId(), Thread.currentThread().isInterrupted());
                try {
                    sendResponse(this.responseChannel, PARAM_OPERATION_PREPARED, result);
                    // The original request has been answered; it cannot carry another response.
                    this.responseChannel = null;
                } catch (IOException e) {
                    this.getResultHandler().failed(e);
                }
            }
        }

        /**
         * Handles a complete-tx request.
         * <ul>
         *   <li>Before prepare, or after the tx already completed: cancels the operation.</li>
         *   <li>Otherwise: adopts {@code context} as the response channel and either commits or rolls
         *       back the active transaction, or sends a cached post-prepare response.</li>
         * </ul>
         *
         * @param context the complete-tx request, which receives the final response
         * @param commit {@code true} to commit, {@code false} to roll back
         */
        synchronized void completeTx(ManagementRequestContext<ExecuteRequestContext> context, boolean commit) {
            if (!this.prepared) {
                assert !commit;
                ControllerLogger.MGMT_OP_LOGGER.tracef("completeTx (cancel unprepared) for %d", this.getOperationId());
                this.rollbackOnPrepare = true;
                this.cancel(context);
            } else if (this.txCompleted) {
                ControllerLogger.MGMT_OP_LOGGER.tracef("completeTx (post-commit cancel) for %d", this.getOperationId());
                this.cancel(context);
            } else if (this.postPrepareRaceResponse == null) {
                this.txCompleted = true;
                if (this.activeTx != null) {
                    try {
                        assert this.responseChannel == null;
                        this.responseChannel = context;
                        ControllerLogger.MGMT_OP_LOGGER.tracef("completeTx (%s) for %d", commit, this.getOperationId());
                        // Exactly one of commit/rollback.
                        if (commit) {
                            this.activeTx.commit();
                        } else {
                            this.activeTx.rollback();
                        }
                    } finally {
                        this.txCompletedLatch.countDown();
                    }
                }
            } else {
                assert this.responseChannel == null;
                this.responseChannel = context;
                ControllerLogger.MGMT_OP_LOGGER.tracef("completeTx (%s) for %d received after a post-prepare response had already been cached; sending the cached response", commit, this.getOperationId());
                this.completed(this.postPrepareRaceResponse);
            }
        }

        /**
         * Reports a failed outcome.
         * <ul>
         *   <li>After prepare: delegates to {@link #completed(OperationResponse)}.</li>
         *   <li>Before prepare: sends a failure response on the current channel and marks the active
         *       operation done.</li>
         * </ul>
         */
        synchronized void failed(ModelNode response) {
            if (this.prepared) {
                this.completed(OperationResponse.Factory.createSimple(response));
            } else {
                assert this.responseChannel != null;
                ControllerLogger.MGMT_OP_LOGGER.tracef("sending pre-prepare failed response for %d  --- interrupted: %s", this.getOperationId(), Thread.currentThread().isInterrupted());
                try {
                    sendResponse(this.responseChannel, PARAM_OPERATION_FAILED, response);
                    this.responseChannel = null;
                } catch (IOException e) {
                    ControllerLogger.MGMT_OP_LOGGER.failedSendingFailedResponse(e, response, this.getOperationId());
                } finally {
                    this.getResultHandler().done(null);
                }
            }
        }

        /**
         * Sends the final response on the current channel, registering any response streams.
         * If no channel is available yet (no complete-tx request has arrived), the response is
         * cached for {@link #completeTx}.
         */
        synchronized void completed(OperationResponse response) {
            assert this.prepared;
            if (this.responseChannel != null) {
                ControllerLogger.MGMT_OP_LOGGER.tracef("sending completed response %s for %d  --- interrupted: %s", response.getResponseNode(), this.getOperationId(), Thread.currentThread().isInterrupted());
                this.streamSupport.registerStreams(this.operation.getOperationId(), response.getInputStreams());
                try {
                    sendResponse(this.responseChannel, PARAM_OPERATION_COMPLETED, response.getResponseNode());
                    this.responseChannel = null;
                } catch (IOException e) {
                    ControllerLogger.MGMT_OP_LOGGER.failedSendingCompletedResponse(e, response.getResponseNode(), this.getOperationId());
                } finally {
                    this.getResultHandler().done(null);
                }
            } else {
                assert this.postPrepareRaceResponse == null;
                ControllerLogger.MGMT_OP_LOGGER.tracef("received a post-prepare response for %d but no COMPLETE_TX_REQUEST has been received; caching the response", this.getOperationId());
                this.postPrepareRaceResponse = response;
            }
        }

        /** Asynchronously cancels the active operation's result handler. */
        private void cancel(ManagementRequestContext<ExecuteRequestContext> context) {
            context.executeAsync(new ManagementRequestContext.AsyncTask<ExecuteRequestContext>() {
                @Override
                public void execute(ManagementRequestContext<ExecuteRequestContext> ignored) throws Exception {
                    operation.getResultHandler().cancel();
                }
            }, false);
        }
    }

    /** Transaction control that reports prepare, then blocks until the tx has been completed. */
    private static class ProxyOperationTransactionControl implements ModelController.OperationTransactionControl {
        private final ExecuteRequestContext requestContext;

        ProxyOperationTransactionControl(ExecuteRequestContext requestContext) {
            this.requestContext = requestContext;
        }

        @Override
        public void operationPrepared(ModelController.OperationTransaction transaction, ModelNode result) {
            this.requestContext.prepare(transaction, result);
            try {
                // Wait for the complete-tx request (or a failure) to commit or roll back.
                this.requestContext.txCompletedLatch.await();
            } catch (InterruptedException e) {
                ControllerLogger.ROOT_LOGGER.tracef("Clearing interrupted status from client request %d", this.requestContext.getOperationId());
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Handles requests that lost a registration race in {@link #resolveHandler}.
     * <ul>
     *   <li>For an execute request: the body is consumed and a "cancelled" outcome is sent.</li>
     *   <li>For a complete-tx request: a commit (as opposed to a cancel) is rejected as having no
     *       matching operation.</li>
     * </ul>
     */
    private class AbortOperationHandler implements ManagementRequestHandler<Void, ExecuteRequestContext> {
        private final boolean forExecuteTxRequest;

        private AbortOperationHandler(boolean forExecuteTxRequest) {
            this.forExecuteTxRequest = forExecuteTxRequest;
        }

        @Override
        public void handleRequest(DataInput input, ActiveOperation.ResultHandler<Void> resultHandler, ManagementRequestContext<ExecuteRequestContext> context) throws IOException {
            if (this.forExecuteTxRequest) {
                try {
                    // Read and discard the request body.
                    ExecutableRequest.parse(input, channelAssociation);
                } finally {
                    ControllerLogger.MGMT_OP_LOGGER.tracef("aborting (cancel received before request) for %d", context.getOperationId());
                    ModelNode response = new ModelNode();
                    response.get("outcome").set("cancelled");
                    ExecuteRequestContext attachment = context.getAttachment();
                    attachment.initialize(context);
                    attachment.failed(response);
                }
            } else {
                byte commitOrRollback = input.readByte();
                if (commitOrRollback == PARAM_COMMIT) {
                    throw ControllerLogger.MGMT_OP_LOGGER.responseHandlerNotFound(context.getOperationId());
                }
                // A cancel: nothing to do until the execute request arrives.
            }
        }
    }

    /** Reads the commit/rollback flag and completes the transaction of the registered operation. */
    private class CompleteTxOperationHandler implements ManagementRequestHandler<Void, ExecuteRequestContext> {
        private CompleteTxOperationHandler() {
        }

        @Override
        public void handleRequest(DataInput input, ActiveOperation.ResultHandler<Void> resultHandler, ManagementRequestContext<ExecuteRequestContext> context) throws IOException {
            ExecuteRequestContext executeRequestContext = context.getAttachment();
            byte commitOrRollback = input.readByte();
            executeRequestContext.completeTx(context, commitOrRollback == PARAM_COMMIT);
        }
    }

    /** The decoded body of an execute request. */
    private static class ExecutableRequest {
        private final ModelNode operation;
        private final int attachmentsLength;
        private final IdentityAddressProtocolUtil.PropagatedIdentity propagatedIdentity;
        private final boolean inVmCall;

        private ExecutableRequest(ModelNode operation, int attachmentsLength, IdentityAddressProtocolUtil.PropagatedIdentity propagatedIdentity, boolean inVmCall) {
            this.operation = operation;
            this.attachmentsLength = attachmentsLength;
            this.propagatedIdentity = propagatedIdentity;
            this.inVmCall = inVmCall;
        }

        /**
         * Reads an execute request body. The identity and in-VM fields are only read when the channel
         * association carries the corresponding {@code TransactionalProtocolClient} attachment set to
         * {@code true}.
         */
        static ExecutableRequest parse(DataInput input, ManagementChannelAssociation channelAssociation) throws IOException {
            ModelNode operation = new ModelNode();
            ProtocolUtils.expectHeader(input, PARAM_OPERATION);
            operation.readExternal(input);
            ProtocolUtils.expectHeader(input, PARAM_INPUTSTREAMS_LENGTH);
            int attachmentsLength = input.readInt();

            boolean readIdentity = Boolean.TRUE.equals(channelAssociation.getAttachments().getAttachment(TransactionalProtocolClient.SEND_IDENTITY));
            IdentityAddressProtocolUtil.PropagatedIdentity propagatedIdentity = readIdentity ? IdentityAddressProtocolUtil.read(input) : null;

            boolean readSendInVm = Boolean.TRUE.equals(channelAssociation.getAttachments().getAttachment(TransactionalProtocolClient.SEND_IN_VM));
            boolean inVmCall = false;
            if (readSendInVm) {
                ProtocolUtils.expectHeader(input, PARAM_IN_VM_CALL);
                inVmCall = input.readBoolean();
            }
            return new ExecutableRequest(operation, attachmentsLength, propagatedIdentity, inVmCall);
        }
    }

    /** Parses an execute request and runs it asynchronously under the propagated identity. */
    private class ExecuteRequestHandler implements ManagementRequestHandler<Void, ExecuteRequestContext> {
        private ExecuteRequestHandler() {
        }

        @Override
        public void handleRequest(DataInput input, ActiveOperation.ResultHandler<Void> resultHandler, final ManagementRequestContext<ExecuteRequestContext> context) throws IOException {
            ControllerLogger.MGMT_OP_LOGGER.tracef("Handling transactional ExecuteRequest for %d", context.getOperationId());
            final ExecutableRequest executableRequest = ExecutableRequest.parse(input, channelAssociation);
            final IdentityAddressProtocolUtil.PropagatedIdentity propagatedIdentity = executableRequest.propagatedIdentity;
            final SecurityIdentity securityIdentity = propagatedIdentity != null ? propagatedIdentity.securityIdentity : null;
            final InetAddress remoteAddress = propagatedIdentity != null ? propagatedIdentity.inetAddress : null;
            final PrivilegedAction<Void> action = () -> {
                doExecute(executableRequest.operation, executableRequest.attachmentsLength, context);
                return null;
            };

            // Initialize here, before going async, so getCurrentRequestHeader has a channel to report.
            final ExecuteRequestContext executeRequestContext = context.getAttachment();
            executeRequestContext.initialize(context);

            ManagementRequestContext.MultipleResponseAsyncTask<ExecuteRequestContext> task = new ManagementRequestContext.MultipleResponseAsyncTask<ExecuteRequestContext>() {
                @Override
                public void execute(ManagementRequestContext<ExecuteRequestContext> asyncContext) throws Exception {
                    Supplier<Void> execution = () -> {
                        if (executableRequest.inVmCall && securityIdentity == null) {
                            return InVmAccess.runInVm((PrivilegedAction<Void>) () -> {
                                AccessAuditContext.doAs(false, null, remoteAddress, action);
                                return null;
                            });
                        }
                        AccessAuditContext.doAs(securityIdentity != null, securityIdentity, remoteAddress, action);
                        return null;
                    };
                    privilegedExecution().execute(execution);
                }

                @Override
                public ManagementProtocolHeader getCurrentRequestHeader() {
                    ManagementRequestContext<ExecuteRequestContext> current = executeRequestContext.responseChannel;
                    return current == null ? null : current.getRequestHeader();
                }
            };
            context.executeAsync(task);
        }

        /**
         * Executes the operation and reports its outcome.
         * <ul>
         *   <li>If the operation was never prepared, its response is sent as a failure.</li>
         *   <li>Otherwise the result is reported as completed.</li>
         *   <li>If execution throws, a failure describing the throwable is sent.</li>
         * </ul>
         * The request context was already initialized by handleRequest. Re-initializing it here would
         * re-arm a response channel that a cancel/failure may already have answered.
         */
        protected void doExecute(ModelNode operation, int attachmentsLength, ManagementRequestContext<ExecuteRequestContext> context) {
            ControllerLogger.MGMT_OP_LOGGER.tracef("Executing transactional ExecuteRequest for %d", context.getOperationId());
            ExecuteRequestContext executeRequestContext = context.getAttachment();
            Integer batchId = executeRequestContext.getOperationId();
            OperationMessageHandler messageHandlerProxy = OperationMessageHandler.DISCARD;
            ProxyOperationTransactionControl control = new ProxyOperationTransactionControl(executeRequestContext);
            OperationAttachmentsProxy attachmentsProxy = OperationAttachmentsProxy.create(operation, channelAssociation, batchId, attachmentsLength);
            OperationResponse result;
            try {
                result = internalExecute(attachmentsProxy, context, messageHandlerProxy, control);
            } catch (Throwable t) {
                executeRequestContext.failed(failureResponse(t.getClass().getName() + ":" + t.getMessage()));
                attachmentsProxy.shutdown();
                ControllerLogger.MGMT_OP_LOGGER.unexpectedOperationExecutionException(t, Collections.singletonList(operation));
                return;
            }
            if (!executeRequestContext.prepared) {
                executeRequestContext.failed(result.getResponseNode());
            } else {
                executeRequestContext.completed(result);
            }
        }
    }
}

