/*
 * Decompiled with CFR 0.152.
 */
package com.google.genai;

import com.google.genai.MergeCandidate;
import com.google.genai.Symbol;
import com.google.genai.Token;
import com.google.genai.Trie;
import com.google.genai.proto.SentencepieceModel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Local, pure-Java implementation of SentencePiece BPE encoding and id decoding.
 *
 * <p>The constructor indexes the pieces of a {@link SentencepieceModel.ModelProto} and rejects
 * models this class does not handle: non-BPE models, models whose normalizer adds a dummy prefix
 * or removes extra whitespace, and models without an unknown-piece id.
 */
final class LocalTokenizerProcessor {
    private static final Logger log = Logger.getLogger(LocalTokenizerProcessor.class.getName());
    /** Text of every USER_DEFINED piece (populated by the constructor; not read in this class). */
    private final Set<String> userDefined = new HashSet<String>();
    /** Trie over USER_DEFINED pieces; prefixes it matches become symbols that are never merged. */
    private final Trie uTrie = new Trie();
    /** Piece text -> id for NORMAL, USER_DEFINED and UNUSED pieces; the only valid merge results. */
    private final Map<String, Integer> pieces = new HashMap<String, Integer>();
    /** Piece text -> id for every remaining non-BYTE piece type (for example CONTROL pieces). */
    private final Map<String, Integer> reserved = new HashMap<String, Integer>();
    /** Byte value -> token of its {@code <0xXX>} byte-fallback piece. */
    private final Map<Byte, Token> byte2Token = new HashMap<Byte, Token>();
    /** Inverse of {@link #byte2Token}: byte-fallback piece id -> byte value. */
    private final Map<Integer, Byte> idToByte = new HashMap<Integer, Byte>();
    private final SentencepieceModel.ModelProto model;
    /** Id reported for text that maps to no known piece. */
    private final int unkID;

    /**
     * Builds the lookup tables for {@code model}.
     *
     * @param model a SentencePiece model proto
     * @throws IllegalArgumentException if the model is not BPE, enables the add-dummy-prefix or
     *     remove-extra-whitespaces normalizer options, has no unknown-piece id, or contains BYTE
     *     pieces while byte fallback is disabled
     */
    public LocalTokenizerProcessor(SentencepieceModel.ModelProto model) {
        this.model = model;
        SentencepieceModel.TrainerSpec tSpec = model.getTrainerSpec();
        if (tSpec.getModelType() != SentencepieceModel.TrainerSpec.ModelType.BPE) {
            throw new IllegalArgumentException(String.format("Unsupported model type %s. Only BPE is supported.", tSpec.getModelType()));
        }
        SentencepieceModel.NormalizerSpec nSpec = model.getNormalizerSpec();
        if (nSpec.getAddDummyPrefix() || nSpec.getRemoveExtraWhitespaces()) {
            throw new IllegalArgumentException(String.format("Unsupported model normalizer option: %s", nSpec));
        }
        if (!tSpec.hasUnkId()) {
            throw new IllegalArgumentException("Unknown ID is not set.");
        }
        this.unkID = tSpec.getUnkId();
        for (int i = 0; i < model.getPiecesCount(); ++i) {
            SentencepieceModel.ModelProto.SentencePiece p = model.getPieces(i);
            SentencepieceModel.ModelProto.SentencePiece.Type type = p.getType();
            if (type == SentencepieceModel.ModelProto.SentencePiece.Type.NORMAL
                    || type == SentencepieceModel.ModelProto.SentencePiece.Type.USER_DEFINED
                    || type == SentencepieceModel.ModelProto.SentencePiece.Type.UNUSED) {
                this.pieces.put(p.getPiece(), i);
            } else if (type == SentencepieceModel.ModelProto.SentencePiece.Type.BYTE) {
                if (!tSpec.getByteFallback()) {
                    throw new IllegalArgumentException(String.format("byte piece %s is found although byte fallback is not enabled.", p.getPiece()));
                }
                // Byte pieces whose text is not of the form <0xNN> with 0 <= NN < 256 are skipped.
                int bValue = this.convertHexValue(p.getPiece());
                if (bValue >= 0 && bValue < 256) {
                    this.byte2Token.put((byte) bValue, new Token(p.getPiece(), i));
                    this.idToByte.put(i, (byte) bValue);
                }
            } else {
                this.reserved.put(p.getPiece(), i);
            }
            if (type == SentencepieceModel.ModelProto.SentencePiece.Type.USER_DEFINED) {
                this.userDefined.add(p.getPiece());
                this.uTrie.insert(p.getPiece());
            }
        }
    }

