/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.hadoop.stochasticsvd;

import java.io.Closeable;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import org.apache.commons.lang3.Validate;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.lib.MultipleOutputs;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.UpperTriangular;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.PlusMult;
import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.SSVDHelper;
import org.apache.mahout.math.hadoop.stochasticsvd.SparseRowBlockAccumulator;
import org.apache.mahout.math.hadoop.stochasticsvd.SparseRowBlockWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRLastStep;

public final class BtJob {
    /** Name of the named output that receives the Q rows written by the mapper. */
    public static final String OUTPUT_Q = "Q";
    /** Base name configured for the job's regular output files (the rows written by the reducer via its context). */
    public static final String OUTPUT_BT = "part";
    /** Name of the named output that receives the accumulated BBt upper-triangular data in the reducer's cleanup. */
    public static final String OUTPUT_BBT = "bbt";
    /** Name of the named output that receives the summed Q rows (emitted by mappers under block key -1). */
    public static final String OUTPUT_SQ = "sq";
    /** Name of the named output that receives the xi-weighted sum of Bt rows accumulated by the reducer. */
    public static final String OUTPUT_SB = "sb";
    /** Configuration key holding the output path of the preceding Q job; the mapper reads QHat and R-* files from it. */
    public static final String PROP_QJOB_PATH = "ssvd.QJob.path";
    /**
     * Configuration key of the boolean switch that makes the reducer accumulate and output BBt products.
     * NOTE: the identifier is misspelled ("OUPTUT"); it is public, so it is kept as is for compatibility.
     */
    public static final String PROP_OUPTUT_BBT_PRODUCTS = "ssvd.BtJob.outputBBtProducts";
    /** Configuration key holding the block height used to group Bt rows into sparse row blocks. */
    public static final String PROP_OUTER_PROD_BLOCK_HEIGHT = "ssvd.outerProdBlockHeight";
    /** Configuration key whose mere presence tells the mapper to read the R-hat files from the distributed cache. */
    public static final String PROP_RHAT_BROADCAST = "ssvd.rhat.broadcast";
    /** Configuration key holding the path of the xi vector; its presence enables the sq/sb outputs. */
    public static final String PROP_XI_PATH = "ssvdpca.xi.path";
    /** Configuration key of the boolean switch that makes the mapper copy NamedVector names of A rows onto Q rows. */
    public static final String PROP_NV = "ssvd.nv";
    // NOTE(review): not referenced anywhere in the visible part of this file; purpose to be confirmed against callers.
    static final double SPARSE_ZEROS_PCT_THRESHOLD = 0.1;

    /** Not instantiable: this class only exposes constants, the static {@code run} entry point and nested task classes. */
    private BtJob() {
    }

