/*
 * Decompiled with CFR 0.152.
 */
package com.google.genai;

import com.google.genai.LocalTokenizerProcessor;
import com.google.genai.errors.GenAiIOException;
import com.google.genai.proto.SentencepieceModel;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileAttribute;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;

/**
 * Loads the SentencePiece tokenizer models used for local token counting.
 *
 * <p>Model files are downloaded over HTTP from pinned URLs and verified against a pinned SHA-256
 * hash. They are cached on disk under {@code java.io.tmpdir}. Parsed protos and processors are
 * memoized per tokenizer name in {@link ConcurrentHashMap}s.
 */
final class LocalTokenizerLoader {
    /** Supported Gemini model names mapped to the tokenizer name each one uses. */
    private static final Map<String, String> GEMINI_MODELS_TO_TOKENIZER_NAMES;
    private static final Logger logger;
    /**
     * Tokenizer names mapped to the pinned model URL and expected SHA-256 of the model file.
     * NOTE(review): not final in the original either — possibly so tests can swap it reflectively;
     * confirm before tightening.
     */
    private static Map<String, TokenizerConfig> TOKENIZERS;
    /** Parsed model protos, keyed by tokenizer name. */
    private static final Map<String, SentencepieceModel.ModelProto> modelProtoCache;
    /** Tokenizer processors built from the cached protos, keyed by tokenizer name. */
    private static final Map<String, LocalTokenizerProcessor> localTokenizerProcessorCache;
    /** NOTE(review): left non-final as in the original; possibly replaced by tests — confirm. */
    private static OkHttpClient httpClient;
    /** Subdirectory of {@code java.io.tmpdir} that holds downloaded model files. */
    private static final String CACHE_DIR_NAME = "vertexai_tokenizer_model";

    private LocalTokenizerLoader() {
    }

    /**
     * Returns the name of the tokenizer used by a Gemini model.
     *
     * @param modelName a model name such as {@code "gemini-2.5-flash"}
     * @return the tokenizer name, e.g. {@code "gemma3"}
     * @throws IllegalArgumentException if the model is not in the supported list
     */
    public static String getTokenizerName(String modelName) {
        String tokenizerName = GEMINI_MODELS_TO_TOKENIZER_NAMES.get(modelName);
        if (tokenizerName != null) {
            return tokenizerName;
        }
        throw new IllegalArgumentException("Model " + modelName + " is not supported. Supported models: " + String.join(", ", GEMINI_MODELS_TO_TOKENIZER_NAMES.keySet()));
    }

    /**
     * Returns the parsed SentencePiece model proto for a tokenizer.
     *
     * <p>On first use the proto is loaded from the disk cache or downloaded, then memoized. If
     * loading throws, no entry is stored, so a later call retries. The download runs inside
     * {@link ConcurrentHashMap#computeIfAbsent}. Concurrent updates to this map may therefore wait
     * for it to finish.
     *
     * @param tokenizerName a key of {@code TOKENIZERS}, e.g. {@code "gemma2"}
     * @return the parsed model proto
     * @throws IllegalArgumentException if the tokenizer name is unknown
     * @throws IllegalStateException if the downloaded bytes are not a valid model proto
     * @throws GenAiIOException if the model file cannot be read or downloaded
     */
    public static SentencepieceModel.ModelProto loadModelProto(String tokenizerName) {
        return modelProtoCache.computeIfAbsent(tokenizerName, key -> {
            try {
                byte[] protoBytes = loadModelProtoBytes(key);
                return SentencepieceModel.ModelProto.parseFrom(protoBytes);
            } catch (InvalidProtocolBufferException e) {
                throw new IllegalStateException("Failed to parse model proto", e);
            } catch (IOException e) {
                throw new GenAiIOException("Failed to load tokenizer model", e);
            }
        });
    }

    /**
     * Returns the memoized {@link LocalTokenizerProcessor} for a tokenizer.
     *
     * <p>The processor is created on first use from {@link #loadModelProto}.
     *
     * @param tokenizerName a key of {@code TOKENIZERS}
     * @return the processor for that tokenizer
     */
    public static LocalTokenizerProcessor getSentencePiece(String tokenizerName) {
        // This nests a computeIfAbsent on a *different* map (modelProtoCache), which is allowed.
        return localTokenizerProcessorCache.computeIfAbsent(
                tokenizerName, key -> new LocalTokenizerProcessor(loadModelProto(key)));
    }

    /**
     * Returns the raw bytes of the model file for a tokenizer.
     *
     * @throws IllegalArgumentException if the tokenizer name is unknown
     * @throws IOException if the model cannot be downloaded
     */
    private static byte[] loadModelProtoBytes(String tokenizerName) throws IOException {
        TokenizerConfig config = TOKENIZERS.get(tokenizerName);
        if (config == null) {
            throw new IllegalArgumentException("Tokenizer " + tokenizerName + " is not supported. Supported tokenizers: " + String.join(", ", TOKENIZERS.keySet()));
        }
        return load(config.modelUrl(), config.modelHash());
    }

