/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sgd.AbstractOnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.VectorWritable;

/**
 * Online logistic regression learner with an annealed learning rate.
 *
 * <p>The global learning rate at step {@code t} is
 * {@code mu0 * decayFactor^t * (t + stepOffset)^forgettingExponent}
 * (see {@link #currentLearningRate()}). Each feature {@code j} additionally has a per-term rate of
 * {@code sqrt(perTermAnnealingOffset / updateCounts[j])} (see {@link #perTermLearningRate(int)}).
 * How these rates are applied during training is defined in
 * {@link AbstractOnlineLogisticRegression}.
 *
 * <p>Instances are serializable through Hadoop's {@link Writable} protocol. The serialized form
 * begins with {@link #WRITABLE_VERSION}.
 */
public class OnlineLogisticRegression
extends AbstractOnlineLogisticRegression
implements Writable {
    /** Format version written first by {@link #write} and required by {@link #readFields}. */
    public static final int WRITABLE_VERSION = 1;

    /** Base learning rate; the leading multiplier in {@link #currentLearningRate()}. */
    private double mu0 = 1.0;
    /** Exponential decay base, raised to the current step in {@link #currentLearningRate()}. */
    private double decayFactor = 0.999;
    /** Offset added to the current step before the power-law decay is applied. */
    private int stepOffset = 10;
    /** Power-law decay exponent; {@link #decayExponent(double)} negates positive arguments. */
    private double forgettingExponent = -0.5;
    /** Initial value of every per-feature update count and numerator of the per-term rate. */
    private int perTermAnnealingOffset = 20;

    /**
     * Creates a learner without allocating the coefficient matrix or update vectors.
     * NOTE(review): presumably used for deserialization, followed by
     * {@link #readFields(DataInput)} — confirm with callers.
     */
    public OnlineLogisticRegression() {
    }

    /**
     * Creates a learner with freshly allocated model state.
     *
     * @param numCategories number of target categories; the coefficient matrix {@code beta}
     *     is allocated with {@code numCategories - 1} rows
     * @param numFeatures   number of features; the column count of {@code beta} and the length
     *     of the per-feature update vectors
     * @param prior         prior function; stored by reference, not copied
     */
    public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
        this.numCategories = numCategories;
        this.prior = prior;
        this.updateSteps = new DenseVector(numFeatures);
        // Every feature starts at the annealing offset, so the initial per-term rate is
        // sqrt(offset / offset) = 1.0.
        this.updateCounts = new DenseVector(numFeatures).assign((double) this.perTermAnnealingOffset);
        this.beta = new DenseMatrix(numCategories - 1, numFeatures);
    }

    /**
     * Sets the exponential decay base applied per step.
     *
     * @param alpha new value for the decay factor
     * @return this learner, for chaining
     */
    public OnlineLogisticRegression alpha(double alpha) {
        this.decayFactor = alpha;
        return this;
    }

    /**
     * Sets lambda via the base class and narrows the return type for chaining.
     *
     * @param lambda value forwarded to {@link AbstractOnlineLogisticRegression#lambda(double)}
     * @return this learner, for chaining
     */
    @Override
    public OnlineLogisticRegression lambda(double lambda) {
        super.lambda(lambda);
        return this;
    }

    /**
     * Sets the base learning rate {@code mu0}.
     *
     * @param learningRate new base learning rate
     * @return this learner, for chaining
     */
    public OnlineLogisticRegression learningRate(double learningRate) {
        this.mu0 = learningRate;
        return this;
    }

    /**
     * Sets the offset added to the step count in the power-law decay term.
     *
     * @param stepOffset new step offset
     * @return this learner, for chaining
     */
    public OnlineLogisticRegression stepOffset(int stepOffset) {
        this.stepOffset = stepOffset;
        return this;
    }

    /**
     * Sets the power-law decay exponent.
     *
     * <p>A positive argument is negated before being stored, so {@code 0.5} and {@code -0.5}
     * have the same effect.
     *
     * @param decayExponent decay exponent; its sign is normalized to non-positive
     * @return this learner, for chaining
     */
    public OnlineLogisticRegression decayExponent(double decayExponent) {
        if (decayExponent > 0.0) {
            decayExponent = -decayExponent;
        }
        this.forgettingExponent = decayExponent;
        return this;
    }

    /**
     * Returns the per-feature learning-rate multiplier
     * {@code sqrt(perTermAnnealingOffset / updateCounts[j])}.
     *
     * @param j feature index
     * @return the annealed per-term rate for feature {@code j}
     */
    @Override
    public double perTermLearningRate(int j) {
        return Math.sqrt((double) this.perTermAnnealingOffset / this.updateCounts.get(j));
    }

    /**
     * Returns the global learning rate for the current step
     * {@code t = getStep()}: {@code mu0 * decayFactor^t * (t + stepOffset)^forgettingExponent}.
     *
     * @return the current global learning rate
     */
    @Override
    public double currentLearningRate() {
        return this.mu0 * Math.pow(this.decayFactor, this.getStep())
                * Math.pow(this.getStep() + this.stepOffset, this.forgettingExponent);
    }

    /**
     * Copies base-class state (via {@code super.copyFrom}) and this class's learning-rate
     * parameters from {@code other}.
     *
     * @param other learner to copy from
     */
    public void copyFrom(OnlineLogisticRegression other) {
        super.copyFrom(other);
        this.mu0 = other.mu0;
        this.decayFactor = other.decayFactor;
        this.stepOffset = other.stepOffset;
        this.forgettingExponent = other.forgettingExponent;
        this.perTermAnnealingOffset = other.perTermAnnealingOffset;
    }

    /**
     * Returns a new learner with the same dimensions and state as this one.
     *
     * <p>Calls {@code close()} on this learner first. NOTE(review): presumably this flushes
     * pending work so the copy sees settled state; see {@link AbstractOnlineLogisticRegression}.
     * The new learner's constructor receives this learner's {@code prior} instance, not a copy.
     *
     * @return a copy of this learner
     */
    public OnlineLogisticRegression copy() {
        this.close();
        OnlineLogisticRegression r = new OnlineLogisticRegression(this.numCategories(), this.numFeatures(), this.prior);
        r.copyFrom(this);
        return r;
    }

    /**
     * Serializes this learner.
     *
     * <p>Field order must match {@link #readFields(DataInput)} exactly: version, mu0, lambda,
     * decayFactor, stepOffset, step, forgettingExponent, perTermAnnealingOffset, numCategories,
     * beta, prior, updateCounts, updateSteps.
     *
     * @param out destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        // Use the named constant so a future version bump updates writer and reader together.
        out.writeInt(WRITABLE_VERSION);
        out.writeDouble(this.mu0);
        out.writeDouble(this.getLambda());
        out.writeDouble(this.decayFactor);
        out.writeInt(this.stepOffset);
        out.writeInt(this.step);
        out.writeDouble(this.forgettingExponent);
        out.writeInt(this.perTermAnnealingOffset);
        out.writeInt(this.numCategories);
        MatrixWritable.writeMatrix(out, this.beta);
        PolymorphicWritable.write(out, this.prior);
        VectorWritable.writeVector(out, this.updateCounts);
        VectorWritable.writeVector(out, this.updateSteps);
    }

    /**
     * Restores this learner from a stream produced by {@link #write(DataOutput)}.
     *
     * <p>Fields are overwritten one at a time, so a failure partway through leaves the object
     * partially updated.
     *
     * @param in source stream
     * @throws IOException if the version header is not {@link #WRITABLE_VERSION} or reading fails
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        int version = in.readInt();
        if (version != WRITABLE_VERSION) {
            throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
        }
        this.mu0 = in.readDouble();
        this.lambda(in.readDouble());
        this.decayFactor = in.readDouble();
        this.stepOffset = in.readInt();
        this.step = in.readInt();
        this.forgettingExponent = in.readDouble();
        this.perTermAnnealingOffset = in.readInt();
        this.numCategories = in.readInt();
        this.beta = MatrixWritable.readMatrix(in);
        this.prior = PolymorphicWritable.read(in, PriorFunction.class);
        this.updateCounts = VectorWritable.readVector(in);
        this.updateSteps = VectorWritable.readVector(in);
    }
}

