/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.ppd;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.FilterOperator;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.PreOrderWalker;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.OpParseContext;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.RowResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;

/**
 * Optimizer transform that propagates filter predicates transitively across the inputs of a join.
 *
 * <p>The walker looks for operator paths of the form {@code FIL -> RS -> JOIN} or
 * {@code FIL -> RS -> MAPJOIN}. For each one, the filter's predicate is rewritten from the source
 * reduce-sink's key columns into the key columns of every other join input it may be pushed to.
 * Whether a push is allowed depends on the join condition types (see
 * {@link JoinTransitive#getTargets}). The rewritten predicates are then installed directly above
 * each target reduce-sink: either merged into an existing parent {@link FilterOperator}, or
 * inserted as a new filter operator.
 */
public class PredicateTransitivePropagate
implements Transform {
    /** Parse context being transformed; set on entry to {@link #transform(ParseContext)}. */
    private ParseContext pGraphContext;

    /**
     * Walks the operator tree, collects the transitively derived predicates, and rewrites the tree
     * so each target reduce-sink is preceded by a filter carrying them.
     *
     * @param pctx parse context whose operator tree is modified in place
     * @return the same parse context that was passed in
     * @throws SemanticException if walking fails, or if a reduce-sink parent that needs a new
     *     filter has no registered {@link OpParseContext}
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        this.pGraphContext = pctx;

        Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", "(" + FilterOperator.getOperatorName() + "%" + ReduceSinkOperator.getOperatorName() + "%" + JoinOperator.getOperatorName() + "%)|" + "(" + FilterOperator.getOperatorName() + "%" + ReduceSinkOperator.getOperatorName() + "%" + MapJoinOperator.getOperatorName() + "%)"), new JoinTransitive());

        TransitiveContext context = new TransitiveContext();
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(null, opRules, context);
        PreOrderWalker ogw = new PreOrderWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>(this.pGraphContext.getTopOps().values());
        ogw.startWalking(topNodes, null);

        // Install each collected predicate between its reduce-sink and that reduce-sink's parent.
        for (Map.Entry<ReduceSinkOperator, ExprNodeDesc> entry : context.getNewFilters().entrySet()) {
            ReduceSinkOperator reducer = entry.getKey();
            Operator<? extends OperatorDesc> parent = reducer.getParentOperators().get(0);
            ExprNodeDesc expr = entry.getValue();

            if (parent instanceof FilterOperator) {
                // A filter already sits right above the RS: AND the new predicate into it.
                FilterDesc parentDesc = ((FilterOperator) parent).getConf();
                ExprNodeDesc merged = ExprNodeDescUtils.mergePredicates(parentDesc.getPredicate(), expr);
                parentDesc.setPredicate(merged);
                continue;
            }

            OpParseContext parentCtx = this.pGraphContext.getOpParseCtx().get(parent);
            if (parentCtx == null) {
                throw new SemanticException("No OpParseContext registered for operator " + parent
                        + " (parent of " + reducer + ")");
            }
            RowResolver parentRR = parentCtx.getRowResolver();
            Operator<FilterDesc> newFilter = this.createFilter(reducer, parent, parentRR, expr);
            // The new filter passes its input through unchanged, so it reuses the parent's row resolver.
            this.pGraphContext.getOpParseCtx().put(newFilter, new OpParseContext(parentRR));
        }
        return this.pGraphContext;
    }

    /**
     * Creates a {@link FilterOperator} for {@code filterExpr} and splices it between
     * {@code parent} and {@code target}.
     *
     * @param target the child operator (here, the reduce-sink) that will be fed by the new filter
     * @param parent the current parent of {@code target}
     * @param parentRR row resolver of {@code parent}; its columns become the filter's schema
     * @param filterExpr predicate for the new filter
     * @return the newly created and wired filter operator
     */
    private Operator<FilterDesc> createFilter(Operator<?> target, Operator<?> parent, RowResolver parentRR, ExprNodeDesc filterExpr) {
        Operator<FilterDesc> filter = OperatorFactory.get(new FilterDesc(filterExpr, false), new RowSchema(parentRR.getColumnInfos()));
        filter.setParentOperators(new ArrayList<Operator<? extends OperatorDesc>>());
        filter.setChildOperators(new ArrayList<Operator<? extends OperatorDesc>>());
        filter.getParentOperators().add(parent);
        filter.getChildOperators().add(target);
        parent.replaceChild(target, filter);
        target.replaceParent(parent, filter);
        return filter;
    }

    /**
     * Small directed graph over join input positions. Each edge {@code from -> to} means a
     * predicate on input {@code from} may be propagated to input {@code to}.
     */
    private static class Vectors {
        /** Adjacency sets, indexed by source position; never null. */
        private final List<Set<Integer>> edges;

        Vectors(int length) {
            this.edges = new ArrayList<Set<Integer>>(length);
            for (int i = 0; i < length; i++) {
                this.edges.add(new HashSet<Integer>());
            }
        }

        /** Records that predicates may flow from position {@code from} to position {@code to}. */
        void add(int from, int to) {
            this.edges.get(from).add(to);
        }

        /**
         * Returns every position reachable from {@code pos} by following one or more edges.
         * {@code pos} itself appears in the result only if it lies on a cycle.
         */
        int[] traverse(int pos) {
            Set<Integer> reached = new HashSet<Integer>();
            this.collect(reached, pos);
            int[] result = new int[reached.size()];
            int index = 0;
            for (int value : reached) {
                result[index++] = value;
            }
            return result;
        }

        /** Depth-first accumulation of the positions reachable from {@code pos} into {@code reached}. */
        private void collect(Set<Integer> reached, int pos) {
            for (int next : this.edges.get(pos)) {
                if (reached.add(next)) {
                    this.collect(reached, next);
                }
            }
        }
    }

    /**
     * Rule processor invoked on the join at the end of a matched {@code FIL -> RS -> (MAP)JOIN} path.
     * It records, in the {@link TransitiveContext}, the filter's predicate rewritten for every
     * other join input it may propagate to.
     */
    private static class JoinTransitive
    implements NodeProcessor {
        private JoinTransitive() {
        }

        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            // The rule regexp only fires when the walk stack ends with FIL, RS, JOIN/MAPJOIN,
            // so the three topmost entries can be cast directly.
            CommonJoinOperator<?> join = (CommonJoinOperator<?>) nd;
            ReduceSinkOperator source = (ReduceSinkOperator) stack.get(stack.size() - 2);
            FilterOperator filter = (FilterOperator) stack.get(stack.size() - 3);

            List<Operator<? extends OperatorDesc>> parents = join.getParentOperators();
            int srcPos = parents.indexOf(source);
            if (srcPos < 0) {
                return null;
            }

            TransitiveContext context = (TransitiveContext) procCtx;
            Map<CommonJoinOperator<?>, int[][]> filterPropagates = context.getFilterPropagates();
            int[][] targets = filterPropagates.get(join);
            if (targets == null) {
                // The join is reached once per filtered input, so cache its propagation graph.
                targets = this.getTargets(join);
                filterPropagates.put(join, targets);
            }

            Map<ReduceSinkOperator, ExprNodeDesc> newFilters = context.getNewFilters();
            ArrayList<ExprNodeDesc> sourceKeys = source.getConf().getKeyCols();
            ExprNodeDesc predicate = filter.getConf().getPredicate();

            for (int targetPos : targets[srcPos]) {
                // Self-propagation would only reproduce the filter that produced this predicate.
                if (targetPos == srcPos) {
                    continue;
                }
                Operator<? extends OperatorDesc> candidate = parents.get(targetPos);
                if (!(candidate instanceof ReduceSinkOperator)) {
                    // Without reduce-sink keys there is nothing to rewrite onto; skipping is always safe.
                    continue;
                }
                ReduceSinkOperator target = (ReduceSinkOperator) candidate;
                ArrayList<ExprNodeDesc> targetKeys = target.getConf().getKeyCols();
                // Rewrite source key references into the target's keys; null means not expressible.
                ExprNodeDesc replaced = ExprNodeDescUtils.replace(predicate, sourceKeys, targetKeys);
                if (replaced == null || this.filterExists(target, replaced)) {
                    continue;
                }
                ExprNodeDesc prev = newFilters.get(target);
                if (prev == null) {
                    newFilters.put(target, replaced);
                } else if (!ExprNodeDescUtils.containsPredicate(prev, replaced)) {
                    // Avoid building "p AND p" when several sources derive the same predicate.
                    newFilters.put(target, ExprNodeDescUtils.mergePredicates(prev, replaced));
                }
            }
            return null;
        }

        /**
         * Computes, for each join input position, the positions its predicates may be propagated to.
         *
         * <p>Direction rules per condition:
         * <ul>
         *   <li>inner join and left semi join: both directions;</li>
         *   <li>left outer join: left to right;</li>
         *   <li>right outer join: right to left;</li>
         *   <li>any other join type: no edges.</li>
         * </ul>
         * Edges are then followed transitively.
         *
         * @param join the join whose conditions define the graph
         * @return array indexed by input position; each entry lists the reachable positions
         */
        private int[][] getTargets(CommonJoinOperator<?> join) {
            JoinCondDesc[] conds = join.getConf().getConds();
            int aliases = conds.length + 1;
            Vectors vector = new Vectors(aliases);
            for (JoinCondDesc cond : conds) {
                int left = cond.getLeft();
                int right = cond.getRight();
                switch (cond.getType()) {
                    case JoinDesc.INNER_JOIN:
                    case JoinDesc.LEFT_SEMI_JOIN:
                        vector.add(left, right);
                        vector.add(right, left);
                        break;
                    case JoinDesc.LEFT_OUTER_JOIN:
                        vector.add(left, right);
                        break;
                    case JoinDesc.RIGHT_OUTER_JOIN:
                        vector.add(right, left);
                        break;
                    default:
                        // Other join types (e.g. FULL OUTER) get no edges: nothing propagates across them.
                        break;
                }
            }
            int[][] result = new int[aliases][];
            for (int pos = 0; pos < aliases; ++pos) {
                result[pos] = vector.traverse(pos);
            }
            return result;
        }

        /**
         * Returns true if a filter in the chain of {@link FilterOperator}s directly above
         * {@code target} already contains {@code replaced}.
         */
        private boolean filterExists(ReduceSinkOperator target, ExprNodeDesc replaced) {
            Operator<? extends OperatorDesc> operator = target.getParentOperators().get(0);
            while (operator instanceof FilterOperator) {
                ExprNodeDesc predicate = ((FilterOperator) operator).getConf().getPredicate();
                if (ExprNodeDescUtils.containsPredicate(predicate, replaced)) {
                    return true;
                }
                List<Operator<? extends OperatorDesc>> above = operator.getParentOperators();
                if (above == null || above.isEmpty()) {
                    break;
                }
                operator = above.get(0);
            }
            return false;
        }
    }

    /** Walker state shared between {@link JoinTransitive} invocations during one transform. */
    private static class TransitiveContext
    implements NodeProcessorCtx {
        /** Cached propagation targets for each join seen so far. */
        private final Map<CommonJoinOperator<?>, int[][]> filterPropagates = new HashMap<CommonJoinOperator<?>, int[][]>();
        /** Pending predicate per target reduce-sink, kept in discovery order so plan rewrites are deterministic. */
        private final Map<ReduceSinkOperator, ExprNodeDesc> newFilters = new LinkedHashMap<ReduceSinkOperator, ExprNodeDesc>();

        public Map<CommonJoinOperator<?>, int[][]> getFilterPropagates() {
            return this.filterPropagates;
        }

        public Map<ReduceSinkOperator, ExprNodeDesc> getNewFilters() {
            return this.newFilters;
        }
    }
}

