/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.ArrayOps$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0005-4QAD\b\u0001'mA\u0001b\r\u0001\u0003\u0002\u0003\u0006I!\u000e\u0005\t\u0003\u0002\u0011\t\u0011)A\u0005k!A!\t\u0001B\u0001B\u0003%1\t\u0003\u0005G\u0001\t\u0005\t\u0015!\u0003H\u0011\u0015q\u0005\u0001\"\u0001P\u0011\u001d)\u0006A1A\u0005\nYCaA\u0017\u0001!\u0002\u00139\u0006bB.\u0001\u0005\u0004%\tF\u0016\u0005\u00079\u0002\u0001\u000b\u0011B,\t\u0011u\u0003\u0001R1A\u0005\nyCqa\u0019\u0001C\u0002\u0013%A\r\u0003\u0004f\u0001\u0001\u0006IA\u0010\u0005\u0006M\u0002!\ta\u001a\u0002\u0015\u0011&tw-\u001a\"m_\u000e\\\u0017iZ4sK\u001e\fGo\u001c:\u000b\u0005A\t\u0012AC1hOJ,w-\u0019;pe*\u0011!cE\u0001\u0006_B$\u0018.\u001c\u0006\u0003)U\t!!\u001c7\u000b\u0005Y9\u0012!B:qCJ\\'B\u0001\r\u001a\u0003\u0019\t\u0007/Y2iK*\t!$A\u0002pe\u001e\u001cB\u0001\u0001\u000f#[A\u0011Q\u0004I\u0007\u0002=)\tq$A\u0003tG\u0006d\u0017-\u0003\u0002\"=\t1\u0011I\\=SK\u001a\u0004Ba\t\u0013'Y5\tq\"\u0003\u0002&\u001f\taB)\u001b4gKJ,g\u000e^5bE2,Gj\\:t\u0003\u001e<'/Z4bi>\u0014\bCA\u0014+\u001b\u0005A#BA\u0015\u0014\u0003\u001d1W-\u0019;ve\u0016L!a\u000b\u0015\u0003\u001b%s7\u000f^1oG\u0016\u0014En\\2l!\t\u0019\u0003\u0001\u0005\u0002/c5\tqF\u0003\u00021+\u0005A\u0011N\u001c;fe:\fG.\u0003\u00023_\t9Aj\\4hS:<\u0017\u0001\u00042d\u0013:4XM]:f'R$7\u0001\u0001\t\u0004meZT\"A\u001c\u000b\u0005a*\u0012!\u00032s_\u0006$7-Y:u\u0013\tQtGA\u0005Ce>\fGmY1tiB\u0019Q\u0004\u0010 
\n\u0005ur\"!B!se\u0006L\bCA\u000f@\u0013\t\u0001eD\u0001\u0004E_V\u0014G.Z\u0001\rE\u000e\u001c6-\u00197fI6+\u0017M\\\u0001\rM&$\u0018J\u001c;fe\u000e,\u0007\u000f\u001e\t\u0003;\u0011K!!\u0012\u0010\u0003\u000f\t{w\u000e\\3b]\u0006q!mY\"pK\u001a4\u0017nY5f]R\u001c\bc\u0001\u001c:\u0011B\u0011\u0011\nT\u0007\u0002\u0015*\u00111jE\u0001\u0007Y&t\u0017\r\\4\n\u00055S%A\u0002,fGR|'/\u0001\u0004=S:LGO\u0010\u000b\u0005!J\u001bF\u000b\u0006\u0002-#\")a)\u0002a\u0001\u000f\")1'\u0002a\u0001k!)\u0011)\u0002a\u0001k!)!)\u0002a\u0001\u0007\u0006Ya.^7GK\u0006$XO]3t+\u00059\u0006CA\u000fY\u0013\tIfDA\u0002J]R\fAB\\;n\r\u0016\fG/\u001e:fg\u0002\n1\u0001Z5n\u0003\u0011!\u0017.\u001c\u0011\u0002#\r|WM\u001a4jG&,g\u000e^:BeJ\f\u00170F\u0001<Q\tQ\u0001\r\u0005\u0002\u001eC&\u0011!M\b\u0002\niJ\fgn]5f]R\fA\"\\1sO&twJ\u001a4tKR,\u0012AP\u0001\u000e[\u0006\u0014x-\u001b8PM\u001a\u001cX\r\u001e\u0011\u0002\u0007\u0005$G\r\u0006\u0002iS6\t\u0001\u0001C\u0003k\u001b\u0001\u0007a%A\u0003cY>\u001c7\u000e")
/**
 * Block-wise loss aggregator that accumulates a weight sum, a loss sum and a
 * gradient sum over {@link InstanceBlock}s.
 *
 * <p>For every row with a positive weight, {@link #add(InstanceBlock)} computes
 * {@code (1 - y * margin) * weight} where {@code y = 2 * label - 1}, and only
 * rows for which that value is positive contribute to the loss and gradient.
 *
 * <p>This class is decompiler output of a Scala class: members whose names
 * contain {@code $} (trait forwarders, {@code ...$lzycompute}, {@code bitmap$...})
 * are compiler-generated plumbing and must keep their exact names.
 */
