/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import java.util.List;
import java.util.Stack;
import org.apache.hadoop.hive.common.ObjectPair;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
import org.apache.hadoop.hive.ql.exec.LimitOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.apache.hadoop.hive.ql.exec.spark.session.SparkSession;
import org.apache.hadoop.hive.ql.exec.spark.session.SparkSessionManagerImpl;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Node processor that decides how many reducers a {@link ReduceSinkOperator}
 * should use when the plan is executed on Spark.
 *
 * <p>For every reduce sink that still needs a parallelism (see
 * {@link #needSetParallelism}) the number of reducers is chosen, in order of
 * precedence, from:
 * <ol>
 *   <li>the user-supplied constant ({@code HADOOPNUMREDUCERS}) when it is positive;</li>
 *   <li>the {@code bucket_count} table property of a downstream
 *       {@link FileSinkOperator} when it is positive;</li>
 *   <li>an estimate derived from the data size statistics of the operators feeding
 *       the sink's first child, raised to at least the number of cores reported by
 *       the Spark session and capped at {@code MAXREDUCERS}.</li>
 * </ol>
 */
public class SetSparkReducerParallelism
implements NodeProcessor {
    private static final Logger LOG = LoggerFactory.getLogger(SetSparkReducerParallelism.class.getName());

    /**
     * Memory/cores pair reported by the Spark session. Fetched on first use and
     * reused by later calls on this instance; stays {@code null} while the lookup
     * has not succeeded yet, in which case it is retried on the next call.
     */
    private ObjectPair<Long, Integer> sparkMemoryAndCores;

    /**
     * Sets the number of reducers on the descriptor of the visited reduce sink.
     *
     * @param nd          the node being visited; must be a {@link ReduceSinkOperator}
     * @param stack       the walker's node stack (not used here)
     * @param procContext must be an {@link OptimizeSparkProcContext}
     * @param nodeOutputs outputs of previously processed nodes (not used here)
     * @return {@code true} when the sink had already been visited, {@code false} otherwise
     * @throws SemanticException if obtaining the Spark session fails with a
     *         {@link HiveException}; the original exception is attached as the cause
     */
    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext, Object ... nodeOutputs) throws SemanticException {
        OptimizeSparkProcContext context = (OptimizeSparkProcContext)procContext;
        ReduceSinkOperator sink = (ReduceSinkOperator)nd;
        ReduceSinkDesc desc = (ReduceSinkDesc)sink.getConf();
        HiveConf conf = context.getConf();
        int maxReducers = conf.getIntVar(HiveConf.ConfVars.MAXREDUCERS);
        int constantReducers = conf.getIntVar(HiveConf.ConfVars.HADOOPNUMREDUCERS);

        if (context.getVisitedReduceSinks().contains(sink)) {
            LOG.debug("Already processed reduce sink: {}", sink.getName());
            return true;
        }
        context.getVisitedReduceSinks().add(sink);

        if (!this.needSetParallelism(sink, conf)) {
            LOG.info("Number of reducers determined to be: {}", desc.getNumReducers());
            return false;
        }

        if (constantReducers > 0) {
            LOG.info("Parallelism for reduce sink {} set by user to {}", sink, constantReducers);
            desc.setNumReducers(constantReducers);
            return false;
        }

        // A bucketed target table dictates the parallelism: one reducer per bucket.
        FileSinkOperator fso = GenSparkUtils.getChildOperator(sink, FileSinkOperator.class);
        if (fso != null) {
            int numBuckets = SetSparkReducerParallelism.getBucketCount(fso);
            if (numBuckets > 0) {
                LOG.info("Set parallelism for reduce sink {} to: {} (buckets)", sink, numBuckets);
                desc.setNumReducers(numBuckets);
                return false;
            }
        }

        long numberOfBytes = SetSparkReducerParallelism.sumSiblingDataSize(sink);
        if (this.sparkMemoryAndCores == null) {
            this.fetchSparkMemoryAndCores(conf);
        }
        int numReducers = this.estimateReducers(conf, numberOfBytes, maxReducers);
        LOG.info("Set parallelism for reduce sink {} to: {} (calculated)", sink, numReducers);
        desc.setNumReducers(numReducers);
        return false;
    }

    /**
     * Reads the {@code bucket_count} property from the file sink's table info;
     * yields 0 when the property is absent.
     */
    private static int getBucketCount(FileSinkOperator fso) {
        String bucketCount = ((FileSinkDesc)fso.getConf()).getTableInfo().getProperties().getProperty("bucket_count");
        return bucketCount == null ? 0 : Integer.parseInt(bucketCount);
    }

    /**
     * Adds up the data size statistics of every parent of the sink's first child
     * (the sink itself and its siblings); operators without statistics are
     * logged and contribute nothing.
     */
    private static long sumSiblingDataSize(ReduceSinkOperator sink) {
        long numberOfBytes = 0L;
        for (Operator<? extends OperatorDesc> sibling : sink.getChildOperators().get(0).getParentOperators()) {
            if (sibling.getStatistics() == null) {
                LOG.warn("No stats available from: {}", sibling);
                continue;
            }
            numberOfBytes += sibling.getStatistics().getDataSize();
            LOG.debug("Sibling {} has stats: {}", sibling, sibling.getStatistics());
        }
        return numberOfBytes;
    }

    /**
     * Borrows a Spark session, stores its memory/cores pair in
     * {@link #sparkMemoryAndCores} and returns the session to the manager.
     * A {@link HiveException} is fatal; any other failure is only logged and
     * leaves the field {@code null}.
     */
    private void fetchSparkMemoryAndCores(HiveConf conf) throws SemanticException {
        SparkSessionManagerImpl sparkSessionManager = null;
        SparkSession sparkSession = null;
        try {
            sparkSessionManager = SparkSessionManagerImpl.getInstance();
            sparkSession = SparkUtilities.getSparkSession(conf, sparkSessionManager);
            this.sparkMemoryAndCores = sparkSession.getMemoryAndCores();
        }
        catch (HiveException e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new SemanticException("Failed to get a spark session: " + e, e);
        }
        catch (Exception e) {
            // Best effort: without this info the estimate falls back to data size only.
            LOG.warn("Failed to get spark memory/core info", e);
        }
        finally {
            if (sparkSession != null && sparkSessionManager != null) {
                try {
                    sparkSessionManager.returnSession(sparkSession);
                }
                catch (HiveException ex) {
                    LOG.error("Failed to return the session to SessionManager: " + ex, ex);
                }
            }
        }
    }

    /**
     * Estimates the reducer count from the input size, raises it to at least the
     * number of Spark cores when that is known, and caps it at {@code maxReducers}.
     */
    private int estimateReducers(HiveConf conf, long numberOfBytes, int maxReducers) {
        // Half of the configured bytes-per-reducer is used as the per-reducer target,
        // which yields more reducers than the configured value alone would.
        long bytesPerReducer = conf.getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2L;
        int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer, maxReducers, false);
        ObjectPair<Long, Integer> memoryAndCores = this.sparkMemoryAndCores;
        if (memoryAndCores != null && memoryAndCores.getFirst() > 0L && memoryAndCores.getSecond() > 0) {
            if ((double)memoryAndCores.getFirst().longValue() / (double)bytesPerReducer < 0.5) {
                LOG.warn("Average load of a reducer is much larger than its available memory. Consider decreasing hive.exec.reducers.bytes.per.reducer");
            }
            numReducers = Math.max(numReducers, memoryAndCores.getSecond());
        }
        return Math.min(numReducers, maxReducers);
    }

    /**
     * Tells whether the parallelism of the given reduce sink still has to be chosen.
     *
     * <p>That is the case when no positive reducer count is set, or when exactly
     * one reducer is set for an order-by that was not deduplicated,
     * {@code HIVESAMPLINGFORORDERBY} is enabled, and following the single-child
     * chain below the sink reaches a reduce sink, a file sink or the end of the
     * chain without meeting a branch or a {@link LimitOperator}.
     *
     * @param reduceSink the reduce sink to inspect
     * @param hiveConf   configuration consulted for the sampling flag
     * @return {@code true} if the number of reducers should be (re)computed
     */
    private boolean needSetParallelism(ReduceSinkOperator reduceSink, HiveConf hiveConf) {
        ReduceSinkDesc desc = (ReduceSinkDesc)reduceSink.getConf();
        if (desc.getNumReducers() <= 0) {
            return true;
        }
        boolean sampledOrderBy = desc.getNumReducers() == 1
            && desc.hasOrderBy()
            && hiveConf.getBoolVar(HiveConf.ConfVars.HIVESAMPLINGFORORDERBY)
            && !desc.isDeduplicated();
        if (!sampledOrderBy) {
            return false;
        }
        List<? extends Operator<? extends OperatorDesc>> children = reduceSink.getChildOperators();
        while (children != null && !children.isEmpty()) {
            if (children.size() != 1) {
                return false;
            }
            Operator<? extends OperatorDesc> child = children.get(0);
            if (child instanceof LimitOperator) {
                return false;
            }
            if (child instanceof ReduceSinkOperator || child instanceof FileSinkOperator) {
                break;
            }
            children = child.getChildOperators();
        }
        return true;
    }
}

