/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.ep.EvolutionaryProcess;
import org.apache.mahout.ep.Mapping;
import org.apache.mahout.ep.Payload;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.stats.OnlineAuc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link OnlineLearner} that trains a population of {@link CrossFoldLearner}s and uses an
 * {@link EvolutionaryProcess} to search over each learner's regularization strength
 * ({@code lambda}) and learning rate.
 *
 * <p>Incoming examples are buffered. Whenever the buffer holds more than {@code bufferSize}
 * examples, every member of the population is trained on the buffered batch and scored (AUC for
 * two-category problems, log-likelihood otherwise, NaN for invalid models). The state returned by
 * {@code parallelDo} is kept and exposed through {@link #getBest()}. Once the number of records seen
 * exceeds the current cutoff, the population is mutated and {@link #nextStep(int)} chooses the next
 * cutoff.
 *
 * <p>The full optimizer state can be serialized through the {@link Writable} methods.
 */
public class AdaptiveLogisticRegression
implements OnlineLearner,
Writable {
    public static final int DEFAULT_THREAD_COUNT = 20;
    public static final int DEFAULT_POOL_SIZE = 20;
    /** Number of population members passed to {@code mutatePopulation} and optionally frozen. */
    private static final int SURVIVORS = 2;
    /** Lower bound enforced on both interval settings by {@link #setInterval(int, int)}. */
    private static final int MIN_ALLOWED_INTERVAL = 200;
    /** Multiplier handed to {@link #stepSize(int, double)} when scheduling the next mutation. */
    private static final double STEP_MULTIPLIER = 2.6;

    private static final Logger log = LoggerFactory.getLogger(AdaptiveLogisticRegression.class);

    /** Number of examples passed to {@code train} so far. */
    private int record;
    /** Record count after which the next population mutation is performed. */
    private int cutoff = 1000;
    private int minInterval = 1000;
    private int maxInterval = 1000;
    private int currentStep = 1000;
    /** The buffer is flushed once it holds more than this many examples. */
    private int bufferSize = 1000;
    private List<TrainingExample> buffer = Lists.newArrayList();
    private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep;
    private State<Wrapper, CrossFoldLearner> best;
    private int threadCount = DEFAULT_THREAD_COUNT;
    private int poolSize = DEFAULT_POOL_SIZE;
    private State<Wrapper, CrossFoldLearner> seed;
    private int numFeatures;
    private boolean freezeSurvivors = true;

    /**
     * Creates an uninitialized instance with no seed state and no optimizer; its state has to be
     * loaded with {@link #readFields(DataInput)} before it can be used.
     */
    public AdaptiveLogisticRegression() {
    }

    /**
     * Creates a learner using {@link #DEFAULT_THREAD_COUNT} threads and a population of
     * {@link #DEFAULT_POOL_SIZE} members.
     */
    public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
        this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE);
    }

    /**
     * Creates a learner whose population is seeded from a single {@link Wrapper} around a fresh
     * {@link CrossFoldLearner}.
     *
     * @param numCategories number of target categories
     * @param numFeatures   size of the feature vectors
     * @param prior         prior handed to each learner
     * @param threadCount   thread count handed to the evolutionary process
     * @param poolSize      population size of the evolutionary process
     */
    public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior,
                                      int threadCount, int poolSize) {
        this.numFeatures = numFeatures;
        this.threadCount = threadCount;
        this.poolSize = poolSize;
        // Two evolvable parameters: index 0 feeds lambda, index 1 feeds the learning rate
        // (see Wrapper.update / Wrapper.setMappings).
        this.seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10.0);
        Wrapper w = new Wrapper(numCategories, numFeatures, prior);
        this.seed.setPayload(w);
        Wrapper.setMappings(this.seed);
        // NOTE(review): the payload is assigned again after the mappings are installed; kept
        // because it is not visible here whether State.setPayload depends on the mappings.
        this.seed.setPayload(w);
        setPoolSize(this.poolSize);
    }

    /** Trains on one example, using the current record count as its tracking key. */
    @Override
    public void train(int actual, Vector instance) {
        train(record, null, actual, instance);
    }

    /** Trains on one example that has no group key. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /**
     * Buffers one example. When the buffer holds more than {@code bufferSize} examples, the whole
     * buffer is trained into the population.
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        record++;
        buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
        if (buffer.size() > bufferSize) {
            trainWithBufferedExamples();
        }
    }

    /**
     * Trains every population member on the buffered examples and stores the state returned by
     * the evolutionary process in {@link #best}. It then clears the buffer. Once {@link #record}
     * has passed {@link #cutoff}, it also mutates the population and schedules the next cutoff.
     */
    private void trainWithBufferedExamples() {
        try {
            best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {

                @Override
                public double apply(Payload<CrossFoldLearner> z, double[] params) {
                    Wrapper x = (Wrapper) z;
                    for (TrainingExample example : buffer) {
                        x.train(example);
                    }
                    CrossFoldLearner learner = x.getLearner();
                    if (!learner.validModel()) {
                        return Double.NaN;
                    }
                    // AUC is the score for binary problems, log-likelihood otherwise.
                    return learner.numCategories() == 2 ? learner.auc() : learner.logLikelihood();
                }
            });
        } catch (InterruptedException e) {
            // Restore the interrupt status so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            log.warn("Ignoring exception", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
        buffer.clear();
        if (record > cutoff) {
            cutoff = nextStep(record);
            ep.mutatePopulation(SURVIVORS);
            if (freezeSurvivors) {
                List<State<Wrapper, CrossFoldLearner>> population = ep.getPopulation();
                // Guard against populations smaller than SURVIVORS.
                int survivors = Math.min(SURVIVORS, population.size());
                for (State<Wrapper, CrossFoldLearner> state : population.subList(0, survivors)) {
                    Wrapper.freeze(state);
                }
            }
        }
    }

    /**
     * Computes the record count at which the next population mutation should happen.
     *
     * <p>The step size from {@link #stepSize(int, double)} is clamped to
     * [{@code minInterval}, {@code maxInterval}] and {@code recordNumber} is rounded up to the next
     * multiple of it. The cutoff never advances by less than the current step; the step is only
     * updated when the new cutoff is at least that far out.
     *
     * @param recordNumber number of records seen so far
     * @return the next cutoff
     */
    public int nextStep(int recordNumber) {
        int step = stepSize(recordNumber, STEP_MULTIPLIER);
        if (step < minInterval) {
            step = minInterval;
        }
        if (step > maxInterval) {
            step = maxInterval;
        }
        int newCutoff = step * (recordNumber / step + 1);
        if (newCutoff < cutoff + currentStep) {
            newCutoff = cutoff + currentStep;
        } else {
            currentStep = step;
        }
        return newCutoff;
    }

    /**
     * Returns a step size from the 1-2-5 series (1, 2, 5, 10, 20, 50, ...). The position in the
     * series is {@code floor(multiplier * log10(recordNumber))}.
     *
     * <p>Record numbers below 1 are treated as 1. A non-positive or NaN position is treated as 0,
     * giving a step of 1. Previously such inputs produced a negative array index.
     *
     * @param recordNumber number of records seen so far
     * @param multiplier   how quickly the step grows with the order of magnitude of the record
     *                     count
     * @return the step size
     */
    public static int stepSize(int recordNumber, double multiplier) {
        int[] bumps = new int[]{1, 2, 5};
        double level = Math.floor(multiplier * Math.log10(Math.max(1, recordNumber)));
        if (!(level > 0.0)) {
            // Also catches NaN, which Math.max would propagate.
            level = 0.0;
        }
        int bump = bumps[(int) level % bumps.length];
        int scale = (int) Math.pow(10.0, Math.floor(level / bumps.length));
        return bump * scale;
    }

    /**
     * Trains the remaining buffered examples and closes every learner in the population. The
     * evolutionary process is always closed, even if closing the learners fails.
     *
     * @throws IllegalStateException wrapping the cause if a population task fails
     */
    @Override
    public void close() {
        trainWithBufferedExamples();
        try {
            ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {

                @Override
                public double apply(Payload<CrossFoldLearner> payload, double[] params) {
                    CrossFoldLearner learner = ((Wrapper) payload).getLearner();
                    learner.close();
                    return learner.logLikelihood();
                }
            });
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("Ignoring exception", e);
        } catch (ExecutionException e) {
            // Unwrap like trainWithBufferedExamples() does.
            throw new IllegalStateException(e.getCause());
        } finally {
            ep.close();
        }
    }

    /** Sets both the minimum and the maximum mutation interval to {@code interval}. */
    public void setInterval(int interval) {
        setInterval(interval, interval);
    }

    /**
     * Controls how many records pass between population mutations.
     *
     * <p>Both bounds are raised to at least 200. The next cutoff is reset to the first multiple of
     * the (clamped) minimum interval above the current record count. The buffer size is capped at
     * that minimum.
     */
    public void setInterval(int minInterval, int maxInterval) {
        this.minInterval = Math.max(MIN_ALLOWED_INTERVAL, minInterval);
        this.maxInterval = Math.max(MIN_ALLOWED_INTERVAL, maxInterval);
        // Use the clamped value: the raw argument may be below the floor, and zero would divide
        // by zero here.
        this.cutoff = this.minInterval * (record / this.minInterval + 1);
        this.currentStep = this.minInterval;
        this.bufferSize = Math.min(this.minInterval, bufferSize);
    }

    /** Sets the population size and rebuilds the optimizer from the seed state. */
    public final void setPoolSize(int poolSize) {
        this.poolSize = poolSize;
        setupOptimizer(poolSize);
    }

    /** Sets the optimizer thread count and rebuilds the optimizer from the seed state. */
    public void setThreadCount(int threadCount) {
        this.threadCount = threadCount;
        setupOptimizer(poolSize);
    }

    /** Installs an AUC evaluator on the seed learner and rebuilds the optimizer from it. */
    public void setAucEvaluator(OnlineAuc auc) {
        seed.getPayload().setAucEvaluator(auc);
        setupOptimizer(poolSize);
    }

    /** Replaces the evolutionary process with a new one seeded from {@link #seed}. */
    private void setupOptimizer(int poolSize) {
        // NOTE(review): the previous optimizer (if any) is dropped without ep.close(); confirm
        // whether it must be closed to release its worker threads.
        ep = new EvolutionaryProcess<Wrapper, CrossFoldLearner>(threadCount, poolSize, seed);
    }

    /** Returns the feature-vector size this learner was built for. */
    public int numFeatures() {
        return numFeatures;
    }

    /** Returns the AUC of the best learner, or NaN if no best state has been recorded yet. */
    public double auc() {
        if (best == null) {
            return Double.NaN;
        }
        Wrapper payload = best.getPayload();
        return payload.getLearner().auc();
    }

    public State<Wrapper, CrossFoldLearner> getBest() {
        return best;
    }

    public void setBest(State<Wrapper, CrossFoldLearner> best) {
        this.best = best;
    }

    public int getRecord() {
        return record;
    }

    public void setRecord(int record) {
        this.record = record;
    }

    public int getMinInterval() {
        return minInterval;
    }

    public int getMaxInterval() {
        return maxInterval;
    }

    /** Returns the number of categories of the seed learner. */
    public int getNumCategories() {
        return seed.getPayload().getLearner().numCategories();
    }

    /** Returns the prior of the seed learner. */
    public PriorFunction getPrior() {
        return seed.getPayload().getLearner().getPrior();
    }

    public void setBuffer(List<TrainingExample> buffer) {
        this.buffer = buffer;
    }

    public List<TrainingExample> getBuffer() {
        return buffer;
    }

    public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() {
        return ep;
    }

    public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) {
        this.ep = ep;
    }

    public State<Wrapper, CrossFoldLearner> getSeed() {
        return seed;
    }

    public void setSeed(State<Wrapper, CrossFoldLearner> seed) {
        this.seed = seed;
    }

    /** Same as {@link #numFeatures()}. */
    public int getNumFeatures() {
        return numFeatures;
    }

    /** Sets the seed learner's averaging window size and rebuilds the optimizer from it. */
    public void setAveragingWindow(int averagingWindow) {
        seed.getPayload().getLearner().setWindowSize(averagingWindow);
        setupOptimizer(poolSize);
    }

    /** Whether survivors are damped via {@link Wrapper#freeze(State)} after each mutation. */
    public void setFreezeSurvivors(boolean freezeSurvivors) {
        this.freezeSurvivors = freezeSurvivors;
    }

    /**
     * Serializes the complete learner state.
     *
     * @throws IllegalStateException if no best state has been recorded yet, because
     *                               {@link #readFields(DataInput)} always expects one
     */
    @Override
    public void write(DataOutput out) throws IOException {
        // Check up front so a failure does not leave a truncated record in the stream.
        if (best == null) {
            throw new IllegalStateException(
                    "No best state yet; train at least one batch before serializing");
        }
        out.writeInt(record);
        out.writeInt(cutoff);
        out.writeInt(minInterval);
        out.writeInt(maxInterval);
        out.writeInt(currentStep);
        out.writeInt(bufferSize);
        out.writeInt(buffer.size());
        for (TrainingExample example : buffer) {
            example.write(out);
        }
        ep.write(out);
        best.write(out);
        out.writeInt(threadCount);
        out.writeInt(poolSize);
        seed.write(out);
        out.writeInt(numFeatures);
        out.writeBoolean(freezeSurvivors);
    }

    /** Restores the state written by {@link #write(DataOutput)}, in the same field order. */
    @Override
    public void readFields(DataInput in) throws IOException {
        record = in.readInt();
        cutoff = in.readInt();
        minInterval = in.readInt();
        maxInterval = in.readInt();
        currentStep = in.readInt();
        bufferSize = in.readInt();
        int n = in.readInt();
        buffer = Lists.newArrayList();
        for (int i = 0; i < n; i++) {
            TrainingExample example = new TrainingExample();
            example.readFields(in);
            buffer.add(example);
        }
        ep = new EvolutionaryProcess<Wrapper, CrossFoldLearner>();
        ep.readFields(in);
        best = new State<Wrapper, CrossFoldLearner>();
        best.readFields(in);
        threadCount = in.readInt();
        poolSize = in.readInt();
        seed = new State<Wrapper, CrossFoldLearner>();
        seed.readFields(in);
        numFeatures = in.readInt();
        freezeSurvivors = in.readBoolean();
    }

    /**
     * One buffered training example: a tracking key, an optional group key, the target category
     * and the feature vector.
     */
    public static class TrainingExample
    implements Writable {
        private long key;
        private String groupKey;
        private int actual;
        private Vector instance;

        /** Used by the enclosing class before {@link #readFields(DataInput)}. */
        private TrainingExample() {
        }

        public TrainingExample(long key, String groupKey, int actual, Vector instance) {
            this.key = key;
            this.groupKey = groupKey;
            this.actual = actual;
            this.instance = instance;
        }

        public long getKey() {
            return key;
        }

        public int getActual() {
            return actual;
        }

        public Vector getInstance() {
            return instance;
        }

        /** Returns the group key, or {@code null} if none was given. */
        public String getGroupKey() {
            return groupKey;
        }

        @Override
        public void write(DataOutput out) throws IOException {
            out.writeLong(key);
            // A presence flag precedes the optional group key.
            boolean hasGroupKey = groupKey != null;
            out.writeBoolean(hasGroupKey);
            if (hasGroupKey) {
                out.writeUTF(groupKey);
            }
            out.writeInt(actual);
            VectorWritable.writeVector(out, instance, true);
        }

        @Override
        public void readFields(DataInput in) throws IOException {
            key = in.readLong();
            // Reset explicitly so a reused instance does not keep a stale group key.
            groupKey = in.readBoolean() ? in.readUTF() : null;
            actual = in.readInt();
            instance = VectorWritable.readVector(in);
        }
    }

    /**
     * {@link Payload} adapter through which the evolutionary process drives a
     * {@link CrossFoldLearner}.
     */
    public static class Wrapper
    implements Payload<CrossFoldLearner> {
        private CrossFoldLearner wrapped;

        /** Creates an empty wrapper; {@link #readFields(DataInput)} supplies the learner. */
        public Wrapper() {
        }

        public Wrapper(int numCategories, int numFeatures, PriorFunction prior) {
            // NOTE(review): the leading 5 is presumably the number of folds; confirm against
            // CrossFoldLearner's constructor.
            this.wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior);
        }

        /** Returns a new wrapper around a copy of the wrapped learner. */
        public Wrapper copy() {
            Wrapper r = new Wrapper();
            r.wrapped = wrapped.copy();
            return r;
        }

        /**
         * Applies evolved parameters: {@code params[0]} becomes lambda and {@code params[1]} the
         * learning rate. The step offset, alpha and decay exponent are pinned to fixed values.
         */
        @Override
        public void update(double[] params) {
            wrapped.lambda(params[0]);
            wrapped.learningRate(params[1]);
            wrapped.stepOffset(1);
            wrapped.alpha(1.0);
            wrapped.decayExponent(0.0);
        }

        /**
         * Damps a surviving state in place. It lowers parameter 1 (the one {@link #update} feeds
         * to the learning rate) by 10 and divides the state's omni value and every step size
         * by 20.
         */
        public static void freeze(State<Wrapper, CrossFoldLearner> s) {
            double[] params = s.getParams();
            params[1] -= 10.0;
            s.setOmni(s.getOmni() / 20.0);
            double[] step = s.getStep();
            for (int i = 0; i < step.length; i++) {
                step[i] /= 20.0;
            }
        }

        /**
         * Installs the parameter mappings: parameter 0 (lambda) is limited to [1e-8, 0.1] and
         * parameter 1 (learning rate) to [1e-8, 1.0], both via {@code Mapping.logLimit}.
         */
        public static void setMappings(State<Wrapper, CrossFoldLearner> x) {
            x.setMap(0, Mapping.logLimit(1.0E-8, 0.1));
            x.setMap(1, Mapping.logLimit(1.0E-8, 1.0));
        }

        /** Feeds one buffered example to the wrapped learner. */
        public void train(TrainingExample example) {
            wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(),
                    example.getInstance());
        }

        public CrossFoldLearner getLearner() {
            return wrapped;
        }

        @Override
        public String toString() {
            return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
        }

        public void setAucEvaluator(OnlineAuc auc) {
            wrapped.setAucEvaluator(auc);
        }

        public void write(DataOutput out) throws IOException {
            wrapped.write(out);
        }

        public void readFields(DataInput input) throws IOException {
            wrapped = new CrossFoldLearner();
            wrapped.readFields(input);
        }
    }
}