    /**
     * Encodes {@code text} into tokens with the BPE merge algorithm.
     *
     * <p>Spaces are first replaced by U+2581. The text is then split into user-defined pieces and
     * single Unicode code points, and adjacent symbols are merged in {@link MergeCandidate} priority
     * order. A final symbol that maps to the unknown id is emitted as its UTF-8 byte pieces when the
     * model enables byte fallback.
     *
     * @param text the text to encode
     * @return the tokens in text order; an empty list for empty input
     * @throws IllegalStateException if a queued merge no longer yields a known piece
     */
    public List<Token> encode(String text) throws IllegalStateException {
        text = this.normalize(text);
        if (text.isEmpty()) {
            // No symbols to build; the split/link steps below require at least one symbol.
            return new ArrayList<Token>();
        }
        List<Symbol> symbols = this.splitIntoSymbols(text);
        this.mergeSymbols(symbols);
        return this.symbolsToTokens(symbols);
    }

    /** Splits non-empty normalized text into a doubly linked list of initial symbols. */
    private List<Symbol> splitIntoSymbols(String text) {
        List<Symbol> symbols = new ArrayList<Symbol>(text.length());
        int i = 0;
        while (i < text.length()) {
            // NOTE(review): substring(i) copies the remaining text at every position (O(n^2));
            // kept because the Trie API visible here only accepts a String.
            int len = this.uTrie.prefixLen(text.substring(i));
            boolean userDefinedMatch = len > 0;
            if (!userDefinedMatch) {
                // Take a whole code point so a surrogate pair is never split into two symbols;
                // lone surrogates cannot be encoded to UTF-8 during byte fallback.
                len = Character.charCount(text.codePointAt(i));
            }
            int index = symbols.size();
            // User-defined matches are created with the flag that addNewCandidate reads as noMerge.
            symbols.add(new Symbol(text.substring(i, i + len), userDefinedMatch, index - 1, index + 1));
            i += len;
        }
        symbols.get(symbols.size() - 1).next = -1;
        return symbols;
    }

    /** Repeatedly applies the best available merge until no adjacent pair forms a known piece. */
    private void mergeSymbols(List<Symbol> symbols) {
        PriorityQueue<MergeCandidate> candidates = new PriorityQueue<MergeCandidate>();
        for (int i = 1; i < symbols.size(); ++i) {
            this.addNewCandidate(symbols, candidates, i - 1, i);
        }
        while (!candidates.isEmpty()) {
            MergeCandidate mc = candidates.poll();
            // Candidates are never removed from the queue, so stale ones are filtered here.
            if (!this.isMergeCandidateValid(symbols, mc)) {
                continue;
            }
            Symbol leftSymbol = symbols.get(mc.left);
            Symbol rightSymbol = symbols.get(mc.right);
            String merged = this.maybeMerge(leftSymbol.text, rightSymbol.text);
            if (merged.isEmpty()) {
                throw new IllegalStateException(String.format("error merge symbols, left %s, right %s", leftSymbol, rightSymbol));
            }
            // The left symbol absorbs the right one; the emptied right symbol is unlinked.
            leftSymbol.text = merged;
            leftSymbol.next = rightSymbol.next;
            rightSymbol.text = "";
            if (rightSymbol.next != -1) {
                symbols.get(rightSymbol.next).prev = mc.left;
            }
            this.addNewCandidate(symbols, candidates, leftSymbol.prev, mc.left);
            this.addNewCandidate(symbols, candidates, mc.left, rightSymbol.next);
        }
    }

    /** Walks the linked symbols from the head (index 0) and converts each one to tokens. */
    private List<Token> symbolsToTokens(List<Symbol> symbols) {
        List<Token> tokens = new ArrayList<Token>();
        boolean byteFallback = this.model.getTrainerSpec().getByteFallback();
        // Symbol 0 is never merged into a left neighbour, so it always heads the list.
        for (int i = 0; i >= 0; ) {
            Symbol s = symbols.get(i);
            int id = this.symbolToID(s);
            if (id == this.unkID && byteFallback) {
                // Byte pieces represent UTF-8 bytes (decodeIds decodes them as UTF-8), so encode
                // explicitly as UTF-8 instead of the platform default charset.
                for (byte b : s.text.getBytes(StandardCharsets.UTF_8)) {
                    tokens.add(this.byte2Token.get(b));
                }
            } else {
                tokens.add(new Token(s.text, id));
            }
            i = s.next;
        }
        return tokens;
    }