    /**
     * Returns the content of {@code fileUrl}, using the on-disk cache when it holds a file with the
     * expected hash, and downloading (then caching) otherwise.
     */
    private static byte[] load(String fileUrl, String expectedHash) throws IOException {
        Path modelDir = Paths.get(System.getProperty("java.io.tmpdir"), CACHE_DIR_NAME);
        // The cache file name is the SHA-1 of the URL, so each distinct URL gets its own entry.
        Path modelPath = modelDir.resolve(sha1(fileUrl));
        Optional<byte[]> cached = maybeLoadFromCache(modelPath, expectedHash);
        if (cached.isPresent()) {
            return cached.get();
        }
        byte[] downloadedData = loadFromUrl(fileUrl, expectedHash);
        maybeSaveToCache(modelDir, modelPath, downloadedData);
        // loadFromUrl already verified sha256(downloadedData).equals(expectedHash); no need to re-hash.
        logger.info("Downloaded model from " + fileUrl + " to " + modelPath + " with hash " + expectedHash);
        return downloadedData;
    }

    /**
     * Reads a cached model file if it exists and matches {@code expectedHash}.
     *
     * <p>The cache is best-effort. An unreadable or invalid file yields {@link Optional#empty()},
     * so the caller falls back to downloading. An invalid file is also deleted.
     */
    private static Optional<byte[]> maybeLoadFromCache(Path filePath, String expectedHash) {
        if (!Files.exists(filePath)) {
            return Optional.empty();
        }
        byte[] content;
        try {
            content = Files.readAllBytes(filePath);
        } catch (IOException e) {
            logger.log(Level.WARNING, "Failed to read cached tokenizer model " + filePath, e);
            return Optional.empty();
        }
        if (isValidModel(content, expectedHash)) {
            return Optional.of(content);
        }
        try {
            Files.deleteIfExists(filePath);
        } catch (IOException e) {
            // Not fatal: a successful download later overwrites the entry (REPLACE_EXISTING).
            logger.log(Level.FINE, "Failed to delete invalid cached tokenizer model " + filePath, e);
        }
        return Optional.empty();
    }

    /**
     * Best-effort write of {@code content} to {@code cachePath}.
     *
     * <p>Failures are logged, never thrown. The content goes to a uniquely named temp file in
     * {@code cacheDir} first, then is moved over the target, so the target is never written in
     * place. Readers that still see a bad file reject it through hash validation.
     */
    private static void maybeSaveToCache(Path cacheDir, Path cachePath, byte[] content) {
        Path tmpPath = null;
        try {
            Files.createDirectories(cacheDir);
            tmpPath = cacheDir.resolve("." + UUID.randomUUID() + ".tmp");
            Files.write(tmpPath, content);
            Files.move(tmpPath, cachePath, StandardCopyOption.REPLACE_EXISTING);
        } catch (IOException e) {
            logger.log(Level.WARNING, "Failed to cache tokenizer model at " + cachePath, e);
            if (tmpPath != null) {
                try {
                    // Don't leave orphaned temp files behind in the shared temp directory.
                    Files.deleteIfExists(tmpPath);
                } catch (IOException ignored) {
                    // Nothing more to do; the original failure is already logged.
                }
            }
        }
    }

    /**
     * Downloads {@code fileUrl} and verifies its SHA-256 against {@code expectedHash}.
     *
     * @throws GenAiIOException on a non-2xx status, a missing body, or a hash mismatch
     * @throws IOException if the HTTP call itself fails
     */
    private static byte[] loadFromUrl(String fileUrl, String expectedHash) throws IOException {
        Request request = new Request.Builder().url(fileUrl).build();
        try (Response response = httpClient.newCall(request).execute()) {
            if (response == null) {
                throw new GenAiIOException("HTTP request failed: response is null");
            }
            if (!response.isSuccessful()) {
                throw new GenAiIOException("Failed to download tokenizer model: HTTP " + response.code());
            }
            ResponseBody body = response.body();
            if (body == null) {
                throw new GenAiIOException("Failed to download tokenizer model: Response body is null");
            }
            byte[] content = body.bytes();
            if (!isValidModel(content, expectedHash)) {
                String actualHash = sha256(content);
                throw new GenAiIOException("Downloaded model file is corrupted. Expected hash " + expectedHash + ". Got file hash " + actualHash + ".");
            }
            return content;
        }
    }

