/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Column;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DataBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;

/**
 * Java wrapper around a native XGBoost data matrix.
 *
 * <p>Each instance holds a native handle in {@link #handle}. The handle is released by
 * {@link #dispose()}, which is also invoked from {@link #finalize()}. After disposal the handle is
 * reset to {@code 0}.
 */
public class DMatrix {
    /** Number of rows buffered per batch when building a matrix from a row iterator. */
    private static final int DEFAULT_BATCH_SIZE = 32768;

    /** Native handle of the underlying matrix; {@code 0} once disposed. */
    protected long handle = 0L;

    /**
     * Creates a matrix by streaming labeled points from an iterator in batches.
     *
     * @param iter source of rows; must not be {@code null}
     * @param cacheInfo cache information forwarded unchanged to the native constructor
     * @throws NullPointerException if {@code iter} is {@code null}
     * @throws XGBoostError if the native call fails
     */
    public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
        checkNotNull(iter, "iter");
        DataBatch.BatchIterator batchIter = new DataBatch.BatchIterator(iter, DEFAULT_BATCH_SIZE);
        long[] out = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
        this.handle = out[0];
    }

    /**
     * Loads a matrix from a file.
     *
     * @param dataPath path of the file to load; must not be {@code null}
     * @throws NullPointerException if {@code dataPath} is {@code null}
     * @throws XGBoostError if the native call fails
     */
    public DMatrix(String dataPath) throws XGBoostError {
        checkNotNull(dataPath, "dataPath");
        long[] out = new long[1];
        // NOTE(review): the literal 1 is presumably the native "silent" flag — confirm.
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
        this.handle = out[0];
    }

    /**
     * Creates a matrix from compressed sparse data, passing {@code 0} as the shape parameter.
     *
     * @param headers row (CSR) or column (CSC) offsets
     * @param indices column (CSR) or row (CSC) index of each stored value
     * @param data stored values
     * @param st layout of the sparse input
     * @throws UnknownError if {@code st} is neither {@link SparseType#CSR} nor {@link SparseType#CSC}
     * @throws XGBoostError if the native call fails
     * @deprecated use {@link #DMatrix(long[], int[], float[], SparseType, int)} instead
     */
    @Deprecated
    public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
        this(headers, indices, data, st, 0);
    }

    /**
     * Creates a matrix from compressed sparse data.
     *
     * @param headers row (CSR) or column (CSC) offsets
     * @param indices column (CSR) or row (CSC) index of each stored value
     * @param data stored values
     * @param st layout of the sparse input
     * @param shapeParam forwarded unchanged to the native constructor; presumably the number of
     *     columns (CSR) or rows (CSC), with 0 meaning "infer" — TODO confirm
     * @throws UnknownError if {@code st} is neither {@link SparseType#CSR} nor {@link SparseType#CSC}
     * @throws XGBoostError if the native call fails
     */
    public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
            throws XGBoostError {
        long[] out = new long[1];
        if (st == SparseType.CSR) {
            XGBoostJNI.checkCall(
                XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, shapeParam, out));
        } else if (st == SparseType.CSC) {
            XGBoostJNI.checkCall(
                XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, shapeParam, out));
        } else {
            // Kept as UnknownError for compatibility with existing callers.
            throw new UnknownError("unknown sparsetype");
        }
        this.handle = out[0];
    }

    /**
     * Creates a dense matrix, treating {@code 0.0f} as the missing value.
     *
     * @param data values laid out row by row
     * @param nrow number of rows
     * @param ncol number of columns
     * @throws XGBoostError if the native call fails
     * @deprecated use {@link #DMatrix(float[], int, int, float)} and pass the missing value
     *     explicitly
     */
    @Deprecated
    public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
        this(data, nrow, ncol, 0.0f);
    }

    /**
     * Creates a matrix from an off-heap dense matrix, treating {@code 0.0f} as the missing value.
     *
     * @param matrix source matrix
     * @throws XGBoostError if the native call fails
     */
    public DMatrix(BigDenseMatrix matrix) throws XGBoostError {
        this(matrix, 0.0f);
    }

    /**
     * Creates a dense matrix.
     *
     * @param data values laid out row by row
     * @param nrow number of rows
     * @param ncol number of columns
     * @param missing value to treat as missing
     * @throws XGBoostError if the native call fails
     */
    public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError {
        long[] out = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
        this.handle = out[0];
    }

    /**
     * Creates a matrix from an off-heap dense matrix, referenced by its address.
     *
     * @param matrix source matrix
     * @param missing value to treat as missing
     * @throws XGBoostError if the native call fails
     */
    public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
        long[] out = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(
            matrix.address, matrix.nrow, matrix.ncol, missing, out));
        this.handle = out[0];
    }

    /**
     * Wraps an existing native handle. The new instance takes it over and frees it on dispose.
     *
     * @param handle native matrix handle
     */
    protected DMatrix(long handle) {
        this.handle = handle;
    }

    /**
     * Creates a matrix from the feature columns of a column batch.
     *
     * @param columnBatch source of the feature array-interface JSON
     * @param missing value to treat as missing
     * @param nthread thread count forwarded to the native constructor
     * @throws XGBoostError if the feature array interface is empty or the native call fails
     */
    public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
        long[] out = new long[1];
        String json = columnBatch.getFeatureArrayInterface();
        if (json == null || json.isEmpty()) {
            throw new XGBoostError("Expecting non-empty feature columns' array interface");
        }
        XGBoostJNI.checkCall(
            XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(json, missing, nthread, out));
        this.handle = out[0];
    }

    /** Sets the "label" field from a column's array-interface JSON. */
    public void setLabel(Column column) throws XGBoostError {
        this.setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
    }

    /** Sets the "weight" field from a column's array-interface JSON. */
    public void setWeight(Column column) throws XGBoostError {
        this.setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
    }

    /** Sets the "base_margin" field from a column's array-interface JSON. */
    public void setBaseMargin(Column column) throws XGBoostError {
        this.setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
    }

    /**
     * Sets a meta-info field from array-interface JSON.
     *
     * @throws XGBoostError if {@code json} is null/empty or the native call fails
     */
    private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
        if (json == null || json.isEmpty()) {
            throw new XGBoostError("Empty " + type + " columns' array interface");
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(this.handle, type, json));
    }

    /**
     * Sets a string-valued feature info field.
     *
     * @throws XGBoostError if {@code type} or {@code values} is null/empty, or the native call fails
     */
    private void setXGBDMatrixFeatureInfo(String type, String[] values) throws XGBoostError {
        if (type == null || type.isEmpty()) {
            throw new XGBoostError("Found empty type");
        }
        if (values == null || values.length == 0) {
            throw new XGBoostError("Found empty values");
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetStrFeatureInfo(this.handle, type, values));
    }

    /**
     * Reads a string-valued feature info field.
     *
     * @return the values; an empty array if the native side reports length 0 without an array
     * @throws XGBoostError if {@code type} is null/empty or the native call fails
     * @throws RuntimeException if the reported length disagrees with the returned array
     */
    private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError {
        if (type == null || type.isEmpty()) {
            throw new XGBoostError("Found empty type");
        }
        long[] outLen = new long[1];
        String[][] outValue = new String[1][];
        XGBoostJNI.checkCall(
            XGBoostJNI.XGDMatrixGetStrFeatureInfo(this.handle, type, outLen, outValue));
        String[] values = outValue[0];
        if (values == null) {
            // Previously an NPE; an absent array with length 0 simply means "no values".
            if (outLen[0] == 0L) {
                return new String[0];
            }
            throw new RuntimeException("Failed to get " + type);
        }
        if (outLen[0] != (long) values.length) {
            throw new RuntimeException("Failed to get " + type);
        }
        return values;
    }

    /** Sets the "feature_name" field. */
    public void setFeatureNames(String[] values) throws XGBoostError {
        this.setXGBDMatrixFeatureInfo("feature_name", values);
    }

    /** Returns the "feature_name" field. */
    public String[] getFeatureNames() throws XGBoostError {
        return this.getXGBDMatrixFeatureInfo("feature_name");
    }

    /** Sets the "feature_type" field. */
    public void setFeatureTypes(String[] values) throws XGBoostError {
        this.setXGBDMatrixFeatureInfo("feature_type", values);
    }

    /** Returns the "feature_type" field. */
    public String[] getFeatureTypes() throws XGBoostError {
        return this.getXGBDMatrixFeatureInfo("feature_type");
    }

    /**
     * Sets the "label" field.
     *
     * @throws NullPointerException if {@code labels} is {@code null}
     */
    public void setLabel(float[] labels) throws XGBoostError {
        checkNotNull(labels, "labels");
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(this.handle, "label", labels));
    }

    /**
     * Sets the "weight" field.
     *
     * @throws NullPointerException if {@code weights} is {@code null}
     */
    public void setWeight(float[] weights) throws XGBoostError {
        checkNotNull(weights, "weights");
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(this.handle, "weight", weights));
    }

    /**
     * Sets the "base_margin" field; requires exactly one value per row.
     *
     * @throws NullPointerException if {@code baseMargin} is {@code null}
     * @throws IllegalArgumentException if its length differs from {@link #rowNum()}
     */
    public void setBaseMargin(float[] baseMargin) throws XGBoostError {
        checkNotNull(baseMargin, "baseMargin");
        // Query the native row count once instead of once for the check and once for the message.
        long expected = this.rowNum();
        if ((long) baseMargin.length != expected) {
            throw new IllegalArgumentException(String.format(
                "base margin must have exactly %s elements, got %s", expected, baseMargin.length));
        }
        XGBoostJNI.checkCall(
            XGBoostJNI.XGDMatrixSetFloatInfo(this.handle, "base_margin", baseMargin));
    }

    /**
     * Flattens {@code baseMargin} row by row and delegates to {@link #setBaseMargin(float[])}.
     *
     * <p>NOTE(review): because the flattened length must equal {@link #rowNum()}, this only
     * accepts a total of one value per row; multi-column margins are rejected — confirm intent.
     */
    public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
        this.setBaseMargin(DMatrix.flatten(baseMargin));
    }

    /**
     * Sets the "group" field.
     *
     * @throws NullPointerException if {@code group} is {@code null}
     */
    public void setGroup(int[] group) throws XGBoostError {
        checkNotNull(group, "group");
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(this.handle, "group", group));
    }

    /**
     * Returns the "group_ptr" field. This is a different field than the "group" field written by
     * {@link #setGroup(int[])} (presumably cumulative offsets rather than sizes — confirm).
     */
    public int[] getGroup() throws XGBoostError {
        return this.getIntInfo("group_ptr");
    }

    /** Reads a float-valued meta-info field. */
    private float[] getFloatInfo(String field) throws XGBoostError {
        float[][] infos = new float[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(this.handle, field, infos));
        return infos[0];
    }

    /** Reads an unsigned-int meta-info field into a Java {@code int[]}. */
    private int[] getIntInfo(String field) throws XGBoostError {
        int[][] infos = new int[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(this.handle, field, infos));
        return infos[0];
    }

    /** Returns the "label" field. */
    public float[] getLabel() throws XGBoostError {
        return this.getFloatInfo("label");
    }

    /** Returns the "weight" field. */
    public float[] getWeight() throws XGBoostError {
        return this.getFloatInfo("weight");
    }

    /** Returns the "base_margin" field. */
    public float[] getBaseMargin() throws XGBoostError {
        return this.getFloatInfo("base_margin");
    }

    /**
     * Creates a new matrix containing the given rows; the caller owns the returned instance.
     *
     * @param rowIndex indices of the rows to keep
     * @throws NullPointerException if {@code rowIndex} is {@code null}
     * @throws XGBoostError if the native call fails
     */
    public DMatrix slice(int[] rowIndex) throws XGBoostError {
        checkNotNull(rowIndex, "rowIndex");
        long[] out = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(this.handle, rowIndex, out));
        return new DMatrix(out[0]);
    }

    /** Returns the number of rows reported by the native matrix. */
    public long rowNum() throws XGBoostError {
        long[] rowNum = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumRow(this.handle, rowNum));
        return rowNum[0];
    }

    /**
     * Saves the matrix in binary form.
     *
     * <p>NOTE(review): the native return code is ignored, so failures are silent. Reporting them
     * would require adding {@code throws XGBoostError}, which would break existing callers.
     *
     * @throws NullPointerException if {@code filePath} is {@code null}
     */
    public void saveBinary(String filePath) {
        checkNotNull(filePath, "filePath");
        XGBoostJNI.XGDMatrixSaveBinary(this.handle, filePath, 1);
    }

    /** Returns the native handle ({@code 0} after {@link #dispose()}). */
    public long getHandle() {
        return this.handle;
    }

    /** Concatenates the rows of {@code mat} into a single array, in row order. */
    private static float[] flatten(float[][] mat) {
        int size = 0;
        for (float[] row : mat) {
            size += row.length;
        }
        float[] result = new float[size];
        int pos = 0;
        for (float[] row : mat) {
            System.arraycopy(row, 0, result, pos, row.length);
            pos += row.length;
        }
        return result;
    }

    /** Throws a {@code NullPointerException} with message "{@code name}: null" if value is null. */
    private static void checkNotNull(Object value, String name) {
        if (value == null) {
            // Rejecting null here keeps it from ever being passed to native code.
            throw new NullPointerException(name + ": null");
        }
    }

    /** Releases the native handle if the caller never called {@link #dispose()}. */
    @Override
    protected void finalize() {
        this.dispose();
    }

    /** Frees the native matrix once; subsequent calls are no-ops. */
    public synchronized void dispose() {
        if (this.handle != 0L) {
            XGBoostJNI.XGDMatrixFree(this.handle);
            this.handle = 0L;
        }
    }

    /** Layout of compressed sparse input. */
    public enum SparseType {
        /** Compressed sparse row. */
        CSR,
        /** Compressed sparse column. */
        CSC
    }
}

