/*
 * Decompiled with CFR 0.152.
 */
package org.openqa.selenium.grid.node;

import com.google.common.collect.ImmutableSet;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Stream;
import org.openqa.selenium.Capabilities;
import org.openqa.selenium.devtools.CdpEndpointFinder;
import org.openqa.selenium.grid.data.Session;
import org.openqa.selenium.grid.node.Node;
import org.openqa.selenium.internal.Debug;
import org.openqa.selenium.internal.Require;
import org.openqa.selenium.remote.SessionId;
import org.openqa.selenium.remote.http.BinaryMessage;
import org.openqa.selenium.remote.http.ClientConfig;
import org.openqa.selenium.remote.http.CloseMessage;
import org.openqa.selenium.remote.http.HttpClient;
import org.openqa.selenium.remote.http.HttpMethod;
import org.openqa.selenium.remote.http.HttpRequest;
import org.openqa.selenium.remote.http.Message;
import org.openqa.selenium.remote.http.TextMessage;
import org.openqa.selenium.remote.http.UrlTemplate;
import org.openqa.selenium.remote.http.WebSocket;

/**
 * Proxies websocket connections addressed to a session on this node to the matching upstream
 * endpoint.
 *
 * <p>The request URI is matched (relative to the configured grid sub path) against four templates:
 * {@code /session/{sessionId}/se/cdp}, {@code .../se/bidi}, {@code .../se/fwd} and
 * {@code .../se/vnc}. When one matches and the node owns the session, an upstream websocket is
 * opened and a consumer that forwards messages to it is returned. Messages arriving from upstream
 * are pushed to the supplied downstream consumer.
 */