public class HingeBlockAggregator
        implements DifferentiableLossAggregator<InstanceBlock, HingeBlockAggregator>, Logging {
    // Lazily extracted backing array of the broadcast coefficients (guarded by bitmap$trans$0).
    private transient double[] coefficientsArray;
    // Broadcast array combined with the coefficients (ddot) and the gradient (daxpy) when fitIntercept is set.
    private final Broadcast<double[]> bcScaledMean;
    // Whether the intercept-related offset and gradient terms are applied.
    private final boolean fitIntercept;
    // Broadcast coefficient vector; must hold a DenseVector when it is first read.
    private final Broadcast<Vector> bcCoefficients;
    // Length of the bcInverseStd array handed to the constructor.
    private final int numFeatures;
    // Size of the broadcast coefficient vector.
    private final int dim;
    // Constant every margin starts from when fitIntercept is set; NaN otherwise.
    private final double marginOffset;
    // Logger slot owned by the Logging trait.
    private transient Logger org$apache$spark$internal$Logging$$log_;
    // Running sum of instance weights.
    private double weightSum;
    // Running sum of per-instance losses.
    private double lossSum;
    // Lazily created gradient accumulator (guarded by bitmap$0).
    private double[] gradientSumArray;
    // Set once coefficientsArray has been initialised.
    private volatile transient boolean bitmap$trans$0;
    // Set once gradientSumArray has been initialised.
    private volatile boolean bitmap$0;

    /** Forwards to {@code Logging.logName$}. */
    public String logName() {
        return Logging.logName$(this);
    }

    /** Forwards to {@code Logging.log$}. */
    public Logger log() {
        return Logging.log$(this);
    }

    /** Forwards to {@code Logging.logInfo$}. */
    public void logInfo(Function0<String> msg) {
        Logging.logInfo$(this, msg);
    }

    /** Forwards to {@code Logging.logDebug$}. */
    public void logDebug(Function0<String> msg) {
        Logging.logDebug$(this, msg);
    }

    /** Forwards to {@code Logging.logTrace$}. */
    public void logTrace(Function0<String> msg) {
        Logging.logTrace$(this, msg);
    }

    /** Forwards to {@code Logging.logWarning$}. */
    public void logWarning(Function0<String> msg) {
        Logging.logWarning$(this, msg);
    }

    /** Forwards to {@code Logging.logError$}. */
    public void logError(Function0<String> msg) {
        Logging.logError$(this, msg);
    }

    /** Forwards to {@code Logging.logInfo$}, passing the throwable along. */
    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logDebug$}, passing the throwable along. */
    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logTrace$}, passing the throwable along. */
    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logWarning$}, passing the throwable along. */
    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logError$}, passing the throwable along. */
    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.isTraceEnabled$}. */
    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    /** Forwards to the single-flag {@code Logging.initializeLogIfNecessary$}. */
    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$(this, isInterpreter);
    }

    /** Forwards to the two-flag {@code Logging.initializeLogIfNecessary$} and returns its result. */
    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$(this, isInterpreter, silent);
    }

    /** Forwards to {@code Logging.initializeLogIfNecessary$default$2$} (Scala default-argument accessor). */
    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    /** Forwards to {@code Logging.initializeForcefully$}. */
    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$(this, isInterpreter, silent);
    }

    /** Forwards to {@code DifferentiableLossAggregator.merge$}. */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /** Forwards to {@code DifferentiableLossAggregator.gradient$}. */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Forwards to {@code DifferentiableLossAggregator.weight$}. */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Forwards to {@code DifferentiableLossAggregator.loss$}. */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Reads the logger slot used by the {@code Logging} trait. */
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    /** Writes the logger slot used by the {@code Logging} trait. */
    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /** @return the accumulated weight sum */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Replaces the accumulated weight sum. */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** @return the accumulated loss sum */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Replaces the accumulated loss sum. */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of {@link #gradientSumArray()}: while holding this object's
     * monitor, creates the array via
     * {@code DifferentiableLossAggregator.gradientSumArray$} if the init flag
     * is still unset, then raises the flag.
     */
    private double[] gradientSumArray$lzycompute() {
        synchronized (this) {
            boolean alreadyBuilt = this.bitmap$0;
            if (!alreadyBuilt) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /** @return the gradient accumulator, creating it on first access */
    @Override
    public double[] gradientSumArray() {
        if (this.bitmap$0) {
            return this.gradientSumArray;
        }
        return this.gradientSumArray$lzycompute();
    }

    /** @return the feature count captured at construction */
    private int numFeatures() {
        return this.numFeatures;
    }

    /** @return the size of the coefficient vector captured at construction */
    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Slow path of {@link #coefficientsArray()}: while holding this object's
     * monitor, extracts the value array from the broadcast coefficients if the
     * init flag is still unset, then raises the flag.
     *
     * @throws IllegalArgumentException if the broadcast vector is not a
     *         {@link DenseVector} (or its extractor yields nothing)
     */
    private double[] coefficientsArray$lzycompute() {
        synchronized (this) {
            if (!this.bitmap$trans$0) {
                Vector broadcastValue = (Vector)this.bcCoefficients.value();
                boolean isDense = broadcastValue instanceof DenseVector;
                // The extractor is only consulted for dense vectors, as in the pattern match it came from.
                Option extracted = isDense
                        ? DenseVector$.MODULE$.unapply((DenseVector)broadcastValue)
                        : null;
                if (!isDense || extracted.isEmpty()) {
                    throw new IllegalArgumentException(
                            "coefficients only supports dense vector but "
                                    + "got type " + this.bcCoefficients.value().getClass() + ".)");
                }
                this.coefficientsArray = (double[])extracted.get();
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    /** @return the dense coefficient values, extracting them on first access */
    private double[] coefficientsArray() {
        if (this.bitmap$trans$0) {
            return this.coefficientsArray;
        }
        return this.coefficientsArray$lzycompute();
    }

    /** @return the margin offset computed at construction */
    private double marginOffset() {
        return this.marginOffset;
    }

    /**
     * Folds one block of instances into the running weight, loss and gradient sums.
     *
     * <p>Requirements (each enforced through {@code Predef.require}): the block
     * matrix reports {@code isTransposed()}, the block's feature count equals
     * this aggregator's, and every weight is {@code >= 0.0}. A block whose
     * weights are all zero leaves the aggregator untouched.
     *
     * @param block the instances to accumulate
     * @return this aggregator
     */
    @Override
    public HingeBlockAggregator add(InstanceBlock block) {
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(
                this.numFeatures() == block.numFeatures(),
                (Function0 & Serializable)() -> "Dimensions mismatch when adding new "
                        + "instance. Expecting " + this.numFeatures()
                        + " but got " + block.numFeatures() + ".");
        Predef$.MODULE$.require(
                block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable)candidate -> candidate >= 0.0),
                (Function0 & Serializable)() -> "instance weights "
                        + block.weightIter().mkString("[", ",", "]")
                        + " has to be >= 0.0");

        // Nothing to accumulate when no instance carries weight.
        boolean everyWeightZero = block.weightIter()
                .forall((Function1)(JFunction1.mcZD.sp & Serializable)candidate -> candidate == 0.0);
        if (everyWeightZero) {
            return this;
        }

        int size = block.size();

        // Holds the margins first; each slot is later overwritten with that row's gradient multiplier.
        double[] buffer = (double[])Array$.MODULE$.ofDim(size, (ClassTag)ClassTag$.MODULE$.Double());
        if (this.fitIntercept) {
            Arrays.fill(buffer, this.marginOffset());
        }
        // buffer = matrix * coefficients + buffer
        BLAS$.MODULE$.gemv(1.0, block.matrix(), this.coefficientsArray(), 1.0, buffer);

        double blockLoss = 0.0;
        double blockWeight = 0.0;
        double blockMultiplier = 0.0;
        for (int row = 0; row < size; ++row) {
            double rowWeight = block.getWeight().apply$mcDI$sp(row);
            blockWeight += rowWeight;

            // Rows with zero weight or non-positive loss leave a zero multiplier behind.
            double rowMultiplier = 0.0;
            if (rowWeight > 0.0) {
                double rawLabel = block.getLabel(row);
                // label + label - 1: e.g. a label of 0 becomes -1 and a label of 1 becomes +1.
                double signedLabel = rawLabel + rawLabel - 1.0;
                double rowLoss = (1.0 - signedLabel * buffer[row]) * rowWeight;
                if (rowLoss > 0.0) {
                    blockLoss += rowLoss;
                    rowMultiplier = -signedLabel * rowWeight;
                    blockMultiplier += rowMultiplier;
                }
            }
            buffer[row] = rowMultiplier;
        }

        this.lossSum_$eq(this.lossSum() + blockLoss);
        this.weightSum_$eq(this.weightSum() + blockWeight);

        // With every multiplier at zero there is no gradient contribution to add.
        boolean noGradient = ArrayOps$.MODULE$.forall$extension(
                Predef$.MODULE$.doubleArrayOps(buffer),
                (Function1)(JFunction1.mcZD.sp & Serializable)value -> value == 0.0);
        if (noGradient) {
            return this;
        }

        // gradientSum += matrix^T * multipliers
        BLAS$.MODULE$.gemv(1.0, block.matrix().transpose(), buffer, 1.0, this.gradientSumArray());

        if (this.fitIntercept) {
            // gradientSum[0..numFeatures) -= multiplierSum * scaledMean
            BLAS$.MODULE$.javaBLAS().daxpy(
                    this.numFeatures(), -blockMultiplier, (double[])this.bcScaledMean.value(), 1,
                    this.gradientSumArray(), 1);
            // The slot right after the feature entries receives the summed multipliers.
            this.gradientSumArray()[this.numFeatures()] += blockMultiplier;
        }
        return this;
    }

    /**
     * Builds an aggregator for one pass over the data.
     *
     * @param bcInverseStd   broadcast array; only its length is read here (it
     *                       becomes the feature count)
     * @param bcScaledMean   broadcast array; when {@code fitIntercept} is true it
     *                       must be non-null and as long as {@code bcInverseStd}
     * @param fitIntercept   whether the intercept-related offset and gradient
     *                       terms are applied
     * @param bcCoefficients broadcast coefficient vector; its size becomes {@link #dim()}
     */
    public HingeBlockAggregator(Broadcast<double[]> bcInverseStd, Broadcast<double[]> bcScaledMean, boolean fitIntercept, Broadcast<Vector> bcCoefficients) {
        this.bcScaledMean = bcScaledMean;
        this.fitIntercept = fitIntercept;
        this.bcCoefficients = bcCoefficients;

        // Trait initialisers run after the constructor-parameter fields are stored.
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$((Logging)this);

        if (fitIntercept) {
            boolean meansUsable = bcScaledMean != null
                    && ((double[])bcScaledMean.value()).length == ((double[])bcInverseStd.value()).length;
            Predef$.MODULE$.require(
                    meansUsable,
                    (Function0 & Serializable)() -> "scaled means is required when center the vectors");
        }

        this.numFeatures = ((double[])bcInverseStd.value()).length;
        this.dim = ((Vector)bcCoefficients.value()).size();

        if (fitIntercept) {
            // Last coefficient minus the dot product of the first numFeatures coefficients with the scaled means.
            double lastCoefficient = BoxesRunTime.unboxToDouble(
                    (Object)ArrayOps$.MODULE$.last$extension(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())));
            double meanProjection = BLAS$.MODULE$.javaBLAS().ddot(
                    this.numFeatures(), this.coefficientsArray(), 1, (double[])bcScaledMean.value(), 1);
            this.marginOffset = lastCoefficient - meanProjection;
        } else {
            this.marginOffset = Double.NaN;
        }
    }
}

