/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.als;

import java.util.Arrays;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.junit.Test;

/**
 * Unit tests for the static helpers of {@link AlternatingLeastSquaresSolver} and for the
 * {@code Y^T Y} computation of {@link ImplicitFeedbackAlternatingLeastSquaresSolver}.
 */
public class AlternatingLeastSquaresSolverTest
extends MahoutTestCase {

    /** Tolerance for approximate floating-point comparisons in these tests. */
    private static final double TOLERANCE = 1.0e-6;

    /**
     * Checks that {@code getYtransposeY} agrees with an explicit {@code transpose().times()} for two
     * small matrices, once many times with 4 threads and once with a single thread.
     */
    @Test
    public void testYtY() {
        double[][] testMatrix = {
            {1.0, 2.0, 3.0, 4.0, 5.0},
            {1.0, 2.0, 3.0, 4.0, 5.0},
            {1.0, 2.0, 3.0, 4.0, 5.0},
            {1.0, 2.0, 3.0, 4.0, 5.0},
            {1.0, 2.0, 3.0, 4.0, 5.0}
        };
        double[][] testMatrix2 = {
            {1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
            {5.0, 4.0, 3.0, 2.0, 1.0, 7.0},
            {1.0, 2.0, 3.0, 4.0, 5.0, 8.0},
            {1.0, 2.0, 3.0, 4.0, 5.0, 8.0},
            {11.0, 12.0, 13.0, 20.0, 27.0, 8.0}
        };
        for (double[][] data : new double[][][]{testMatrix, testMatrix2}) {
            Matrix matrixToTest = new DenseMatrix(data);
            // The 4-thread run is repeated so that an order- or race-dependent bug has many chances
            // to surface (assumes numThreads drives parallelism inside the solver -- confirm).
            for (int run = 0; run < 100; ++run) {
                validateYtY(matrixToTest, 4);
            }
            validateYtY(matrixToTest, 1);
        }
    }

    /**
     * Asserts that the solver's {@code Y^T Y} equals {@code matrixToTest^T * matrixToTest}.
     *
     * @param matrixToTest the matrix Y whose rows are handed to the solver
     * @param numThreads   thread count passed to the solver's constructor
     */
    private void validateYtY(Matrix matrixToTest, int numThreads) {
        OpenIntObjectHashMap<Vector> rowVectors = asRowVectors(matrixToTest);
        ImplicitFeedbackAlternatingLeastSquaresSolver solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(
            matrixToTest.columnSize(), 1.0, 1.0, rowVectors, numThreads);

        Matrix expected = matrixToTest.transpose().times(matrixToTest);
        Matrix actual = solver.getYtransposeY(rowVectors);

        assertEquals(expected.rowSize(), actual.rowSize());
        assertEquals(expected.columnSize(), actual.columnSize());
        for (int row = 0; row < expected.rowSize(); ++row) {
            for (int column = 0; column < expected.columnSize(); ++column) {
                // The fixtures in testYtY hold only small integers, so every product and partial sum
                // is exactly representable in a double; exact equality holds regardless of the order
                // in which the solver accumulates.
                assertEquals(expected.getQuick(row, column), actual.getQuick(row, column), 0.0);
            }
        }
    }

    /**
     * Copies each row of {@code matrix} into a map keyed by row index.
     *
     * @param matrix source matrix
     * @return map from row index to an independent copy of that row
     */
    private OpenIntObjectHashMap<Vector> asRowVectors(Matrix matrix) {
        OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<Vector>();
        for (int row = 0; row < matrix.numRows(); ++row) {
            // clone() so the map does not hold views backed by the source matrix.
            rows.put(row, matrix.viewRow(row).clone());
        }
        return rows;
    }

    /**
     * Adding {@code lambda * nui * E} to a zero matrix must put {@code lambda * nui} on the diagonal
     * and leave every off-diagonal cell at zero.
     */
    @Test
    public void addLambdaTimesNuiTimesE() {
        int nui = 5;
        double lambda = 0.2;
        int size = 5;
        Matrix matrix = new SparseMatrix(size, size);

        AlternatingLeastSquaresSolver.addLambdaTimesNuiTimesE(matrix, lambda, nui);

        for (int row = 0; row < size; ++row) {
            for (int column = 0; column < size; ++column) {
                double expected = row == column ? lambda * nui : 0.0;
                assertEquals(expected, matrix.getQuick(row, column), TOLERANCE);
            }
        }
    }

    /** Each feature vector passed to {@code createMiIi} must become one column of the result. */
    @Test
    public void createMiIi() {
        // Declared as Vector so Arrays.asList yields a List<Vector> without relying on inference.
        Vector f1 = new DenseVector(new double[]{1.0, 2.0, 3.0});
        Vector f2 = new DenseVector(new double[]{4.0, 5.0, 6.0});

        Matrix miIi = AlternatingLeastSquaresSolver.createMiIi(Arrays.asList(f1, f2), 3);

        assertEquals(1.0, miIi.getQuick(0, 0), TOLERANCE);
        assertEquals(2.0, miIi.getQuick(1, 0), TOLERANCE);
        assertEquals(3.0, miIi.getQuick(2, 0), TOLERANCE);
        assertEquals(4.0, miIi.getQuick(0, 1), TOLERANCE);
        assertEquals(5.0, miIi.getQuick(1, 1), TOLERANCE);
        assertEquals(6.0, miIi.getQuick(2, 1), TOLERANCE);
    }

    /**
     * The non-default ratings of a sequential-access vector must be packed, in index order, into a
     * single-column matrix with one row per rating.
     */
    @Test
    public void createRiIiMaybeTransposed() {
        // Cardinality 6 so that indices 1, 3 and 5 all lie inside the vector.
        SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(6);
        ratings.setQuick(1, 1.0);
        ratings.setQuick(3, 3.0);
        ratings.setQuick(5, 5.0);

        Matrix riIiMaybeTransposed = AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings);

        // Dimensions are integers: compare them exactly rather than with a floating-point delta.
        assertEquals(1, riIiMaybeTransposed.numCols());
        assertEquals(3, riIiMaybeTransposed.numRows());
        assertEquals(1.0, riIiMaybeTransposed.getQuick(0, 0), TOLERANCE);
        assertEquals(3.0, riIiMaybeTransposed.getQuick(1, 0), TOLERANCE);
        assertEquals(5.0, riIiMaybeTransposed.getQuick(2, 0), TOLERANCE);
    }

    /** {@code createRiIiMaybeTransposed} must reject a vector that is not sequential-access. */
    @Test
    public void createRiIiMaybeTransposedExceptionOnNonSequentialVector() {
        RandomAccessSparseVector ratings = new RandomAccessSparseVector(6);
        ratings.setQuick(1, 1.0);
        ratings.setQuick(3, 3.0);
        ratings.setQuick(5, 5.0);

        // try/fail/catch (rather than @Test(expected=...)) so that only the call under test may
        // satisfy the expectation.
        try {
            AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings);
            fail("Expected IllegalArgumentException for a non-sequential-access vector");
        } catch (IllegalArgumentException expected) {
            // Expected: only sequential-access rating vectors are accepted.
        }
    }
}

