/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.AbstractFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelSGDFactorizer
extends AbstractFactorizer {
    private final DataModel dataModel;
    private final double lambda;
    private final int rank;
    private final int numEpochs;
    private int numThreads;
    private double mu0 = 0.01;
    private double decayFactor = 1.0;
    private int stepOffset = 0;
    private double forgettingExponent = 0.0;
    private double biasMuRatio = 0.5;
    private double biasLambdaRatio = 0.1;
    protected volatile double[][] userVectors;
    protected volatile double[][] itemVectors;
    private final PreferenceShuffler shuffler;
    private int epoch = 1;
    private static final int USER_BIAS_INDEX = 1;
    private static final int ITEM_BIAS_INDEX = 2;
    private static final int FEATURE_OFFSET = 3;
    private static final double NOISE = 0.02;
    private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizer.class);

    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs) throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.rank = numFeatures + 3;
        this.lambda = lambda;
        this.numEpochs = numEpochs;
        this.shuffler = new PreferenceShuffler(dataModel);
        this.numThreads = Math.min(Runtime.getRuntime().availableProcessors(), (int)Math.pow(this.shuffler.size(), 0.25));
    }

    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, double mu0, double decayFactor, int stepOffset, double forgettingExponent) throws TasteException {
        this(dataModel, numFeatures, lambda, numIterations);
        this.mu0 = mu0;
        this.decayFactor = decayFactor;
        this.stepOffset = stepOffset;
        this.forgettingExponent = forgettingExponent;
    }

    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, double mu0, double decayFactor, int stepOffset, double forgettingExponent, int numThreads) throws TasteException {
        this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent);
        this.numThreads = numThreads;
    }

    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, double mu0, double decayFactor, int stepOffset, double forgettingExponent, double biasMuRatio, double biasLambdaRatio) throws TasteException {
        this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent);
        this.biasMuRatio = biasMuRatio;
        this.biasLambdaRatio = biasLambdaRatio;
    }

    public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, double mu0, double decayFactor, int stepOffset, double forgettingExponent, double biasMuRatio, double biasLambdaRatio, int numThreads) throws TasteException {
        this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent, biasMuRatio, biasLambdaRatio);
        this.numThreads = numThreads;
    }

    protected void initialize() throws TasteException {
        int feature;
        RandomWrapper random = RandomUtils.getRandom();
        this.userVectors = new double[this.dataModel.getNumUsers()][this.rank];
        this.itemVectors = new double[this.dataModel.getNumItems()][this.rank];
        double globalAverage = this.getAveragePreference();
        for (int userIndex = 0; userIndex < this.userVectors.length; ++userIndex) {
            this.userVectors[userIndex][0] = globalAverage;
            this.userVectors[userIndex][1] = 0.0;
            this.userVectors[userIndex][2] = 1.0;
            for (feature = 3; feature < this.rank; ++feature) {
                this.userVectors[userIndex][feature] = random.nextGaussian() * 0.02;
            }
        }
        for (int itemIndex = 0; itemIndex < this.itemVectors.length; ++itemIndex) {
            this.itemVectors[itemIndex][0] = 1.0;
            this.itemVectors[itemIndex][1] = 1.0;
            this.itemVectors[itemIndex][2] = 0.0;
            for (feature = 3; feature < this.rank; ++feature) {
                this.itemVectors[itemIndex][feature] = random.nextGaussian() * 0.02;
            }
        }
    }

    private double getMu(int i) {
        return this.mu0 * Math.pow(this.decayFactor, i - 1) * Math.pow(i + this.stepOffset, this.forgettingExponent);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public Factorization factorize() throws TasteException {
        this.initialize();
        if (logger.isInfoEnabled()) {
            logger.info("starting to compute the factorization...");
        }
        this.epoch = 1;
        while (this.epoch <= this.numEpochs) {
            this.shuffler.stage();
            final double mu = this.getMu(this.epoch);
            int subSize = this.shuffler.size() / this.numThreads + 1;
            ExecutorService executor = Executors.newFixedThreadPool(this.numThreads);
            try {
                for (int t = 0; t < this.numThreads; ++t) {
                    final int iStart = t * subSize;
                    final int iEnd = Math.min((t + 1) * subSize, this.shuffler.size());
                    executor.execute(new Runnable(){

                        @Override
                        public void run() {
                            for (int i = iStart; i < iEnd; ++i) {
                                ParallelSGDFactorizer.this.update(ParallelSGDFactorizer.this.shuffler.get(i), mu);
                            }
                        }
                    });
                }
            }
            finally {
                executor.shutdown();
                this.shuffler.shuffle();
                try {
                    boolean terminated = executor.awaitTermination(this.numEpochs * this.shuffler.size(), TimeUnit.MICROSECONDS);
                    if (!terminated) {
                        logger.error("subtasks takes forever, return anyway");
                    }
                }
                catch (InterruptedException e) {
                    throw new TasteException("waiting fof termination interrupted", e);
                }
            }
            ++this.epoch;
        }
        return this.createFactorization(this.userVectors, this.itemVectors);
    }

    double getAveragePreference() throws TasteException {
        FullRunningAverage average = new FullRunningAverage();
        LongPrimitiveIterator it = this.dataModel.getUserIDs();
        while (it.hasNext()) {
            for (Preference pref : this.dataModel.getPreferencesFromUser(it.nextLong())) {
                average.addDatum(pref.getValue());
            }
        }
        return average.getAverage();
    }

    protected void update(Preference preference, double mu) {
        int userIndex = this.userIndex(preference.getUserID());
        int itemIndex = this.itemIndex(preference.getItemID());
        double[] userVector = this.userVectors[userIndex];
        double[] itemVector = this.itemVectors[itemIndex];
        double prediction = this.dot(userVector, itemVector);
        double err = (double)preference.getValue() - prediction;
        int k = 3;
        while (k < this.rank) {
            double userFeature = userVector[k];
            double itemFeature = itemVector[k];
            int n = k;
            userVector[n] = userVector[n] + mu * (err * itemFeature - this.lambda * userFeature);
            int n2 = k++;
            itemVector[n2] = itemVector[n2] + mu * (err * userFeature - this.lambda * itemFeature);
        }
        userVector[1] = userVector[1] + this.biasMuRatio * mu * (err - this.biasLambdaRatio * this.lambda * userVector[1]);
        itemVector[2] = itemVector[2] + this.biasMuRatio * mu * (err - this.biasLambdaRatio * this.lambda * itemVector[2]);
    }

    private double dot(double[] userVector, double[] itemVector) {
        double sum = 0.0;
        for (int k = 0; k < this.rank; ++k) {
            sum += userVector[k] * itemVector[k];
        }
        return sum;
    }

    protected static class PreferenceShuffler {
        private Preference[] preferences;
        private Preference[] unstagedPreferences;
        protected final RandomWrapper random = RandomUtils.getRandom();

        public PreferenceShuffler(DataModel dataModel) throws TasteException {
            this.cachePreferences(dataModel);
            this.shuffle();
            this.stage();
        }

        private int countPreferences(DataModel dataModel) throws TasteException {
            int numPreferences = 0;
            LongPrimitiveIterator userIDs = dataModel.getUserIDs();
            while (userIDs.hasNext()) {
                PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
                numPreferences += preferencesFromUser.length();
            }
            return numPreferences;
        }

        private void cachePreferences(DataModel dataModel) throws TasteException {
            int numPreferences = this.countPreferences(dataModel);
            this.preferences = new Preference[numPreferences];
            LongPrimitiveIterator userIDs = dataModel.getUserIDs();
            int index = 0;
            while (userIDs.hasNext()) {
                long userID = userIDs.nextLong();
                PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
                for (Preference preference : preferencesFromUser) {
                    this.preferences[index++] = preference;
                }
            }
        }

        public void shuffle() {
            this.unstagedPreferences = (Preference[])this.preferences.clone();
            for (int i = this.unstagedPreferences.length - 1; i > 0; --i) {
                int rand = this.random.nextInt(i + 1);
                this.swapCachedPreferences(i, rand);
            }
        }

        private void swapCachedPreferences(int x, int y) {
            Preference p = this.unstagedPreferences[x];
            this.unstagedPreferences[x] = this.unstagedPreferences[y];
            this.unstagedPreferences[y] = p;
        }

        public void stage() {
            this.preferences = this.unstagedPreferences;
        }

        public Preference get(int i) {
            return this.preferences[i];
        }

        public int size() {
            return this.preferences.length;
        }
    }
}

