/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.AbstractFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factorizes the preferences of a {@link DataModel} into a user feature matrix U and an item feature matrix M by
 * alternating least squares. Each iteration first solves for every user's feature vector with the item features
 * held fixed, then for every item's feature vector with the user features held fixed. Both half-steps run their
 * per-row solves as tasks on an executor obtained from {@link #createQueue()}.
 *
 * <p>When {@code usesImplicitFeedback} is set, the per-row solves are delegated to
 * {@link ImplicitFeedbackAlternatingLeastSquaresSolver}; otherwise {@link AlternatingLeastSquaresSolver} is used.</p>
 */
public class ALSWRFactorizer extends AbstractFactorizer {

    private final DataModel dataModel;
    /** Number of latent features per user and per item; always at least 1. */
    private final int numFeatures;
    /** Passed through to both least-squares solvers (presumably the regularization weight). */
    private final double lambda;
    /** Number of full user-then-item passes performed by {@link #factorize()}. */
    private final int numIterations;
    private final boolean usesImplicitFeedback;
    /** Passed to the implicit-feedback solver only; ignored in explicit-feedback mode. */
    private final double alpha;
    /** Thread-pool size used by {@link #createQueue()}; also handed to the implicit-feedback solver. */
    private final int numTrainingThreads;

    private static final double DEFAULT_ALPHA = 40.0;
    private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);

    /**
     * Creates a factorizer with every parameter given explicitly.
     *
     * @param dataModel            preferences to factorize
     * @param numFeatures          number of latent features; must be at least 1
     * @param lambda               passed through to the least-squares solvers
     * @param numIterations        number of alternating user/item passes
     * @param usesImplicitFeedback whether to use the implicit-feedback solver
     * @param alpha                passed to the implicit-feedback solver; ignored otherwise
     * @param numTrainingThreads   size of the thread pool used for each half-step; must be at least 1
     * @throws TasteException            propagated from the superclass constructor
     * @throws IllegalArgumentException if {@code numFeatures} or {@code numTrainingThreads} is less than 1
     */
    public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
                           boolean usesImplicitFeedback, double alpha, int numTrainingThreads) throws TasteException {
        super(dataModel);
        // Features writes column 0 of every item row, so zero features would fail later with an
        // ArrayIndexOutOfBoundsException; a non-positive pool size would only fail once factorize() runs.
        if (numFeatures < 1) {
            throw new IllegalArgumentException("numFeatures must be at least 1, was " + numFeatures);
        }
        if (numTrainingThreads < 1) {
            throw new IllegalArgumentException("numTrainingThreads must be at least 1, was " + numTrainingThreads);
        }
        this.dataModel = dataModel;
        this.numFeatures = numFeatures;
        this.lambda = lambda;
        this.numIterations = numIterations;
        this.usesImplicitFeedback = usesImplicitFeedback;
        this.alpha = alpha;
        this.numTrainingThreads = numTrainingThreads;
    }

    /**
     * Creates a factorizer that uses one training thread per available processor.
     */
    public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
                           boolean usesImplicitFeedback, double alpha) throws TasteException {
        this(dataModel, numFeatures, lambda, numIterations, usesImplicitFeedback, alpha,
                Runtime.getRuntime().availableProcessors());
    }

    /**
     * Creates an explicit-feedback factorizer using {@code DEFAULT_ALPHA} (unused in that mode) and one training
     * thread per available processor.
     */
    public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations)
            throws TasteException {
        this(dataModel, numFeatures, lambda, numIterations, false, DEFAULT_ALPHA);
    }

    /**
     * Runs {@code numIterations} alternating passes and returns the resulting factorization.
     *
     * @return the factorization built from the final U and M matrices
     * @throws TasteException if the data model fails, or if the calling thread is interrupted while waiting for a
     *                        half-step to finish (the interrupt flag is restored in that case)
     */
    @Override
    public Factorization factorize() throws TasteException {
        log.info("starting to compute the factorization...");
        Features features = new Features(this);

        // The implicit-feedback solvers need the opposite side's feature vectors. These maps wrap the rows of U and M
        // via DenseVector(row, true), which is assumed to share the row array rather than copy it. That is why they
        // can be built once here and still reflect the values written back by later half-steps.
        OpenIntObjectHashMap<Vector> userY = null;
        OpenIntObjectHashMap<Vector> itemY = null;
        if (usesImplicitFeedback) {
            userY = userFeaturesMapping(dataModel.getUserIDs(), dataModel.getNumUsers(), features.getU());
            itemY = itemFeaturesMapping(dataModel.getItemIDs(), dataModel.getNumItems(), features.getM());
        }

        for (int iteration = 0; iteration < numIterations; iteration++) {
            log.info("iteration {}", iteration);
            computeUserFeatures(features, itemY);
            computeItemFeatures(features, userY);
        }

        log.info("finished computation of the factorization...");
        return createFactorization(features.getU(), features.getM());
    }

    /** Re-solves every user's row of U with M held fixed, one task per user. */
    private void computeUserFeatures(final Features features, OpenIntObjectHashMap<Vector> itemY)
            throws TasteException {
        final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
                ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, itemY,
                        numTrainingThreads)
                : null;
        ExecutorService queue = createQueue();
        try {
            LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
            while (userIDsIterator.hasNext()) {
                final long userID = userIDsIterator.nextLong();
                final PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
                queue.execute(new Runnable() {
                    @Override
                    public void run() {
                        Vector userFeatures;
                        if (usesImplicitFeedback) {
                            userFeatures = implicitFeedbackSolver.solve(sparseUserRatingVector(userPrefs));
                        } else {
                            // The i-th feature vector must belong to the item whose rating is the i-th entry of
                            // ratingVector(userPrefs), so both are read from userPrefs in index order. Iterating
                            // getItemIDsFromUser() instead would pair ratings with the wrong items.
                            int numPrefs = userPrefs.length();
                            List<Vector> itemFeatureVectors = Lists.newArrayListWithCapacity(numPrefs);
                            for (int i = 0; i < numPrefs; i++) {
                                itemFeatureVectors.add(
                                        features.getItemFeatureColumn(itemIndex(userPrefs.getItemID(i))));
                            }
                            userFeatures = AlternatingLeastSquaresSolver.solve(itemFeatureVectors,
                                    ratingVector(userPrefs), lambda, numFeatures);
                        }
                        features.setFeatureColumnInU(userIndex(userID), userFeatures);
                    }
                });
            }
        } finally {
            queue.shutdown();
        }
        // Allows one second per user, as before.
        awaitCompletion(queue, dataModel.getNumUsers(), "user features");
    }

    /** Re-solves every item's row of M with U held fixed, one task per item. */
    private void computeItemFeatures(final Features features, OpenIntObjectHashMap<Vector> userY)
            throws TasteException {
        final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
                ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, userY,
                        numTrainingThreads)
                : null;
        ExecutorService queue = createQueue();
        try {
            LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
            while (itemIDsIterator.hasNext()) {
                final long itemID = itemIDsIterator.nextLong();
                final PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
                queue.execute(new Runnable() {
                    @Override
                    public void run() {
                        Vector itemFeatures;
                        if (usesImplicitFeedback) {
                            itemFeatures = implicitFeedbackSolver.solve(sparseItemRatingVector(itemPrefs));
                        } else {
                            // Same index alignment with ratingVector(itemPrefs) as in computeUserFeatures.
                            int numPrefs = itemPrefs.length();
                            List<Vector> userFeatureVectors = Lists.newArrayListWithCapacity(numPrefs);
                            for (int i = 0; i < numPrefs; i++) {
                                userFeatureVectors.add(
                                        features.getUserFeatureColumn(userIndex(itemPrefs.getUserID(i))));
                            }
                            itemFeatures = AlternatingLeastSquaresSolver.solve(userFeatureVectors,
                                    ratingVector(itemPrefs), lambda, numFeatures);
                        }
                        features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
                    }
                });
            }
        } finally {
            queue.shutdown();
        }
        // Allows one second per item, as before.
        awaitCompletion(queue, dataModel.getNumItems(), "item features");
    }

    /**
     * Waits for an already shut-down queue to drain.
     *
     * <p>On timeout a warning is logged and computation continues (best effort, as before). On interruption the
     * remaining tasks are cancelled, the interrupt flag is restored, and a {@link TasteException} is thrown. Carrying
     * on at that point would let the next half-step read rows that workers may still be writing.</p>
     *
     * @param queue          executor on which {@code shutdown()} has already been called
     * @param timeoutSeconds maximum time to wait
     * @param phase          human-readable name of the half-step, used in log and exception messages
     * @throws TasteException if the calling thread is interrupted while waiting
     */
    private static void awaitCompletion(ExecutorService queue, long timeoutSeconds, String phase)
            throws TasteException {
        try {
            if (!queue.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
                log.warn("Timed out after {} seconds waiting for {}; continuing with partially updated features",
                        timeoutSeconds, phase);
            }
        } catch (InterruptedException e) {
            queue.shutdownNow();
            Thread.currentThread().interrupt();
            throw new TasteException("Error when computing " + phase, e);
        }
    }

    /**
     * Creates the executor used for one half-step. It is shut down after all tasks have been submitted.
     * Subclasses may override it to supply a different executor.
     */
    protected ExecutorService createQueue() {
        return Executors.newFixedThreadPool(numTrainingThreads);
    }

    /**
     * Returns the preference values of {@code prefs} as a dense vector, in the array's index order.
     */
    protected static Vector ratingVector(PreferenceArray prefs) {
        int length = prefs.length();
        double[] ratings = new double[length];
        for (int n = 0; n < length; n++) {
            ratings[n] = prefs.get(n).getValue();
        }
        return new DenseVector(ratings, true);
    }

    /**
     * Maps each item ID, truncated to {@code int}, to a vector over that item's row of {@code featureMatrix}.
     * The keys use the same truncation as {@link #sparseUserRatingVector}.
     * NOTE(review): IDs outside the int range can collide.
     */
    protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator itemIDs, int numItems,
                                                               double[][] featureMatrix) {
        OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<Vector>(numItems);
        while (itemIDs.hasNext()) {
            long itemID = itemIDs.nextLong();
            mapping.put((int) itemID, new DenseVector(featureMatrix[itemIndex(itemID)], true));
        }
        return mapping;
    }

    /**
     * Maps each user ID, truncated to {@code int}, to a vector over that user's row of {@code featureMatrix}.
     * The keys use the same truncation as {@link #sparseItemRatingVector}.
     * NOTE(review): IDs outside the int range can collide.
     */
    protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator userIDs, int numUsers,
                                                               double[][] featureMatrix) {
        OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<Vector>(numUsers);
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            mapping.put((int) userID, new DenseVector(featureMatrix[userIndex(userID)], true));
        }
        return mapping;
    }

    /**
     * Returns an item's preferences as a sparse vector indexed by the (int-truncated) user ID.
     */
    protected Vector sparseItemRatingVector(PreferenceArray prefs) {
        SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
        for (Preference preference : prefs) {
            ratings.set((int) preference.getUserID(), preference.getValue());
        }
        return ratings;
    }

    /**
     * Returns a user's preferences as a sparse vector indexed by the (int-truncated) item ID.
     */
    protected Vector sparseUserRatingVector(PreferenceArray prefs) {
        SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
        for (Preference preference : prefs) {
            ratings.set((int) preference.getItemID(), preference.getValue());
        }
        return ratings;
    }

    /**
     * Holds the user feature matrix U and the item feature matrix M, one row of {@code numFeatures} values per
     * user/item index. During a half-step, each worker task writes a single row.
     */
    static class Features {

        private final DataModel dataModel;
        private final int numFeatures;
        private final double[][] M;
        private final double[][] U;

        /**
         * Initializes the matrices. Column 0 of each item row in M is the item's average rating, the other item
         * columns are {@code random.nextDouble() * 0.1}, and U starts as all zeros.
         */
        Features(ALSWRFactorizer factorizer) throws TasteException {
            this.dataModel = factorizer.dataModel;
            this.numFeatures = factorizer.numFeatures;
            RandomWrapper random = RandomUtils.getRandom();
            this.M = new double[dataModel.getNumItems()][numFeatures];
            LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
            while (itemIDsIterator.hasNext()) {
                long itemID = itemIDsIterator.nextLong();
                double[] row = M[factorizer.itemIndex(itemID)];
                row[0] = averateRating(itemID);
                for (int feature = 1; feature < numFeatures; feature++) {
                    row[feature] = random.nextDouble() * 0.1;
                }
            }
            this.U = new double[dataModel.getNumUsers()][numFeatures];
        }

        double[][] getM() {
            return M;
        }

        double[][] getU() {
            return U;
        }

        Vector getUserFeatureColumn(int index) {
            return new DenseVector(U[index]);
        }

        Vector getItemFeatureColumn(int index) {
            return new DenseVector(M[index]);
        }

        void setFeatureColumnInU(int idIndex, Vector vector) {
            setFeatureColumn(U, idIndex, vector);
        }

        void setFeatureColumnInM(int idIndex, Vector vector) {
            setFeatureColumn(M, idIndex, vector);
        }

        /** Copies the first {@code numFeatures} entries of {@code vector} into row {@code idIndex} of {@code matrix}. */
        protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
            double[] row = matrix[idIndex];
            for (int feature = 0; feature < numFeatures; feature++) {
                row[feature] = vector.get(feature);
            }
        }

        /**
         * Returns the mean preference value for {@code itemID}. The method name ("averate") is a typo kept for
         * compatibility with existing subclasses.
         */
        protected double averateRating(long itemID) throws TasteException {
            PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
            FullRunningAverage avg = new FullRunningAverage();
            for (Preference pref : prefs) {
                avg.addDatum(pref.getValue());
            }
            return avg.getAverage();
        }
    }
}

