/*
 * Decompiled with CFR 0.152.
 */
package com.android.build.gradle.internal.tasks.mlkit.codegen;

import com.android.build.gradle.internal.tasks.mlkit.codegen.ClassNames;
import com.android.build.gradle.internal.tasks.mlkit.codegen.ModelGenerator;
import com.android.build.gradle.internal.tasks.mlkit.codegen.ModelUtils;
import com.android.build.gradle.internal.tasks.mlkit.codegen.codeinjector.InjectorUtils;
import com.android.build.gradle.internal.tasks.mlkit.codegen.codeinjector.codeblock.AssociatedFileInjector;
import com.android.build.gradle.internal.tasks.mlkit.codegen.codeinjector.codeblock.CodeBlockInjector;
import com.android.tools.mlkit.MetadataExtractor;
import com.android.tools.mlkit.ModelInfo;
import com.android.tools.mlkit.ModelParsingException;
import com.android.tools.mlkit.TensorInfo;
import com.google.common.base.CaseFormat;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import javax.lang.model.element.Modifier;
import org.apache.commons.io.FilenameUtils;
import org.gradle.api.file.DirectoryProperty;
import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;

/**
 * Generates a Java wrapper class for a TensorFlow Lite model file using JavaPoet.
 *
 * <p>The generated class is named after the model file (its lower_underscore base name converted
 * to UpperCamel) and placed in the configured package. It contains:
 * <ul>
 *   <li>tensor fields and a private final model field,
 *   <li>a public constructor taking a context,
 *   <li>{@code createInputs()} and {@code run(Inputs)} methods,
 *   <li>a {@code getAssociatedFile} accessor,
 *   <li>nested {@code Inputs}/{@code Outputs} classes contributed by the injectors in
 *       {@link InjectorUtils}.
 * </ul>
 */
public class TfliteModelGenerator implements ModelGenerator {
    private static final String FIELD_MODEL = "model";
    // NOTE(review): not referenced anywhere in this class; kept in case it is meant for future use.
    private static final String FIELD_METADATA_EXTRACTOR = "extractor";
    /** Simple name of the nested inputs class in the generated code. */
    private static final String INPUTS_CLASS_NAME = "Inputs";
    /** Simple name of the nested outputs class in the generated code. */
    private static final String OUTPUTS_CLASS_NAME = "Outputs";

    private final Logger logger;
    private final String localModelPath;
    private final MetadataExtractor extractor;
    private final ModelInfo modelInfo;
    private final String className;
    private final String packageName;

    /**
     * Parses the model's metadata and derives the name of the generated class.
     *
     * @param modelFile the .tflite model file to read metadata from; its base name (extension
     *     removed) is converted from lower_underscore to UpperCamel to form the class name
     * @param packageName package the generated class is written into
     * @param localModelPath asset path of the model, embedded as a string literal in the
     *     generated code
     * @throws ModelParsingException propagated from metadata extraction or {@code ModelInfo}
     *     construction when the model cannot be parsed
     */
    public TfliteModelGenerator(File modelFile, String packageName, String localModelPath)
            throws ModelParsingException {
        this.extractor = ModelUtils.createMetadataExtractor(modelFile);
        this.localModelPath = localModelPath;
        this.modelInfo = ModelInfo.buildFrom(this.extractor);
        this.packageName = packageName;
        this.logger = Logging.getLogger(getClass());
        this.className =
                CaseFormat.LOWER_UNDERSCORE.to(
                        CaseFormat.UPPER_CAMEL, FilenameUtils.removeExtension(modelFile.getName()));
    }

    /**
     * Builds the wrapper class and writes it as a Java source file under the given directory.
     *
     * <p>Write failures are logged rather than rethrown. This interface method cannot declare
     * checked exceptions.
     *
     * @param outputDirProperty root directory for generated sources; JavaPoet creates the
     *     package sub-directories beneath it
     */
    @Override
    public void generateBuildClass(DirectoryProperty outputDirProperty) {
        TypeSpec.Builder classBuilder =
                TypeSpec.classBuilder(className).addModifiers(Modifier.PUBLIC, Modifier.FINAL);

        String modelDescription = modelInfo.getModelDescription();
        if (modelDescription != null && !modelDescription.isEmpty()) {
            // The description comes from the model file, so it must be an argument, not the
            // format string: a literal '$' in it would otherwise be interpreted by JavaPoet.
            classBuilder.addJavadoc("$L", modelDescription);
        }

        buildFields(classBuilder);
        buildConstructor(classBuilder);
        buildCreateInputsMethod(classBuilder);
        buildGetAssociatedFileMethod(classBuilder);
        buildRunMethod(classBuilder);
        buildInnerClass(classBuilder);

        JavaFile javaFile = JavaFile.builder(packageName, classBuilder.build()).build();
        File outputDir = outputDirProperty.getAsFile().get();
        try {
            javaFile.writeTo(outputDir);
        } catch (IOException e) {
            // Surface the failure with its cause; a missing generated class otherwise shows up
            // only as an unrelated-looking compile error later in the build.
            logger.error("Failed to write mlkit generated java file to " + outputDir, e);
        }
    }

    /**
     * Adds one field per input and output tensor (via the field injector), followed by the
     * private final model field.
     */
    private void buildFields(TypeSpec.Builder classBuilder) {
        for (TensorInfo tensorInfo : modelInfo.getInputs()) {
            InjectorUtils.getFieldInjector().inject(classBuilder, tensorInfo);
        }
        for (TensorInfo tensorInfo : modelInfo.getOutputs()) {
            InjectorUtils.getFieldInjector().inject(classBuilder, tensorInfo);
        }
        FieldSpec modelField =
                FieldSpec.builder(ClassNames.MODEL, FIELD_MODEL)
                        .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
                        .build();
        classBuilder.addField(modelField);
    }

