/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.ParallelSGDFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.SVDRecommender;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.VectorFunction;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelSGDFactorizerTest
extends TasteTestCase {
    protected DataModel dataModel;
    protected int rank;
    protected double lambda;
    protected int numIterations;
    private RandomWrapper random = RandomUtils.getRandom();
    protected Factorizer factorizer;
    protected SVDRecommender svdRecommender;
    private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizerTest.class);

    private Matrix randomMatrix(int numRows, int numColumns, double range) {
        double[][] data = new double[numRows][numColumns];
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numColumns; ++j) {
                double sqrtUniform = this.random.nextDouble();
                data[i][j] = sqrtUniform * range;
            }
        }
        return new DenseMatrix(data);
    }

    private void normalize(Matrix source, final double range) {
        final double max = source.aggregateColumns(new VectorFunction(){

            public double apply(Vector column) {
                return column.maxValue();
            }
        }).maxValue();
        final double min = source.aggregateColumns(new VectorFunction(){

            public double apply(Vector column) {
                return column.minValue();
            }
        }).minValue();
        source.assign(new DoubleFunction(){

            public double apply(double value) {
                return (value - min) * range / (max - min);
            }
        });
    }

    public void setUpSyntheticData() throws Exception {
        int numUsers = 2000;
        int numItems = 1000;
        double sparsity = 0.5;
        this.rank = 20;
        this.lambda = 1.0E-9;
        this.numIterations = 100;
        Matrix users = this.randomMatrix(numUsers, this.rank, 1.0);
        Matrix items = this.randomMatrix(this.rank, numItems, 1.0);
        Matrix ratings = users.times(items);
        this.normalize(ratings, 5.0);
        FastByIDMap userData = new FastByIDMap();
        for (int userIndex = 0; userIndex < numUsers; ++userIndex) {
            ArrayList row = Lists.newArrayList();
            for (int itemIndex = 0; itemIndex < numItems; ++itemIndex) {
                if (!(this.random.nextDouble() <= sparsity)) continue;
                row.add(new GenericPreference((long)userIndex, (long)itemIndex, (float)ratings.get(userIndex, itemIndex)));
            }
            userData.put((long)userIndex, (Object)new GenericUserPreferenceArray((List)row));
        }
        this.dataModel = new GenericDataModel(userData);
    }

    public void setUpToyData() throws Exception {
        this.rank = 3;
        this.lambda = 0.01;
        this.numIterations = 1000;
        FastByIDMap userData = new FastByIDMap();
        userData.put(1L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1L, 1L, 5.0f), new GenericPreference(1L, 2L, 5.0f), new GenericPreference(1L, 3L, 2.0f))));
        userData.put(2L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2L, 1L, 2.0f), new GenericPreference(2L, 3L, 3.0f), new GenericPreference(2L, 4L, 5.0f))));
        userData.put(3L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(3L, 2L, 5.0f), new GenericPreference(3L, 4L, 3.0f))));
        userData.put(4L, (Object)new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(4L, 1L, 3.0f), new GenericPreference(4L, 4L, 5.0f))));
        this.dataModel = new GenericDataModel(userData);
    }

    @Test
    public void testPreferenceShufflerWithSyntheticData() throws Exception {
        this.setUpSyntheticData();
        ParallelSGDFactorizer.PreferenceShuffler shuffler = new ParallelSGDFactorizer.PreferenceShuffler(this.dataModel);
        shuffler.shuffle();
        shuffler.stage();
        FastByIDMap checked = new FastByIDMap();
        for (int i = 0; i < shuffler.size(); ++i) {
            Preference pref = shuffler.get(i);
            float value = this.dataModel.getPreferenceValue(pref.getUserID(), pref.getItemID()).floatValue();
            ParallelSGDFactorizerTest.assertEquals((double)pref.getValue(), (double)value, (double)0.0);
            if (!checked.containsKey(pref.getUserID())) {
                checked.put(pref.getUserID(), (Object)new FastByIDMap());
            }
            ParallelSGDFactorizerTest.assertNull((Object)((FastByIDMap)checked.get(pref.getUserID())).get(pref.getItemID()));
            ((FastByIDMap)checked.get(pref.getUserID())).put(pref.getItemID(), (Object)true);
        }
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        int index = 0;
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            PreferenceArray preferencesFromUser = this.dataModel.getPreferencesFromUser(userID);
            for (Preference preference : preferencesFromUser) {
                ParallelSGDFactorizerTest.assertTrue((boolean)((Boolean)((FastByIDMap)checked.get(preference.getUserID())).get(preference.getItemID())));
                ++index;
            }
        }
        ParallelSGDFactorizerTest.assertEquals((long)index, (long)shuffler.size());
    }

    @ThreadLeakLingering(linger=1000)
    @Test
    public void testFactorizerWithToyData() throws Exception {
        double regularization;
        this.setUpToyData();
        long start = System.currentTimeMillis();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01, 1.0, 0, 0.0);
        Factorization factorization = this.factorizer.factorize();
        long duration = System.currentTimeMillis() - start;
        FullRunningAverage avg = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            for (Preference pref : this.dataModel.getPreferencesFromUser(userID)) {
                double rating = pref.getValue();
                DenseVector userVector = new DenseVector(factorization.getUserFeatures(userID));
                DenseVector itemVector = new DenseVector(factorization.getItemFeatures(pref.getItemID()));
                double estimate = userVector.dot((Vector)itemVector);
                double err = rating - estimate;
                avg.addDatum(err * err);
            }
        }
        double sum = 0.0;
        userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            DenseVector userVector = new DenseVector(factorization.getUserFeatures(userID));
            regularization = userVector.dot((Vector)userVector);
            sum += regularization;
        }
        LongPrimitiveIterator itemIDs = this.dataModel.getItemIDs();
        while (itemIDs.hasNext()) {
            long itemID = itemIDs.nextLong();
            DenseVector itemVector = new DenseVector(factorization.getUserFeatures(itemID));
            regularization = itemVector.dot((Vector)itemVector);
            sum += regularization;
        }
        double rmse = Math.sqrt(avg.getAverage());
        double loss = avg.getAverage() / 2.0 + this.lambda / 2.0 * sum;
        logger.info("RMSE: " + rmse + ";\tLoss: " + loss + ";\tTime Used: " + duration);
        ParallelSGDFactorizerTest.assertTrue((rmse < 0.2 ? 1 : 0) != 0);
    }

    @ThreadLeakLingering(linger=1000)
    @Test
    public void testRecommenderWithToyData() throws Exception {
        this.setUpToyData();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01, 1.0, 0, 0.0);
        this.svdRecommender = new SVDRecommender(this.dataModel, this.factorizer);
        FullRunningAverage avg = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            for (Preference pref : this.dataModel.getPreferencesFromUser(userID)) {
                double rating = pref.getValue();
                double estimate = this.svdRecommender.estimatePreference(userID, pref.getItemID());
                double err = rating - estimate;
                avg.addDatum(err * err);
            }
        }
        double rmse = Math.sqrt(avg.getAverage());
        logger.info("rmse: " + rmse);
        ParallelSGDFactorizerTest.assertTrue((rmse < 0.2 ? 1 : 0) != 0);
    }

    @Test
    public void testFactorizerWithWithSyntheticData() throws Exception {
        double regularization;
        this.setUpSyntheticData();
        long start = System.currentTimeMillis();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01, 1.0, 0, 0.0);
        Factorization factorization = this.factorizer.factorize();
        long duration = System.currentTimeMillis() - start;
        FullRunningAverage avg = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            for (Preference pref : this.dataModel.getPreferencesFromUser(userID)) {
                double rating = pref.getValue();
                DenseVector userVector = new DenseVector(factorization.getUserFeatures(userID));
                DenseVector itemVector = new DenseVector(factorization.getItemFeatures(pref.getItemID()));
                double estimate = userVector.dot((Vector)itemVector);
                double err = rating - estimate;
                avg.addDatum(err * err);
            }
        }
        double sum = 0.0;
        userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            DenseVector userVector = new DenseVector(factorization.getUserFeatures(userID));
            regularization = userVector.dot((Vector)userVector);
            sum += regularization;
        }
        LongPrimitiveIterator itemIDs = this.dataModel.getItemIDs();
        while (itemIDs.hasNext()) {
            long itemID = itemIDs.nextLong();
            DenseVector itemVector = new DenseVector(factorization.getUserFeatures(itemID));
            regularization = itemVector.dot((Vector)itemVector);
            sum += regularization;
        }
        double rmse = Math.sqrt(avg.getAverage());
        double loss = avg.getAverage() / 2.0 + this.lambda / 2.0 * sum;
        logger.info("RMSE: " + rmse + ";\tLoss: " + loss + ";\tTime Used: " + duration + "ms");
        ParallelSGDFactorizerTest.assertTrue((rmse < 0.2 ? 1 : 0) != 0);
    }

    @Test
    public void testRecommenderWithSyntheticData() throws Exception {
        this.setUpSyntheticData();
        this.factorizer = new ParallelSGDFactorizer(this.dataModel, this.rank, this.lambda, this.numIterations, 0.01, 1.0, 0, 0.0);
        this.svdRecommender = new SVDRecommender(this.dataModel, this.factorizer);
        FullRunningAverage avg = new FullRunningAverage();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userID = userIDs.nextLong();
            for (Preference pref : this.dataModel.getPreferencesFromUser(userID)) {
                double rating = pref.getValue();
                double estimate = this.svdRecommender.estimatePreference(userID, pref.getItemID());
                double err = rating - estimate;
                avg.addDatum(err * err);
            }
        }
        double rmse = Math.sqrt(avg.getAverage());
        logger.info("rmse: " + rmse);
        ParallelSGDFactorizerTest.assertTrue((rmse < 0.2 ? 1 : 0) != 0);
    }
}