    /**
     * Configures the "Bt-job" map-reduce job, submits it and blocks until it has finished.
     *
     * <p>Named outputs are registered on an old-API {@link JobConf} which is then wrapped into a
     * new-API {@link Job}. The output {@code Q} is always registered; {@code bbt} is added when
     * {@code outputBBtProducts} is set (which also switches on {@link #PROP_NV}); {@code sq} and
     * {@code sb} are added when {@code xiPath} is given.
     *
     * @param conf              base configuration the job configuration is copied from
     * @param inputPathA        input paths of the job (sequence files)
     * @param inputPathQJob     output path of the Q job; stored under {@link #PROP_QJOB_PATH}
     * @param xiPath            optional path stored under {@link #PROP_XI_PATH}; may be {@code null}
     * @param outputPath        output path of this job
     * @param minSplitSize      minimum input split size; only applied when positive
     * @param k                 value stored under {@code ssvd.k}
     * @param p                 value stored under {@code ssvd.p}
     * @param btBlockHeight     value stored under {@link #PROP_OUTER_PROD_BLOCK_HEIGHT}
     * @param numReduceTasks    number of reduce tasks
     * @param broadcast         if set, the {@code R-*} files found under {@code inputPathQJob} are
     *                          added to the distributed cache and {@link #PROP_RHAT_BROADCAST} is set
     * @param labelClass        key class of the {@code Q} named output
     * @param outputBBtProducts whether the reducer should also produce the {@code bbt} output
     * @throws IOException if the job could not be set up or did not complete successfully
     * @throws InterruptedException if waiting for the job is interrupted
     * @throws ClassNotFoundException propagated from job submission
     */
    public static void run(Configuration conf, Path[] inputPathA, Path inputPathQJob, Path xiPath, Path outputPath, int minSplitSize, int k, int p, int btBlockHeight, int numReduceTasks, boolean broadcast, Class<? extends Writable> labelClass, boolean outputBBtProducts) throws ClassNotFoundException, InterruptedException, IOException {
        final boolean xiGiven = xiPath != null;

        // Named outputs have to be declared before the configuration is copied into the Job below.
        JobConf namedOutputsConf = new JobConf(conf);
        MultipleOutputs.addNamedOutput(namedOutputsConf, OUTPUT_Q, SequenceFileOutputFormat.class, labelClass, VectorWritable.class);
        if (outputBBtProducts) {
            MultipleOutputs.addNamedOutput(namedOutputsConf, OUTPUT_BBT, SequenceFileOutputFormat.class, IntWritable.class, VectorWritable.class);
            namedOutputsConf.setBoolean(PROP_NV, true);
        }
        if (xiGiven) {
            MultipleOutputs.addNamedOutput(namedOutputsConf, OUTPUT_SQ, SequenceFileOutputFormat.class, IntWritable.class, VectorWritable.class);
            MultipleOutputs.addNamedOutput(namedOutputsConf, OUTPUT_SB, SequenceFileOutputFormat.class, IntWritable.class, VectorWritable.class);
        }

        Job job = new Job((Configuration)namedOutputsConf);
        Configuration jobConf = job.getConfiguration();

        // Identity and input/output formats.
        job.setJobName("Bt-job");
        job.setJarByClass(BtJob.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setOutputFormatClass(org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);

        // Input and output locations, block-compressed sequence file output.
        FileInputFormat.setInputPaths(job, inputPathA);
        if (minSplitSize > 0) {
            FileInputFormat.setMinInputSplitSize(job, (long)minSplitSize);
        }
        FileOutputFormat.setOutputPath(job, outputPath);
        jobConf.set("mapreduce.output.basename", OUTPUT_BT);
        FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
        org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.setOutputCompressionType(job, SequenceFile.CompressionType.BLOCK);

        // Key/value types and task classes.
        job.setMapOutputKeyClass(LongWritable.class);
        job.setMapOutputValueClass(SparseRowBlockWritable.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(VectorWritable.class);
        job.setMapperClass(BtMapper.class);
        job.setCombinerClass(OuterProductCombiner.class);
        job.setReducerClass(OuterProductReducer.class);

        // Parameters read back by the tasks.
        jobConf.setInt("ssvd.k", k);
        jobConf.setInt("ssvd.p", p);
        jobConf.set(PROP_QJOB_PATH, inputPathQJob.toString());
        jobConf.setBoolean(PROP_OUPTUT_BBT_PRODUCTS, outputBBtProducts);
        jobConf.setInt(PROP_OUTER_PROD_BLOCK_HEIGHT, btBlockHeight);
        job.setNumReduceTasks(numReduceTasks);
        if (xiGiven) {
            jobConf.set(PROP_XI_PATH, xiPath.toString());
        }

        if (broadcast) {
            jobConf.set(PROP_RHAT_BROADCAST, "y");
            FileSystem qJobFs = FileSystem.get(inputPathQJob.toUri(), conf);
            FileStatus[] rHatFiles = qJobFs.globStatus(new Path(inputPathQJob, "R-*"));
            // globStatus may hand back null; in that case nothing is added to the cache.
            int rHatCount = rHatFiles == null ? 0 : rHatFiles.length;
            for (int i = 0; i < rHatCount; ++i) {
                DistributedCache.addCacheFile(rHatFiles[i].getPath().toUri(), jobConf);
            }
        }

        job.submit();
        job.waitForCompletion(false);
        if (!job.isSuccessful()) {
            throw new IOException("Bt job unsuccessful.");
        }
    }

    /**
     * Reducer of the Bt job.
     *
     * <p>For every block key it sums the incoming sparse row blocks and writes each resulting row
     * under the key {@code blockKey * blockHeight + rowIndexWithinBlock}. The special block key
     * {@code -1} carries the summed Q rows, which are written to the {@code sq} named output instead.
     * Optionally the reducer also accumulates the upper triangle of the sum of row outer products
     * ({@code bbt} output) and an xi-weighted sum of the rows ({@code sb} output); both are written
     * in {@link #cleanup}.
     */
    public static class OuterProductReducer
    extends Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable> {
        protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
        protected final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
        protected int blockHeight;
        private boolean outputBBt;
        private UpperTriangular mBBt;
        private MultipleOutputs outputs;
        private final IntWritable btKey = new IntWritable();
        private final VectorWritable btValue = new VectorWritable();
        private Vector xi;
        private final PlusMult pmult = new PlusMult(0.0);
        private Vector sbAccum;

        /**
         * Reads the block height and the BBt switch from the configuration, allocates the
         * {@code (k + p)}-sized BBt accumulator when requested, loads xi when its path is configured
         * and opens the named outputs if either feature is active.
         *
         * @throws IOException if an xi path is configured but no vector could be loaded from it
         */
        protected void setup(Reducer.Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            this.blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
            this.outputBBt = conf.getBoolean(BtJob.PROP_OUPTUT_BBT_PRODUCTS, false);

            if (this.outputBBt) {
                int k = conf.getInt("ssvd.k", -1);
                int p = conf.getInt("ssvd.p", -1);
                Validate.isTrue(k > 0, "invalid k parameter");
                Validate.isTrue(p >= 0, "invalid p parameter");
                this.mBBt = new UpperTriangular(k + p);
            }

            String xiLocation = conf.get(BtJob.PROP_XI_PATH);
            if (xiLocation != null) {
                this.xi = SSVDHelper.loadAndSumUpVectors(new Path(xiLocation), conf);
                if (this.xi == null) {
                    throw new IOException(String.format("unable to load mean path xi from %s.", xiLocation));
                }
            }

            boolean namedOutputsNeeded = this.outputBBt || this.xi != null;
            if (namedOutputsNeeded) {
                this.outputs = new MultipleOutputs(new JobConf(conf));
                this.closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(this.outputs));
            }
        }

        /**
         * Sums all blocks received for {@code key} and emits the result as described on the class.
         */
        protected void reduce(LongWritable key, Iterable<SparseRowBlockWritable> values, Reducer.Context context) throws IOException, InterruptedException {
            this.accum.clear();
            Iterator<SparseRowBlockWritable> blocks = values.iterator();
            while (blocks.hasNext()) {
                this.accum.plusBlock(blocks.next());
            }

            // Block key -1 is the channel mappers use for their partial sums of Q rows.
            if (key.get() == -1L) {
                Vector sq = this.accum.getRows()[0];
                OutputCollector sqOut = this.outputs.getCollector(BtJob.OUTPUT_SQ, null);
                sqOut.collect(new IntWritable(0), new VectorWritable(sq));
                return;
            }

            for (int row = 0; row < this.accum.getNumRows(); ++row) {
                Vector btRow = this.accum.getRows()[row];
                long globalRow = key.get() * (long)this.blockHeight + (long)this.accum.getRowIndices()[row];
                this.btKey.set((int)globalRow);
                this.btValue.set(btRow);
                context.write(this.btKey, this.btValue);

                if (this.outputBBt) {
                    this.addOuterProduct(btRow);
                }
                if (this.xi != null) {
                    this.addXiWeightedRow(btRow);
                }
            }
        }

        /** Adds the upper triangle of {@code btRow * btRow'} to the BBt accumulator, skipping zero factors. */
        private void addOuterProduct(Vector btRow) {
            int kp = this.mBBt.numRows();
            for (int i = 0; i < kp; ++i) {
                double vi = btRow.get(i);
                if (vi != 0.0) {
                    for (int j = i; j < kp; ++j) {
                        double vj = btRow.get(j);
                        if (vj != 0.0) {
                            this.mBBt.setQuick(i, j, this.mBBt.getQuick(i, j) + vi * vj);
                        }
                    }
                }
            }
        }

        /**
         * Adds {@code btRow}, combined through {@code pmult} with the xi entry of the row just written,
         * to the sb accumulator. An index beyond the end of xi is treated as a zero weight.
         */
        private void addXiWeightedRow(Vector btRow) {
            int btIndex = this.btKey.get();
            double weight = 0.0;
            if (this.xi.size() > btIndex) {
                weight = this.xi.getQuick(btIndex);
            }
            this.pmult.setMultiplicator(weight);
            if (this.sbAccum == null) {
                this.sbAccum = new DenseVector(btRow.size());
            }
            this.sbAccum.assign(btRow, this.pmult);
        }

        /**
         * Writes the accumulated {@code bbt} and {@code sb} results (whichever were produced) to their
         * named outputs and then closes everything registered in {@code closeables}, even if writing fails.
         */
        protected void cleanup(Reducer.Context context) throws IOException, InterruptedException {
            try {
                if (this.outputBBt) {
                    OutputCollector bbtOut = this.outputs.getCollector(BtJob.OUTPUT_BBT, null);
                    bbtOut.collect(new IntWritable(), new VectorWritable(new DenseVector(this.mBBt.getData())));
                }
                if (this.sbAccum != null) {
                    OutputCollector sbOut = this.outputs.getCollector(BtJob.OUTPUT_SB, null);
                    sbOut.collect(new IntWritable(), new VectorWritable(this.sbAccum));
                }
            }
            finally {
                IOUtils.close(this.closeables);
            }
        }
    }

    /**
     * Combiner of the Bt job: adds up all sparse row blocks that share a key and forwards the single
     * summed block under that same key.
     */
    public static class OuterProductCombiner
    extends Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable> {
        protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
        protected final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
        protected int blockHeight;

        /** Reads the configured block height (-1 when the property is absent). */
        protected void setup(Reducer.Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            this.blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
        }

        /**
         * Folds every incoming block into the shared accumulator, writes the sum and then resets the
         * accumulator so the next key starts from an empty block.
         */
        protected void reduce(Writable key, Iterable<SparseRowBlockWritable> values, Reducer.Context context) throws IOException, InterruptedException {
            Iterator<SparseRowBlockWritable> blocks = values.iterator();
            while (blocks.hasNext()) {
                SparseRowBlockWritable block = blocks.next();
                this.accum.plusBlock(block);
            }
            context.write(key, this.accum);
            this.accum.clear();
        }

        /** Closes whatever has been registered in {@code closeables}. */
        protected void cleanup(Reducer.Context context) throws IOException, InterruptedException {
            IOUtils.close(this.closeables);
        }
    }

    public static class BtMapper
    extends Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable> {
        private QRLastStep qr;
        private final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
        private int blockNum;
        private MultipleOutputs outputs;
        private final VectorWritable qRowValue = new VectorWritable();
        private Vector btRow;
        private SparseRowBlockAccumulator btCollector;
        private Mapper.Context mapContext;
        private boolean nv;
        private Vector sqAccum;
        private boolean computeSq;

        protected void map(Writable key, VectorWritable value, Mapper.Context context) throws IOException, InterruptedException {
            this.mapContext = context;
            Vector aRow = value.get();
            Vector qRow = this.qr.next();
            int kp = qRow.size();
            this.outputQRow(key, qRow, aRow);
            if (this.computeSq) {
                if (this.sqAccum == null) {
                    this.sqAccum = new DenseVector(kp);
                }
                this.sqAccum.assign(qRow, Functions.PLUS);
            }
            if (this.btRow == null) {
                this.btRow = new DenseVector(kp);
            }
            if (!aRow.isDense()) {
                for (Vector.Element el : aRow.nonZeroes()) {
                    double mul = el.get();
                    for (int j = 0; j < kp; ++j) {
                        this.btRow.setQuick(j, mul * qRow.getQuick(j));
                    }
                    this.btCollector.collect(Long.valueOf(el.index()), this.btRow);
                }
            } else {
                int n = aRow.size();
                for (int i = 0; i < n; ++i) {
                    double mul = aRow.getQuick(i);
                    for (int j = 0; j < kp; ++j) {
                        this.btRow.setQuick(j, mul * qRow.getQuick(j));
                    }
                    this.btCollector.collect(Long.valueOf(i), this.btRow);
                }
            }
        }

        protected void setup(Mapper.Context context) throws IOException, InterruptedException {
            SequenceFileDirValueIterator rhatInput;
            boolean distributedRHat;
            super.setup(context);
            Configuration conf = context.getConfiguration();
            Path qJobPath = new Path(conf.get(BtJob.PROP_QJOB_PATH));
            Path qInputPath = new Path(qJobPath, FileOutputFormat.getUniqueFile((TaskAttemptContext)context, (String)"QHat", (String)""));
            this.blockNum = context.getTaskAttemptID().getTaskID().getId();
            SequenceFileValueIterator qhatInput = new SequenceFileValueIterator(qInputPath, true, conf);
            this.closeables.addFirst(qhatInput);
            boolean bl = distributedRHat = conf.get(BtJob.PROP_RHAT_BROADCAST) != null;
            if (distributedRHat) {
                Path[] rFiles = HadoopUtil.getCachedFiles(conf);
                Validate.notNull((Object)rFiles, (String)"no RHat files in distributed cache job definition", (Object[])new Object[0]);
                Configuration lconf = new Configuration();
                lconf.set("fs.default.name", "file:///");
                rhatInput = new SequenceFileDirValueIterator(rFiles, SSVDHelper.PARTITION_COMPARATOR, true, lconf);
            } else {
                Path rPath = new Path(qJobPath, "R-*");
                rhatInput = new SequenceFileDirValueIterator(rPath, PathType.GLOB, null, SSVDHelper.PARTITION_COMPARATOR, true, conf);
            }
            Validate.isTrue((boolean)rhatInput.hasNext(), (String)"Empty R-hat input!", (Object[])new Object[0]);
            this.closeables.addFirst(rhatInput);
            this.outputs = new MultipleOutputs(new JobConf(conf));
            this.closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(this.outputs));
            this.qr = new QRLastStep((Iterator<DenseBlockWritable>)((Object)qhatInput), (Iterator<VectorWritable>)((Object)rhatInput), this.blockNum);
            this.closeables.addFirst(this.qr);
            if (!rhatInput.hasNext()) {
                this.closeables.remove(rhatInput);
                rhatInput.close();
            }
            OutputCollector<LongWritable, SparseRowBlockWritable> btBlockCollector = new OutputCollector<LongWritable, SparseRowBlockWritable>(){

                public void collect(LongWritable blockKey, SparseRowBlockWritable block) throws IOException {
                    try {
                        BtMapper.this.mapContext.write((Object)blockKey, (Object)block);
                    }
                    catch (InterruptedException exc) {
                        throw new IOException("Interrupted.", exc);
                    }
                }
            };
            this.btCollector = new SparseRowBlockAccumulator(conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1), btBlockCollector);
            this.closeables.addFirst(this.btCollector);
            this.computeSq = conf.get(BtJob.PROP_XI_PATH) != null;
            this.nv = conf.getBoolean(BtJob.PROP_NV, false);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        protected void cleanup(Mapper.Context context) throws IOException, InterruptedException {
            try {
                if (this.sqAccum != null) {
                    SparseRowBlockWritable sbrw = new SparseRowBlockWritable(1);
                    sbrw.plusRow(0, this.sqAccum);
                    LongWritable lw = new LongWritable(-1L);
                    context.write((Object)lw, (Object)sbrw);
                }
            }
            finally {
                IOUtils.close(this.closeables);
            }
        }

        private void outputQRow(Writable key, Vector qRow, Vector aRow) throws IOException {
            if (this.nv && aRow instanceof NamedVector) {
                this.qRowValue.set((Vector)new NamedVector(qRow, ((NamedVector)aRow).getName()));
            } else {
                this.qRowValue.set(qRow);
            }
            this.outputs.getCollector(BtJob.OUTPUT_Q, null).collect((Object)key, (Object)this.qRowValue);
        }
    }
}

