/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.transformer.MissingOps;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Collections;

/**
 * Embedding block that looks up integer ids in a learned weight table of shape
 * {@code (dictionarySize, embeddingSize)}.
 *
 * <p>Instances are created through {@link Builder}.
 */
public final class IdEmbedding
extends AbstractBlock {
    private static final String EMBEDDING_PARAM_NAME = "embedding";
    /** Number of rows of the embedding table (ids are presumably expected in {@code [0, dictionarySize)}). */
    private final int dictionarySize;
    /** Width of each embedding vector (number of columns of the embedding table). */
    private final int embeddingSize;
    /** The weight parameter of shape {@code (dictionarySize, embeddingSize)}. */
    private final Parameter embedding;

    private IdEmbedding(Builder builder) {
        this.dictionarySize = builder.dictionarySize;
        this.embeddingSize = builder.embeddingSize;
        this.embedding = this.addParameter(
                Parameter.builder()
                        .setName(EMBEDDING_PARAM_NAME)
                        .setType(Parameter.Type.WEIGHT)
                        .optShape(new Shape(this.dictionarySize, this.embeddingSize))
                        .build());
    }

    /**
     * Returns the output shape: the input shape with {@code embeddingSize} appended as a new last axis.
     *
     * @param inputShapes the input shapes; only the first is used
     * @return a single-element array holding the output shape
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[]{inputShapes[0].addAll(new Shape(this.embeddingSize))};
    }

    /**
     * Looks up the embedding vector for every id in the (single) input array.
     *
     * @param ps the parameter store supplying the embedding table
     * @param inputs a list holding exactly one array of ids
     * @param training whether this is a training pass
     * @param params unused extra parameters
     * @return a list holding one array shaped like the input plus a trailing embedding axis
     */
    @Override
    protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // All ids are laid out in a single row of shape (1, N). This assumes gatherNd treats that row
        // as indices into the table's first axis (gather_nd-style semantics) — confirm against MissingOps.
        NDArray ids = input.flatten().reshape(1L, input.getShape().size());
        NDArray embeddingTable = ps.getValue(this.embedding, ids.getDevice(), training);
        NDArray result = MissingOps.gatherNd(embeddingTable, ids);
        // Restore the original id layout and append the embedding width as the last axis.
        Shape targetShape = input.getShape().addAll(new Shape(embeddingTable.getShape().get(1)));
        return new NDList(result.reshape(targetShape));
    }

    /**
     * Scores vectors against every row of the embedding table and normalizes over the dictionary.
     *
     * <p>Note: despite the name, the returned values are <em>log</em>-probabilities ({@code logSoftmax}).
     *
     * @param parameterStore the parameter store supplying the embedding table
     * @param input an array whose last axis has size {@code embeddingSize}
     * @param training whether this is a training pass
     * @return log-probabilities shaped like {@code input} with the last axis replaced by {@code dictionarySize}
     */
    public NDArray probabilities(ParameterStore parameterStore, NDArray input, boolean training) {
        NDArray asMatrix = input.reshape(-1L, this.embeddingSize);
        NDArray embeddingTableTransposed = parameterStore.getValue(this.embedding, input.getDevice(), training).transpose();
        // Presumably ties the transposed table's lifetime to the input's manager — confirm intent.
        embeddingTableTransposed.attach(input.getManager());
        NDArray logitsFlat = asMatrix.dot(embeddingTableTransposed);
        NDArray logProbsFlat = logitsFlat.logSoftmax(1);
        Shape targetShape = input.getShape().slice(0, input.getShape().dimension() - 1).addAll(new Shape(this.dictionarySize));
        return logProbsFlat.reshape(targetShape);
    }

    /**
     * Returns the raw embedding table from the parameter store.
     *
     * @param ps the parameter store
     * @param device the device to fetch the table on
     * @param training whether this is a training pass
     * @return the {@code (dictionarySize, embeddingSize)} embedding table
     */
    public NDArray getValue(ParameterStore ps, Device device, boolean training) {
        return ps.getValue(this.embedding, device, training);
    }

    /** Declares the block's single input name; this block has no child blocks to initialize. */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.inputNames = Collections.singletonList("tokenIds");
    }

    /** Builder for {@link IdEmbedding}; both sizes must be set to strictly positive values. */
    public static final class Builder {
        private int dictionarySize;
        private int embeddingSize;

        /**
         * Sets the number of rows of the embedding table.
         *
         * @param dictionarySize the dictionary size; must be positive
         * @return this builder
         */
        public Builder setDictionarySize(int dictionarySize) {
            this.dictionarySize = dictionarySize;
            return this;
        }

        /**
         * Sets the width of each embedding vector.
         *
         * @param embeddingSize the embedding size; must be positive
         * @return this builder
         */
        public Builder setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link IdEmbedding}
         * @throws IllegalArgumentException if either size is not strictly positive
         */
        public IdEmbedding build() {
            if (this.dictionarySize <= 0) {
                throw new IllegalArgumentException("You must specify the dictionary Size for the embedding.");
            }
            // Like the dictionary size, the embedding size must be strictly positive; a negative
            // value would otherwise reach the parameter's Shape unchecked.
            if (this.embeddingSize <= 0) {
                throw new IllegalArgumentException("You must specify the embedding size");
            }
            return new IdEmbedding(this);
        }
    }
}

