/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.adapter.druid.DruidQuery;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.RelFieldTrimmer;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.IntPair;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.apache.hadoop.hive.ql.parse.ColumnAccessInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Hive-specific extension of Calcite's {@link RelFieldTrimmer}.
 *
 * <p>On top of the inherited trimming behaviour this class:
 * <ul>
 *   <li>trims the inputs of a {@link HiveMultiJoin};</li>
 *   <li>places a synthetic {@link HiveProject} on top of a {@link DruidQuery} when only some
 *       of its fields are needed;</li>
 *   <li>records accessed view columns into a {@link ColumnAccessInfo} while trimming
 *       projects that are registered in the view-project-to-table map;</li>
 *   <li>optionally requests column statistics for the columns of a trimmed table scan.</li>
 * </ul>
 */
public class HiveRelFieldTrimmer
extends RelFieldTrimmer {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveRelFieldTrimmer.class);

    /** Sink for accessed view columns; may be {@code null}, in which case nothing is recorded. */
    private ColumnAccessInfo columnAccessInfo;
    /** Maps a view's project to the table describing it; may be {@code null}. */
    private Map<HiveProject, Table> viewProjectToTableSchema;
    /** Builder used to create the projection placed on top of a Druid query. */
    private final RelBuilder relBuilder;
    /** When {@code true}, trimming a table scan also requests column statistics. */
    private final boolean fetchStats;

    /**
     * Creates a trimmer that neither records column access nor fetches statistics.
     *
     * @param validator  validator handed to the superclass
     * @param relBuilder builder handed to the superclass and used for Druid projections
     */
    public HiveRelFieldTrimmer(SqlValidator validator, RelBuilder relBuilder) {
        this(validator, relBuilder, false);
    }

    /**
     * Creates a trimmer that records accessed view columns; statistics fetching is disabled.
     *
     * @param validator         validator handed to the superclass
     * @param relBuilder        builder handed to the superclass and used for Druid projections
     * @param columnAccessInfo  receives (table name, column name) pairs for used view columns
     * @param viewToTableSchema projects that represent views, mapped to their table
     */
    public HiveRelFieldTrimmer(SqlValidator validator, RelBuilder relBuilder, ColumnAccessInfo columnAccessInfo, Map<HiveProject, Table> viewToTableSchema) {
        this(validator, relBuilder, false);
        this.columnAccessInfo = columnAccessInfo;
        this.viewProjectToTableSchema = viewToTableSchema;
    }

    /**
     * Creates a trimmer with explicit control over statistics fetching.
     *
     * @param validator  validator handed to the superclass
     * @param relBuilder builder handed to the superclass and used for Druid projections
     * @param fetchStats whether trimming a table scan should also request column statistics
     */
    public HiveRelFieldTrimmer(SqlValidator validator, RelBuilder relBuilder, boolean fetchStats) {
        super(validator, relBuilder);
        this.relBuilder = relBuilder;
        this.fetchStats = fetchStats;
    }

    /**
     * Trims the inputs of a {@link HiveMultiJoin}.
     *
     * <p>The fields required from the join are {@code fieldsUsed} plus every input referenced
     * by the join condition. Each input is trimmed to the slice of those fields that falls
     * into its own column range, and the per-input mappings are concatenated into a single
     * mapping for the whole join.
     *
     * @param join        the multi-join to trim
     * @param fieldsUsed  ordinals of the join's output fields needed by the consumer
     * @param extraFields extra fields requested by the consumer
     * @return the original join with an identity mapping when no input changed, otherwise a
     *         new {@link HiveMultiJoin} over the trimmed inputs together with the mapping
     *         from old to new field ordinals
     */
    public RelFieldTrimmer.TrimResult trimFields(HiveMultiJoin join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
        int fieldCount = join.getRowType().getFieldCount();
        RexNode conditionExpr = join.getCondition();
        List<RexNode> joinFilters = join.getJoinFilters();

        // Fields used by the consumer plus those referenced by the join condition.
        // NOTE(review): only the condition is scanned here; the join filters are merely
        // rewritten below. Presumably every field they reference also appears in the
        // condition - confirm against HiveMultiJoin.
        Set<RelDataTypeField> combinedInputExtraFields = new LinkedHashSet<RelDataTypeField>(extraFields);
        RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(combinedInputExtraFields);
        inputFinder.inputBitSet.addAll(fieldsUsed);
        conditionExpr.accept(inputFinder);
        ImmutableBitSet fieldsUsedPlus = inputFinder.inputBitSet.build();

        int inputStartPos = 0;
        int changeCount = 0;
        int newFieldCount = 0;
        List<RelNode> newInputs = new ArrayList<RelNode>();
        List<Mapping> inputMappings = new ArrayList<Mapping>();
        for (RelNode input : join.getInputs()) {
            int inputFieldCount = input.getRowType().getFieldCount();

            // Translate join-level ordinals that fall inside this input into input-local ones.
            ImmutableBitSet.Builder inputFieldsUsed = ImmutableBitSet.builder();
            for (int bit : fieldsUsedPlus) {
                if (bit >= inputStartPos && bit < inputStartPos + inputFieldCount) {
                    inputFieldsUsed.set(bit - inputStartPos);
                }
            }

            // No extra fields are forwarded to the inputs.
            Set<RelDataTypeField> inputExtraFields = Collections.emptySet();
            RelFieldTrimmer.TrimResult trimResult = this.trimChild(join, input, inputFieldsUsed.build(), inputExtraFields);
            newInputs.add(trimResult.left);
            if (trimResult.left != input) {
                ++changeCount;
            }

            Mapping inputMapping = trimResult.right;
            inputMappings.add(inputMapping);
            inputStartPos += inputFieldCount;
            newFieldCount += inputMapping.getTargetCount();
        }

        // Concatenate the per-input mappings, shifting each by the running offsets.
        Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
        int offset = 0;
        int newOffset = 0;
        for (Mapping inputMapping : inputMappings) {
            for (IntPair pair : inputMapping) {
                mapping.set(pair.source + offset, pair.target + newOffset);
            }
            offset += inputMapping.getSourceCount();
            newOffset += inputMapping.getTargetCount();
        }

        if (changeCount == 0 && mapping.isIdentity()) {
            return new RelFieldTrimmer.TrimResult(join, Mappings.createIdentity(fieldCount));
        }

        // Rewrite the condition and the join filters against the trimmed inputs.
        RexPermuteInputsShuttle shuttle = new RexPermuteInputsShuttle(mapping, newInputs.toArray(new RelNode[newInputs.size()]));
        RexNode newConditionExpr = conditionExpr.accept(shuttle);
        List<RexNode> newJoinFilters = Lists.newArrayList();
        for (RexNode joinFilter : joinFilters) {
            newJoinFilters.add(joinFilter.accept(shuttle));
        }

        RelDataType newRowType = RelOptUtil.permute(join.getCluster().getTypeFactory(), join.getRowType(), mapping);
        HiveMultiJoin newJoin = new HiveMultiJoin(join.getCluster(), newInputs, newConditionExpr, newRowType, join.getJoinInputs(), join.getJoinTypes(), newJoinFilters);
        return new RelFieldTrimmer.TrimResult(newJoin, mapping);
    }

    /**
     * Trims a {@link DruidQuery} by placing a projection of the used fields on top of it.
     *
     * @param dq          the Druid query to trim
     * @param fieldsUsed  ordinals of the query's output fields needed by the consumer
     * @param extraFields extra fields requested by the consumer
     * @return the result of the generic {@link RelNode} trimming when every field is used and
     *         no extra field is requested; a dummy project when no field is used; otherwise
     *         the projection together with the mapping derived from {@code fieldsUsed}
     */
    public RelFieldTrimmer.TrimResult trimFields(DruidQuery dq, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
        int fieldCount = dq.getRowType().getFieldCount();
        if (fieldsUsed.equals(ImmutableBitSet.range(fieldCount)) && extraFields.isEmpty()) {
            // The cast selects the generic RelNode overload; without it this call would
            // resolve to this very method and recurse forever.
            return this.trimFields((RelNode) dq, fieldsUsed, extraFields);
        }

        RelNode newTableAccessRel = HiveRelFieldTrimmer.project(dq, fieldsUsed, extraFields, this.relBuilder);
        if (fieldsUsed.cardinality() == 0) {
            // Nothing is used: skip a zero-field projection and wrap its input instead.
            RelNode input = newTableAccessRel;
            if (input instanceof Project) {
                Project project = (Project) input;
                if (project.getRowType().getFieldCount() == 0) {
                    input = project.getInput();
                }
            }
            return this.dummyProject(fieldCount, input);
        }

        Mapping mapping = this.createMapping(fieldsUsed, fieldCount);
        return this.result(newTableAccessRel, mapping);
    }

    /**
     * Builds a synthetic {@link HiveProject} over {@code dq} that exposes the used fields
     * followed by one typed {@code NULL} per extra field.
     *
     * @param dq          the Druid query to project
     * @param fieldsUsed  ordinals of the fields to keep
     * @param extraFields extra fields to append as typed null constants
     * @param relBuilder  builder used to create the projection
     * @return {@code dq} itself when every field is used and no extra field is requested,
     *         otherwise the new projection, marked as synthetic
     */
    private static RelNode project(DruidQuery dq, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields, RelBuilder relBuilder) {
        int fieldCount = dq.getRowType().getFieldCount();
        if (fieldsUsed.equals(ImmutableBitSet.range(fieldCount)) && extraFields.isEmpty()) {
            return dq;
        }

        List<RexNode> exprList = new ArrayList<RexNode>();
        List<String> nameList = new ArrayList<String>();
        RexBuilder rexBuilder = dq.getCluster().getRexBuilder();
        List<RelDataTypeField> fields = dq.getRowType().getFieldList();

        // One input reference per used field, keeping the field's original name.
        for (int i : fieldsUsed) {
            RelDataTypeField field = fields.get(i);
            exprList.add(rexBuilder.makeInputRef(dq, i));
            nameList.add(field.getName());
        }

        // Extra fields are materialised as NULL constants of the requested type.
        for (RelDataTypeField extraField : extraFields) {
            exprList.add(rexBuilder.ensureType(extraField.getType(), rexBuilder.constantNull(), true));
            nameList.add(extraField.getName());
        }

        HiveProject hp = (HiveProject) relBuilder.push(dq).project(exprList, nameList).build();
        hp.setSynthetic();
        return hp;
    }

    /**
     * Records accessed view columns, then delegates to the inherited project trimming.
     *
     * <p>Recording only happens when a {@link ColumnAccessInfo} and a view map were supplied
     * and {@code project} is a key of that map. For each used output ordinal, the column at
     * the same position in the mapped table is added under the table's complete name.
     *
     * @param project     the project to trim
     * @param fieldsUsed  ordinals of the project's output fields needed by the consumer
     * @param extraFields extra fields requested by the consumer
     * @return the result of the superclass trimming
     */
    @Override
    public RelFieldTrimmer.TrimResult trimFields(Project project, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
        // The eligibility check does not depend on the individual expression, so do it once.
        boolean recordAccess = this.columnAccessInfo != null
                && this.viewProjectToTableSchema != null
                && this.viewProjectToTableSchema.containsKey(project);
        if (recordAccess) {
            Table tab = this.viewProjectToTableSchema.get(project);
            for (Ord<RexNode> ord : Ord.zip(project.getProjects())) {
                if (fieldsUsed.get(ord.i)) {
                    this.columnAccessInfo.add(tab.getCompleteName(), tab.getCols().get(ord.i).getName());
                }
            }
        }
        return super.trimFields(project, fieldsUsed, extraFields);
    }

    /**
     * Delegates to the inherited table-scan trimming and, when statistics fetching is
     * enabled, requests column statistics for the columns of the trimmed result.
     *
     * @param tableAccessRel the table scan to trim
     * @param fieldsUsed     ordinals of the scan's output fields needed by the consumer
     * @param extraFields    extra fields requested by the consumer
     * @return the result of the superclass trimming
     */
    @Override
    public RelFieldTrimmer.TrimResult trimFields(TableScan tableAccessRel, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
        RelFieldTrimmer.TrimResult result = super.trimFields(tableAccessRel, fieldsUsed, extraFields);
        if (this.fetchStats) {
            this.fetchColStats(result.getKey(), tableAccessRel, fieldsUsed, extraFields);
        }
        return result;
    }

    /**
     * Requests column statistics for the columns referenced by a trimmed table scan.
     *
     * <p>When {@code key} is a {@link Project}, the columns are the input references of its
     * expressions. Otherwise all columns are taken, but only if every field is used and no
     * extra field was requested. For a {@link HiveTableScan}, the columns it reports via
     * {@code getPartOrVirtualCols()} are excluded. Statistics are requested only when
     * columns remain and the scan's table is a {@link RelOptHiveTable}.
     *
     * @param key            the relational expression produced by trimming the scan
     * @param tableAccessRel the original table scan
     * @param fieldsUsed     ordinals of the scan's output fields needed by the consumer
     * @param extraFields    extra fields requested by the consumer
     */
    private void fetchColStats(RelNode key, TableScan tableAccessRel, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
        List<Integer> iRefSet = Lists.newArrayList();
        if (key instanceof Project) {
            Project project = (Project) key;
            for (RexNode rx : project.getChildExps()) {
                iRefSet.addAll(HiveCalciteUtil.getInputRefs(rx));
            }
        } else {
            int fieldCount = tableAccessRel.getRowType().getFieldCount();
            if (fieldsUsed.equals(ImmutableBitSet.range(fieldCount)) && extraFields.isEmpty()) {
                iRefSet.addAll(ImmutableBitSet.range(fieldCount).asList());
            }
        }

        if (tableAccessRel instanceof HiveTableScan) {
            iRefSet.removeAll(((HiveTableScan) tableAccessRel).getPartOrVirtualCols());
        }

        if (iRefSet.isEmpty()) {
            return;
        }
        RelOptTable table = tableAccessRel.getTable();
        if (table instanceof RelOptHiveTable) {
            ((RelOptHiveTable) table).getColStat(iRefSet, true);
            LOG.debug("Got col stats for {} in {}", iRefSet, tableAccessRel.getTable().getQualifiedName());
        }
    }
}

