/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.streaming.cluster;

import java.util.Arrays;
import java.util.List;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.DataUtils;
import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class StreamingKMeansTest {
    private static final int NUM_DATA_POINTS = 65536;
    private static final int NUM_DIMENSIONS = 6;
    private static final int NUM_PROJECTIONS = 2;
    private static final int SEARCH_SIZE = 10;
    private static Pair<List<Centroid>, List<Centroid>> syntheticData;
    private UpdatableSearcher searcher;
    private boolean allAtOnce;

    @Before
    public void setUp() {
        RandomUtils.useTestSeed();
        syntheticData = DataUtils.sampleMultiNormalHypercube(6, 65536);
    }

    public StreamingKMeansTest(UpdatableSearcher searcher, boolean allAtOnce) {
        this.searcher = searcher;
        this.allAtOnce = allAtOnce;
    }

    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        return Arrays.asList({new ProjectionSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure(), 2, 10), true}, {new FastProjectionSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure(), 2, 10), true}, {new ProjectionSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure(), 2, 10), false}, {new FastProjectionSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure(), 2, 10), false});
    }

    @Test
    public void testAverageDistanceCutoff() {
        double avgDistanceCutoff = 0.0;
        double avgNumClusters = 0.0;
        int numTests = 1;
        System.out.printf("Distance cutoff for %s\n", this.searcher.getClass().getName());
        for (int i = 0; i < numTests; ++i) {
            this.searcher.clear();
            int numStreamingClusters = (int)Math.log(((List)syntheticData.getFirst()).size()) * 64;
            double distanceCutoff = 1.0E-6;
            double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff((Iterable)((Iterable)syntheticData.getFirst()), (DistanceMeasure)this.searcher.getDistanceMeasure(), (int)100);
            System.out.printf("[%d] Generated synthetic data [magic] %f [estimate] %f\n", i, distanceCutoff, estimatedCutoff);
            StreamingKMeans clusterer = new StreamingKMeans(this.searcher, numStreamingClusters, estimatedCutoff);
            clusterer.cluster((Iterable)syntheticData.getFirst());
            avgDistanceCutoff += clusterer.getDistanceCutoff();
            avgNumClusters += (double)clusterer.getNumClusters();
            System.out.printf("[%d] %f\n", i, clusterer.getDistanceCutoff());
        }
        System.out.printf("Final: distanceCutoff: %f estNumClusters: %f\n", avgDistanceCutoff /= (double)numTests, avgNumClusters /= (double)numTests);
    }

    @Test
    public void testClustering() {
        this.searcher.clear();
        int numStreamingClusters = (int)Math.log(((List)syntheticData.getFirst()).size()) * 64;
        System.out.printf("k log n = %d\n", numStreamingClusters);
        double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff((Iterable)((Iterable)syntheticData.getFirst()), (DistanceMeasure)this.searcher.getDistanceMeasure(), (int)100);
        StreamingKMeans clusterer = new StreamingKMeans(this.searcher, numStreamingClusters, estimatedCutoff);
        long startTime = System.currentTimeMillis();
        if (this.allAtOnce) {
            clusterer.cluster((Iterable)syntheticData.getFirst());
        } else {
            for (Centroid datapoint : (List)syntheticData.getFirst()) {
                clusterer.cluster(datapoint);
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.printf("%s %s\n", this.searcher.getClass().getName(), this.searcher.getDistanceMeasure().getClass().getName());
        System.out.printf("Total number of clusters %d\n", clusterer.getNumClusters());
        System.out.printf("Weights: %f %f\n", ClusteringUtils.totalWeight((Iterable)((Iterable)syntheticData.getFirst())), ClusteringUtils.totalWeight((Iterable)clusterer));
        Assert.assertEquals((String)"Total weight not preserved", (double)ClusteringUtils.totalWeight((Iterable)((Iterable)syntheticData.getFirst())), (double)ClusteringUtils.totalWeight((Iterable)clusterer), (double)1.0E-9);
        double maxWeight = 0.0;
        for (Vector mean : (List)syntheticData.getSecond()) {
            WeightedThing v = (WeightedThing)this.searcher.search(mean, 1).get(0);
            maxWeight = Math.max(v.getWeight(), maxWeight);
        }
        Assert.assertTrue((String)("Maximum weight too large " + maxWeight), (maxWeight < 0.05 ? 1 : 0) != 0);
        double clusterTime = (double)(endTime - startTime) / 1000.0;
        System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n", this.searcher.getClass().getName(), clusterTime, clusterTime / (double)((List)syntheticData.getFirst()).size() * 1000000.0);
        double[] cornerWeights = new double[64];
        BruteSearch trueFinder = new BruteSearch((DistanceMeasure)new EuclideanDistanceMeasure());
        for (Vector trueCluster : (List)syntheticData.getSecond()) {
            trueFinder.add(trueCluster);
        }
        for (Centroid centroid : clusterer) {
            WeightedThing closest = (WeightedThing)trueFinder.search((Vector)centroid, 1).get(0);
            int n = ((Centroid)closest.getValue()).getIndex();
            cornerWeights[n] = cornerWeights[n] + centroid.getWeight();
        }
        int expectedNumPoints = 1024;
        for (double v : cornerWeights) {
            System.out.printf("%f ", v);
        }
        System.out.println();
        for (double v : cornerWeights) {
            Assert.assertEquals((double)expectedNumPoints, (double)v, (double)0.0);
        }
    }
}

