/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.solver;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.solver.ConjugateGradientSolver;
import org.apache.mahout.math.solver.JacobiConditioner;
import org.apache.mahout.math.solver.Preconditioner;
import org.junit.Test;

/**
 * Tests for {@link ConjugateGradientSolver}.
 *
 * <p>Covers three cases:
 * <ul>
 *   <li>convergence on a 10x10 system;</li>
 *   <li>solving a poorly scaled 10x10 system with and without a {@link JacobiConditioner};</li>
 *   <li>early termination, either by a loose maximum error or by an iteration cap.</li>
 * </ul>
 */
public class TestConjugateGradientSolver extends MahoutTestCase {

    /** Tolerance on {@code ||A*x - b||} for a solve that is expected to converge. */
    private static final double DISTANCE_TOLERANCE = 1.0e-6;

    /** Maximum error handed to the solver, also used as the tolerance on its reported residual norm. */
    private static final double MAX_ERROR = 1.0e-9;

    /** Loose maximum error (and matching assertion tolerance) used to force an early stop. */
    private static final double EARLY_STOP_ERROR = 0.1;

    /** Iteration budget large enough for the preconditioning test to converge. */
    private static final int MAX_ITERATIONS = 100;

    /** Solves {@code A*x = b} with default settings; expects convergence in 10 iterations on this 10x10 system. */
    @Test
    public void testConjugateGradientSolver() {
        Matrix a = getA();
        Vector b = getB();

        ConjugateGradientSolver solver = new ConjugateGradientSolver();
        Vector x = solver.solve(a, b);

        assertEquals(0.0, residualDistance(a, x, b), DISTANCE_TOLERANCE);
        assertEquals(0.0, solver.getResidualNorm(), MAX_ERROR);
        assertEquals(10, solver.getIterations());
    }

    /**
     * Solves a poorly scaled system twice: without a preconditioner (16 iterations) and with a
     * {@link JacobiConditioner} (15 iterations), reusing the same solver instance.
     */
    @Test
    public void testConditionedConjugateGradientSolver() {
        Matrix a = getIllConditionedMatrix();
        Vector b = getB();
        JacobiConditioner conditioner = new JacobiConditioner(a);
        ConjugateGradientSolver solver = new ConjugateGradientSolver();

        // Unpreconditioned solve (null preconditioner).
        Vector x = solver.solve(a, b, null, MAX_ITERATIONS, MAX_ERROR);
        assertEquals(0.0, residualDistance(a, x, b), DISTANCE_TOLERANCE);
        assertEquals(0.0, solver.getResidualNorm(), MAX_ERROR);
        assertEquals(16, solver.getIterations());

        // Same system and limits, now Jacobi-preconditioned: converges one iteration sooner.
        Vector x2 = solver.solve(a, b, conditioner, MAX_ITERATIONS, MAX_ERROR);
        assertEquals(0.0, residualDistance(a, x2, b), DISTANCE_TOLERANCE);
        assertEquals(0.0, solver.getResidualNorm(), MAX_ERROR);
        assertEquals(15, solver.getIterations());
    }

    /**
     * Checks that the solver stops before reaching the exact solution when either the maximum
     * error is loose or the iteration cap is tight. Both cases stop after 7 iterations.
     */
    @Test
    public void testEarlyStop() {
        Matrix a = getA();
        Vector b = getB();
        ConjugateGradientSolver solver = new ConjugateGradientSolver();

        // Loose maximum error: stops after 7 of the allowed 10 iterations, not yet fully converged.
        Vector x = solver.solve(a, b, null, 10, EARLY_STOP_ERROR);
        double distance = residualDistance(a, x, b);
        assertTrue(distance > DISTANCE_TOLERANCE);
        assertEquals(0.0, distance, EARLY_STOP_ERROR);
        assertEquals(7, solver.getIterations());

        // Tight maximum error but only 7 iterations allowed: again stops short of convergence.
        x = solver.solve(a, b, null, 7, MAX_ERROR);
        distance = residualDistance(a, x, b);
        assertTrue(distance > DISTANCE_TOLERANCE);
        assertEquals(0.0, distance, EARLY_STOP_ERROR);
        assertEquals(7, solver.getIterations());
    }

