/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptPredicateList;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RexImplicationChecker;
import org.apache.calcite.plan.Strong;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexExecutor;
import org.apache.calcite.rex.RexExecutorImpl;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAntiJoin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveAntiSemiJoinRule
extends RelOptRule {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveAntiSemiJoinRule.class);
    public static final HiveAntiSemiJoinRule INSTANCE = new HiveAntiSemiJoinRule();

    public HiveAntiSemiJoinRule() {
        super(HiveAntiSemiJoinRule.operand(Project.class, (RelOptRuleOperand)HiveAntiSemiJoinRule.operand(Filter.class, (RelOptRuleOperand)HiveAntiSemiJoinRule.operand(Join.class, (RelOptRuleOperandChildren)RelOptRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), (RelOptRuleOperand[])new RelOptRuleOperand[0]), "HiveJoinWithFilterToAntiJoinRule:filter");
    }

    public void onMatch(RelOptRuleCall call) {
        Project project = (Project)call.rel(0);
        Filter filter = (Filter)call.rel(1);
        Join join = (Join)call.rel(2);
        this.perform(call, project, filter, join);
    }

    protected void perform(RelOptRuleCall call, Project project, Filter filter, Join join) {
        Project newProject;
        LOG.debug("Start Matching HiveAntiJoinRule");
        if (join.getCondition().isAlwaysTrue()) {
            return;
        }
        if (join.getJoinType() != JoinRelType.LEFT) {
            return;
        }
        assert (filter != null);
        ImmutableBitSet rhsFields = HiveCalciteUtil.getRightSideBitset((RelNode)join);
        Optional<List<RexNode>> optFilterList = this.getResidualFilterNodes(filter, join, rhsFields);
        if (optFilterList.isEmpty()) {
            return;
        }
        List<RexNode> filterList = optFilterList.get();
        ImmutableBitSet projectedFields = RelOptUtil.InputFinder.bits((List)project.getProjects(), null);
        boolean projectionUsesRHS = projectedFields.intersects(rhsFields);
        if (projectionUsesRHS) {
            return;
        }
        if (HiveCalciteUtil.checkIfJoinConditionOnlyUsesLeftOperands(join)) {
            return;
        }
        LOG.debug("Matched HiveAntiJoinRule");
        HiveAntiJoin anti = HiveAntiJoin.getAntiJoin(join.getLeft().getCluster(), join.getLeft().getTraitSet(), join.getLeft(), join.getRight(), join.getCondition());
        if (filterList.isEmpty()) {
            newProject = project.copy(project.getTraitSet(), (RelNode)anti, project.getProjects(), project.getRowType());
        } else {
            RexNode condition = filterList.size() == 1 ? filterList.get(0) : join.getCluster().getRexBuilder().makeCall((SqlOperator)SqlStdOperatorTable.AND, filterList);
            Filter newFilter = filter.copy(filter.getTraitSet(), (RelNode)anti, condition);
            newProject = project.copy(project.getTraitSet(), (RelNode)newFilter, project.getProjects(), project.getRowType());
        }
        call.transformTo((RelNode)newProject);
    }

    private Optional<List<RexNode>> getResidualFilterNodes(Filter filter, Join join, ImmutableBitSet rhsFields) {
        List aboveFilters = RelOptUtil.conjunctions((RexNode)filter.getCondition());
        boolean hasNullFilterOnRightSide = false;
        ArrayList<RexNode> filterList = new ArrayList<RexNode>();
        ImmutableBitSet notNullColumnsFromRightSide = this.getNotNullColumnsFromRightSide((RelNode)join);
        for (RexNode filterNode : aboveFilters) {
            ImmutableBitSet usedFields = RelOptUtil.InputFinder.bits((RexNode)filterNode);
            boolean usesFieldFromRHS = usedFields.intersects(rhsFields);
            if (!usesFieldFromRHS) {
                filterList.add(filterNode);
                continue;
            }
            if (filterNode.getKind() != SqlKind.IS_NULL) {
                return Optional.empty();
            }
            boolean usesRHSFieldsOnly = rhsFields.contains(usedFields);
            if (!usesRHSFieldsOnly) {
                return Optional.empty();
            }
            RexNode arg = (RexNode)((RexCall)filterNode).getOperands().get(0);
            if (this.isStrong(arg, notNullColumnsFromRightSide)) {
                hasNullFilterOnRightSide = true;
                continue;
            }
            if (this.isStrong(arg, rhsFields)) continue;
            return Optional.empty();
        }
        if (!hasNullFilterOnRightSide) {
            return Optional.empty();
        }
        return Optional.of(filterList);
    }

    private ImmutableBitSet getNotNullColumnsFromRightSide(RelNode joinRel) {
        int shift = joinRel.getInput(0).getRowType().getFieldCount();
        ImmutableBitSet rhsNotnullColumns = this.deduceNotNullColumns(joinRel.getInput(1));
        return rhsNotnullColumns.shift(shift);
    }

    private ImmutableBitSet deduceNotNullColumns(RelNode relNode) {
        RelOptCluster cluster = relNode.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RelMetadataQuery mq = cluster.getMetadataQuery();
        ImmutableBitSet.Builder result = ImmutableBitSet.builder();
        ImmutableBitSet.Builder candidatesBuilder = ImmutableBitSet.builder();
        List fieldList = relNode.getRowType().getFieldList();
        for (int i = 0; i < fieldList.size(); ++i) {
            if (((RelDataTypeField)fieldList.get(i)).getType().isNullable()) {
                candidatesBuilder.set(i);
                continue;
            }
            result.set(i);
        }
        ImmutableBitSet candidates = candidatesBuilder.build();
        if (candidates.isEmpty()) {
            return result.build();
        }
        RexExecutor executor = cluster.getPlanner().getExecutor();
        if (!(executor instanceof RexExecutorImpl)) {
            return result.build();
        }
        RexImplicationChecker checker = new RexImplicationChecker(rexBuilder, executor, relNode.getRowType());
        RelOptPredicateList predicates = mq.getPulledUpPredicates(relNode);
        ImmutableList preds = predicates.pulledUpPredicates;
        ArrayList antecedent = new ArrayList(preds);
        RexNode first = RexUtil.composeConjunction((RexBuilder)rexBuilder, antecedent);
        Iterator iterator = candidates.iterator();
        while (iterator.hasNext()) {
            int c = (Integer)iterator.next();
            RelDataTypeField field = (RelDataTypeField)fieldList.get(c);
            RexNode second = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NOT_NULL, new RexNode[]{rexBuilder.makeInputRef(field.getType(), field.getIndex())});
            if (!checker.implies(first, second)) continue;
            result.set(c);
        }
        return result.build();
    }

    private boolean isStrong(RexNode rexNode, ImmutableBitSet rightSideBitset) {
        try {
            rexNode.accept((RexVisitor)new RexVisitorImpl<Void>(this, true){

                public Void visitCall(RexCall call) {
                    if (call.getKind() == SqlKind.CAST) {
                        throw Util.FoundOne.NULL;
                    }
                    return (Void)super.visitCall(call);
                }
            });
        }
        catch (Util.FoundOne e) {
            return false;
        }
        return Strong.isNull((RexNode)rexNode, (ImmutableBitSet)rightSideBitset);
    }
}

