/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Sets;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Random;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;

/**
 * Online multi-class classifier implemented as a network with a single hidden layer.
 *
 * <p>Scoring: each hidden unit takes the dot product of its weight row with the input, adds its
 * bias, clamps the result to [-40, 40] and applies {@code Functions.SIGMOID}. Each output unit is
 * the dot product of its weight row with the hidden activations plus its bias; no link function
 * is applied to the outputs.
 *
 * <p>Training: {@link #train} performs one in-place update of a pairwise ranking hinge loss. The
 * update compares the correct label against the highest-scoring of a few randomly sampled
 * incorrect labels (see {@link #updateRanking}).
 *
 * <p>The {@code sparsity} and {@code sparsityLearningRate} settings are stored, copied and
 * serialized, but no code in this class otherwise reads them.
 *
 * <p>NOTE(review): weights are mutated in place without synchronization, so an instance should
 * presumably not be trained from several threads at once.
 */
public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {

    /** Format version written first by {@link #write} and checked by {@link #readFields}. */
    public static final int WRITABLE_VERSION = 1;

    /** Step size used by {@link #updateRanking} for weight and bias updates. */
    private double learningRate = 0.1;
    /** Weight-decay style coefficient used by {@link #updateRanking}. */
    private double regularization = 0.1;
    /** Stored, copied and serialized only; not otherwise used by this class. */
    private double sparsity = 0.1;
    /** Stored, copied and serialized only; not otherwise used by this class. */
    private double sparsityLearningRate = 0.1;
    private int numFeatures = 10;
    private int numHidden = 100;
    private int numOutput = 2;
    /** One weight row of length {@code numFeatures} per hidden unit. */
    private Vector[] hiddenWeights;
    /** One weight row of length {@code numHidden} per output category. */
    private Vector[] outputWeights;
    /** Per-hidden-unit bias (length {@code numHidden}). */
    private Vector hiddenBias;
    /** Per-category output bias (length {@code numOutput}). */
    private Vector outputBias;
    /** Random source used by {@link #train} to sample competing labels. */
    private final Random rnd;

    /**
     * Creates a machine whose weights and biases are all zero.
     *
     * <p>Weights stay zero until {@link #initWeights(Random)} is called; biases are never
     * randomized.
     *
     * @param numFeatures length of the input vectors
     * @param numHidden   number of hidden units
     * @param numOutput   number of output categories
     */
    public GradientMachine(int numFeatures, int numHidden, int numOutput) {
        this.numFeatures = numFeatures;
        this.numHidden = numHidden;
        this.numOutput = numOutput;
        this.hiddenWeights = zeroRows(numHidden, numFeatures);
        this.hiddenBias = new DenseVector(numHidden).assign(0.0);
        this.outputWeights = zeroRows(numOutput, numHidden);
        this.outputBias = new DenseVector(numOutput).assign(0.0);
        this.rnd = RandomUtils.getRandom();
    }

    /** Allocates {@code rows} dense vectors of length {@code cols}, each explicitly zeroed. */
    private static Vector[] zeroRows(int rows, int cols) {
        Vector[] result = new Vector[rows];
        for (int r = 0; r < rows; r++) {
            result[r] = new DenseVector(cols).assign(0.0);
        }
        return result;
    }

    /**
     * Randomizes the hidden and output weights; biases are left unchanged.
     *
     * <p>Each weight is drawn uniformly from {@code (-1/sqrt(fanIn), 1/sqrt(fanIn))}, where fanIn
     * is the row length. Hidden rows are filled first, then output rows, each in row-major order.
     *
     * @param gen random source
     */
    public void initWeights(Random gen) {
        fillUniform(this.hiddenWeights, this.numHidden, this.numFeatures, gen);
        fillUniform(this.outputWeights, this.numOutput, this.numHidden, gen);
    }

    /** Fills the first {@code count} rows with uniform values scaled by {@code 1/sqrt(cols)}. */
    private static void fillUniform(Vector[] rows, int count, int cols, Random gen) {
        double scale = 1.0 / Math.sqrt(cols);
        for (int i = 0; i < count; i++) {
            for (int j = 0; j < cols; j++) {
                rows[i].setQuick(j, (2.0 * gen.nextDouble() - 1.0) * scale);
            }
        }
    }

    /** Sets the learning rate; returns {@code this} for chaining. */
    public GradientMachine learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /** Sets the regularization coefficient; returns {@code this} for chaining. */
    public GradientMachine regularization(double regularization) {
        this.regularization = regularization;
        return this;
    }

    /** Sets the (currently unused) sparsity value; returns {@code this} for chaining. */
    public GradientMachine sparsity(double sparsity) {
        this.sparsity = sparsity;
        return this;
    }

    /** Sets the (currently unused) sparsity learning rate; returns {@code this} for chaining. */
    public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
        this.sparsityLearningRate = sparsityLearningRate;
        return this;
    }

    /**
     * Replaces this machine's dimensions, hyper-parameters, weights and biases with deep copies of
     * {@code other}'s. The random source is not copied.
     *
     * @param other machine to copy from
     */
    public void copyFrom(GradientMachine other) {
        this.numFeatures = other.numFeatures;
        this.numHidden = other.numHidden;
        this.numOutput = other.numOutput;
        this.learningRate = other.learningRate;
        this.regularization = other.regularization;
        this.sparsity = other.sparsity;
        this.sparsityLearningRate = other.sparsityLearningRate;
        this.hiddenWeights = cloneRows(other.hiddenWeights, this.numHidden);
        this.hiddenBias = other.hiddenBias.clone();
        this.outputWeights = cloneRows(other.outputWeights, this.numOutput);
        this.outputBias = other.outputBias.clone();
    }

    /**
     * Clones the first {@code count} rows into a new array.
     *
     * <p>The array is typed {@code Vector[]} (not {@code DenseVector[]}) so that any vector
     * implementation can be stored without an {@code ArrayStoreException}.
     */
    private static Vector[] cloneRows(Vector[] source, int count) {
        Vector[] result = new Vector[count];
        for (int i = 0; i < count; i++) {
            result[i] = source[i].clone();
        }
        return result;
    }

    @Override
    public int numCategories() {
        return this.numOutput;
    }

    /** Returns the expected input length. */
    public int numFeatures() {
        return this.numFeatures;
    }

    /** Returns the number of hidden units. */
    public int numHidden() {
        return this.numHidden;
    }

    /**
     * Computes hidden-layer activations for {@code input}.
     *
     * <p>Pre-activations (weight row dot input, plus bias) are clamped to [-40, 40] before the
     * sigmoid is applied, presumably to avoid saturation/overflow in the sigmoid.
     *
     * @param input feature vector; expected to have {@code numFeatures} entries
     * @return a new vector of length {@code numHidden}
     */
    public DenseVector inputToHidden(Vector input) {
        DenseVector activations = new DenseVector(this.numHidden);
        for (int i = 0; i < this.numHidden; i++) {
            activations.setQuick(i, this.hiddenWeights[i].dot(input));
        }
        activations.assign(this.hiddenBias, Functions.PLUS);
        activations.assign(Functions.min(40.0)).assign(Functions.max(-40.0));
        activations.assign(Functions.SIGMOID);
        return activations;
    }

    /**
     * Computes raw output scores (weight row dot activations, plus bias) with no link function.
     *
     * @param hiddenActivation hidden-layer activations of length {@code numHidden}
     * @return a new vector of length {@code numOutput}
     */
    public DenseVector hiddenToOutput(Vector hiddenActivation) {
        DenseVector activations = new DenseVector(this.numOutput);
        for (int i = 0; i < this.numOutput; i++) {
            activations.setQuick(i, this.outputWeights[i].dot(hiddenActivation));
        }
        activations.assign(this.outputBias, Functions.PLUS);
        return activations;
    }

    /**
     * Applies one in-place stochastic update of a pairwise ranking hinge loss.
     *
     * <p>For each label in {@code goodLabels}, {@code numTrials} labels outside that collection
     * are drawn from {@code gen}. The one whose output-weight row has the largest dot product with
     * {@code hiddenActivation} becomes the competing "bad" label. When
     * {@code 1 - goodScore + badScore >= 0}, the output weights, output biases and hidden weights
     * are adjusted. The hidden biases are never changed here.
     *
     * <p>Nothing happens when {@code goodLabels} has at least as many entries as there are
     * categories, or when no competing label could be selected (for example
     * {@code numTrials < 1}).
     *
     * @param hiddenActivation hidden activations, e.g. from {@link #inputToHidden(Vector)}
     * @param goodLabels       indices of the correct categories
     * @param numTrials        number of competing labels sampled per good label
     * @param gen              random source for the sampling
     */
    public void updateRanking(Vector hiddenActivation, Collection<Integer> goodLabels, int numTrials, Random gen) {
        // Every category is "good": there is nothing to rank against.
        if (goodLabels.size() >= this.numOutput) {
            return;
        }
        for (Integer goodLabel : goodLabels) {
            int good = goodLabel;
            // NOTE(review): scores here omit outputBias, although hiddenToOutput adds it and the
            // biases are updated below. Confirm this is intended.
            double goodScore = this.outputWeights[good].dot(hiddenActivation);

            int bad = -1;
            double badScore = Double.NEGATIVE_INFINITY;
            for (int trial = 0; trial < numTrials; trial++) {
                int candidate = gen.nextInt(this.numOutput);
                while (goodLabels.contains(candidate)) {
                    candidate = gen.nextInt(this.numOutput);
                }
                double candidateScore = this.outputWeights[candidate].dot(hiddenActivation);
                if (candidateScore > badScore) {
                    badScore = candidateScore;
                    bad = candidate;
                }
            }
            // No competitor was selected. Skip explicitly rather than risk indexing
            // outputWeights[-1] when the loss below comes out NaN.
            if (bad < 0) {
                continue;
            }
            double loss = 1.0 - goodScore + badScore;
            if (loss < 0.0) {
                continue;
            }

            Vector gradGood = this.outputWeights[good].clone();
            gradGood.assign(Functions.NEGATE);
            // propHidden = -w_good + w_bad, the loss gradient with respect to hiddenActivation.
            Vector propHidden = gradGood.clone();
            Vector gradBad = this.outputWeights[bad].clone();
            propHidden.assign(gradBad, Functions.PLUS);

            // NOTE(review): for loss = 1 - w_good.h + w_bad.h, d(loss)/d(w_good) = -h and
            // d(loss)/d(w_bad) = +h. These updates use the weight rows themselves in place of h,
            // so they only rescale w_good / w_bad. Verify before relying on the learned model.
            gradGood.assign(Functions.mult(-this.learningRate * (1.0 - this.regularization)));
            this.outputWeights[good].assign(gradGood, Functions.PLUS);
            gradBad.assign(Functions.mult(-this.learningRate * (1.0 + this.regularization)));
            this.outputWeights[bad].assign(gradBad, Functions.PLUS);
            this.outputBias.setQuick(good, this.outputBias.get(good) + this.learningRate);
            this.outputBias.setQuick(bad, this.outputBias.get(bad) - this.learningRate);

            // Back-propagate through the sigmoid. SIGMOIDGRADIENT is applied to the sigmoid
            // outputs, presumably computing s * (1 - s); then scale by the upstream gradient.
            Vector gradSig = hiddenActivation.clone();
            gradSig.assign(Functions.SIGMOIDGRADIENT);
            for (int i = 0; i < this.numHidden; i++) {
                gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
            }

            // NOTE(review): d(loss)/d(W_hidden[i][j]) would be gradSig[i] * x[j], but the input x
            // is not available here. Every feature weight of unit i receives the same delta.
            for (int i = 0; i < this.numHidden; i++) {
                Vector row = this.hiddenWeights[i];
                double delta = gradSig.get(i);
                for (int j = 0; j < this.numFeatures; j++) {
                    double v = row.get(j);
                    v -= this.learningRate * (delta + this.regularization * v);
                    row.setQuick(j, v);
                }
            }
        }
    }

    /**
     * Returns a one-hot vector marking the highest-scoring category, with entry 0 dropped.
     *
     * <p>The result is a view of elements 1..n-1. When category 0 wins, the view is all zeros.
     */
    @Override
    public Vector classify(Vector instance) {
        Vector result = this.classifyNoLink(instance);
        int max = result.maxValueIndex();
        result.assign(0.0);
        result.setQuick(max, 1.0);
        return result.viewPart(1, result.size() - 1);
    }

    /** Returns the raw output scores for {@code instance} (all categories, no link function). */
    @Override
    public Vector classifyNoLink(Vector instance) {
        return this.hiddenToOutput(this.inputToHidden(instance));
    }

    /**
     * Binary decision based on the first two outputs.
     *
     * @return {@code 0.0} when output 0 scores strictly higher than output 1, otherwise {@code 1.0}
     *     (ties go to category 1)
     */
    @Override
    public double classifyScalar(Vector instance) {
        Vector output = this.classifyNoLink(instance);
        return output.get(0) > output.get(1) ? 0.0 : 1.0;
    }

    /** Returns a deep copy of this machine. It has its own random source and copied state. */
    public GradientMachine copy() {
        this.close();
        GradientMachine r = new GradientMachine(this.numFeatures(), this.numHidden(), this.numCategories());
        r.copyFrom(this);
        return r;
    }

    /**
     * Serializes the model in the following order:
     * <ol>
     *   <li>the version;</li>
     *   <li>the four hyper-parameters;</li>
     *   <li>the three dimensions;</li>
     *   <li>the hidden bias and hidden rows;</li>
     *   <li>the output bias and output rows.</li>
     * </ol>
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(WRITABLE_VERSION);
        out.writeDouble(this.learningRate);
        out.writeDouble(this.regularization);
        out.writeDouble(this.sparsity);
        out.writeDouble(this.sparsityLearningRate);
        out.writeInt(this.numFeatures);
        out.writeInt(this.numHidden);
        out.writeInt(this.numOutput);
        VectorWritable.writeVector(out, this.hiddenBias);
        for (int i = 0; i < this.numHidden; i++) {
            VectorWritable.writeVector(out, this.hiddenWeights[i]);
        }
        VectorWritable.writeVector(out, this.outputBias);
        for (int i = 0; i < this.numOutput; i++) {
            VectorWritable.writeVector(out, this.outputWeights[i]);
        }
    }

    /**
     * Reads a model in the layout produced by {@link #write(DataOutput)}.
     *
     * <p>The whole record is read and validated before any field is assigned. If an exception is
     * thrown, this instance is left unchanged.
     *
     * @throws IOException on an unexpected version, negative dimensions, or a read failure
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        int version = in.readInt();
        if (version != WRITABLE_VERSION) {
            throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
        }
        double newLearningRate = in.readDouble();
        double newRegularization = in.readDouble();
        double newSparsity = in.readDouble();
        double newSparsityLearningRate = in.readDouble();
        int newNumFeatures = in.readInt();
        int newNumHidden = in.readInt();
        int newNumOutput = in.readInt();
        if (newNumFeatures < 0 || newNumHidden < 0 || newNumOutput < 0) {
            throw new IOException("Corrupt GradientMachine: negative dimensions " + newNumFeatures
                    + ", " + newNumHidden + ", " + newNumOutput);
        }
        Vector newHiddenBias = VectorWritable.readVector(in);
        // Vector[] rather than DenseVector[]: readVector may return any Vector implementation.
        Vector[] newHiddenWeights = new Vector[newNumHidden];
        for (int i = 0; i < newNumHidden; i++) {
            newHiddenWeights[i] = VectorWritable.readVector(in);
        }
        Vector newOutputBias = VectorWritable.readVector(in);
        Vector[] newOutputWeights = new Vector[newNumOutput];
        for (int i = 0; i < newNumOutput; i++) {
            newOutputWeights[i] = VectorWritable.readVector(in);
        }

        this.learningRate = newLearningRate;
        this.regularization = newRegularization;
        this.sparsity = newSparsity;
        this.sparsityLearningRate = newSparsityLearningRate;
        this.numFeatures = newNumFeatures;
        this.numHidden = newNumHidden;
        this.numOutput = newNumOutput;
        this.hiddenBias = newHiddenBias;
        this.hiddenWeights = newHiddenWeights;
        this.outputBias = newOutputBias;
        this.outputWeights = newOutputWeights;
    }

    /** No resources are held, so this does nothing. */
    @Override
    public void close() {
    }

    /**
     * Performs one ranking update for a single labelled example.
     *
     * <p>{@code trackingKey} and {@code groupKey} are ignored. Two competing labels are sampled per
     * update.
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        Vector hiddenActivation = this.inputToHidden(instance);
        // Output scores are not needed by updateRanking, so hiddenToOutput is not computed here.
        Collection<Integer> goodLabels = Sets.newHashSet();
        goodLabels.add(actual);
        this.updateRanking(hiddenActivation, goodLabels, 2, this.rnd);
    }

    /** Trains on one example; {@code trackingKey} is ignored. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        this.train(trackingKey, null, actual, instance);
    }

    /** Trains on one example. */
    @Override
    public void train(int actual, Vector instance) {
        this.train(0L, null, actual, instance);
    }
}