    /**
     * Returns the Euclidean norm of {@code a*x - b}, i.e. how far {@code x} is from solving the system.
     */
    private static double residualDistance(Matrix a, Vector x, Vector b) {
        return Math.sqrt(a.times(x).getDistanceSquared(b));
    }

    /**
     * Returns the 10x10 coefficient matrix used by the basic and early-stop tests.
     *
     * <p>Values are listed in column-major order; each pair of lines below is one column. The
     * entries appear mirrored across the diagonal. Conjugate gradient assumes a symmetric
     * positive-definite matrix; that property is not verified here.
     */
    private static Matrix getA() {
        return reshape(new double[] {
            11.7155649822794, -0.7125253363083646, 4.647361396186018, 1.6020939468348456, -4.6789817799137134,
            -0.814041676343497, -4.5995617505618345, -1.174907004277534, -1.6747995811678336, 3.1922255171058342,
            -0.7125253363083646, 12.340057968399487, -2.6498099427000645, 0.5264507222630669, 0.3783428369189767,
            -2.117018615918881, 2.369513425219053, 3.8182131490333013, 6.528594229827035, 2.8564814419366353,
            4.647361396186018, -2.6498099427000645, 16.13179339216685, -0.0409475448061225, 1.4805687075608227,
            -2.995807648462895, -2.5288893025027264, -0.9614557539842487, -2.2974738351519077, -1.5516184284572598,
            1.6020939468348456, 0.5264507222630669, -0.0409475448061225, 4.194680212269448, -2.5210038046912198,
            0.6634899962909317, 0.4036187419205338, -0.2829211393003727, -0.2283091172980954, 1.1253516563552464,
            -4.6789817799137134, 0.3783428369189767, 1.4805687075608227, -2.5210038046912198, 19.430736186273343,
            -2.5200132222091787, 2.374851197144451, 11.642659844330552, -0.1508136510863874, 4.347134388806351,
            -0.814041676343497, -2.117018615918881, -2.995807648462895, 0.6634899962909317, -2.5200132222091787,
            7.671233441970075, -3.868777362950285, -3.045341871159153, -0.1155580876143619, -2.402545946742212,
            -4.5995617505618345, 2.369513425219053, -2.5288893025027264, 0.4036187419205338, 2.374851197144451,
            -3.868777362950285, 10.468166605747008, 1.652718086617123, 2.9341795819365384, -2.17081763727631,
            -1.174907004277534, 3.8182131490333013, -0.9614557539842487, -0.2829211393003727, 11.642659844330552,
            -3.045341871159153, 1.652718086617123, 16.005061693417623, 1.1689747208793086, 1.666509094595487,
            -1.6747995811678336, 6.528594229827035, -2.2974738351519077, -0.2283091172980954, -0.1508136510863874,
            -0.1155580876143619, 2.9341795819365384, 1.1689747208793086, 6.479432975163748, -1.9197339981871877,
            3.1922255171058342, 2.8564814419366353, -1.5516184284572598, 1.1253516563552464, 4.347134388806351,
            -2.402545946742212, -2.17081763727631, 1.666509094595487, -1.9197339981871877, 18.91490213563446
        }, 10, 10);
    }

    /** Returns the right-hand side {@code b} shared by every test. */
    private static Vector getB() {
        return new DenseVector(new double[] {
            -0.552252, 0.03843, 0.058392, -1.234496, 1.240369,
            0.373649, 0.505113, 0.503723, 1.21534, -0.391908
        });
    }

