/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.mlp;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.Closeable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.WritableUtils;
import org.apache.mahout.classifier.mlp.NeuralNetworkFunctions;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class NeuralNetwork {
    private static final Logger log = LoggerFactory.getLogger(NeuralNetwork.class);
    public static final double DEFAULT_LEARNING_RATE = 0.5;
    public static final double DEFAULT_REGULARIZATION_WEIGHT = 0.0;
    public static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
    protected String modelType;
    protected String modelPath;
    protected double learningRate;
    protected double regularizationWeight;
    protected double momentumWeight;
    protected String costFunctionName;
    protected List<Integer> layerSizeList;
    protected TrainingMethod trainingMethod;
    protected List<Matrix> weightMatrixList;
    protected List<Matrix> prevWeightUpdatesList;
    protected List<String> squashingFunctionList;
    protected int finalLayerIndex;

    public NeuralNetwork() {
        log.info("Initialize model...");
        this.learningRate = 0.5;
        this.regularizationWeight = 0.0;
        this.momentumWeight = 0.1;
        this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
        this.costFunctionName = "Minus_Squared";
        this.modelType = this.getClass().getSimpleName();
        this.layerSizeList = Lists.newArrayList();
        this.layerSizeList = Lists.newArrayList();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        this.squashingFunctionList = Lists.newArrayList();
    }

    public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) {
        this();
        this.setLearningRate(learningRate);
        this.setMomentumWeight(momentumWeight);
        this.setRegularizationWeight(regularizationWeight);
    }

    public NeuralNetwork(String modelPath) throws IOException {
        this.modelPath = modelPath;
        this.readFromModel();
    }

    public String getModelType() {
        return this.modelType;
    }

    public final NeuralNetwork setLearningRate(double learningRate) {
        Preconditions.checkArgument((learningRate > 0.0 ? 1 : 0) != 0, (Object)"Learning rate must be larger than 0.");
        this.learningRate = learningRate;
        return this;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public final NeuralNetwork setRegularizationWeight(double regularizationWeight) {
        Preconditions.checkArgument((regularizationWeight >= 0.0 && regularizationWeight < 0.1 ? 1 : 0) != 0, (Object)"Regularization weight must be in range [0, 0.1)");
        this.regularizationWeight = regularizationWeight;
        return this;
    }

    public double getRegularizationWeight() {
        return this.regularizationWeight;
    }

    public final NeuralNetwork setMomentumWeight(double momentumWeight) {
        Preconditions.checkArgument((momentumWeight >= 0.0 && momentumWeight <= 1.0 ? 1 : 0) != 0, (Object)"Momentum weight must be in range [0, 1.0]");
        this.momentumWeight = momentumWeight;
        return this;
    }

    public double getMomentumWeight() {
        return this.momentumWeight;
    }

    public NeuralNetwork setTrainingMethod(TrainingMethod method) {
        this.trainingMethod = method;
        return this;
    }

    public TrainingMethod getTrainingMethod() {
        return this.trainingMethod;
    }

    public NeuralNetwork setCostFunction(String costFunction) {
        this.costFunctionName = costFunction;
        return this;
    }

    public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) {
        Preconditions.checkArgument((size > 0 ? 1 : 0) != 0, (Object)"Size of layer must be larger than 0.");
        log.info("Add layer with size {} and squashing function {}", (Object)size, (Object)squashingFunctionName);
        int actualSize = size;
        if (!isFinalLayer) {
            ++actualSize;
        }
        this.layerSizeList.add(actualSize);
        int layerIndex = this.layerSizeList.size() - 1;
        if (isFinalLayer) {
            this.finalLayerIndex = layerIndex;
        }
        if (layerIndex > 0) {
            int sizePrevLayer = this.layerSizeList.get(layerIndex - 1);
            int row = isFinalLayer ? actualSize : actualSize - 1;
            DenseMatrix weightMatrix = new DenseMatrix(row, sizePrevLayer);
            final RandomWrapper rnd = RandomUtils.getRandom();
            weightMatrix.assign(new DoubleFunction(){

                public double apply(double value) {
                    return rnd.nextDouble() - 0.5;
                }
            });
            this.weightMatrixList.add((Matrix)weightMatrix);
            this.prevWeightUpdatesList.add((Matrix)new DenseMatrix(row, sizePrevLayer));
            this.squashingFunctionList.add(squashingFunctionName);
        }
        return layerIndex;
    }

    public int getLayerSize(int layer) {
        Preconditions.checkArgument((layer >= 0 && layer < this.layerSizeList.size() ? 1 : 0) != 0, (Object)String.format("Input must be in range [0, %d]\n", this.layerSizeList.size() - 1));
        return this.layerSizeList.get(layer);
    }

    protected List<Integer> getLayerSizeList() {
        return this.layerSizeList;
    }

    public Matrix getWeightsByLayer(int layerIndex) {
        return this.weightMatrixList.get(layerIndex);
    }

    public void updateWeightMatrices(Matrix[] matrices) {
        for (int i = 0; i < matrices.length; ++i) {
            Matrix matrix = this.weightMatrixList.get(i);
            this.weightMatrixList.set(i, matrix.plus(matrices[i]));
        }
    }

    public void setWeightMatrices(Matrix[] matrices) {
        this.weightMatrixList = Lists.newArrayList();
        Collections.addAll(this.weightMatrixList, matrices);
    }

    public void setWeightMatrix(int index, Matrix matrix) {
        Preconditions.checkArgument((0 <= index && index < this.weightMatrixList.size() ? 1 : 0) != 0, (Object)String.format("index [%s] should be in range [%s, %s).", index, 0, this.weightMatrixList.size()));
        this.weightMatrixList.set(index, matrix);
    }

    public Matrix[] getWeightMatrices() {
        Matrix[] matrices = new Matrix[this.weightMatrixList.size()];
        this.weightMatrixList.toArray(matrices);
        return matrices;
    }

    public Vector getOutput(Vector instance) {
        Preconditions.checkArgument((this.layerSizeList.get(0) == instance.size() + 1 ? 1 : 0) != 0, (Object)String.format("The dimension of input instance should be %d, but the input has dimension %d.", this.layerSizeList.get(0) - 1, instance.size()));
        DenseVector instanceWithBias = new DenseVector(instance.size() + 1);
        instanceWithBias.set(0, 0.99999);
        for (int i = 1; i < instanceWithBias.size(); ++i) {
            instanceWithBias.set(i, instance.get(i - 1));
        }
        List<Vector> outputCache = this.getOutputInternal((Vector)instanceWithBias);
        Vector result = outputCache.get(outputCache.size() - 1);
        return result.viewPart(1, result.size() - 1);
    }

    protected List<Vector> getOutputInternal(Vector instance) {
        ArrayList outputCache = Lists.newArrayList();
        Vector intermediateOutput = instance;
        outputCache.add(intermediateOutput);
        for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
            intermediateOutput = this.forward(i, intermediateOutput);
            outputCache.add(intermediateOutput);
        }
        return outputCache;
    }

    protected Vector forward(int fromLayer, Vector intermediateOutput) {
        Matrix weightMatrix = this.weightMatrixList.get(fromLayer);
        Vector vec = weightMatrix.times(intermediateOutput);
        vec = vec.assign(NeuralNetworkFunctions.getDoubleFunction(this.squashingFunctionList.get(fromLayer)));
        DenseVector vecWithBias = new DenseVector(vec.size() + 1);
        vecWithBias.set(0, 1.0);
        for (int i = 0; i < vec.size(); ++i) {
            vecWithBias.set(i + 1, vec.get(i));
        }
        return vecWithBias;
    }

    public void trainOnline(Vector trainingInstance) {
        Matrix[] matrices = this.trainByInstance(trainingInstance);
        this.updateWeightMatrices(matrices);
    }

    public Matrix[] trainByInstance(Vector trainingInstance) {
        int outputDimension;
        int inputDimension = this.layerSizeList.get(0) - 1;
        Preconditions.checkArgument((inputDimension + (outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1).intValue()) == trainingInstance.size() ? 1 : 0) != 0, (Object)String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(), inputDimension + outputDimension));
        if (this.trainingMethod.equals((Object)TrainingMethod.GRADIENT_DESCENT)) {
            return this.trainByInstanceGradientDescent(trainingInstance);
        }
        throw new IllegalArgumentException("Training method is not supported.");
    }

    private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
        int inputDimension = this.layerSizeList.get(0) - 1;
        DenseVector inputInstance = new DenseVector(this.layerSizeList.get(0).intValue());
        inputInstance.set(0, 1.0);
        for (int i = 0; i < inputDimension; ++i) {
            inputInstance.set(i + 1, trainingInstance.get(i));
        }
        Vector labels = trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);
        Matrix[] weightUpdateMatrices = new Matrix[this.weightMatrixList.size()];
        for (int m = 0; m < weightUpdateMatrices.length; ++m) {
            weightUpdateMatrices[m] = new DenseMatrix(this.weightMatrixList.get(m).rowSize(), this.weightMatrixList.get(m).columnSize());
        }
        List<Vector> internalResults = this.getOutputInternal((Vector)inputInstance);
        DenseVector deltaVec = new DenseVector(this.layerSizeList.get(this.layerSizeList.size() - 1).intValue());
        Vector output = internalResults.get(internalResults.size() - 1);
        DoubleFunction derivativeSquashingFunction = NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(this.squashingFunctionList.size() - 1));
        DoubleDoubleFunction costFunction = NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(this.costFunctionName);
        Matrix lastWeightMatrix = this.weightMatrixList.get(this.weightMatrixList.size() - 1);
        for (int i = 0; i < deltaVec.size(); ++i) {
            double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1));
            deltaVec.set(i, costFuncDerivative += this.regularizationWeight * lastWeightMatrix.viewRow(i).zSum());
            deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1)));
        }
        for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
            deltaVec = this.backPropagate(layer, (Vector)deltaVec, internalResults, weightUpdateMatrices[layer]);
        }
        this.prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);
        return weightUpdateMatrices;
    }

    private Vector backPropagate(int currentLayerIndex, Vector nextLayerDelta, List<Vector> outputCache, Matrix weightUpdateMatrix) {
        final DoubleFunction derivativeSquashingFunction = NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(currentLayerIndex));
        Vector curLayerOutput = outputCache.get(currentLayerIndex);
        Matrix weightMatrix = this.weightMatrixList.get(currentLayerIndex);
        Matrix prevWeightMatrix = this.prevWeightUpdatesList.get(currentLayerIndex);
        if (currentLayerIndex != this.layerSizeList.size() - 2) {
            nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
        }
        Vector delta = weightMatrix.transpose().times(nextLayerDelta);
        delta = delta.assign(curLayerOutput, new DoubleDoubleFunction(){

            public double apply(double deltaElem, double curLayerOutputElem) {
                return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
            }
        });
        for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
            for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
                weightUpdateMatrix.set(i, j, -this.learningRate * nextLayerDelta.get(i) * curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
            }
        }
        return delta;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    protected void readFromModel() throws IOException {
        log.info("Load model from {}", (Object)this.modelPath);
        Preconditions.checkArgument((this.modelPath != null ? 1 : 0) != 0, (Object)"Model path has not been set.");
        FSDataInputStream is = null;
        try {
            Path path = new Path(this.modelPath);
            FileSystem fs = path.getFileSystem(new Configuration());
            is = new FSDataInputStream((InputStream)fs.open(path));
            this.readFields((DataInput)is);
        }
        catch (Throwable throwable) {
            Closeables.close(is, (boolean)true);
            throw throwable;
        }
        Closeables.close((Closeable)is, (boolean)true);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void writeModelToFile() throws IOException {
        log.info("Write model to {}.", (Object)this.modelPath);
        Preconditions.checkArgument((this.modelPath != null ? 1 : 0) != 0, (Object)"Model path has not been set.");
        FSDataOutputStream stream = null;
        try {
            Path path = new Path(this.modelPath);
            FileSystem fs = path.getFileSystem(new Configuration());
            stream = fs.create(path, true);
            this.write((DataOutput)stream);
        }
        catch (Throwable throwable) {
            Closeables.close(stream, (boolean)false);
            throw throwable;
        }
        Closeables.close((Closeable)stream, (boolean)false);
    }

    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    public String getModelPath() {
        return this.modelPath;
    }

    public void write(DataOutput output) throws IOException {
        WritableUtils.writeString((DataOutput)output, (String)this.modelType);
        output.writeDouble(this.learningRate);
        if (this.modelPath != null) {
            WritableUtils.writeString((DataOutput)output, (String)this.modelPath);
        } else {
            WritableUtils.writeString((DataOutput)output, (String)"null");
        }
        output.writeDouble(this.regularizationWeight);
        output.writeDouble(this.momentumWeight);
        WritableUtils.writeString((DataOutput)output, (String)this.costFunctionName);
        output.writeInt(this.layerSizeList.size());
        for (Integer aLayerSizeList : this.layerSizeList) {
            output.writeInt(aLayerSizeList);
        }
        WritableUtils.writeEnum((DataOutput)output, (Enum)this.trainingMethod);
        output.writeInt(this.squashingFunctionList.size());
        for (String aSquashingFunctionList : this.squashingFunctionList) {
            WritableUtils.writeString((DataOutput)output, (String)aSquashingFunctionList);
        }
        output.writeInt(this.weightMatrixList.size());
        for (Matrix aWeightMatrixList : this.weightMatrixList) {
            MatrixWritable.writeMatrix(output, aWeightMatrixList);
        }
    }

    public void readFields(DataInput input) throws IOException {
        this.modelType = WritableUtils.readString((DataInput)input);
        if (!this.modelType.equals(this.getClass().getSimpleName())) {
            throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
        }
        this.learningRate = input.readDouble();
        this.modelPath = WritableUtils.readString((DataInput)input);
        if (this.modelPath.equals("null")) {
            this.modelPath = null;
        }
        this.regularizationWeight = input.readDouble();
        this.momentumWeight = input.readDouble();
        this.costFunctionName = WritableUtils.readString((DataInput)input);
        int numLayers = input.readInt();
        this.layerSizeList = Lists.newArrayList();
        for (int i = 0; i < numLayers; ++i) {
            this.layerSizeList.add(input.readInt());
        }
        this.trainingMethod = (TrainingMethod)WritableUtils.readEnum((DataInput)input, TrainingMethod.class);
        int squashingFunctionSize = input.readInt();
        this.squashingFunctionList = Lists.newArrayList();
        for (int i = 0; i < squashingFunctionSize; ++i) {
            this.squashingFunctionList.add(WritableUtils.readString((DataInput)input));
        }
        int numOfMatrices = input.readInt();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        for (int i = 0; i < numOfMatrices; ++i) {
            Matrix matrix = MatrixWritable.readMatrix(input);
            this.weightMatrixList.add(matrix);
            this.prevWeightUpdatesList.add((Matrix)new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
        }
    }

    public static enum TrainingMethod {
        GRADIENT_DESCENT;

    }
}

