/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector;

/**
 * Measures how strongly each traced feature drives a classifier's output.
 *
 * <p>{@link #update} probes the learner once per feature name. Each probe vector has only that
 * feature's indices set to 1. The resulting score vector is remembered per feature.
 * {@link #summary} then reports the features whose largest-magnitude score is biggest.
 *
 * <p>Not synchronized: the internal map is a plain {@code HashMap}.
 */
public class ModelDissector {
    /** Learner output per feature name; a feature is probed only the first time it is seen. */
    private final Map<String, Vector> weightMap = Maps.newHashMap();

    /**
     * Probes {@code learner} for every feature in {@code traceDictionary} that has not been
     * recorded yet.
     *
     * @param features scratch vector used to build the probes. It is zeroed on entry and is
     *        all-zero again on return, also when the learner throws. Assumes every index in
     *        {@code traceDictionary} is valid for this vector (TODO confirm with callers).
     * @param traceDictionary maps a feature name to the vector indices associated with it
     * @param learner classifier whose {@code classifyNoLink} output is recorded for each probe
     */
    public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
        features.assign(0.0);
        for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
            String key = entry.getKey();
            if (this.weightMap.containsKey(key)) {
                // Already probed in an earlier call; keep the first measurement.
                continue;
            }
            Set<Integer> positions = entry.getValue();
            for (Integer where : positions) {
                features.set(where.intValue(), 1.0);
            }
            try {
                Vector v = learner.classifyNoLink(features);
                this.weightMap.put(key, v);
            } finally {
                // Restore the scratch vector so the next probe (or the caller) sees zeros.
                for (Integer where : positions) {
                    features.set(where.intValue(), 0.0);
                }
            }
        }
    }

    /**
     * Returns the recorded features with the largest-magnitude impact.
     *
     * @param n maximum number of entries to return; {@code n <= 0} yields an empty list
     * @return at most {@code n} weights, sorted by descending absolute value. Ties are broken
     *         by descending feature name.
     */
    public List<Weight> summary(int n) {
        // Min-heap on |value|: evicting the head keeps the n largest seen so far.
        PriorityQueue<Weight> pq = new PriorityQueue<Weight>();
        for (Map.Entry<String, Vector> entry : this.weightMap.entrySet()) {
            pq.add(new Weight(entry.getKey(), entry.getValue()));
            while (pq.size() > n) {
                pq.poll();
            }
        }
        List<Weight> r = Lists.newArrayList(pq);
        Collections.sort(r, Ordering.<Weight>natural().reverse());
        return r;
    }

    /**
     * A feature's impact, summarized as its top categories by absolute score.
     *
     * <p>Note: {@link #compareTo} looks only at {@code |value|} and the feature name. It is
     * therefore coarser than {@link #equals}, which also compares the sign, the max index and
     * the category list.
     */
    public static class Weight
    implements Comparable<Weight> {
        private final String feature;
        /** Signed score of the category with the largest absolute score. */
        private final double value;
        /** Index of the category with the largest absolute score. */
        private final int maxIndex;
        /** Top categories, sorted by descending absolute score. */
        private final List<Category> categories;

        /**
         * Summarizes {@code weights} keeping the top 3 categories.
         *
         * @throws IllegalArgumentException if {@code weights} has no elements
         */
        public Weight(String feature, Vector weights) {
            this(feature, weights, 3);
        }

        /**
         * Summarizes {@code weights} keeping the top {@code n} categories by absolute score.
         *
         * @param feature feature name this weight describes
         * @param weights per-category scores; every element (via {@code all()}) is considered
         * @param n number of categories to keep; must be at least 1
         * @throws IllegalArgumentException if {@code n < 1} or {@code weights} has no elements
         */
        public Weight(String feature, Vector weights, int n) {
            if (n < 1) {
                throw new IllegalArgumentException("n must be at least 1, was " + n);
            }
            this.feature = feature;
            // Natural order is ascending |weight|, so the head is always the weakest kept entry.
            PriorityQueue<Category> biggest = new PriorityQueue<Category>(n + 1);
            for (Vector.Element element : weights.all()) {
                biggest.add(new Category(element.index(), element.get()));
                while (biggest.size() > n) {
                    biggest.poll();
                }
            }
            if (biggest.isEmpty()) {
                throw new IllegalArgumentException("weights vector for feature " + feature + " has no elements");
            }
            List<Category> sorted = Lists.newArrayList(biggest);
            Collections.sort(sorted, Ordering.<Category>natural().reverse());
            this.categories = sorted;
            this.value = sorted.get(0).weight;
            this.maxIndex = sorted.get(0).index;
        }

        /** Orders by absolute {@code value}, then by feature name. */
        @Override
        public int compareTo(Weight other) {
            int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
            if (r == 0) {
                return this.feature.compareTo(other.feature);
            }
            return r;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Weight)) {
                return false;
            }
            Weight other = (Weight) o;
            return this.feature.equals(other.feature)
                && this.value == other.value
                && this.maxIndex == other.maxIndex
                && this.categories.equals(other.categories);
        }

        @Override
        public int hashCode() {
            return this.feature.hashCode() ^ RandomUtils.hashDouble(this.value) ^ this.maxIndex ^ this.categories.hashCode();
        }

        public String getFeature() {
            return this.feature;
        }

        /** Returns the signed score of the strongest category. */
        public double getWeight() {
            return this.value;
        }

        /** Returns the signed score of the {@code n}-th strongest kept category (0-based). */
        public double getWeight(int n) {
            return this.categories.get(n).weight;
        }

        /** Returns the index of the {@code n}-th strongest kept category, widened to double. */
        public double getCategory(int n) {
            return this.categories.get(n).index;
        }

        /** Returns the index of the category with the largest absolute score. */
        public int getMaxImpact() {
            return this.maxIndex;
        }
    }

    /** One (category index, score) pair, ordered by absolute score. */
    private static final class Category
    implements Comparable<Category> {
        private final int index;
        private final double weight;

        private Category(int index, double weight) {
            this.index = index;
            this.weight = weight;
        }

        /**
         * Orders by ascending absolute weight. On a tie, the higher index compares as smaller,
         * so lower indices rank as "bigger" and survive top-n selection.
         */
        @Override
        public int compareTo(Category o) {
            int r = Double.compare(Math.abs(this.weight), Math.abs(o.weight));
            if (r == 0) {
                if (o.index < this.index) {
                    return -1;
                }
                if (o.index > this.index) {
                    return 1;
                }
                return 0;
            }
            return r;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Category)) {
                return false;
            }
            Category other = (Category) o;
            return this.index == other.index && this.weight == other.weight;
        }

        @Override
        public int hashCode() {
            return RandomUtils.hashDouble(this.weight) ^ this.index;
        }
    }
}

