/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.pooling.Pool;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Factory for a GoogLeNet-style image classification network built from Inception blocks.
 *
 * <p>The network is a {@link SequentialBlock} of five stages followed by a {@link Linear} output
 * layer whose unit count is configured through {@link Builder#setOutSize(long)} (10 by default).
 * Stages 3 to 5 are composed of {@link #inceptionBlock(int, int[], int[], int) Inception blocks}.
 *
 * <p>Obtain a network via {@code GoogLeNet.builder().setOutSize(n).build()}.
 */
public final class GoogLeNet {

    private GoogLeNet() {
    }

    /**
     * Builds the complete GoogLeNet block.
     *
     * <p>Stages, in order:
     *
     * <ol>
     *   <li>7x7 convolution (64 filters, stride 2, padding 3), ReLU, 3x3/2 max pooling;
     *   <li>1x1 convolution (64), ReLU, 3x3 convolution (192, padding 1), ReLU, 3x3/2 max pooling;
     *   <li>two Inception blocks, 3x3/2 max pooling;
     *   <li>five Inception blocks, 3x3/2 max pooling;
     *   <li>two Inception blocks, global average pooling;
     * </ol>
     *
     * followed by a {@link Linear} layer with {@code builder.outSize} units.
     *
     * @param builder the configuration to build from
     * @return the assembled network
     * @throws NullPointerException if {@code builder} is null
     */
    public static Block googLeNet(Builder builder) {
        // Fail before allocating any layers rather than after the whole graph is built.
        Objects.requireNonNull(builder, "builder");
        GoogLeNet net = new GoogLeNet();

        SequentialBlock block1 = new SequentialBlock();
        block1.add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(7, 7))
                                .optPadding(new Shape(3, 3))
                                .optStride(new Shape(2, 2))
                                .setFilters(64)
                                .build())
                .add(Activation::relu)
                .add(downsamplingMaxPool());

        SequentialBlock block2 = new SequentialBlock();
        block2.add(conv1x1(64))
                .add(Activation::relu)
                .add(paddedConv(192, 3, 1))
                .add(Activation::relu)
                .add(downsamplingMaxPool());

        SequentialBlock block3 = new SequentialBlock();
        block3.add(net.inceptionBlock(64, new int[] {96, 128}, new int[] {16, 32}, 32))
                .add(net.inceptionBlock(128, new int[] {128, 192}, new int[] {32, 96}, 64))
                .add(downsamplingMaxPool());

        SequentialBlock block4 = new SequentialBlock();
        block4.add(net.inceptionBlock(192, new int[] {96, 208}, new int[] {16, 48}, 64))
                .add(net.inceptionBlock(160, new int[] {112, 224}, new int[] {24, 64}, 64))
                .add(net.inceptionBlock(128, new int[] {128, 256}, new int[] {24, 64}, 64))
                .add(net.inceptionBlock(112, new int[] {144, 288}, new int[] {32, 64}, 64))
                .add(net.inceptionBlock(256, new int[] {160, 320}, new int[] {32, 128}, 128))
                .add(downsamplingMaxPool());

        SequentialBlock block5 = new SequentialBlock();
        block5.add(net.inceptionBlock(256, new int[] {160, 320}, new int[] {32, 128}, 128))
                .add(net.inceptionBlock(384, new int[] {192, 384}, new int[] {48, 128}, 128))
                .add(Pool.globalAvgPool2dBlock());

        // NOTE(review): there is no explicit flatten between global average pooling and the
        // Linear head; this relies on the pooled output already being acceptable to Linear.
        return new SequentialBlock()
                .addAll(
                        block1,
                        block2,
                        block3,
                        block4,
                        block5,
                        Linear.builder().setUnits(builder.outSize).build());
    }

    /**
     * Builds one Inception block: four parallel paths whose outputs are concatenated along axis 1.
     *
     * <ul>
     *   <li>path 1: 1x1 convolution ({@code c1} filters), ReLU;
     *   <li>path 2: 1x1 convolution ({@code c2[0]}), ReLU, 3x3 convolution ({@code c2[1]}, padding
     *       1), ReLU;
     *   <li>path 3: 1x1 convolution ({@code c3[0]}), ReLU, 5x5 convolution ({@code c3[1]}, padding
     *       2), ReLU;
     *   <li>path 4: 3x3 max pooling (stride 1, padding 1), 1x1 convolution ({@code c4}), ReLU.
     * </ul>
     *
     * <p>Assuming axis 1 is the channel axis (NCHW input — TODO confirm), the block emits {@code
     * c1 + c2[1] + c3[1] + c4} channels. Elements of {@code c2}/{@code c3} beyond index 1 are
     * ignored.
     *
     * <p>NOTE(review): this is an instance method while the constructor is private, so it is only
     * reachable from inside this class; it stays non-static to keep the public signature intact.
     *
     * @param c1 filters of path 1
     * @param c2 filters of path 2: {@code {1x1 reduction, 3x3 output}}
     * @param c3 filters of path 3: {@code {1x1 reduction, 5x5 output}}
     * @param c4 filters of the 1x1 convolution in path 4
     * @return the parallel Inception block
     * @throws NullPointerException if {@code c2} or {@code c3} is null
     * @throws IllegalArgumentException if {@code c2} or {@code c3} has fewer than two elements
     */
    public ParallelBlock inceptionBlock(int c1, int[] c2, int[] c3, int c4) {
        requireTwoChannelCounts(c2, "c2");
        requireTwoChannelCounts(c3, "c3");

        Block p1 = new SequentialBlock().add(conv1x1(c1)).add(Activation::relu);
        Block p2 =
                new SequentialBlock()
                        .add(conv1x1(c2[0]))
                        .add(Activation::relu)
                        .add(paddedConv(c2[1], 3, 1))
                        .add(Activation::relu);
        Block p3 =
                new SequentialBlock()
                        .add(conv1x1(c3[0]))
                        .add(Activation::relu)
                        .add(paddedConv(c3[1], 5, 2))
                        .add(Activation::relu);
        Block p4 =
                new SequentialBlock()
                        .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                        .add(conv1x1(c4))
                        .add(Activation::relu);

        return new ParallelBlock(
                outputs -> {
                    // Each path yields an NDList; take its first array and concatenate on axis 1.
                    List<NDArray> heads =
                            outputs.stream().map(NDList::head).collect(Collectors.toList());
                    return new NDList(NDArrays.concat(new NDList(heads), 1));
                },
                Arrays.asList(p1, p2, p3, p4));
    }

    /** Returns the 3x3 max pooling (stride 2, padding 1) that closes stages 1 to 4. */
    private static Block downsamplingMaxPool() {
        return Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2), new Shape(1, 1));
    }

    /** Returns a 1x1 convolution with {@code filters} filters and default stride and padding. */
    private static Block conv1x1(int filters) {
        return Conv2d.builder().setFilters(filters).setKernelShape(new Shape(1, 1)).build();
    }

    /** Returns a square {@code kernel}x{@code kernel} convolution with symmetric padding. */
    private static Block paddedConv(int filters, long kernel, long padding) {
        return Conv2d.builder()
                .setFilters(filters)
                .setKernelShape(new Shape(kernel, kernel))
                .optPadding(new Shape(padding, padding))
                .build();
    }

    /** Checks that a path's filter array is present and holds at least two counts. */
    private static void requireTwoChannelCounts(int[] channels, String name) {
        Objects.requireNonNull(channels, name);
        if (channels.length < 2) {
            throw new IllegalArgumentException(
                    name + " must contain at least 2 filter counts, but had " + channels.length);
        }
    }

    /**
     * Creates a new {@link Builder} with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent configuration for a {@link GoogLeNet} network. */
    public static final class Builder {

        /** Number of units in the final {@link Linear} layer; defaults to 10. */
        long outSize = 10L;

        Builder() {
        }

        /**
         * Sets the number of units in the final {@link Linear} layer (presumably the number of
         * classes).
         *
         * @param outSize the number of output units
         * @return this builder
         */
        public Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        /**
         * Builds the GoogLeNet network from this configuration.
         *
         * @return the network block
         */
        public Block build() {
            return GoogLeNet.googLeNet(this);
        }
    }
}

