/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.parse.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.Context;
import org.apache.hadoop.hive.ql.exec.ConditionalTask;
import org.apache.hadoop.hive.ql.exec.DummyStoreOperator;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.UnionOperator;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.hooks.ReadEntity;
import org.apache.hadoop.hive.ql.hooks.WriteEntity;
import org.apache.hadoop.hive.ql.lib.CompositeProcessor;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TypeRule;
import org.apache.hadoop.hive.ql.log.PerfLogger;
import org.apache.hadoop.hive.ql.optimizer.physical.MetadataOnlyOptimizer;
import org.apache.hadoop.hive.ql.optimizer.physical.NullScanOptimizer;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.SparkCrossProductCheck;
import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
import org.apache.hadoop.hive.ql.optimizer.physical.StageIDsRearranger;
import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
import org.apache.hadoop.hive.ql.optimizer.spark.SetSparkReducerParallelism;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinHintOptimizer;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkJoinOptimizer;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkReduceSinkMapJoinProc;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkSkewJoinResolver;
import org.apache.hadoop.hive.ql.optimizer.spark.SplitSparkWorkResolver;
import org.apache.hadoop.hive.ql.parse.GlobalLimitCtx;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.TaskCompiler;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkProcContext;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkWork;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkWorkWalker;
import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
import org.apache.hadoop.hive.ql.parse.spark.SparkFileSinkProcessor;
import org.apache.hadoop.hive.ql.parse.spark.SparkProcessAnalyzeTable;
import org.apache.hadoop.hive.ql.parse.spark.SparkSMBMapJoinInfo;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.MoveWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.SparkWork;

/**
 * {@link TaskCompiler} specialization that builds Spark work from an operator plan.
 *
 * <p>The class overrides the compilation hooks visible below:
 * <ul>
 *   <li>{@link #optimizeOperatorPlan} walks the operator tree with Spark-specific
 *       parallelism and join optimization rules;</li>
 *   <li>{@link #generateTaskTree} walks the operator tree again to split it into work
 *       units and post-processes unions, map work and file sinks;</li>
 *   <li>{@link #setInputFormat(Task)} propagates the bucketized-input-format flag from
 *       operators to their owning {@link MapWork};</li>
 *   <li>{@link #optimizeTaskPlan} runs a fixed sequence of physical resolvers, most of
 *       them gated by a {@link HiveConf} setting.</li>
 * </ul>
 */
