/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.streaming.mapreduce;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mrunit.mapreduce.MapDriver;
import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
import org.apache.hadoop.mrunit.types.Pair;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.DataUtils;
import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansDriver;
import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansMapper;
import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansReducer;
import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansUtilsMR;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class StreamingKMeansTestMR
extends MahoutTestCase {
    private static final int NUM_DATA_POINTS = 32768;
    private static final int NUM_DIMENSIONS = 8;
    private static final int NUM_PROJECTIONS = 3;
    private static final int SEARCH_SIZE = 5;
    private static final int MAX_NUM_ITERATIONS = 10;
    private static final double DISTANCE_CUTOFF = 1.0E-6;
    private static org.apache.mahout.common.Pair<List<Centroid>, List<Centroid>> syntheticData;
    private final String searcherClassName;
    private final String distanceMeasureClassName;

    @Override
    @Before
    public void setUp() {
        RandomUtils.useTestSeed();
        syntheticData = DataUtils.sampleMultiNormalHypercube(8, 32768, 1.0E-4);
    }

    public StreamingKMeansTestMR(String searcherClassName, String distanceMeasureClassName) {
        this.searcherClassName = searcherClassName;
        this.distanceMeasureClassName = distanceMeasureClassName;
    }

    private void configure(Configuration configuration) {
        configuration.set("distanceMeasure", this.distanceMeasureClassName);
        configuration.setInt("searchSize", 5);
        configuration.setInt("numProjections", 3);
        configuration.set("searcherClass", this.searcherClassName);
        configuration.setInt("numClusters", 256);
        configuration.setInt("estimatedNumMapClusters", 256 * (int)Math.log(32768.0));
        configuration.setFloat("estimatedDistanceCutoff", 1.0E-6f);
        configuration.setInt("maxNumIterations", 10);
        configuration.setBoolean("reduceStreamingKMeans", true);
    }

    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        return Arrays.asList({ProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()}, {FastProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()}, {LocalitySensitiveHashSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()});
    }

    @Test
    public void testHypercubeMapper() throws IOException {
        MapDriver mapDriver = MapDriver.newMapDriver((Mapper)new StreamingKMeansMapper());
        this.configure(mapDriver.getConfiguration());
        System.out.printf("%s mapper test\n", mapDriver.getConfiguration().get("searcherClass"));
        for (Centroid datapoint : (List)syntheticData.getFirst()) {
            mapDriver.addInput((Object)new IntWritable(0), (Object)new VectorWritable((Vector)datapoint));
        }
        List results = mapDriver.run();
        BruteSearch resultSearcher = new BruteSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure());
        for (Pair result : results) {
            resultSearcher.add((Vector)((CentroidWritable)result.getSecond()).getCentroid());
        }
        System.out.printf("Clustered the data into %d clusters\n", results.size());
        for (Vector mean : (List)syntheticData.getSecond()) {
            WeightedThing closest = (WeightedThing)resultSearcher.search(mean, 1).get(0);
            StreamingKMeansTestMR.assertTrue((String)("Weight " + closest.getWeight() + " not less than 0.5"), (closest.getWeight() < 0.5 ? 1 : 0) != 0);
        }
    }

    @Test
    public void testMapperVsLocal() throws IOException {
        MapDriver mapDriver = MapDriver.newMapDriver((Mapper)new StreamingKMeansMapper());
        Configuration configuration = mapDriver.getConfiguration();
        this.configure(configuration);
        System.out.printf("%s mapper vs local test\n", mapDriver.getConfiguration().get("searcherClass"));
        for (Centroid datapoint : (List)syntheticData.getFirst()) {
            mapDriver.addInput((Object)new IntWritable(0), (Object)new VectorWritable((Vector)datapoint));
        }
        ArrayList mapperCentroids = Lists.newArrayList();
        for (Pair pair : mapDriver.run()) {
            mapperCentroids.add(((CentroidWritable)pair.getSecond()).getCentroid());
        }
        StreamingKMeans batchClusterer = new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration((Configuration)configuration), mapDriver.getConfiguration().getInt("estimatedNumMapClusters", -1), 1.0E-6);
        batchClusterer.cluster((Iterable)syntheticData.getFirst());
        ArrayList batchCentroids = Lists.newArrayList();
        for (Vector v : batchClusterer) {
            batchCentroids.add((Centroid)v);
        }
        StreamingKMeans perPointClusterer = new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration((Configuration)configuration), 256 * (int)Math.log(32768.0), 1.0E-6);
        for (Centroid datapoint : (List)syntheticData.getFirst()) {
            perPointClusterer.cluster(datapoint);
        }
        ArrayList perPointCentroids = Lists.newArrayList();
        for (Vector v : perPointClusterer) {
            perPointCentroids.add((Centroid)v);
        }
        double mapperCost = ClusteringUtils.totalClusterCost((Iterable)((Iterable)syntheticData.getFirst()), (Iterable)mapperCentroids);
        double localCost = ClusteringUtils.totalClusterCost((Iterable)((Iterable)syntheticData.getFirst()), (Iterable)batchCentroids);
        double perPointCost = ClusteringUtils.totalClusterCost((Iterable)((Iterable)syntheticData.getFirst()), (Iterable)perPointCentroids);
        System.out.printf("[Total cost] Mapper %f [%d] Local %f [%d] Perpoint local %f [%d];[ratio m-vs-l %f] [ratio pp-vs-l %f]\n", mapperCost, mapperCentroids.size(), localCost, batchCentroids.size(), perPointCost, perPointCentroids.size(), mapperCost / localCost, perPointCost / localCost);
        StreamingKMeansTestMR.assertEquals((String)"Mapper StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1", (double)1.0, (double)(mapperCost / localCost), (double)0.8);
        StreamingKMeansTestMR.assertEquals((String)"One by one local StreamingKMeans / Batch local StreamingKMeans total cost ratio too high", (double)1.0, (double)(perPointCost / localCost), (double)0.8);
    }

    @Test
    public void testHypercubeReducer() throws IOException {
        ReduceDriver reduceDriver = ReduceDriver.newReduceDriver((Reducer)new StreamingKMeansReducer());
        Configuration configuration = reduceDriver.getConfiguration();
        this.configure(configuration);
        System.out.printf("%s reducer test\n", configuration.get("searcherClass"));
        StreamingKMeans clusterer = new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration((Configuration)configuration), 256 * (int)Math.log(32768.0), 1.0E-6);
        long start = System.currentTimeMillis();
        clusterer.cluster((Iterable)syntheticData.getFirst());
        long end = System.currentTimeMillis();
        System.out.printf("%f [s]\n", (double)(end - start) / 1000.0);
        ArrayList reducerInputs = Lists.newArrayList();
        int postMapperTotalWeight = 0;
        for (Centroid intermediateCentroid : clusterer) {
            reducerInputs.add(new CentroidWritable(intermediateCentroid));
            postMapperTotalWeight = (int)((double)postMapperTotalWeight + intermediateCentroid.getWeight());
        }
        reduceDriver.addInput((Object)new IntWritable(0), (List)reducerInputs);
        List results = reduceDriver.run();
        StreamingKMeansTestMR.testReducerResults(postMapperTotalWeight, results);
    }

    @Test
    public void testHypercubeMapReduce() throws IOException {
        MapReduceDriver mapReduceDriver = new MapReduceDriver((Mapper)new StreamingKMeansMapper(), (Reducer)new StreamingKMeansReducer());
        Configuration configuration = mapReduceDriver.getConfiguration();
        this.configure(configuration);
        System.out.printf("%s full test\n", configuration.get("searcherClass"));
        for (Centroid datapoint : (List)syntheticData.getFirst()) {
            mapReduceDriver.addInput((Object)new IntWritable(0), (Object)new VectorWritable((Vector)datapoint));
        }
        List results = mapReduceDriver.run();
        StreamingKMeansTestMR.testReducerResults(((List)syntheticData.getFirst()).size(), results);
    }

    @Test
    public void testHypercubeMapReduceRunSequentially() throws Exception {
        Configuration configuration = this.getConfiguration();
        this.configure(configuration);
        configuration.set("method", "sequential");
        Path inputPath = new Path("testInput");
        Path outputPath = new Path("testOutput");
        StreamingKMeansUtilsMR.writeVectorsToSequenceFile((Iterable)((Iterable)syntheticData.getFirst()), (Path)inputPath, (Configuration)configuration);
        StreamingKMeansDriver.run((Configuration)configuration, (Path)inputPath, (Path)outputPath);
        StreamingKMeansTestMR.testReducerResults(((List)syntheticData.getFirst()).size(), Lists.newArrayList((Iterable)Iterables.transform((Iterable)new SequenceFileIterable(outputPath, configuration), (Function)new Function<org.apache.mahout.common.Pair<IntWritable, CentroidWritable>, Pair<IntWritable, CentroidWritable>>(){

            public Pair<IntWritable, CentroidWritable> apply(org.apache.mahout.common.Pair<IntWritable, CentroidWritable> input) {
                return new Pair(input.getFirst(), input.getSecond());
            }
        })));
    }

    private static void testReducerResults(int totalWeight, List<Pair<IntWritable, CentroidWritable>> results) {
        int expectedNumClusters = 256;
        double expectedWeight = (double)totalWeight / (double)expectedNumClusters;
        int numClusters = 0;
        int numUnbalancedClusters = 0;
        int totalReducerWeight = 0;
        for (Pair<IntWritable, CentroidWritable> result : results) {
            if (((CentroidWritable)result.getSecond()).getCentroid().getWeight() != expectedWeight) {
                System.out.printf("Unbalanced weight %f in centroid %d\n", ((CentroidWritable)result.getSecond()).getCentroid().getWeight(), ((CentroidWritable)result.getSecond()).getCentroid().getIndex());
                ++numUnbalancedClusters;
            }
            StreamingKMeansTestMR.assertEquals((String)"Final centroid index is invalid", (long)numClusters, (long)((IntWritable)result.getFirst()).get());
            totalReducerWeight = (int)((double)totalReducerWeight + ((CentroidWritable)result.getSecond()).getCentroid().getWeight());
            ++numClusters;
        }
        System.out.printf("%d clusters are unbalanced\n", numUnbalancedClusters);
        StreamingKMeansTestMR.assertEquals((String)"Invalid total weight", (long)totalWeight, (long)totalReducerWeight);
        StreamingKMeansTestMR.assertEquals((String)"Invalid number of clusters", (long)256L, (long)numClusters);
    }
}