    /**
     * Returns whether the lowercase-hex SHA-256 of {@code modelData} equals {@code expectedHash}.
     *
     * @throws IllegalArgumentException if {@code expectedHash} is null or empty
     */
    private static boolean isValidModel(byte[] modelData, String expectedHash) {
        if (expectedHash == null || expectedHash.isEmpty()) {
            throw new IllegalArgumentException("expected_hash is required");
        }
        return sha256(modelData).equals(expectedHash);
    }

    /** Lowercase-hex SHA-256 of {@code data}. */
    private static String sha256(byte[] data) {
        return hash(data, "SHA-256");
    }

    /** Lowercase-hex SHA-1 of the UTF-8 encoding of {@code input}. */
    private static String sha1(String input) {
        // Explicit charset: the no-arg getBytes() depends on the platform default encoding.
        return hash(input.getBytes(StandardCharsets.UTF_8), "SHA-1");
    }

    /** Digests {@code data} with the named {@link MessageDigest} algorithm and hex-encodes it. */
    private static String hash(byte[] data, String algorithm) {
        try {
            MessageDigest digest = MessageDigest.getInstance(algorithm);
            return bytesToHex(digest.digest(data));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /** Encodes bytes as lowercase hex, two characters per byte. */
    private static String bytesToHex(byte[] hash) {
        StringBuilder hexString = new StringBuilder(2 * hash.length);
        for (byte b : hash) {
            String hex = Integer.toHexString(0xFF & b);
            if (hex.length() == 1) {
                hexString.append('0');
            }
            hexString.append(hex);
        }
        return hexString.toString();
    }

    static {
        logger = Logger.getLogger(LocalTokenizerLoader.class.getName());
        HashMap<String, String> modelMap = new HashMap<String, String>();
        modelMap.put("gemini-1.0-pro-001", "gemma2");
        modelMap.put("gemini-1.0-pro-002", "gemma2");
        modelMap.put("gemini-1.0-pro", "gemma2");
        modelMap.put("gemini-1.5-flash-001", "gemma2");
        modelMap.put("gemini-1.5-flash-002", "gemma2");
        modelMap.put("gemini-1.5-flash", "gemma2");
        modelMap.put("gemini-1.5-pro-001", "gemma2");
        modelMap.put("gemini-1.5-pro-002", "gemma2");
        modelMap.put("gemini-1.5-pro", "gemma2");
        modelMap.put("gemini-2.0-flash-001", "gemma3");
        modelMap.put("gemini-2.0-flash-lite-001", "gemma3");
        modelMap.put("gemini-2.0-flash-lite", "gemma3");
        modelMap.put("gemini-2.0-flash", "gemma3");
        modelMap.put("gemini-2.5-flash-lite-preview-06-17", "gemma3");
        modelMap.put("gemini-2.5-flash-lite", "gemma3");
        modelMap.put("gemini-2.5-flash-preview-04-17", "gemma3");
        modelMap.put("gemini-2.5-flash-preview-05-20", "gemma3");
        modelMap.put("gemini-2.5-flash", "gemma3");
        modelMap.put("gemini-2.5-pro-exp-03-25", "gemma3");
        modelMap.put("gemini-2.5-pro-preview-05-06", "gemma3");
        modelMap.put("gemini-2.5-pro-preview-06-05", "gemma3");
        modelMap.put("gemini-2.5-pro", "gemma3");
        modelMap.put("gemini-live-2.5-flash", "gemma3");
        GEMINI_MODELS_TO_TOKENIZER_NAMES = Collections.unmodifiableMap(modelMap);
        HashMap<String, TokenizerConfig> map = new HashMap<String, TokenizerConfig>();
        map.put("gemma2", new TokenizerConfig("https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model", "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"));
        map.put("gemma3", new TokenizerConfig("https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model", "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"));
        TOKENIZERS = Collections.unmodifiableMap(map);
        modelProtoCache = new ConcurrentHashMap<String, SentencepieceModel.ModelProto>();
        localTokenizerProcessorCache = new ConcurrentHashMap<String, LocalTokenizerProcessor>();
        httpClient = new OkHttpClient();
    }

    /** Immutable pair of a model download URL and the expected SHA-256 (lowercase hex) of the file. */
    static final class TokenizerConfig {
        private final String modelUrl;
        private final String modelHash;

        TokenizerConfig(String modelUrl, String modelHash) {
            this.modelUrl = modelUrl;
            this.modelHash = modelHash;
        }

        String modelUrl() {
            return this.modelUrl;
        }

        String modelHash() {
            return this.modelHash;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            TokenizerConfig that = (TokenizerConfig) o;
            return Objects.equals(this.modelUrl, that.modelUrl) && Objects.equals(this.modelHash, that.modelHash);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.modelUrl, this.modelHash);
        }

        @Override
        public String toString() {
            return "TokenizerConfig[modelUrl=" + this.modelUrl + ", modelHash=" + this.modelHash + "]";
        }
    }
}

