/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.UnionOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.AbstractBucketJoinProc;
import org.apache.hadoop.hive.ql.optimizer.GroupByOptimizer;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.optimizer.ppr.PartitionPruner;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
import org.apache.hadoop.hive.ql.parse.QB;
import org.apache.hadoop.hive.ql.parse.QBJoinTree;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.TableAccessAnalyzer;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.util.StringUtils;

public class BucketMapJoinOptimizer
implements Transform {
    /**
     * Logger for this optimizer. It is registered under this class's own name so
     * that messages emitted here are attributed to the bucket map-join optimizer.
     */
    private static final Log LOG = LogFactory.getLog(BucketMapJoinOptimizer.class.getName());

    /**
     * Walks the operator graph starting from the parse context's top operators
     * and dispatches each visited node to the processors registered below.
     *
     * @param pctx parse context whose top operators are walked
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException propagated from the graph walk
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        final String mapJoinName = MapJoinOperator.getOperatorName();
        BucketMapjoinOptProcCtx walkCtx = new BucketMapjoinOptProcCtx(pctx.getConf());

        // LinkedHashMap keeps the rules in registration order (R1..R4).
        LinkedHashMap<Rule, NodeProcessor> rules = new LinkedHashMap<Rule, NodeProcessor>();
        rules.put(new RuleRegExp("R1", mapJoinName + "%"), this.getBucketMapjoinProc(pctx));
        // NOTE(review): unlike R3 and R4 this pattern has no trailing "%" -- confirm whether that is intended.
        rules.put(new RuleRegExp("R2", ReduceSinkOperator.getOperatorName() + "%.*" + mapJoinName), this.getBucketMapjoinRejectProc(pctx));
        rules.put(new RuleRegExp("R3", UnionOperator.getOperatorName() + "%.*" + mapJoinName + "%"), this.getBucketMapjoinRejectProc(pctx));
        rules.put(new RuleRegExp("R4", mapJoinName + "%.*" + mapJoinName + "%"), this.getBucketMapjoinRejectProc(pctx));

        DefaultRuleDispatcher dispatcher = new DefaultRuleDispatcher(this.getDefaultProc(), rules, walkCtx);
        DefaultGraphWalker walker = new DefaultGraphWalker(dispatcher);

        ArrayList<Node> roots = new ArrayList<Node>(pctx.getTopOps().values());
        walker.startWalking(roots, null);
        return pctx;
    }

    /**
     * Builds a processor that records every map-join operator it visits in the
     * walk context's rejected set.
     *
     * @param pctx parse context (not used by the returned processor)
     * @return a processor that adds the visited node to the rejected set and
     *         always returns {@code null}
     */
    private NodeProcessor getBucketMapjoinRejectProc(ParseContext pctx) {
        NodeProcessor rejectProc = new NodeProcessor() {

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs) throws SemanticException {
                MapJoinOperator rejected = (MapJoinOperator) nd;
                BucketMapjoinOptProcCtx walkCtx = (BucketMapjoinOptProcCtx) procCtx;
                walkCtx.getListOfRejectedMapjoins().add(rejected);
                return null;
            }
        };
        return rejectProc;
    }

    /**
     * Creates the processor that attempts the bucket map-join conversion.
     *
     * @param pctx parse context handed to the new processor
     * @return a new {@link BucketMapjoinOptProc} bound to {@code pctx}
     */
    private NodeProcessor getBucketMapjoinProc(ParseContext pctx) {
        NodeProcessor bucketJoinProc = new BucketMapjoinOptProc(pctx);
        return bucketJoinProc;
    }

    /**
     * Creates the fallback processor handed to the rule dispatcher.
     *
     * @return a processor that does nothing and returns {@code null}
     */
    private NodeProcessor getDefaultProc() {
        NodeProcessor noOp = new NodeProcessor() {

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs) throws SemanticException {
                // Intentionally a no-op.
                return null;
            }
        };
        return noOp;
    }

    /**
     * State shared by the processors during one graph walk: the Hive
     * configuration and the set of map joins that were marked as rejected.
     */
    class BucketMapjoinOptProcCtx
    implements NodeProcessorCtx {
        /** Map-join operators recorded as rejected during the walk. */
        Set<MapJoinOperator> listOfRejectedMapjoins;
        private final HiveConf conf;

        /**
         * @param conf configuration exposed through {@link #getConf()}
         */
        public BucketMapjoinOptProcCtx(HiveConf conf) {
            this.listOfRejectedMapjoins = new HashSet<MapJoinOperator>();
            this.conf = conf;
        }

        /** @return the configuration supplied at construction */
        public HiveConf getConf() {
            return this.conf;
        }

        /** @return the live (mutable) set of rejected map-join operators */
        public Set<MapJoinOperator> getListOfRejectedMapjoins() {
            return this.listOfRejectedMapjoins;
        }
    }

    class BucketMapjoinOptProc
    extends AbstractBucketJoinProc
    implements NodeProcessor {
        /** Parse context supplied at construction; read by the methods of this processor. */
        protected ParseContext pGraphContext;

        /**
         * Creates a processor bound to the given parse context.
         *
         * @param pGraphContext parse context stored for later use by this processor
         */
        public BucketMapjoinOptProc(ParseContext pGraphContext) {
            this.pGraphContext = pGraphContext;
        }

        /**
         * Tries to annotate the visited map join with bucket-to-bucket file
         * mappings so that it can run as a bucket map join.
         *
         * <p>The method gives up (returns {@code false}) when the operator was
         * marked as rejected, has no join context, no big-table alias can be
         * determined, an alias cannot be traced back to a top-level table scan,
         * the join keys do not line up with the bucket columns, or the bucket
         * counts of the small tables are not compatible with the big table.
         * On success the map-join descriptor receives the alias-to-bucket-file
         * mapping, the big table alias and, for a partitioned big table, the
         * partition-name-to-files mapping.
         *
         * @param nd the visited node; must be a {@link MapJoinOperator}
         * @param stack operator stack of the walk (unused)
         * @param procCtx walk context; must be a {@link BucketMapjoinOptProcCtx}
         * @param nodeOutputs outputs of child nodes (unused)
         * @return {@code true} if the descriptor was annotated, {@code false} if
         *         the join cannot be converted
         * @throws SemanticException if partition pruning fails, the bucket files
         *         cannot be listed, or the number of files found for a table or
         *         partition differs from its declared bucket count
         */
        private boolean convertBucketMapJoin(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            MapJoinOperator mapJoinOp = (MapJoinOperator)nd;
            BucketMapjoinOptProcCtx context = (BucketMapjoinOptProcCtx)procCtx;
            if (context.getListOfRejectedMapjoins().contains(mapJoinOp)) {
                return false;
            }
            QBJoinTree joinCxt = this.pGraphContext.getMapJoinContext().get(mapJoinOp);
            if (joinCxt == null) {
                return false;
            }

            // Collect the distinct join aliases: left aliases first, then base sources.
            // The last alias that is not listed as a map-side alias becomes the big table.
            List<String> joinAliases = new ArrayList<String>();
            String[] srcs = joinCxt.getBaseSrc();
            String[] left = joinCxt.getLeftAliases();
            List<String> mapAlias = joinCxt.getMapAliases();
            String baseBigAlias = null;
            for (String s : left) {
                if (s == null) {
                    continue;
                }
                String subQueryAlias = QB.getAppendedAliasFromId(joinCxt.getId(), s);
                if (joinAliases.contains(subQueryAlias)) {
                    continue;
                }
                joinAliases.add(subQueryAlias);
                if (!mapAlias.contains(s)) {
                    baseBigAlias = subQueryAlias;
                }
            }
            for (String s : srcs) {
                if (s == null) {
                    continue;
                }
                String subQueryAlias = QB.getAppendedAliasFromId(joinCxt.getId(), s);
                if (joinAliases.contains(subQueryAlias)) {
                    continue;
                }
                joinAliases.add(subQueryAlias);
                if (!mapAlias.contains(s)) {
                    baseBigAlias = subQueryAlias;
                }
            }
            if (baseBigAlias == null) {
                // Without a big table there is nothing to map buckets against;
                // previously this case ended in a NullPointerException further down.
                return false;
            }

            MapJoinDesc mjDesc = (MapJoinDesc)mapJoinOp.getConf();
            LinkedHashMap<String, List<Integer>> aliasToPartitionBucketNumberMapping = new LinkedHashMap<String, List<Integer>>();
            LinkedHashMap<String, List<List<String>>> aliasToPartitionBucketFileNamesMapping = new LinkedHashMap<String, List<List<String>>>();
            HashMap<String, Operator<? extends OperatorDesc>> topOps = this.pGraphContext.getTopOps();
            HashMap<TableScanOperator, Table> topToTable = this.pGraphContext.getTopToTable();
            // For the big table: partition -> bucket files / bucket count (key is null when not partitioned).
            LinkedHashMap<Partition, List<String>> bigTblPartsToBucketFileNames = new LinkedHashMap<Partition, List<String>>();
            LinkedHashMap<Partition, Integer> bigTblPartsToBucketNumber = new LinkedHashMap<Partition, Integer>();
            // Shared across all aliases so every table must map its keys to the same bucket-column positions.
            Integer[] orders = null;
            boolean bigTablePartitioned = true;

            for (int index = 0; index < joinAliases.size(); ++index) {
                String alias = joinAliases.get(index);
                Operator<? extends OperatorDesc> topOp = joinCxt.getAliasToOpInfo().get(alias);
                if (topOp == null) {
                    return false;
                }
                List<String> keys = this.toColumns(mjDesc.getKeys().get((byte)index));
                if (keys == null || keys.isEmpty()) {
                    return false;
                }
                int oldKeySize = keys.size();
                TableScanOperator tso = TableAccessAnalyzer.genRootTableScan(topOp, keys);
                if (tso == null) {
                    return false;
                }
                if (!topOps.containsValue(tso)) {
                    return false;
                }
                // Replace the join alias with the alias under which the table scan is registered as a top operator.
                for (Map.Entry<String, Operator<? extends OperatorDesc>> topOpEntry : topOps.entrySet()) {
                    if (topOpEntry.getValue() != tso) {
                        continue;
                    }
                    String newAlias = topOpEntry.getKey();
                    joinAliases.set(index, newAlias);
                    if (alias.equals(baseBigAlias)) {
                        baseBigAlias = newAlias;
                    }
                    alias = newAlias;
                    break;
                }
                // NOTE(review): genRootTableScan is handed the keys list; a size change afterwards
                // presumably means some key could not be traced to the table scan -- confirm.
                if (keys.size() != oldKeySize) {
                    return false;
                }
                if (orders == null) {
                    orders = new Integer[keys.size()];
                }

                Table tbl = topToTable.get(tso);
                boolean isBigTable = alias.equals(baseBigAlias);

                if (tbl.isPartitioned()) {
                    PrunedPartitionList prunedParts;
                    try {
                        prunedParts = this.pGraphContext.getOpToPartList().get(tso);
                        if (prunedParts == null) {
                            prunedParts = PartitionPruner.prune(tbl, this.pGraphContext.getOpToPartPruner().get(tso), this.pGraphContext.getConf(), alias, this.pGraphContext.getPrunedPartitions());
                            this.pGraphContext.getOpToPartList().put(tso, prunedParts);
                        }
                    }
                    catch (HiveException e) {
                        LOG.error(StringUtils.stringifyException(e));
                        throw new SemanticException(e.getMessage(), e);
                    }
                    // An empty partition list simply yields empty bucket/file lists for a small
                    // table and contributes nothing for the big table.
                    List<Integer> buckets = new ArrayList<Integer>();
                    List<List<String>> files = new ArrayList<List<String>>();
                    for (Partition p : prunedParts.getNotDeniedPartns()) {
                        if (!this.checkBucketColumns(p.getBucketCols(), keys, orders)) {
                            return false;
                        }
                        List<String> fileNames = this.getOnePartitionBucketFileNames(p.getDataLocation());
                        int bucketCount = p.getBucketCount();
                        if (fileNames.size() != bucketCount) {
                            String msg = "The number of buckets for table " + tbl.getTableName() + " partition " + p.getName() + " is " + p.getBucketCount() + ", whereas the number of files is " + fileNames.size();
                            throw new SemanticException(ErrorMsg.BUCKETED_TABLE_METADATA_INCORRECT.getMsg(msg));
                        }
                        if (isBigTable) {
                            bigTblPartsToBucketFileNames.put(p, fileNames);
                            bigTblPartsToBucketNumber.put(p, bucketCount);
                        } else {
                            files.add(fileNames);
                            buckets.add(bucketCount);
                        }
                    }
                    if (!isBigTable) {
                        aliasToPartitionBucketNumberMapping.put(alias, buckets);
                        aliasToPartitionBucketFileNamesMapping.put(alias, files);
                    }
                    continue;
                }

                // Non-partitioned table: the table itself is the single "partition".
                if (!this.checkBucketColumns(tbl.getBucketCols(), keys, orders)) {
                    return false;
                }
                List<String> fileNames = this.getOnePartitionBucketFileNames(tbl.getDataLocation());
                Integer num = Integer.valueOf(tbl.getNumBuckets());
                if (fileNames.size() != num.intValue()) {
                    String msg = "The number of buckets for table " + tbl.getTableName() + " is " + tbl.getNumBuckets() + ", whereas the number of files is " + fileNames.size();
                    throw new SemanticException(ErrorMsg.BUCKETED_TABLE_METADATA_INCORRECT.getMsg(msg));
                }
                if (isBigTable) {
                    bigTblPartsToBucketFileNames.put(null, fileNames);
                    bigTblPartsToBucketNumber.put(null, num);
                    bigTablePartitioned = false;
                } else {
                    aliasToPartitionBucketNumberMapping.put(alias, Arrays.asList(num));
                    aliasToPartitionBucketFileNamesMapping.put(alias, Arrays.asList(fileNames));
                }
            }

            // Every small-table bucket count must be a multiple or a divisor of each big-table bucket count.
            for (Integer bucketNumber : bigTblPartsToBucketNumber.values()) {
                if (!this.checkBucketNumberAgainstBigTable(aliasToPartitionBucketNumberMapping, bucketNumber)) {
                    return false;
                }
            }

            // Sort file names so that the position in each list identifies the bucket.
            LinkedHashMap<String, Map<String, List<String>>> aliasBucketFileNameMapping = new LinkedHashMap<String, Map<String, List<String>>>();
            for (List<String> partBucketNames : bigTblPartsToBucketFileNames.values()) {
                Collections.sort(partBucketNames);
            }
            for (String alias : joinAliases) {
                if (alias.equals(baseBigAlias)) {
                    continue;
                }
                List<Integer> smallTblBucketNums = aliasToPartitionBucketNumberMapping.get(alias);
                List<List<String>> smallTblFilesList = aliasToPartitionBucketFileNamesMapping.get(alias);
                for (List<String> names : smallTblFilesList) {
                    Collections.sort(names);
                }
                Map<String, List<String>> mapping = new LinkedHashMap<String, List<String>>();
                aliasBucketFileNameMapping.put(alias, mapping);

                // Both big-table maps were filled together, so they iterate in lockstep.
                Iterator<Map.Entry<Partition, List<String>>> bigTblPartToBucketNames = bigTblPartsToBucketFileNames.entrySet().iterator();
                Iterator<Map.Entry<Partition, Integer>> bigTblPartToBucketNum = bigTblPartsToBucketNumber.entrySet().iterator();
                while (bigTblPartToBucketNames.hasNext()) {
                    assert (bigTblPartToBucketNum.hasNext());
                    int bigTblBucketNum = bigTblPartToBucketNum.next().getValue();
                    List<String> bigTblBucketNameList = bigTblPartToBucketNames.next().getValue();
                    this.fillMapping(smallTblBucketNums, smallTblFilesList, mapping, bigTblBucketNum, bigTblBucketNameList, mjDesc.getBigTableBucketNumMapping());
                }
            }
            mjDesc.setAliasBucketFileNameMapping(aliasBucketFileNameMapping);
            mjDesc.setBigTableAlias(baseBigAlias);
            if (bigTablePartitioned) {
                mjDesc.setBigTablePartSpecToFileMapping(this.convert(bigTblPartsToBucketFileNames));
            }
            return true;
        }

        /**
         * Attempts the bucket map-join conversion for the visited node and, when
         * it is not possible while {@code HIVEENFORCEBUCKETMAPJOIN} is enabled,
         * fails the compilation.
         *
         * @param nd the visited node
         * @param stack operator stack of the walk
         * @param procCtx walk context; must be a {@link BucketMapjoinOptProcCtx}
         * @param nodeOutputs outputs of child nodes
         * @return always {@code null}
         * @throws SemanticException if the conversion is not possible and bucket
         *         map joins are enforced, or if the conversion itself fails
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            if (this.convertBucketMapJoin(nd, stack, procCtx, nodeOutputs)) {
                return null;
            }
            HiveConf hiveConf = ((BucketMapjoinOptProcCtx)procCtx).getConf();
            boolean enforced = hiveConf.getBoolVar(HiveConf.ConfVars.HIVEENFORCEBUCKETMAPJOIN);
            if (enforced) {
                throw new SemanticException(ErrorMsg.BUCKET_MAPJOIN_NOT_POSSIBLE.getMsg());
            }
            return null;
        }

        /**
         * Re-keys a partition-to-files mapping by partition name.
         *
         * @param mapping bucket file names per partition; keys must be non-null
         *        because {@code getName()} is called on each of them
         * @return a new map from partition name to the same file-name lists
         */
        private Map<String, List<String>> convert(Map<Partition, List<String>> mapping) {
            Map<String, List<String>> byPartitionName = new HashMap<String, List<String>>();
            Iterator<Map.Entry<Partition, List<String>>> entries = mapping.entrySet().iterator();
            while (entries.hasNext()) {
                Map.Entry<Partition, List<String>> current = entries.next();
                Partition partition = current.getKey();
                byPartitionName.put(partition.getName(), current.getValue());
            }
            return byPartitionName;
        }

        /**
         * For every big-table bucket file, records which small-table bucket files
         * have to be read with it, and records the bucket index of the big-table file.
         *
         * <p>The position of a file in its (sorted) list is treated as its bucket
         * id. Assuming rows are assigned to buckets by {@code hash mod bucketCount}:
         * <ul>
         * <li>if the big table has at least as many buckets as the small table,
         * big bucket {@code b} pairs with small bucket {@code b % small};</li>
         * <li>otherwise big bucket {@code b} pairs with every small bucket
         * {@code s} where {@code s % big == b}, i.e. {@code b, b + big, b + 2*big, ...}.</li>
         * </ul>
         * The second case previously stepped by {@code small / big}, which selects
         * the wrong files whenever {@code small / big} does not divide {@code big}
         * (for example big=2, small=8 picked {0,4} for bucket 0 instead of {0,2,4,6}).
         *
         * @param smallTblBucketNums bucket count of each small-table partition
         * @param smallTblFilesList bucket files of each small-table partition,
         *        parallel to {@code smallTblBucketNums}
         * @param mapping output: big-table bucket file to small-table bucket files
         * @param bigTblBucketNum bucket count of the big-table partition
         * @param bigTblBucketNameList bucket files of the big-table partition
         * @param bucketFileNameMapping output: big-table bucket file to its index
         * @throws IllegalArgumentException if bucket files are given but
         *         {@code bigTblBucketNum} is not positive
         */
        private void fillMapping(List<Integer> smallTblBucketNums, List<List<String>> smallTblFilesList, Map<String, List<String>> mapping, int bigTblBucketNum, List<String> bigTblBucketNameList, Map<String, Integer> bucketFileNameMapping) {
            if (bigTblBucketNameList.isEmpty()) {
                return;
            }
            if (bigTblBucketNum <= 0) {
                // A non-positive stride below would never terminate.
                throw new IllegalArgumentException("Big table bucket count must be positive but was " + bigTblBucketNum);
            }
            for (int bindex = 0; bindex < bigTblBucketNameList.size(); ++bindex) {
                List<String> resultFileNames = new ArrayList<String>();
                for (int sindex = 0; sindex < smallTblBucketNums.size(); ++sindex) {
                    int smallTblBucketNum = smallTblBucketNums.get(sindex);
                    List<String> smallTblFileNames = smallTblFilesList.get(sindex);
                    if (bigTblBucketNum >= smallTblBucketNum) {
                        // Several big buckets share one small bucket.
                        resultFileNames.add(smallTblFileNames.get(bindex % smallTblBucketNum));
                    } else {
                        // One big bucket covers every small bucket congruent to it modulo the big bucket count.
                        for (int i = bindex; i < smallTblFileNames.size(); i += bigTblBucketNum) {
                            resultFileNames.add(smallTblFileNames.get(i));
                        }
                    }
                }
                String inputBigTBLBucket = bigTblBucketNameList.get(bindex);
                mapping.put(inputBigTBLBucket, resultFileNames);
                bucketFileNameMapping.put(inputBigTBLBucket, bindex);
            }
        }

        /**
         * Checks that every small-table bucket count is either a multiple or a
         * divisor of the given big-table bucket count.
         *
         * <p>A non-positive bucket count on either side is reported as
         * incompatible instead of failing with a division-by-zero
         * {@link ArithmeticException}.
         *
         * @param aliasToBucketNumber bucket counts per small-table alias (one
         *        entry per partition)
         * @param bucketNumberInPart bucket count of one big-table partition
         * @return {@code true} if all counts are compatible, {@code false} otherwise
         */
        private boolean checkBucketNumberAgainstBigTable(Map<String, List<Integer>> aliasToBucketNumber, int bucketNumberInPart) {
            if (bucketNumberInPart <= 0) {
                return false;
            }
            for (List<Integer> bucketNums : aliasToBucketNumber.values()) {
                for (int nxt : bucketNums) {
                    if (nxt <= 0) {
                        return false;
                    }
                    boolean ok = nxt >= bucketNumberInPart ? nxt % bucketNumberInPart == 0 : bucketNumberInPart % nxt == 0;
                    if (!ok) {
                        return false;
                    }
                }
            }
            return true;
        }

        /**
         * Lists the entries directly under the given location and returns their
         * paths as strings.
         *
         * @param location location of a table or partition
         * @return the path of every listed entry, in the order returned by the
         *         file system; empty if the listing is {@code null}
         * @throws SemanticException wrapping any {@link IOException} raised while
         *         obtaining the file system or listing the location
         */
        private List<String> getOnePartitionBucketFileNames(URI location) throws SemanticException {
            List<String> names = new ArrayList<String>();
            FileStatus[] statuses;
            try {
                FileSystem fileSystem = FileSystem.get(location, this.pGraphContext.getConf());
                Path dir = new Path(location.toString());
                statuses = fileSystem.listStatus(dir);
            }
            catch (IOException e) {
                throw new SemanticException(e);
            }
            if (statuses == null) {
                return names;
            }
            for (int i = 0; i < statuses.length; ++i) {
                names.add(statuses[i].getPath().toString());
            }
            return names;
        }

        /**
         * Verifies that the join keys cover all bucket columns and that each key
         * sits at the same bucket-column position as recorded by earlier calls.
         *
         * @param bucketColumns bucket columns of the table or partition
         * @param keys join key column names
         * @param orders per key position, the bucket-column index seen so far
         *        ({@code null} if none yet); updated in place with the index found
         *        here, which is {@code -1} when the key is not a bucket column
         * @return {@code false} if {@code keys} or {@code bucketColumns} is
         *         {@code null}, {@code bucketColumns} is empty, a key's position
         *         disagrees with {@code orders}, or some bucket column is not
         *         among the keys; {@code true} otherwise
         */
        private boolean checkBucketColumns(List<String> bucketColumns, List<String> keys, Integer[] orders) {
            boolean nothingToCompare = keys == null || bucketColumns == null || bucketColumns.isEmpty();
            if (nothingToCompare) {
                return false;
            }
            int position = 0;
            for (String key : keys) {
                int bucketColIdx = bucketColumns.indexOf(key);
                Integer expected = orders[position];
                if (expected != null && expected.intValue() != bucketColIdx) {
                    return false;
                }
                orders[position] = bucketColIdx;
                ++position;
            }
            return keys.containsAll(bucketColumns);
        }
    }
}