public class ProxyNodeWebsockets
    implements BiFunction<String, Consumer<Message>, Optional<Consumer<Message>>> {
    private static final UrlTemplate CDP_TEMPLATE = new UrlTemplate("/session/{sessionId}/se/cdp");
    private static final UrlTemplate BIDI_TEMPLATE = new UrlTemplate("/session/{sessionId}/se/bidi");
    private static final UrlTemplate FWD_TEMPLATE = new UrlTemplate("/session/{sessionId}/se/fwd");
    private static final UrlTemplate VNC_TEMPLATE = new UrlTemplate("/session/{sessionId}/se/vnc");
    private static final Logger LOG = Logger.getLogger(ProxyNodeWebsockets.class.getName());
    /** Capability names that are probed, in order, for a reported CDP endpoint. */
    private static final ImmutableSet<String> CDP_ENDPOINT_CAPS =
        ImmutableSet.of("goog:chromeOptions", "moz:debuggerAddress", "ms:edgeOptions");
    private final HttpClient.Factory clientFactory;
    private final Node node;
    private final String gridSubPath;

    /**
     * @param clientFactory factory used to create the HTTP clients for endpoint discovery and for
     *     the upstream websocket; must not be {@code null}
     * @param node node whose sessions are proxied; must not be {@code null}
     * @param gridSubPath prefix handed to {@link UrlTemplate#match} together with the request URI
     */
    public ProxyNodeWebsockets(HttpClient.Factory clientFactory, Node node, String gridSubPath) {
        this.clientFactory = Objects.requireNonNull(clientFactory);
        this.node = Objects.requireNonNull(node);
        this.gridSubPath = gridSubPath;
    }

    /**
     * Resolves the upstream endpoint for a websocket request.
     *
     * @param uri request URI of the incoming websocket connection
     * @param downstream consumer receiving every message that arrives from upstream
     * @return a consumer forwarding messages to the upstream websocket, or an empty optional when
     *     the URI matches none of the templates, the node does not own the session, or no usable
     *     endpoint could be determined
     */
    @Override
    public Optional<Consumer<Message>> apply(String uri, Consumer<Message> downstream) {
        UrlTemplate.Match fwdMatch = FWD_TEMPLATE.match(uri, this.gridSubPath);
        UrlTemplate.Match cdpMatch = CDP_TEMPLATE.match(uri, this.gridSubPath);
        UrlTemplate.Match bidiMatch = BIDI_TEMPLATE.match(uri, this.gridSubPath);
        UrlTemplate.Match vncMatch = VNC_TEMPLATE.match(uri, this.gridSubPath);

        // All templates share the {sessionId} parameter, so whichever matched supplies the id.
        Optional<UrlTemplate.Match> firstMatch =
            Stream.of(fwdMatch, cdpMatch, bidiMatch, vncMatch).filter(Objects::nonNull).findFirst();
        if (firstMatch.isEmpty()) {
            return Optional.empty();
        }

        String sessionId = firstMatch.get().getParameters().get("sessionId");
        LOG.fine("Matching websockets for session id: " + sessionId);
        SessionId id = new SessionId(sessionId);
        if (!this.node.isSessionOwner(id)) {
            LOG.warning("Not owner of " + id);
            return Optional.empty();
        }

        Session session = this.node.getSession(id);
        Capabilities caps = session.getCapabilities();
        LOG.fine("Scanning for endpoint: " + caps);

        // Called after every message received from upstream (see ForwardingListener).
        // NOTE(review): presumably this keeps the session from timing out -- confirm against Node.
        Consumer<SessionId> sessionConsumer = this.node::isSessionOwner;

        if (bidiMatch != null) {
            return this.findBiDiEndpoint(downstream, caps, sessionConsumer, id);
        }
        if (vncMatch != null) {
            // VNC traffic deliberately does not touch the node on each message.
            sessionConsumer = ignored -> {};
            return this.findVncEndpoint(downstream, caps, sessionConsumer, id);
        }
        if (fwdMatch != null) {
            LOG.info("Matched endpoint where CDP connection is being forwarded");
            return this.findCdpEndpoint(downstream, caps, sessionConsumer, id);
        }
        if (caps.getCapabilityNames().contains("se:forwardCdp")) {
            LOG.info("Found endpoint where CDP connection needs to be forwarded");
            return this.findForwardCdpEndpoint(downstream, caps, sessionConsumer, id);
        }
        return this.findCdpEndpoint(downstream, caps, sessionConsumer, id);
    }

    /**
     * Probes each capability in {@link #CDP_ENDPOINT_CAPS} for a reported debugger address and
     * opens a websocket to the first CDP endpoint that can be resolved from it.
     *
     * <p>The HTTP client created for the lookup is always closed before returning or moving on,
     * because the websocket itself is opened with a separate client.
     */
    private Optional<Consumer<Message>> findCdpEndpoint(Consumer<Message> downstream, Capabilities caps, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
        for (String cdpEndpointCap : CDP_ENDPOINT_CAPS) {
            Optional<URI> reportedUri = CdpEndpointFinder.getReportedUri(cdpEndpointCap, caps);
            Optional<HttpClient> client =
                reportedUri.map(uri -> CdpEndpointFinder.getHttpClient(this.clientFactory, uri));
            Optional<URI> cdpUri;
            try {
                cdpUri = client.flatMap(CdpEndpointFinder::getCdpEndPoint);
            } catch (Exception e) {
                try {
                    client.ifPresent(HttpClient::close);
                } catch (Exception ex) {
                    e.addSuppressed(ex);
                }
                throw e;
            }

            // The lookup client is not reused for the websocket, so release it in every case.
            try {
                client.ifPresent(HttpClient::close);
            } catch (Exception e) {
                LOG.log(Level.FINE, "failed to close the http client used to check the reported CDP endpoint: " + reportedUri.orElse(null), e);
            }

            if (cdpUri.isPresent()) {
                LOG.log(Debug.getDebugLogLevel(), String.format("Endpoint found in %s", cdpEndpointCap));
                return cdpUri.map(cdp -> this.createWsEndPoint(cdp, downstream, sessionConsumer, sessionId));
            }
        }
        return Optional.empty();
    }

    /**
     * Opens a websocket to the URI stored in the {@code se:gridWebSocketUrl} capability.
     *
     * @return the forwarding consumer, or empty when the capability is absent or not a valid URI
     */
    private Optional<Consumer<Message>> findBiDiEndpoint(Consumer<Message> downstream, Capabilities caps, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
        Object rawUrl = caps.getCapability("se:gridWebSocketUrl");
        if (rawUrl == null) {
            // String.valueOf(null) would yield the syntactically valid URI "null".
            LOG.warning("No se:gridWebSocketUrl capability found for session " + sessionId);
            return Optional.empty();
        }
        try {
            URI uri = new URI(String.valueOf(rawUrl));
            return Optional.of(this.createWsEndPoint(uri, downstream, sessionConsumer, sessionId));
        } catch (URISyntaxException e) {
            LOG.warning("Unable to create URI from: " + rawUrl);
            return Optional.empty();
        }
    }

    /**
     * Opens a websocket to the URI stored in the {@code se:forwardCdp} capability.
     *
     * @return the forwarding consumer, or empty when the capability is absent or not a valid URI
     */
    private Optional<Consumer<Message>> findForwardCdpEndpoint(Consumer<Message> downstream, Capabilities caps, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
        Object rawUrl = caps.getCapability("se:forwardCdp");
        if (rawUrl == null) {
            // The capability name can be present while its value is null.
            LOG.warning("No se:forwardCdp capability value found for session " + sessionId);
            return Optional.empty();
        }
        try {
            URI uri = new URI(String.valueOf(rawUrl));
            return Optional.of(this.createWsEndPoint(uri, downstream, sessionConsumer, sessionId));
        } catch (URISyntaxException e) {
            LOG.warning("Unable to create URI from: " + rawUrl);
            return Optional.empty();
        }
    }

    /**
     * Opens a websocket to the URI stored in the {@code se:vncLocalAddress} capability.
     *
     * @return the forwarding consumer, or empty when the capability is absent or not a valid URI
     */
    private Optional<Consumer<Message>> findVncEndpoint(Consumer<Message> downstream, Capabilities caps, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
        Object rawAddress = caps.getCapability("se:vncLocalAddress");
        if (rawAddress == null) {
            // new URI(null) would throw a NullPointerException.
            LOG.warning("No se:vncLocalAddress capability found for session " + sessionId);
            return Optional.empty();
        }
        String vncLocalAddress = String.valueOf(rawAddress);
        URI vncUri;
        try {
            vncUri = new URI(vncLocalAddress);
        } catch (URISyntaxException e) {
            LOG.warning("Invalid URI for endpoint " + vncLocalAddress);
            return Optional.empty();
        }
        LOG.log(Debug.getDebugLogLevel(), String.format("Endpoint found in %s", "se:vncLocalAddress"));
        return Optional.of(this.createWsEndPoint(vncUri, downstream, sessionConsumer, sessionId));
    }

    /**
     * Opens the upstream websocket at {@code uri} and returns a consumer that sends messages to it.
     *
     * <p>The HTTP client backing the socket is closed when opening the socket fails and after a
     * {@link CloseMessage} has been forwarded upstream, so it does not outlive the connection.
     *
     * @param uri upstream websocket address
     * @param downstream receives every message arriving from upstream
     * @param sessionConsumer invoked with {@code sessionId} after each upstream message
     * @param sessionId id of the session the connection belongs to
     * @return consumer forwarding messages to the upstream websocket
     */
    private Consumer<Message> createWsEndPoint(URI uri, Consumer<Message> downstream, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
        Require.nonNull("downstream", downstream);
        Require.nonNull("uri", uri);
        Require.nonNull("sessionConsumer", sessionConsumer);
        Require.nonNull("sessionId", sessionId);
        LOG.info("Establishing connection to " + uri);
        HttpClient client = this.clientFactory.createClient(ClientConfig.defaultConfig().baseUri(uri));
        WebSocket upstream;
        try {
            upstream = client.openSocket(
                new HttpRequest(HttpMethod.GET, uri.toString()),
                new ForwardingListener(downstream, sessionConsumer, sessionId));
        } catch (Exception e) {
            // Nobody else holds a reference to the client, so release it before propagating.
            try {
                client.close();
            } catch (Exception ex) {
                e.addSuppressed(ex);
            }
            throw e;
        }
        return message -> {
            try {
                upstream.send(message);
            } finally {
                // A close frame ends the conversation; the client is of no further use.
                if (message instanceof CloseMessage) {
                    try {
                        client.close();
                    } catch (Exception e) {
                        LOG.log(Level.FINE, "failed to close the http client of the upstream websocket: " + uri, e);
                    }
                }
            }
        };
    }

    /**
     * Relays everything received from the upstream websocket to the downstream consumer and
     * notifies the session consumer after each relayed message.
     */
    private static class ForwardingListener
        implements WebSocket.Listener {
        private final Consumer<Message> downstream;
        private final Consumer<SessionId> sessionConsumer;
        private final SessionId sessionId;

        public ForwardingListener(Consumer<Message> downstream, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
            this.downstream = Objects.requireNonNull(downstream);
            this.sessionConsumer = Objects.requireNonNull(sessionConsumer);
            this.sessionId = Objects.requireNonNull(sessionId);
        }

        @Override
        public void onBinary(byte[] data) {
            this.downstream.accept(new BinaryMessage(data));
            this.sessionConsumer.accept(this.sessionId);
        }

        @Override
        public void onClose(int code, String reason) {
            this.downstream.accept(new CloseMessage(code, reason));
            this.sessionConsumer.accept(this.sessionId);
        }

        @Override
        public void onText(CharSequence data) {
            this.downstream.accept(new TextMessage(data));
            this.sessionConsumer.accept(this.sessionId);
        }

        /** Errors are only logged; nothing is forwarded downstream. */
        @Override
        public void onError(Throwable cause) {
            LOG.log(Level.WARNING, "Error proxying websocket command", cause);
        }
    }
}