    /**
     * Returns a poorly scaled 10x10 matrix used to compare unpreconditioned and Jacobi-preconditioned
     * solves.
     *
     * <ul>
     *   <li>Values are listed in column-major order; each pair of lines below is one column.</li>
     *   <li>The diagonal ranges from about 0.007 to about 9.0.</li>
     *   <li>Mirrored entries agree only up to last-digit rounding in a few places (e.g.
     *       0.0738811358014665 vs 0.07388113580146649). The data is kept exactly as-is because the
     *       expected iteration counts depend on it.</li>
     * </ul>
     */
    private static Matrix getIllConditionedMatrix() {
        return reshape(new double[] {
            0.00695278043678842, 0.09911830022078683, 0.01309584636255063, 0.00652917453032394, 0.04337631487735064,
            0.14232165273321387, 0.05808722912361313, -0.06591965049732287, 0.06055771542862332, 0.00577423310349649,
            0.09911830022078683, 1.5007140241806143, 0.14988743575884242, 0.07195514527480981, 0.6374736234175272,
            1.3071181902041469, 0.8215160938511595, -0.7261612552458794, 1.0349013600202295, 0.12800239664439328,
            0.01309584636255063, 0.14988743575884242, 0.04068462583124965, 0.02147022047006482, 0.0738811358014665,
            0.58070223915076, 0.11280336266257514, -0.21690068430020618, 0.04065087561300068, -0.00876895259593769,
            0.00652917453032394, 0.07195514527480981, 0.02147022047006482, 0.01140105250542524, 0.03624164348693958,
            0.31291554581393255, 0.05648457235205666, -0.1150758301607778, 0.01475756130709823, -0.00584453679519805,
            0.04337631487735064, 0.6374736234175272, 0.07388113580146649, 0.03624164348693959, 0.2749154320076057,
            0.7341054316874812, 0.36120630002843257, -0.36583546331208316, 0.41472509341940017, 0.0458145875825548,
            0.14232165273321387, 1.3071181902041467, 0.58070223915076, 0.31291554581393255, 0.7341054316874812,
            9.02536073121807, 1.254263855828831, -3.1618633512559464, -0.19740140818905436, -0.26613760880058035,
            0.05808722912361314, 0.8215160938511595, 0.11280336266257514, 0.05648457235205667, 0.36120630002843257,
            1.2542638558288313, 0.4866105845160682, -0.570305113365622, 0.491512804648181, 0.04428280690189127,
            -0.06591965049732286, -0.7261612552458794, -0.21690068430020618, -0.11507583016077781, -0.36583546331208316,
            -3.1618633512559464, -0.570305113365622, 1.1627081503807895, -0.14837898963724327, 0.05917203395002889,
            0.06055771542862331, 1.0349013600202293, 0.04065087561300068, 0.01475756130709823, 0.4147250934194002,
            -0.19740140818905436, 0.49151280464818103, -0.14837898963724327, 0.8669382068204972, 0.1408968875257034,
            0.00577423310349649, 0.12800239664439328, -0.00876895259593769, -0.00584453679519805, 0.0458145875825548,
            -0.26613760880058035, 0.04428280690189126, 0.05917203395002889, 0.1408968875257034, 0.02901858439788401
        }, 10, 10);
    }

    /**
     * Builds a {@code rows x columns} dense matrix from {@code values} laid out in column-major order.
     * Value {@code i} goes to row {@code i % rows}, column {@code i / rows}.
     *
     * @param values  matrix entries in column-major order; must contain exactly {@code rows * columns} values
     * @param rows    number of rows
     * @param columns number of columns
     * @return the populated matrix
     * @throws IllegalArgumentException if {@code values.length != rows * columns}
     */
    private static Matrix reshape(double[] values, int rows, int columns) {
        // Without this check, too many values overflow the column index and too few leave silent zeros.
        if (values.length != rows * columns) {
            throw new IllegalArgumentException("Expected " + (rows * columns) + " values for a "
                + rows + "x" + columns + " matrix but got " + values.length);
        }
        DenseMatrix m = new DenseMatrix(rows, columns);
        for (int i = 0; i < values.length; i++) {
            m.set(i % rows, i / rows, values[i]);
        }
        return m;
    }
}