    /**
     * Adds {@code InputStream getAssociatedFile(Context context, String fileName)} to the
     * generated class.
     *
     * <p>The generated method:
     * <ul>
     *   <li>reads the model asset fully into memory and closes the asset stream;
     *   <li>opens the bytes as a zip archive;
     *   <li>returns the raw stream of the entry named {@code fileName}.
     * </ul>
     * The in-memory channel is deliberately left open, because the returned stream still reads
     * from it.
     */
    private void buildGetAssociatedFileMethod(TypeSpec.Builder classBuilder) {
        MethodSpec.Builder methodBuilder =
                MethodSpec.methodBuilder("getAssociatedFile")
                        .addParameter(ClassNames.CONTEXT, "context")
                        .addParameter(String.class, "fileName")
                        .addException(IOException.class)
                        .returns(InputStream.class);
        methodBuilder
                .addStatement(
                        "$T inputStream = context.getAssets().open($S)",
                        InputStream.class,
                        localModelPath)
                .addStatement("byte[] modelBytes")
                // Close the asset stream once it has been copied into memory.
                .beginControlFlow("try")
                .addStatement("modelBytes = $T.toByteArray(inputStream)", ClassNames.IO_UTILS)
                .nextControlFlow("finally")
                .addStatement("inputStream.close()")
                .endControlFlow()
                .addStatement(
                        "$T zipFile = new $T(new $T(modelBytes))",
                        ClassNames.ZIP_FILE,
                        ClassNames.ZIP_FILE,
                        ClassNames.SEEKABLE_IN_MEMORY_BYTE_CHANNEL)
                .addStatement("return zipFile.getRawInputStream(zipFile.getEntry(fileName))");
        classBuilder.addMethod(methodBuilder.build());
    }

    /** Adds the nested Outputs and Inputs classes (in that order) via their injectors. */
    private void buildInnerClass(TypeSpec.Builder classBuilder) {
        InjectorUtils.getOutputsClassInjector().inject(classBuilder, modelInfo.getOutputs());
        InjectorUtils.getInputsClassInjector().inject(classBuilder, modelInfo.getInputs());
    }

    /**
     * Adds the public {@code (Context)} constructor.
     *
     * <p>The constructor builds the model from {@code localModelPath}. It then runs, in order:
     * <ul>
     *   <li>each input tensor's pre-processor injector;
     *   <li>for each output tensor, its post-processor injector and then the associated-file
     *       injector.
     * </ul>
     */
    private void buildConstructor(TypeSpec.Builder classBuilder) {
        MethodSpec.Builder constructorBuilder =
                MethodSpec.constructorBuilder()
                        .addModifiers(Modifier.PUBLIC)
                        .addParameter(ClassNames.CONTEXT, "context")
                        .addException(ClassNames.IO_EXCEPTION)
                        .addStatement(
                                "$L = new $T.Builder(context, $S).build()",
                                FIELD_MODEL,
                                ClassNames.MODEL,
                                localModelPath);
        for (TensorInfo tensorInfo : modelInfo.getInputs()) {
            CodeBlockInjector preprocessorInjector =
                    InjectorUtils.getInputProcessorInjector(tensorInfo);
            preprocessorInjector.inject(constructorBuilder, tensorInfo);
        }
        for (TensorInfo tensorInfo : modelInfo.getOutputs()) {
            CodeBlockInjector postprocessorInjector =
                    InjectorUtils.getOutputProcessorInjector(tensorInfo);
            postprocessorInjector.inject(constructorBuilder, tensorInfo);
            AssociatedFileInjector associatedFileInjector =
                    InjectorUtils.getAssociatedFileInjector();
            ((CodeBlockInjector) associatedFileInjector).inject(constructorBuilder, tensorInfo);
        }
        classBuilder.addMethod(constructorBuilder.build());
    }

    /**
     * Adds {@code public Outputs run(Inputs inputs)}. The generated method creates a fresh
     * Outputs, runs the model on the two buffers, and returns the Outputs.
     */
    private void buildRunMethod(TypeSpec.Builder classBuilder) {
        ClassName outputType = nestedClassName(OUTPUTS_CLASS_NAME);
        ClassName inputType = nestedClassName(INPUTS_CLASS_NAME);
        String localInputs = "inputs";
        String localOutputs = "outputs";
        MethodSpec.Builder methodBuilder =
                MethodSpec.methodBuilder("run")
                        .addModifiers(Modifier.PUBLIC)
                        .addParameter(inputType, localInputs)
                        .returns(outputType)
                        .addStatement("$T $L = new $T()", outputType, localOutputs, outputType)
                        .addStatement(
                                "$L.run($L.getBuffer(), $L.getBuffer())",
                                FIELD_MODEL,
                                localInputs,
                                localOutputs)
                        .addStatement("return $L", localOutputs);
        classBuilder.addMethod(methodBuilder.build());
    }

    /** Adds {@code public Inputs createInputs()}, which returns a new, empty Inputs instance. */
    private void buildCreateInputsMethod(TypeSpec.Builder classBuilder) {
        MethodSpec.Builder methodBuilder =
                MethodSpec.methodBuilder("createInputs")
                        .addModifiers(Modifier.PUBLIC)
                        .returns(nestedClassName(INPUTS_CLASS_NAME))
                        .addStatement("return new $L()", INPUTS_CLASS_NAME);
        classBuilder.addMethod(methodBuilder.build());
    }

    /** Returns the {@link ClassName} of a class nested directly inside the generated class. */
    private ClassName nestedClassName(String simpleName) {
        return ClassName.get(packageName, className).nestedClass(simpleName);
    }
}

