/*
 * Decompiled with CFR 0.152.
 */
package org.eigenbase.rel.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import net.hydromatic.optiq.prepare.OptiqPrepareImpl;
import net.hydromatic.optiq.util.BitSets;
import org.eigenbase.rel.JoinRelType;
import org.eigenbase.rel.RelFactories;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.metadata.RelMdUtil;
import org.eigenbase.rel.metadata.RelMetadataQuery;
import org.eigenbase.rel.rules.LoptMultiJoin;
import org.eigenbase.rel.rules.MultiJoinRel;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.rex.RexBuilder;
import org.eigenbase.rex.RexNode;
import org.eigenbase.rex.RexPermuteInputsShuttle;
import org.eigenbase.rex.RexUtil;
import org.eigenbase.util.Pair;
import org.eigenbase.util.Util;
import org.eigenbase.util.mapping.Mappings;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class OptimizeBushyJoinRule
extends RelOptRule {
    public static final OptimizeBushyJoinRule INSTANCE = new OptimizeBushyJoinRule(RelFactories.DEFAULT_JOIN_FACTORY, RelFactories.DEFAULT_PROJECT_FACTORY);
    private final RelFactories.JoinFactory joinFactory;
    private final RelFactories.ProjectFactory projectFactory;
    private final PrintWriter pw = OptiqPrepareImpl.DEBUG ? new PrintWriter(System.out, true) : null;

    public OptimizeBushyJoinRule(RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
        super(OptimizeBushyJoinRule.operand(MultiJoinRel.class, OptimizeBushyJoinRule.any()));
        this.joinFactory = joinFactory;
        this.projectFactory = projectFactory;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        MultiJoinRel multiJoinRel = (MultiJoinRel)call.rel(0);
        RexBuilder rexBuilder = multiJoinRel.getCluster().getRexBuilder();
        LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
        final ArrayList vertexes = Lists.newArrayList();
        int x = 0;
        for (int i = 0; i < multiJoin.getNumJoinFactors(); ++i) {
            RelNode rel = multiJoin.getJoinFactor(i);
            double cost = RelMetadataQuery.getRowCount(rel);
            vertexes.add(new LeafVertex(i, rel, cost, x));
            x += rel.getRowType().getFieldCount();
        }
        assert (x == multiJoin.getNumTotalFields());
        ArrayList unusedEdges = Lists.newArrayList();
        for (RexNode node : multiJoin.getJoinFilters()) {
            unusedEdges.add(multiJoin.createEdge(node));
        }
        Comparator<LoptMultiJoin.Edge> edgeComparator = new Comparator<LoptMultiJoin.Edge>(){

            @Override
            public int compare(LoptMultiJoin.Edge e0, LoptMultiJoin.Edge e1) {
                return Double.compare(this.rowCountDiff(e0), this.rowCountDiff(e1));
            }

            private double rowCountDiff(LoptMultiJoin.Edge edge) {
                assert (edge.factors.cardinality() == 2) : edge.factors;
                int factor0 = edge.factors.nextSetBit(0);
                int factor1 = edge.factors.nextSetBit(factor0 + 1);
                return Math.abs(((Vertex)vertexes.get((int)factor0)).cost - ((Vertex)vertexes.get((int)factor1)).cost);
            }
        };
        ArrayList usedEdges = Lists.newArrayList();
        block2: while (true) {
            int minorFactor;
            int majorFactor;
            int[] factors;
            int edgeOrdinal = this.chooseBestEdge(unusedEdges, edgeComparator);
            if (this.pw != null) {
                this.trace(vertexes, unusedEdges, usedEdges, edgeOrdinal, this.pw);
            }
            if (edgeOrdinal == -1) {
                Vertex lastVertex = (Vertex)Util.last(vertexes);
                int z = BitSets.previousClearBit(lastVertex.factors, lastVertex.id - 1);
                if (z < 0) break;
                factors = new int[]{z, lastVertex.id};
            } else {
                LoptMultiJoin.Edge bestEdge = (LoptMultiJoin.Edge)unusedEdges.get(edgeOrdinal);
                assert (bestEdge.factors.cardinality() == 2);
                factors = BitSets.toArray(bestEdge.factors);
            }
            if (((Vertex)vertexes.get((int)factors[0])).cost <= ((Vertex)vertexes.get((int)factors[1])).cost) {
                majorFactor = factors[0];
                minorFactor = factors[1];
            } else {
                majorFactor = factors[1];
                minorFactor = factors[0];
            }
            Vertex majorVertex = (Vertex)vertexes.get(majorFactor);
            Vertex minorVertex = (Vertex)vertexes.get(minorFactor);
            BitSet newFactors = BitSets.union(majorVertex.factors, minorVertex.factors);
            ArrayList conditions = Lists.newArrayList();
            Iterator edgeIterator = unusedEdges.iterator();
            while (edgeIterator.hasNext()) {
                LoptMultiJoin.Edge edge = (LoptMultiJoin.Edge)edgeIterator.next();
                if (!BitSets.contains(newFactors, edge.factors)) continue;
                conditions.add(edge.condition);
                edgeIterator.remove();
                usedEdges.add(edge);
            }
            int v = vertexes.size();
            double cost = majorVertex.cost * minorVertex.cost * RelMdUtil.guessSelectivity(RexUtil.composeConjunction(rexBuilder, conditions, false));
            newFactors.set(v);
            JoinVertex newVertex = new JoinVertex(v, majorFactor, minorFactor, newFactors, cost, (ImmutableList<RexNode>)ImmutableList.copyOf((Collection)conditions));
            vertexes.add(newVertex);
            BitSet merged = BitSets.of(minorFactor, majorFactor);
            int i = 0;
            while (true) {
                if (i >= unusedEdges.size()) continue block2;
                LoptMultiJoin.Edge edge = (LoptMultiJoin.Edge)unusedEdges.get(i);
                if (edge.factors.intersects(merged)) {
                    BitSet newEdgeFactors = (BitSet)edge.factors.clone();
                    newEdgeFactors.andNot(newFactors);
                    newEdgeFactors.set(v);
                    assert (newEdgeFactors.cardinality() == 2);
                    LoptMultiJoin.Edge newEdge = new LoptMultiJoin.Edge(edge.condition, newEdgeFactors, edge.columns);
                    unusedEdges.set(i, newEdge);
                }
                ++i;
            }
            break;
        }
        ArrayList relNodes = Lists.newArrayList();
        for (Vertex vertex : vertexes) {
            if (vertex instanceof LeafVertex) {
                LeafVertex leafVertex = (LeafVertex)vertex;
                Mappings.TargetMapping mapping = Mappings.offsetSource(Mappings.createIdentity(leafVertex.rel.getRowType().getFieldCount()), leafVertex.fieldOffset, multiJoin.getNumTotalFields());
                relNodes.add(Pair.of(leafVertex.rel, mapping));
            } else {
                JoinVertex joinVertex = (JoinVertex)vertex;
                Pair leftPair = (Pair)relNodes.get(joinVertex.leftFactor);
                RelNode left = (RelNode)leftPair.left;
                Mappings.TargetMapping leftMapping = (Mappings.TargetMapping)leftPair.right;
                Pair rightPair = (Pair)relNodes.get(joinVertex.rightFactor);
                RelNode right = (RelNode)rightPair.left;
                Mappings.TargetMapping rightMapping = (Mappings.TargetMapping)rightPair.right;
                Mappings.TargetMapping mapping = Mappings.merge(leftMapping, Mappings.offsetTarget(rightMapping, left.getRowType().getFieldCount()));
                if (this.pw != null) {
                    this.pw.println("left: " + leftMapping);
                    this.pw.println("right: " + rightMapping);
                    this.pw.println("combined: " + mapping);
                    this.pw.println();
                }
                RexPermuteInputsShuttle shuttle = new RexPermuteInputsShuttle(mapping, left, right);
                RexNode condition = RexUtil.composeConjunction(rexBuilder, joinVertex.conditions, false);
                relNodes.add(Pair.of(this.joinFactory.createJoin(left, right, condition.accept(shuttle), JoinRelType.INNER, (Set<String>)ImmutableSet.of(), false), mapping));
            }
            if (this.pw == null) continue;
            this.pw.println(Util.last(relNodes));
        }
        Pair top = (Pair)Util.last(relNodes);
        RelNode project = RelOptUtil.createProject(this.projectFactory, (RelNode)top.left, Mappings.asList((Mappings.TargetMapping)top.right));
        call.transformTo(project);
    }

    private void trace(List<Vertex> vertexes, List<LoptMultiJoin.Edge> unusedEdges, List<LoptMultiJoin.Edge> usedEdges, int edgeOrdinal, PrintWriter pw) {
        pw.println("bestEdge: " + edgeOrdinal);
        pw.println("vertexes:");
        for (Vertex vertex : vertexes) {
            pw.println(vertex);
        }
        pw.println("unused edges:");
        for (LoptMultiJoin.Edge edge : unusedEdges) {
            pw.println(edge);
        }
        pw.println("edges:");
        for (LoptMultiJoin.Edge edge : usedEdges) {
            pw.println(edge);
        }
        pw.println();
        pw.flush();
    }

    int chooseBestEdge(List<LoptMultiJoin.Edge> edges, Comparator<LoptMultiJoin.Edge> comparator) {
        return OptimizeBushyJoinRule.minPos(edges, comparator);
    }

    static <E> int minPos(List<E> list, Comparator<E> fn) {
        if (list.isEmpty()) {
            return -1;
        }
        E eBest = list.get(0);
        int iBest = 0;
        for (int i = 1; i < list.size(); ++i) {
            E e = list.get(i);
            if (fn.compare(e, eBest) >= 0) continue;
            eBest = e;
            iBest = i;
        }
        return iBest;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    static class JoinVertex
    extends Vertex {
        private final int leftFactor;
        private final int rightFactor;
        final ImmutableList<RexNode> conditions;

        JoinVertex(int id, int leftFactor, int rightFactor, BitSet factors, double cost, ImmutableList<RexNode> conditions) {
            super(id, factors, cost);
            this.leftFactor = leftFactor;
            this.rightFactor = rightFactor;
            this.conditions = (ImmutableList)Preconditions.checkNotNull(conditions);
        }

        public String toString() {
            return "JoinVertex(id: " + this.id + ", cost: " + Util.human(this.cost) + ", factors: " + this.factors + ", leftFactor: " + this.leftFactor + ", rightFactor: " + this.rightFactor + ")";
        }
    }

    static class LeafVertex
    extends Vertex {
        private final RelNode rel;
        final int fieldOffset;

        LeafVertex(int id, RelNode rel, double cost, int fieldOffset) {
            super(id, BitSets.of(id), cost);
            this.rel = rel;
            this.fieldOffset = fieldOffset;
        }

        public String toString() {
            return "LeafVertex(id: " + this.id + ", cost: " + Util.human(this.cost) + ", factors: " + this.factors + ", fieldOffset: " + this.fieldOffset + ")";
        }
    }

    static abstract class Vertex {
        final int id;
        protected final BitSet factors;
        final double cost;

        Vertex(int id, BitSet factors, double cost) {
            this.id = id;
            this.factors = factors;
            this.cost = cost;
        }
    }
}

