/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.naivebayes;

import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

public abstract class NaiveBayesTestBase
extends MahoutTestCase {
    private NaiveBayesModel standardModel;
    private NaiveBayesModel complementaryModel;

    @Override
    public void setUp() throws Exception {
        super.setUp();
        this.standardModel = NaiveBayesTestBase.createStandardNaiveBayesModel();
        this.standardModel.validate();
        this.complementaryModel = NaiveBayesTestBase.createComplementaryNaiveBayesModel();
        this.complementaryModel.validate();
    }

    protected NaiveBayesModel getStandardModel() {
        return this.standardModel;
    }

    protected NaiveBayesModel getComplementaryModel() {
        return this.complementaryModel;
    }

    protected static double complementaryNaiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) {
        double weight = 0.0;
        double alpha = 1.0;
        for (int i = 0; i < featureSum.size(); ++i) {
            double score = weightMatrix.get(i, label);
            double lSum = labelSum.get(label);
            double fSum = featureSum.get(i);
            double totalSum = featureSum.zSum();
            double numerator = fSum - score + alpha;
            double denominator = totalSum - lSum + (double)featureSum.size();
            weight += Math.abs(Math.log(numerator / denominator));
        }
        return weight;
    }

    protected static double naiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) {
        double weight = 0.0;
        double alpha = 1.0;
        for (int feature = 0; feature < featureSum.size(); ++feature) {
            double score = weightMatrix.get(feature, label);
            double lSum = labelSum.get(label);
            double numerator = score + alpha;
            double denominator = lSum + (double)featureSum.size();
            weight += Math.abs(Math.log(numerator / denominator));
        }
        return weight;
    }

    protected static NaiveBayesModel createStandardNaiveBayesModel() {
        double[][] matrix = new double[][]{{0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1}, {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
        double[] labelSumArray = new double[]{1.2, 1.0, 1.0, 1.0};
        double[] featureSumArray = new double[]{1.3, 0.6, 1.1, 1.2};
        DenseMatrix weightMatrix = new DenseMatrix((double[][])matrix);
        DenseVector labelSum = new DenseVector(labelSumArray);
        DenseVector featureSum = new DenseVector(featureSumArray);
        return new NaiveBayesModel((Matrix)weightMatrix, (Vector)featureSum, (Vector)labelSum, null, 1.0f, false);
    }

    protected static NaiveBayesModel createComplementaryNaiveBayesModel() {
        double[][] matrix = new double[][]{{0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1}, {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
        double[] labelSumArray = new double[]{1.2, 1.0, 1.0, 1.0};
        double[] featureSumArray = new double[]{1.3, 0.6, 1.1, 1.2};
        DenseMatrix weightMatrix = new DenseMatrix((double[][])matrix);
        DenseVector labelSum = new DenseVector(labelSumArray);
        DenseVector featureSum = new DenseVector(featureSumArray);
        double[] thetaNormalizerSum = new double[]{NaiveBayesTestBase.complementaryNaiveBayesThetaWeight(0, (Matrix)weightMatrix, (Vector)labelSum, (Vector)featureSum), NaiveBayesTestBase.complementaryNaiveBayesThetaWeight(1, (Matrix)weightMatrix, (Vector)labelSum, (Vector)featureSum), NaiveBayesTestBase.complementaryNaiveBayesThetaWeight(2, (Matrix)weightMatrix, (Vector)labelSum, (Vector)featureSum), NaiveBayesTestBase.complementaryNaiveBayesThetaWeight(3, (Matrix)weightMatrix, (Vector)labelSum, (Vector)featureSum)};
        return new NaiveBayesModel((Matrix)weightMatrix, (Vector)featureSum, (Vector)labelSum, (Vector)new DenseVector(thetaNormalizerSum), 1.0f, true);
    }

    protected static int maxIndex(Vector instance) {
        int maxIndex = -1;
        double maxScore = -2.147483648E9;
        for (Vector.Element label : instance.all()) {
            if (!(label.get() >= maxScore)) continue;
            maxIndex = label.index();
            maxScore = label.get();
        }
        return maxIndex;
    }
}

