/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim;

import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.feature.OffsetInstance;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.IterativelyReweightedLeastSquaresModel;
import org.apache.spark.ml.optim.WeightedLeastSquares;
import org.apache.spark.ml.optim.WeightedLeastSquares$;
import org.apache.spark.ml.optim.WeightedLeastSquaresModel;
import org.apache.spark.ml.util.OptionalInstrumentation;
import org.apache.spark.ml.util.OptionalInstrumentation$;
import org.apache.spark.rdd.RDD;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;

@ScalaSignature(bytes="\u0006\u0001i4Q!\u0001\u0002\u0001\t1\u0011\u0011%\u0013;fe\u0006$\u0018N^3msJ+w/Z5hQR,G\rT3bgR\u001c\u0016/^1sKNT!a\u0001\u0003\u0002\u000b=\u0004H/[7\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7c\u0001\u0001\u000e'A\u0011a\"E\u0007\u0002\u001f)\t\u0001#A\u0003tG\u0006d\u0017-\u0003\u0002\u0013\u001f\t1\u0011I\\=SK\u001a\u0004\"A\u0004\u000b\n\u0005Uy!\u0001D*fe&\fG.\u001b>bE2,\u0007\u0002C\f\u0001\u0005\u000b\u0007I\u0011A\r\u0002\u0019%t\u0017\u000e^5bY6{G-\u001a7\u0004\u0001U\t!\u0004\u0005\u0002\u001c95\t!!\u0003\u0002\u001e\u0005\tIr+Z5hQR,G\rT3bgR\u001c\u0016/^1sKNlu\u000eZ3m\u0011!y\u0002A!A!\u0002\u0013Q\u0012!D5oSRL\u0017\r\\'pI\u0016d\u0007\u0005\u0003\u0005\"\u0001\t\u0015\r\u0011\"\u0001#\u00031\u0011Xm^3jO\"$h)\u001e8d+\u0005\u0019\u0003#\u0002\b%Mia\u0013BA\u0013\u0010\u0005%1UO\\2uS>t'\u0007\u0005\u0002(U5\t\u0001F\u0003\u0002*\t\u00059a-Z1ukJ,\u0017BA\u0016)\u00059yeMZ:fi&s7\u000f^1oG\u0016\u0004BAD\u00170_%\u0011af\u0004\u0002\u0007)V\u0004H.\u001a\u001a\u0011\u00059\u0001\u0014BA\u0019\u0010\u0005\u0019!u.\u001e2mK\"A1\u0007\u0001B\u0001B\u0003%1%A\u0007sK^,\u0017n\u001a5u\rVt7\r\t\u0005\tk\u0001\u0011)\u0019!C\u0001m\u0005aa-\u001b;J]R,'oY3qiV\tq\u0007\u0005\u0002\u000fq%\u0011\u0011h\u0004\u0002\b\u0005>|G.Z1o\u0011!Y\u0004A!A!\u0002\u00139\u0014!\u00044ji&sG/\u001a:dKB$\b\u0005\u0003\u0005>\u0001\t\u0015\r\u0011\"\u0001?\u0003!\u0011Xm\u001a)be\u0006lW#A\u0018\t\u0011\u0001\u0003!\u0011!Q\u0001\n=\n\u0011B]3h!\u0006\u0014\u0018-\u001c\u0011\t\u0011\t\u0003!Q1A\u0005\u0002\r\u000bq!\\1y\u0013R,'/F\u0001E!\tqQ)\u0003\u0002G\u001f\t\u0019\u0011J\u001c;\t\u0011!\u0003!\u0011!Q\u0001\n\u0011\u000b\u0001\"\\1y\u0013R,'\u000f\t\u0005\t\u0015\u0002\u0011)\u0019!C\u0001}\u0005\u0019Ao\u001c7\t\u00111\u0003!\u0011!Q\u0001\n=\nA\u0001^8mA!)a\n\u0001C\u0001\u001f\u00061A(\u001b8jiz\"r\u0001U)S'R+f\u000b\u0005\u0002\u001c\u0001!)q#\u0014a\u00015!)\u0011%\u0014a\u0001G!)Q'\u0014a\u0001o!)Q(\u0014a\u0001_!)!)\u0014a\u0001\t\")!*\u0014a\u0001_!)\u0001\f\u0001C\u00013\u0006\u0019a-\u001b;\u0015\u0007ikV\r\u0005\u0002\u001c7&\u0011AL\u0001\u0002'\u0013R,'/\u0019;jm\u0016d\u0017PU3xK&<\u0007\u000e^3e\u0019\u0016\f7\u000f^*rk\u0006\u0014Xm]'pI\u0016d\u0007\"\u00020X\u0001\u0004y\u0016!C5ogR\fgnY3t!\r\u00017MJ\u0007\u0002C*\u0011!MB\u0001\u0004e\u0012$\u0017B\u00013b\u0005\r\u0011F\t\u0012\u0005\bM^\u0003\n\u00111\u0001h\u0003\u0015Ign\u001d;s!\tA7.D\u0001j\u0015\tQG!\u0001\u0003vi&d\u0017B\u00017j\u0005]y\u0005\u000f^5p]\u0006d\u0017J\\:ueVlWM\u001c;bi&|g\u000eC\u0004o\u0001E\u0005I\u0011A8\u0002\u001b\u0019LG\u000f\n3fM\u0006,H\u000e\u001e\u00133+\u0005\u0001(FA4rW\u0005\u0011\bCA:y\u001b\u0005!(BA;w\u0003%)hn\u00195fG.,GM\u0003\u0002x\u001f\u0005Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\u0005e$(!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dK\u0002")
public class IterativelyReweightedLeastSquares
implements Serializable {
    private final WeightedLeastSquaresModel initialModel;
    private final Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc;
    private final boolean fitIntercept;
    private final double regParam;
    private final int maxIter;
    private final double tol;

    public WeightedLeastSquaresModel initialModel() {
        return this.initialModel;
    }

    public Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc() {
        return this.reweightFunc;
    }

    public boolean fitIntercept() {
        return this.fitIntercept;
    }

    public double regParam() {
        return this.regParam;
    }

    public int maxIter() {
        return this.maxIter;
    }

    public double tol() {
        return this.tol;
    }

    public IterativelyReweightedLeastSquaresModel fit(RDD<OffsetInstance> instances, OptionalInstrumentation instr) {
        boolean converged = false;
        IntRef iter = IntRef.create((int)0);
        WeightedLeastSquaresModel model = this.initialModel();
        ObjectRef oldModel = ObjectRef.create(null);
        while (iter.elem < this.maxIter() && !converged) {
            oldModel.elem = model;
            RDD newInstances = instances.map((Function1)new Serializable(this, oldModel){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ IterativelyReweightedLeastSquares $outer;
                private final ObjectRef oldModel$1;

                public final Instance apply(OffsetInstance instance) {
                    Tuple2 tuple2 = (Tuple2)this.$outer.reweightFunc().apply((Object)instance, (Object)((WeightedLeastSquaresModel)this.oldModel$1.elem));
                    if (tuple2 != null) {
                        Tuple2.mcDD.sp sp2;
                        double newLabel = tuple2._1$mcD$sp();
                        double newWeight = tuple2._2$mcD$sp();
                        Tuple2.mcDD.sp sp3 = sp2 = new Tuple2.mcDD.sp(newLabel, newWeight);
                        double newLabel2 = sp3._1$mcD$sp();
                        double newWeight2 = sp3._2$mcD$sp();
                        return new Instance(newLabel2, newWeight2, instance.features());
                    }
                    throw new MatchError((Object)tuple2);
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                    this.oldModel$1 = oldModel$1;
                }
            }, ClassTag$.MODULE$.apply(Instance.class));
            model = new WeightedLeastSquares(this.fitIntercept(), this.regParam(), 0.0, false, false, WeightedLeastSquares$.MODULE$.$lessinit$greater$default$6(), WeightedLeastSquares$.MODULE$.$lessinit$greater$default$7(), WeightedLeastSquares$.MODULE$.$lessinit$greater$default$8()).fit((RDD<Instance>)newInstances, instr);
            DenseVector oldCoefficients = ((WeightedLeastSquaresModel)oldModel.elem).coefficients();
            DenseVector coefficients2 = model.coefficients();
            BLAS$.MODULE$.axpy(-1.0, (Vector)coefficients2, (Vector)oldCoefficients);
            double maxTolOfCoefficients = BoxesRunTime.unboxToDouble((Object)Predef$.MODULE$.doubleArrayOps(oldCoefficients.toArray()).foldLeft((Object)BoxesRunTime.boxToDouble((double)0.0), (Function2)new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final double apply(double x, double y) {
                    return this.apply$mcDDD$sp(x, y);
                }

                public double apply$mcDDD$sp(double x, double y) {
                    return package$.MODULE$.max(package$.MODULE$.abs(x), package$.MODULE$.abs(y));
                }
            }));
            double maxTol = package$.MODULE$.max(maxTolOfCoefficients, package$.MODULE$.abs(((WeightedLeastSquaresModel)oldModel.elem).intercept() - model.intercept()));
            if (maxTol < this.tol()) {
                converged = true;
                instr.logInfo((Function0<String>)new Serializable(this, iter){
                    public static final long serialVersionUID = 0L;
                    private final IntRef iter$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"IRLS converged in ", " iterations."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.iter$1.elem)}));
                    }
                    {
                        this.iter$1 = iter$1;
                    }
                });
            }
            instr.logInfo((Function0<String>)new Serializable(this, iter, maxTol){
                public static final long serialVersionUID = 0L;
                private final IntRef iter$1;
                private final double maxTol$1;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Iteration ", " : relative tolerance = ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.iter$1.elem), BoxesRunTime.boxToDouble((double)this.maxTol$1)}));
                }
                {
                    this.iter$1 = iter$1;
                    this.maxTol$1 = maxTol$1;
                }
            });
            ++iter.elem;
            if (iter.elem != this.maxIter()) continue;
            instr.logInfo((Function0<String>)new Serializable(this){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ IterativelyReweightedLeastSquares $outer;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"IRLS reached the max number of iterations: ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.maxIter())}));
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                }
            });
        }
        return new IterativelyReweightedLeastSquaresModel(model.coefficients(), model.intercept(), model.diagInvAtWA(), iter.elem);
    }

    public OptionalInstrumentation fit$default$2() {
        return OptionalInstrumentation$.MODULE$.create(IterativelyReweightedLeastSquares.class);
    }

    public IterativelyReweightedLeastSquares(WeightedLeastSquaresModel initialModel, Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc, boolean fitIntercept, double regParam, int maxIter, double tol) {
        this.initialModel = initialModel;
        this.reweightFunc = reweightFunc;
        this.fitIntercept = fitIntercept;
        this.regParam = regParam;
        this.maxIter = maxIter;
        this.tol = tol;
    }
}