    /**
     * Decodes piece ids back into text.
     *
     * <p>Consecutive byte-fallback ids are decoded together as UTF-8, because one character may
     * span several byte pieces. Malformed sequences are replaced by {@code new String}. CONTROL
     * pieces produce no text. The unknown id produces the trainer's unknown surface. Any other piece
     * contributes its text with U+2581 turned back into a space.
     *
     * @param ids piece ids to decode
     * @return the decoded text
     */
    public String decodeIds(List<Integer> ids) {
        StringBuilder sb = new StringBuilder();
        int i = 0;
        while (i < ids.size()) {
            int runEnd = i;
            while (runEnd < ids.size() && this.isByteId(ids.get(runEnd))) {
                ++runEnd;
            }
            if (runEnd > i) {
                byte[] buf = new byte[runEnd - i];
                for (int j = i; j < runEnd; ++j) {
                    buf[j - i] = this.idToByte.get(ids.get(j));
                }
                sb.append(new String(buf, StandardCharsets.UTF_8));
            }
            if (runEnd >= ids.size()) {
                break;
            }
            int currentId = ids.get(runEnd);
            SentencepieceModel.ModelProto.SentencePiece pieceProto = this.model.getPieces(currentId);
            if (pieceProto.getType() != SentencepieceModel.ModelProto.SentencePiece.Type.CONTROL) {
                if (currentId == this.unkID) {
                    sb.append(this.model.getTrainerSpec().getUnkSurface());
                } else {
                    sb.append(this.replaceSentencePieceSeparator(pieceProto.getPiece()));
                }
            }
            i = runEnd + 1;
        }
        return sb.toString();
    }

    /** Replaces every ASCII space with the SentencePiece word separator U+2581. */
    private String normalize(String text) {
        return text.replace(' ', '\u2581');
    }

    /** Queues a merge of {@code left} and {@code right} if both exist, are mergeable and form a known piece. */
    private void addNewCandidate(List<Symbol> symbols, PriorityQueue<MergeCandidate> pq, int left, int right) {
        if (left == -1 || right == -1 || symbols.get(left).noMerge || symbols.get(right).noMerge) {
            return;
        }
        String merged = this.maybeMerge(symbols.get(left).text, symbols.get(right).text);
        if (merged.isEmpty()) {
            return;
        }
        pq.add(new MergeCandidate(left, right, merged.length(), this.model.getPieces(this.pieces.get(merged)).getScore()));
    }

    /** Returns {@code a + b} if it is a known piece, otherwise the empty string. */
    private String maybeMerge(String a, String b) {
        String merged = a + b;
        return this.pieces.containsKey(merged) ? merged : "";
    }

    /**
     * Checks that a queued candidate still describes the current symbols: neither side has been
     * absorbed (emptied) and neither side has grown since the candidate was queued.
     */
    private boolean isMergeCandidateValid(List<Symbol> symbols, MergeCandidate symbol) {
        String left = symbols.get(symbol.left).text;
        String right = symbols.get(symbol.right).text;
        return !left.isEmpty() && !right.isEmpty() && left.length() + right.length() == symbol.length;
    }

    /** Maps a symbol's text to a normal/user-defined/unused id, then a reserved id, else the unknown id. */
    private int symbolToID(Symbol symbol) {
        Integer id = this.pieces.get(symbol.text);
        if (id != null) {
            return id;
        }
        id = this.reserved.get(symbol.text);
        if (id != null) {
            return id;
        }
        return this.unkID;
    }

    /** Returns whether piece {@code id} is a BYTE piece in the model. */
    private boolean isByteId(int id) {
        return this.model.getPieces(id).getType() == SentencepieceModel.ModelProto.SentencePiece.Type.BYTE;
    }

    /** Turns U+2581 separators back into spaces; {@code null} yields the empty string. */
    private String replaceSentencePieceSeparator(String pieceText) {
        if (pieceText == null) {
            return "";
        }
        return pieceText.replace('\u2581', ' ');
    }

    /**
     * Parses a byte piece of the form {@code <0xNN>}.
     *
     * @return the hexadecimal value, or -1 if {@code bv} is null, malformed or not valid hex
     */
    private int convertHexValue(String bv) {
        if (bv == null || !bv.startsWith("<0x") || !bv.endsWith(">")) {
            return -1;
        }
        String hexPart = bv.substring(3, bv.length() - 1);
        if (hexPart.isEmpty()) {
            return -1;
        }
        try {
            return Integer.parseInt(hexPart, 16);
        } catch (NumberFormatException e) {
            return -1;
        }
    }
}