public class SparkCompiler
extends TaskCompiler {
    private static final String CLASS_NAME = SparkCompiler.class.getName();
    private static final PerfLogger PERF_LOGGER = PerfLogger.getPerfLogger();
    // NOTE(review): not referenced anywhere in this class; the debug messages below go
    // through the LOG field, which is not declared here (presumably inherited).
    private static final Log LOGGER = LogFactory.getLog(SparkCompiler.class);

    /**
     * Walks the operator tree starting at the plan's top operators and applies the
     * Spark operator-level rules (reducer parallelism, join optimization, join hints).
     *
     * @param pCtx    parse context supplying the top operators
     * @param inputs  read entities handed to the optimization context
     * @param outputs write entities handed to the optimization context
     * @throws SemanticException if the graph walk or one of the processors fails
     */
    @Override
    protected void optimizeOperatorPlan(ParseContext pCtx, Set<ReadEntity> inputs, Set<WriteEntity> outputs) throws SemanticException {
        PERF_LOGGER.PerfLogBegin(CLASS_NAME, "SparkOptimizeOperatorTree");

        // Queue handed to the processing context, seeded with the plan's top operators.
        LinkedList<Operator<? extends OperatorDesc>> deque =
                new LinkedList<Operator<? extends OperatorDesc>>(pCtx.getTopOps().values());
        OptimizeSparkProcContext procCtx = new OptimizeSparkProcContext(this.conf, pCtx, inputs, outputs, deque);

        // Insertion-ordered so the rules are registered in a stable order.
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("Set parallelism - ReduceSink", ReduceSinkOperator.getOperatorName() + "%"), new SetSparkReducerParallelism());
        opRules.put(new TypeRule(JoinOperator.class), new SparkJoinOptimizer(pCtx));
        opRules.put(new TypeRule(MapJoinOperator.class), new SparkJoinHintOptimizer(pCtx));

        // No default processor is supplied (first argument is null).
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(null, opRules, procCtx);
        DefaultGraphWalker ogw = new DefaultGraphWalker(disp);

        List<Node> topNodes = new ArrayList<Node>(pCtx.getTopOps().values());
        ogw.startWalking(topNodes, null);

        PERF_LOGGER.PerfLogEnd(CLASS_NAME, "SparkGenerateTaskTree".equals("") ? "" : "SparkOptimizeOperatorTree");
    }

    /**
     * Generates the task tree by walking the operator plan with the work-splitting
     * rules, then removes union operators from the affected works, annotates the map
     * work and processes every collected file sink.
     *
     * @param rootTasks root task list passed to the generation context
     * @param pCtx      parse context supplying the top operators
     * @param mvTask    move tasks passed to the generation context
     * @param inputs    read entities passed to the generation context
     * @param outputs   write entities passed to the generation context
     * @throws SemanticException if the graph walk or any post-processing step fails
     */
    @Override
    protected void generateTaskTree(List<Task<? extends Serializable>> rootTasks, ParseContext pCtx, List<Task<MoveWork>> mvTask, Set<ReadEntity> inputs, Set<WriteEntity> outputs) throws SemanticException {
        PERF_LOGGER.PerfLogBegin(CLASS_NAME, "SparkGenerateTaskTree");

        GenSparkUtils utils = GenSparkUtils.getUtils();
        utils.resetSequenceNumber();

        ParseContext tempParseContext = this.getParseContext(pCtx, rootTasks);
        // The same GenSparkWork instance serves both the ReduceSink and FileSink rules.
        GenSparkWork genSparkWork = new GenSparkWork(GenSparkUtils.getUtils());
        GenSparkProcContext procCtx = new GenSparkProcContext(this.conf, tempParseContext, mvTask, rootTasks, inputs, outputs, pCtx.getTopOps());

        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("Split Work - ReduceSink", ReduceSinkOperator.getOperatorName() + "%"), genSparkWork);
        opRules.put(new TypeRule(MapJoinOperator.class), new SparkReduceSinkMapJoinProc());
        opRules.put(new RuleRegExp("Split Work + Move/Merge - FileSink", FileSinkOperator.getOperatorName() + "%"), new CompositeProcessor(new SparkFileSinkProcessor(), genSparkWork));
        opRules.put(new RuleRegExp("Handle Analyze Command", TableScanOperator.getOperatorName() + "%"), new SparkProcessAnalyzeTable(GenSparkUtils.getUtils()));
        opRules.put(new RuleRegExp("Remember union", UnionOperator.getOperatorName() + "%"), new NodeProcessor(){

            /** Records every visited union operator in the generation context. */
            @Override
            public Object process(Node n, Stack<Node> s, NodeProcessorCtx nodeCtx, Object ... os) throws SemanticException {
                GenSparkProcContext context = (GenSparkProcContext)nodeCtx;
                context.currentUnionOperators.add((UnionOperator)n);
                return null;
            }
        });
        opRules.put(new TypeRule(SMBMapJoinOperator.class), new NodeProcessor(){

            /**
             * Registers the current root operator with the SMB join's bookkeeping entry:
             * as a small-table root when a DummyStoreOperator is on the walk stack
             * (returns true), otherwise as the big-table root (returns false).
             */
            @Override
            public Object process(Node currNode, Stack<Node> stack, NodeProcessorCtx nodeCtx, Object ... os) throws SemanticException {
                GenSparkProcContext context = (GenSparkProcContext)nodeCtx;
                SMBMapJoinOperator currSmbNode = (SMBMapJoinOperator)currNode;

                // Create the per-join entry on first visit.
                SparkSMBMapJoinInfo smbMapJoinCtx = context.smbMapJoinCtxMap.get(currSmbNode);
                if (smbMapJoinCtx == null) {
                    smbMapJoinCtx = new SparkSMBMapJoinInfo();
                    context.smbMapJoinCtxMap.put(currSmbNode, smbMapJoinCtx);
                }

                for (Node stackNode : stack) {
                    if (stackNode instanceof DummyStoreOperator) {
                        smbMapJoinCtx.smallTableRootOps.add(context.currentRootOperator);
                        return true;
                    }
                }
                smbMapJoinCtx.bigTableRootOp = context.currentRootOperator;
                return false;
            }
        });

        // No default processor is supplied (first argument is null).
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(null, opRules, procCtx);
        List<Node> topNodes = new ArrayList<Node>(pCtx.getTopOps().values());
        GenSparkWorkWalker ogw = new GenSparkWorkWalker(disp, procCtx);
        ogw.startWalking(topNodes, null);

        // Post-processing, in the same order as before: unions, map work, file sinks.
        for (BaseWork w : procCtx.workWithUnionOperators) {
            GenSparkUtils.getUtils().removeUnionOperators(this.conf, procCtx, w);
        }
        GenSparkUtils.getUtils().annotateMapWork(procCtx);
        for (FileSinkOperator fileSink : procCtx.fileSinkSet) {
            GenSparkUtils.getUtils().processFileSink(procCtx, fileSink);
        }

        PERF_LOGGER.PerfLogEnd(CLASS_NAME, "SparkGenerateTaskTree");
    }

    /**
     * Recursively visits a task, the tasks listed by a {@link ConditionalTask}, and all
     * child tasks; for every {@link MapWork} of a {@link SparkTask} with a non-empty
     * alias-to-operator map, each root operator tree is checked for the bucketized
     * input format flag.
     *
     * @param task task whose work (and descendants) should be inspected
     */
    @Override
    protected void setInputFormat(Task<? extends Serializable> task) {
        if (task instanceof SparkTask) {
            SparkWork work = (SparkWork)((SparkTask)task).getWork();
            for (BaseWork w : work.getAllWork()) {
                if (!(w instanceof MapWork)) {
                    continue;
                }
                MapWork mapWork = (MapWork)w;
                LinkedHashMap<String, Operator<? extends OperatorDesc>> opMap = mapWork.getAliasToWork();
                if (opMap.isEmpty()) {
                    continue;
                }
                // Typed iteration; no raw Operator and no cast of the map is needed.
                for (Operator<? extends OperatorDesc> op : opMap.values()) {
                    this.setInputFormat(mapWork, op);
                }
            }
        } else if (task instanceof ConditionalTask) {
            List<Task<? extends Serializable>> listTasks = ((ConditionalTask)task).getListTasks();
            for (Task<? extends Serializable> tsk : listTasks) {
                this.setInputFormat(tsk);
            }
        }

        // Child tasks are visited regardless of the concrete type of this task.
        if (task.getChildTasks() != null) {
            for (Task<? extends Serializable> childTask : task.getChildTasks()) {
                this.setInputFormat(childTask);
            }
        }
    }

    /**
     * Depth-first search of an operator tree: as soon as an operator reports
     * {@code isUseBucketizedHiveInputFormat()}, the flag is set on {@code work} and that
     * branch is not descended further.
     *
     * @param work map work receiving the flag
     * @param op   root of the operator (sub)tree to inspect
     */
    private void setInputFormat(MapWork work, Operator<? extends OperatorDesc> op) {
        if (op.isUseBucketizedHiveInputFormat()) {
            work.setUseBucketizedHiveInputFormat(true);
            return;
        }
        if (op.getChildOperators() != null) {
            for (Operator<? extends OperatorDesc> childOp : op.getChildOperators()) {
                this.setInputFormat(work, childOp);
            }
        }
    }

    /**
     * No-op in this compiler: the body is empty.
     */
    @Override
    protected void decideExecMode(List<Task<? extends Serializable>> rootTasks, Context ctx, GlobalLimitCtx globalLimitCtx) throws SemanticException {
    }

    /**
     * Runs the physical optimization pipeline over the generated tasks. Work splitting
     * and map-join resolution always run; the remaining resolvers are each gated by a
     * configuration setting, and a debug message is logged when one is skipped.
     *
     * @param rootTasks root tasks of the plan
     * @param pCtx      parse context supplying the query context and fetch task
     * @param ctx       compilation context (not read by this method)
     * @throws SemanticException if any resolver fails
     */
    @Override
    protected void optimizeTaskPlan(List<Task<? extends Serializable>> rootTasks, ParseContext pCtx, Context ctx) throws SemanticException {
        PERF_LOGGER.PerfLogBegin(CLASS_NAME, "SparkOptimizeTaskTree");

        PhysicalContext physicalCtx = new PhysicalContext(this.conf, pCtx, pCtx.getContext(), rootTasks, pCtx.getFetchTask());
        physicalCtx = new SplitSparkWorkResolver().resolve(physicalCtx);

        if (this.conf.getBoolVar(HiveConf.ConfVars.HIVESKEWJOIN)) {
            // Return value deliberately left unassigned, matching the existing behaviour.
            new SparkSkewJoinResolver().resolve(physicalCtx);
        } else {
            LOG.debug("Skipping runtime skew join optimization");
        }

        physicalCtx = new SparkMapJoinResolver().resolve(physicalCtx);

        if (this.conf.getBoolVar(HiveConf.ConfVars.HIVENULLSCANOPTIMIZE)) {
            physicalCtx = new NullScanOptimizer().resolve(physicalCtx);
        } else {
            LOG.debug("Skipping null scan query optimization");
        }

        if (this.conf.getBoolVar(HiveConf.ConfVars.HIVEMETADATAONLYQUERIES)) {
            physicalCtx = new MetadataOnlyOptimizer().resolve(physicalCtx);
        } else {
            LOG.debug("Skipping metadata only query optimization");
        }

        if (this.conf.getBoolVar(HiveConf.ConfVars.HIVE_CHECK_CROSS_PRODUCT)) {
            physicalCtx = new SparkCrossProductCheck().resolve(physicalCtx);
        } else {
            LOG.debug("Skipping cross product analysis");
        }

        if (this.conf.getBoolVar(HiveConf.ConfVars.HIVE_VECTORIZATION_ENABLED)) {
            new Vectorizer().resolve(physicalCtx);
        } else {
            LOG.debug("Skipping vectorization");
        }

        // Stage id rearrangement runs unless the setting is "none" (case-insensitive).
        if (!"none".equalsIgnoreCase(this.conf.getVar(HiveConf.ConfVars.HIVESTAGEIDREARRANGE))) {
            new StageIDsRearranger().resolve(physicalCtx);
        } else {
            LOG.debug("Skipping stage id rearranger");
        }

        PERF_LOGGER.PerfLogEnd(CLASS_NAME, "SparkOptimizeTaskTree");
    }
}

