/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.lda.cvb.CVB0Driver;
import org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixUtils;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.function.DoubleFunction;
import org.junit.Test;

public final class TestCVBModelTrainer
extends MahoutTestCase {
    private static final double ETA = 0.1;
    private static final double ALPHA = 0.1;

    @Test
    public void testInMemoryCVB0() throws Exception {
        String[] terms = new String[26];
        for (int i = 0; i < terms.length; ++i) {
            terms[i] = String.valueOf((char)(i + 97));
        }
        int numGeneratingTopics = 3;
        int numTerms = 26;
        Matrix matrix = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction(){

            public double apply(double d) {
                return 1.0 / Math.pow(d + 1.0, 2.0);
            }
        });
        int numDocs = 100;
        int numSamples = 20;
        int numTopicsPerDoc = 1;
        Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(matrix, (Random)RandomUtils.getRandom(), numDocs, numSamples, numTopicsPerDoc);
        ArrayList perplexities = Lists.newArrayList();
        int numTrials = 1;
        for (int numTestTopics = 1; numTestTopics < 2 * numGeneratingTopics; ++numTestTopics) {
            double[] perps = new double[numTrials];
            for (int trial = 0; trial < numTrials; ++trial) {
                InMemoryCollapsedVariationalBayes0 cvb = new InMemoryCollapsedVariationalBayes0(sampledCorpus, terms, numTestTopics, 0.1, 0.1, 2, 1, 0.0);
                cvb.setVerbose(true);
                perps[trial] = cvb.iterateUntilConvergence(0.0, 5, 0, 0.2);
                System.out.println(perps[trial]);
            }
            Arrays.sort(perps);
            System.out.println(Arrays.toString(perps));
            perplexities.add(perps[0]);
        }
        System.out.println(Joiner.on((String)",").join((Iterable)perplexities));
    }

    @Test
    public void testRandomStructuredModelViaMR() throws Exception {
        int startTopic;
        int numGeneratingTopics = 3;
        int numTerms = 9;
        Matrix matrix = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction(){

            public double apply(double d) {
                return 1.0 / Math.pow(d + 1.0, 3.0);
            }
        });
        int numDocs = 500;
        int numSamples = 10;
        int numTopicsPerDoc = 1;
        Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(matrix, RandomUtils.getRandom((long)1234L), numDocs, numSamples, numTopicsPerDoc);
        Path sampleCorpusPath = this.getTestTempDirPath("corpus");
        Configuration configuration = this.getConfiguration();
        MatrixUtils.write((Path)sampleCorpusPath, (Configuration)configuration, (VectorIterable)sampledCorpus);
        int numIterations = 5;
        ArrayList perplexities = Lists.newArrayList();
        for (int numTestTopics = startTopic = numGeneratingTopics - 1; numTestTopics < numGeneratingTopics + 2; ++numTestTopics) {
            Path topicModelStateTempPath = this.getTestTempDirPath("topicTemp" + numTestTopics);
            Configuration conf = this.getConfiguration();
            CVB0Driver cvb0Driver = new CVB0Driver();
            cvb0Driver.run(conf, sampleCorpusPath, null, numTestTopics, numTerms, 0.1, 0.1, numIterations, 1, 0.0, null, null, topicModelStateTempPath, 1234L, 0.2f, 2, 1, 3, 1, false);
            perplexities.add(TestCVBModelTrainer.lowestPerplexity(conf, topicModelStateTempPath));
        }
        int bestTopic = -1;
        double lowestPerplexity = Double.MAX_VALUE;
        for (int t = 0; t < perplexities.size(); ++t) {
            if (!((Double)perplexities.get(t) < lowestPerplexity)) continue;
            lowestPerplexity = (Double)perplexities.get(t);
            bestTopic = t + startTopic;
        }
        TestCVBModelTrainer.assertEquals((String)"The optimal number of topics is not that of the generating distribution", (long)4L, (long)bestTopic);
        System.out.println("Perplexities: " + Joiner.on((String)", ").join((Iterable)perplexities));
    }

    private static double lowestPerplexity(Configuration conf, Path topicModelTemp) throws IOException {
        double current;
        double lowest = Double.MAX_VALUE;
        int iteration = 2;
        while (!Double.isNaN(current = CVB0Driver.readPerplexity((Configuration)conf, (Path)topicModelTemp, (int)iteration))) {
            lowest = Math.min(current, lowest);
            ++iteration;
        }
        return lowest;
    }
}

