/*
 * Decompiled with CFR 0.152.
 */
package io.asyncer.r2dbc.mysql;

import io.asyncer.r2dbc.mysql.ConnectionContext;
import io.asyncer.r2dbc.mysql.Extensions;
import io.asyncer.r2dbc.mysql.HandshakeExchangeable;
import io.asyncer.r2dbc.mysql.QueryFlow;
import io.asyncer.r2dbc.mysql.ServerVersion;
import io.asyncer.r2dbc.mysql.TextSimpleStatement;
import io.asyncer.r2dbc.mysql.api.MySqlResult;
import io.asyncer.r2dbc.mysql.cache.Caches;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.codec.CodecsBuilder;
import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm;
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.extension.CodecRegistrar;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.Readable;
import java.time.DateTimeException;
import java.time.Duration;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.jetbrains.annotations.Nullable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;

final class InitFlow {
    /** Logger shared by every step of the initialization flow in this class. */
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(InitFlow.class);
    // MariaDB versions at or above this read @@transaction_isolation; older ones read @@tx_isolation
    // (see transactionIsolationColumn).
    // NOTE(review): the trailing `true` presumably marks the version as MariaDB — confirm in ServerVersion.create.
    private static final ServerVersion MARIA_11_1_1 = ServerVersion.create(11, 1, 1, true);
    // MySQL versions at or above this read @@transaction_isolation.
    private static final ServerVersion MYSQL_8_0_3 = ServerVersion.create(8, 0, 3);
    // Lower bound (inclusive) of the MySQL 5.7.x range that also reads @@transaction_isolation.
    private static final ServerVersion MYSQL_5_7_20 = ServerVersion.create(5, 7, 20);
    // Upper bound (exclusive) paired with MYSQL_5_7_20 when checking that range.
    private static final ServerVersion MYSQL_8 = ServerVersion.create(8, 0, 0);
    /**
     * Response handler used for the first attempt to select the database.
     * <ul>
     * <li>An {@link ErrorMessage} is logged at debug level and reported as {@code false}, so the
     * caller can react instead of failing.</li>
     * <li>A {@link CompleteMessage} whose {@code isDone()} is true is reported as {@code true}.</li>
     * <li>Any other message is released and otherwise ignored.</li>
     * </ul>
     */
    private static final BiConsumer<ServerMessage, SynchronousSink<Boolean>> INIT_DB = (message, sink) -> {
        if (message instanceof ErrorMessage) {
            ErrorMessage msg = (ErrorMessage) message;
            logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(), msg.getMessage());
            // Emit a plain boolean: the sink is typed as SynchronousSink<Boolean>.
            sink.next(false);
            sink.complete();
        } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
            sink.next(true);
            sink.complete();
        } else {
            // Intermediate messages are not needed here; release them so they do not linger.
            ReferenceCountUtil.safeRelease(message);
        }
    };
    /**
     * Response handler used for the second attempt to select the database. Unlike {@code INIT_DB},
     * an {@link ErrorMessage} is converted to an exception and signalled as an error; a done
     * {@link CompleteMessage} completes the sink; every other message is released.
     */
    private static final BiConsumer<ServerMessage, SynchronousSink<Void>> INIT_DB_AFTER = (message, sink) -> {
        if (message instanceof ErrorMessage) {
            ErrorMessage failure = (ErrorMessage) message;
            sink.error(failure.toException());
            return;
        }

        boolean finished = message instanceof CompleteMessage && ((CompleteMessage) message).isDone();

        if (finished) {
            sink.complete();
        } else {
            ReferenceCountUtil.safeRelease(message);
        }
    };

    /**
     * Performs the handshake exchange on the given client.
     *
     * @param client                the client to exchange messages with
     * @param sslMode               SSL mode passed to the handshake exchangeable
     * @param database              database name passed to the handshake exchangeable
     * @param user                  user name passed to the handshake exchangeable
     * @param password              password passed to the handshake exchangeable, may be {@code null}
     * @param compressionAlgorithms compression algorithms passed to the handshake exchangeable
     * @param zstdCompressionLevel  zstd compression level passed to the handshake exchangeable
     * @return a {@link Mono} that completes when the exchange completes, ignoring emitted items
     */
    static Mono<Void> initHandshake(Client client, SslMode sslMode, String database, String user, @Nullable CharSequence password, Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel) {
        HandshakeExchangeable handshake = new HandshakeExchangeable(
            client, sslMode, database, user, password, compressionAlgorithms, zstdCompressionLevel);

        return client.exchange(handshake).then();
    }

    /**
     * Initializes the session of a connected client and returns the codecs built for it.
     * <p>
     * On subscription this builds the {@link Codecs} (letting every {@link CodecRegistrar} extension
     * register itself), applies the merged session variables, loads the session state, resolves the
     * lock wait timeout, stores the result in the client's {@link ConnectionContext} and finally
     * selects {@code database} when it is not empty.
     *
     * @param client           the client to initialize
     * @param database         database to select; an empty string skips database selection
     * @param prepareCacheSize size passed to {@code Caches.createPrepareCache}
     * @param sessionVariables user supplied session variables
     * @param forceTimeZone    whether the context time zone should be forced onto the session
     * @param lockWaitTimeout  initial lock wait timeout, or {@code null} to leave the server value
     * @param statementTimeout initial statement timeout, or {@code null} for none
     * @param extensions       extensions scanned for {@link CodecRegistrar}s
     * @return a {@link Mono} emitting the built codecs once initialization has finished
     */
    static Mono<Codecs> initSession(Client client, String database, int prepareCacheSize, List<String> sessionVariables, boolean forceTimeZone, @Nullable Duration lockWaitTimeout, @Nullable Duration statementTimeout, Extensions extensions) {
        return Mono.defer(() -> {
            ByteBufAllocator allocator = client.getByteBufAllocator();
            CodecsBuilder builder = Codecs.builder();
            extensions.forEach(CodecRegistrar.class, registrar -> registrar.register(allocator, builder));
            Codecs codecs = builder.build();
            List<String> variables = InitFlow.mergeSessionVariables(client, sessionVariables, forceTimeZone, statementTimeout);
            logger.debug("Initializing client session: {}", variables);

            return QueryFlow.setSessionVariables(client, variables)
                .then(InitFlow.loadSessionVariables(client, codecs))
                .flatMap(data -> InitFlow.loadAndInitInnoDbEngineStatus(data, client, codecs, lockWaitTimeout))
                .flatMap(data -> {
                    ConnectionContext context = client.getContext();
                    logger.debug("Initializing connection {} context: {}", context.getConnectionId(), data);
                    context.initSession(
                        Caches.createPrepareCache(prepareCacheSize),
                        data.level,
                        data.lockWaitTimeoutSupported,
                        data.lockWaitTimeout,
                        data.product,
                        data.timeZone);

                    if (!data.lockWaitTimeoutSupported) {
                        logger.info("Lock wait timeout is not supported by server, all related operations will be ignored");
                    }

                    // Keep the element type as Codecs so the pipeline matches the declared return type.
                    return database.isEmpty()
                        ? Mono.just(codecs)
                        : InitFlow.initDatabase(client, database).then(Mono.just(codecs));
                });
        });
    }

    /**
     * Reads the server's {@code innodb_lock_wait_timeout} variable and, when requested, overrides it.
     * <p>
     * If the query returns a non-empty value, the state is marked as supporting lock wait timeout
     * with that value (in seconds); if it returns no row or an empty value, {@code data} is kept
     * unchanged. A value that is not a valid {@code long} surfaces as a
     * {@link NumberFormatException} from {@link Long#parseLong(String)}.
     * <p>
     * When {@code lockWaitTimeout} is non-null and the server supports it, the corresponding
     * statement is executed and the resulting state carries the requested timeout. When it is not
     * supported, a warning is logged and the state is returned as loaded.
     *
     * @param data            session state loaded so far; also the fallback when no row is returned
     * @param client          the client used to run the statements
     * @param codecs          codecs used to decode the result
     * @param lockWaitTimeout timeout to apply, or {@code null} to keep the server value
     * @return a {@link Mono} emitting the resulting session state
     */
    private static Mono<SessionState> loadAndInitInnoDbEngineStatus(SessionState data, Client client, Codecs codecs, @Nullable Duration lockWaitTimeout) {
        return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb_lock_wait_timeout'")
            .execute()
            .flatMap(r -> r.map(readable -> {
                // Column 1 is read as the variable's value.
                String value = readable.get(1, String.class);

                if (value == null || value.isEmpty()) {
                    return data;
                }

                return data.lockWaitTimeout(Duration.ofSeconds(Long.parseLong(value)));
            }))
            // No row at all falls back to the incoming state, which reports "not supported".
            .single(data)
            .flatMap(state -> {
                if (lockWaitTimeout == null) {
                    return Mono.just(state);
                }

                if (!state.lockWaitTimeoutSupported) {
                    logger.warn("Lock wait timeout is not supported by server, ignore initial setting");
                    return Mono.just(state);
                }

                // Only record the new timeout after the statement has completed successfully.
                return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(lockWaitTimeout))
                    .then(Mono.fromSupplier(() -> state.lockWaitTimeout(lockWaitTimeout)));
            });
    }

    /**
     * Loads the session state with a single {@code SELECT} of session/system variables.
     * <p>
     * The isolation column and {@code @@version_comment} are always selected. The two time zone
     * variables are appended only when the context has no time zone initialized yet, and only then
     * is the time zone read from the result.
     *
     * @param client the client used to run the query
     * @param codecs codecs used to decode the result
     * @return a {@link Mono} emitting the last {@link SessionState} produced by the query
     */
    private static Mono<SessionState> loadSessionVariables(Client client, Codecs codecs) {
        ConnectionContext context = client.getContext();
        StringBuilder query = new StringBuilder(128)
            .append("SELECT ")
            .append(InitFlow.transactionIsolationColumn(context))
            .append(",@@version_comment AS v");
        // The time zone columns are needed only while the context does not know its time zone.
        boolean readTimeZone = !context.isTimeZoneInitialized();

        if (readTimeZone) {
            query.append(",@@system_time_zone AS s,@@time_zone AS t");
        }

        // A typed lambda keeps the element type as SessionState instead of a raw Flux.
        return new TextSimpleStatement(client, codecs, query.toString())
            .execute()
            .flatMap(r -> InitFlow.convertSessionData(r, readTimeZone))
            .last();
    }

    /**
     * Selects {@code database} on the client, creating it first when the initial selection fails.
     * <p>
     * The first attempt reports failure as {@code false} instead of an error. In that case a
     * {@code CREATE DATABASE IF NOT EXISTS} statement is executed and the selection is retried; a
     * failure of the retry is signalled as an error.
     *
     * @param client   the client to exchange messages with
     * @param database name of the database to select
     * @return a {@link Mono} that completes once the database has been selected
     */
    private static Mono<Void> initDatabase(Client client, String database) {
        return client.exchange(new InitDbMessage(database), INIT_DB).last().flatMap(selected -> {
            if (!selected) {
                String createSql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database);

                return QueryFlow.executeVoid(client, createSql)
                    .then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then());
            }

            return Mono.empty();
        });
    }

    /**
     * Combines the user supplied session variables with the ones derived from configuration.
     * <p>
     * The input list is returned as-is when nothing needs to be added. Otherwise a new list is
     * built that holds the input followed by the time zone variable (when forced and the context
     * time zone is initialized) and the statement timeout variable (when requested and supported;
     * a warning is logged if it is requested but unsupported).
     *
     * @param client           the client whose context is consulted
     * @param sessionVariables user supplied session variables
     * @param forceTimeZone    whether the context time zone should be forced onto the session
     * @param statementTimeout statement timeout to apply, or {@code null} for none
     * @return the variables to set on the session
     */
    private static List<String> mergeSessionVariables(Client client, List<String> sessionVariables, boolean forceTimeZone, @Nullable Duration statementTimeout) {
        ConnectionContext context = client.getContext();

        if (!(forceTimeZone && context.isTimeZoneInitialized()) && statementTimeout == null) {
            return sessionVariables;
        }

        // At most two entries are appended: time zone and statement timeout.
        List<String> merged = new ArrayList<>(sessionVariables.size() + 2);
        merged.addAll(sessionVariables);

        if (forceTimeZone && context.isTimeZoneInitialized()) {
            merged.add(InitFlow.timeZoneVariable(context.getTimeZone()));
        }

        if (statementTimeout == null) {
            return merged;
        }

        if (context.isStatementTimeoutSupported()) {
            merged.add(StringUtils.statementTimeoutVariable(statementTimeout, context.isMariaDb()));
        } else {
            logger.warn("Statement timeout is not supported in {}, ignore initial setting", context.getServerVersion());
        }

        return merged;
    }

    /**
     * Builds the {@code time_zone='...'} session variable assignment for the given zone.
     * <p>
     * A {@link ZoneOffset} whose id is {@code Z} is written as {@code +00:00}; every other zone is
     * written with its own id.
     *
     * @param timeZone the zone to render
     * @return the assignment text, e.g. {@code time_zone='+00:00'}
     */
    private static String timeZoneVariable(ZoneId timeZone) {
        String zoneText = timeZone.getId();

        if (timeZone instanceof ZoneOffset && "Z".equalsIgnoreCase(zoneText)) {
            zoneText = "+00:00";
        }

        return "time_zone='" + zoneText + "'";
    }

    /**
     * Maps each row of the session variable query to a {@link SessionState}.
     * <p>
     * Column 0 is converted to an isolation level and column 1 is used as the product. The time
     * zone is read from the row only when {@code timeZone} is true; otherwise it is {@code null}.
     *
     * @param r        result of the session variable query
     * @param timeZone whether the row carries the time zone columns
     * @return a {@link Flux} of the converted states
     */
    private static Flux<SessionState> convertSessionData(MySqlResult r, boolean timeZone) {
        return r.map(readable -> {
            IsolationLevel isolation = InitFlow.convertIsolationLevel(readable.get(0, String.class));
            String productName = readable.get(1, String.class);
            ZoneId zone = null;

            if (timeZone) {
                zone = InitFlow.readZoneId(readable);
            }

            return new SessionState(isolation, productName, zone);
        });
    }

    /**
     * Chooses the select-list entry that reads the session's transaction isolation level.
     * <p>
     * {@code @@transaction_isolation} is used for MariaDB at or above 11.1.1, and for MySQL at or
     * above 8.0.3 or within [5.7.20, 8.0.0). Every other version uses {@code @@tx_isolation}.
     *
     * @param context the connection context providing the server version and product kind
     * @return the column expression, aliased as {@code i}
     */
    private static String transactionIsolationColumn(ConnectionContext context) {
        ServerVersion version = context.getServerVersion();
        boolean usesNewName;

        if (context.isMariaDb()) {
            usesNewName = version.isGreaterThanOrEqualTo(MARIA_11_1_1);
        } else {
            usesNewName = version.isGreaterThanOrEqualTo(MYSQL_8_0_3)
                || version.isGreaterThanOrEqualTo(MYSQL_5_7_20) && version.isLessThan(MYSQL_8);
        }

        return usesNewName ? "@@transaction_isolation AS i" : "@@tx_isolation AS i";
    }

    /**
     * Resolves the session time zone from a row of the session variable query.
     * <p>
     * Column 3 is preferred unless it is missing, empty or {@code SYSTEM} (case-insensitive); then
     * column 2 is used. If that is missing or empty too, a warning is logged and the JVM default
     * zone is returned.
     *
     * @param readable the row holding the time zone columns at indexes 2 and 3
     * @return the resolved zone
     */
    private static ZoneId readZoneId(Readable readable) {
        String systemZone = readable.get(2, String.class);
        String sessionZone = readable.get(3, String.class);
        boolean sessionZoneUsable = sessionZone != null
            && !sessionZone.isEmpty()
            && !"SYSTEM".equalsIgnoreCase(sessionZone);

        if (sessionZoneUsable) {
            return InitFlow.convertZoneId(sessionZone);
        }

        if (systemZone != null && !systemZone.isEmpty()) {
            return InitFlow.convertZoneId(systemZone);
        }

        logger.warn("MySQL does not return any timezone, trying to use system default timezone");
        return ZoneId.systemDefault().normalized();
    }

    /**
     * Parses a server supplied zone id, falling back to the JVM default zone when parsing throws a
     * {@link DateTimeException}. The failure is logged as a warning together with the exception.
     *
     * @param id the zone id text reported by the server
     * @return the parsed zone, or the normalized JVM default zone on failure
     */
    private static ZoneId convertZoneId(String id) {
        ZoneId zone;

        try {
            zone = StringUtils.parseZoneId(id);
        } catch (DateTimeException e) {
            logger.warn("The server timezone is unknown <{}>, trying to use system default timezone", id, e);
            zone = ZoneId.systemDefault().normalized();
        }

        return zone;
    }

    /**
     * Converts the isolation level name reported by the server to an {@link IsolationLevel}.
     * <p>
     * A {@code null} or unrecognized name is logged as a warning and mapped to
     * {@link IsolationLevel#REPEATABLE_READ}. Matching is exact and case-sensitive.
     *
     * @param name the isolation level name, may be {@code null}
     * @return the matching isolation level, or repeatable read as the fallback
     */
    private static IsolationLevel convertIsolationLevel(@Nullable String name) {
        if (name == null) {
            logger.warn("Isolation level is null in current session, fallback to repeatable read");
            return IsolationLevel.REPEATABLE_READ;
        }

        if ("READ-UNCOMMITTED".equals(name)) {
            return IsolationLevel.READ_UNCOMMITTED;
        } else if ("READ-COMMITTED".equals(name)) {
            return IsolationLevel.READ_COMMITTED;
        } else if ("REPEATABLE-READ".equals(name)) {
            return IsolationLevel.REPEATABLE_READ;
        } else if ("SERIALIZABLE".equals(name)) {
            return IsolationLevel.SERIALIZABLE;
        }

        logger.warn("Unknown isolation level {} in current session, fallback to repeatable read", name);
        return IsolationLevel.REPEATABLE_READ;
    }

    /** Not instantiable: every member of this class is static. */
    private InitFlow() {
    }

    /**
     * Immutable holder for the values loaded from the server while initializing a session:
     * isolation level, product, time zone and lock wait timeout information.
     */
    private static final class SessionState {
        private final IsolationLevel level;
        @Nullable
        private final String product;
        @Nullable
        private final ZoneId timeZone;
        private final Duration lockWaitTimeout;
        private final boolean lockWaitTimeoutSupported;

        /**
         * Creates a state with a zero lock wait timeout that is marked as not supported.
         *
         * @param level    the session isolation level
         * @param product  the product text, may be {@code null}
         * @param timeZone the session time zone, may be {@code null}
         */
        SessionState(IsolationLevel level, @Nullable String product, @Nullable ZoneId timeZone) {
            this(level, product, timeZone, Duration.ZERO, false);
        }

        private SessionState(IsolationLevel level, @Nullable String product, @Nullable ZoneId timeZone, Duration lockWaitTimeout, boolean lockWaitTimeoutSupported) {
            this.level = level;
            this.product = product;
            this.timeZone = timeZone;
            this.lockWaitTimeout = lockWaitTimeout;
            this.lockWaitTimeoutSupported = lockWaitTimeoutSupported;
        }

        /**
         * Returns a copy of this state with the given lock wait timeout, marked as supported.
         *
         * @param timeout the lock wait timeout to record
         * @return the new state; this instance is left untouched
         */
        SessionState lockWaitTimeout(Duration timeout) {
            return new SessionState(level, product, timeZone, timeout, true);
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof SessionState)) {
                return false;
            }

            SessionState other = (SessionState) o;

            if (lockWaitTimeoutSupported != other.lockWaitTimeoutSupported) {
                return false;
            }
            if (!level.equals(other.level)) {
                return false;
            }
            if (!Objects.equals(product, other.product)) {
                return false;
            }
            if (!Objects.equals(timeZone, other.timeZone)) {
                return false;
            }
            return lockWaitTimeout.equals(other.lockWaitTimeout);
        }

        @Override
        public int hashCode() {
            // Objects.hashCode yields 0 for null, matching the nullable fields' contribution.
            int hash = level.hashCode();
            hash = 31 * hash + Objects.hashCode(product);
            hash = 31 * hash + Objects.hashCode(timeZone);
            hash = 31 * hash + lockWaitTimeout.hashCode();
            hash = 31 * hash + (lockWaitTimeoutSupported ? 1 : 0);
            return hash;
        }

        @Override
        public String toString() {
            return "SessionState{level=" + level
                + ", product='" + product
                + "', timeZone=" + timeZone
                + ", lockWaitTimeout=" + lockWaitTimeout
                + ", lockWaitTimeoutSupported=" + lockWaitTimeoutSupported
                + '}';
        }
    }
}

