/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pig.backend.hadoop.executionengine.spark.converter;

import com.google.common.base.Optional;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.PhysicalOperator;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.plans.PhysicalPlan;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POLocalRearrange;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POSkewedJoin;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
import org.apache.pig.backend.hadoop.executionengine.spark.converter.IndexedKey;
import org.apache.pig.backend.hadoop.executionengine.spark.converter.IteratorTransform;
import org.apache.pig.backend.hadoop.executionengine.spark.converter.RDDConverter;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.plan.NodeIdGenerator;
import org.apache.pig.impl.plan.OperatorKey;
import org.apache.pig.impl.plan.PlanException;
import org.apache.pig.impl.util.MultiMap;
import org.apache.pig.impl.util.Pair;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Function1;
import scala.Tuple2;
import scala.runtime.AbstractFunction1;

/**
 * Converts a {@link POSkewedJoin} into a Spark join of its two predecessor RDDs.
 *
 * <p>The first predecessor is the skewed side. Each of its records goes to exactly one
 * partition, cycling round-robin through the reducer range that the broadcast key distribution
 * assigns to its key. The second predecessor is the streamed side. Each of its records is
 * replicated to every partition in its key's reducer range. Keys that are absent from the
 * distribution get partition id -1, and {@link SkewedJoinPartitioner} hash-partitions them.
 */
public class SkewedJoinConverter
implements RDDConverter<Tuple, Tuple, POSkewedJoin>,
Serializable {
    private static final Log log = LogFactory.getLog(SkewedJoinConverter.class);

    /**
     * Type code passed to setResultType/setKeyType on the generated local rearranges.
     * NOTE(review): presumably org.apache.pig.data.DataType.TUPLE; confirm.
     */
    private static final byte TUPLE_TYPE = 110;

    /** One local rearrange per join input: index 0 is the skewed side, index 1 the streamed side. */
    private POLocalRearrange[] LRs;
    private POSkewedJoin poSkewedJoin;
    /** Key used to look up the broadcast key distribution in SparkPigContext. */
    private String skewedJoinPartitionFile;

    /**
     * Sets the key under which the sampled key distribution was broadcast.
     *
     * @param partitionFile key into {@code SparkPigContext.getBroadcastedVars()}
     */
    public void setSkewedJoinPartitionFile(String partitionFile) {
        this.skewedJoinPartitionFile = partitionFile;
    }

    /**
     * Builds the joined RDD.
     *
     * @param predecessors exactly two RDDs: the skewed input first, the streamed input second
     * @param poSkewedJoin the physical operator being converted
     * @return the joined tuples (left fields followed by right fields)
     * @throws IOException if the join plans cannot be turned into local rearranges
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public RDD<Tuple> convert(List<RDD<Tuple>> predecessors, POSkewedJoin poSkewedJoin) throws IOException {
        SparkUtil.assertPredecessorSize(predecessors, poSkewedJoin, 2);
        this.LRs = new POLocalRearrange[2];
        this.poSkewedJoin = poSkewedJoin;
        this.createJoinPlans(poSkewedJoin.getJoinPlans());
        RDD<Tuple> rdd1 = predecessors.get(0);
        RDD<Tuple> rdd2 = predecessors.get(1);
        // The unused SparkPigContext.get() calls are decompiler residue of instance-qualified
        // static calls. They are kept in case get() has initialization side effects.
        SparkPigContext.get();
        Broadcast<List<Tuple>> keyDist = SparkPigContext.getBroadcastedVars().get(this.skewedJoinPartitionFile);
        SparkPigContext.get();
        Integer defaultParallelism = SparkPigContext.getParallelism(predecessors, poSkewedJoin);

        SkewPartitionIndexKeyFunction skewFun = new SkewPartitionIndexKeyFunction(this, keyDist, defaultParallelism);
        RDD skewIdxKeyRDD = rdd1.map((Function1) skewFun, SparkUtil.getTuple2Manifest());
        JavaPairRDD skewIndexedJavaPairRDD = new JavaPairRDD(skewIdxKeyRDD,
                SparkUtil.getManifest(PartitionIndexedKey.class), SparkUtil.getManifest(Tuple.class));

        StreamPartitionIndexKeyFunction streamFun = new StreamPartitionIndexKeyFunction(this, keyDist, defaultParallelism);
        JavaRDD streamIdxKeyJavaRDD = rdd2.toJavaRDD().flatMap((FlatMapFunction) streamFun);
        JavaPairRDD streamIndexedJavaPairRDD = new JavaPairRDD(streamIdxKeyJavaRDD.rdd(),
                SparkUtil.getManifest(PartitionIndexedKey.class), SparkUtil.getManifest(Tuple.class));

        JavaRDD<Tuple> result = this.doJoin(
                (JavaPairRDD<PartitionIndexedKey, Tuple>) skewIndexedJavaPairRDD,
                (JavaPairRDD<PartitionIndexedKey, Tuple>) streamIndexedJavaPairRDD,
                this.buildPartitioner(keyDist, defaultParallelism),
                keyDist);
        return result.rdd();
    }

    /**
     * Creates one POLocalRearrange per join input, indexed in key-set iteration order.
     *
     * @param inpPlans key-extraction plans grouped by input operator
     * @throws PlanException if the rearrange index cannot be set
     */
    private void createJoinPlans(MultiMap<PhysicalOperator, PhysicalPlan> inpPlans) throws PlanException {
        int i = -1;
        for (PhysicalOperator inpPhyOp : inpPlans.keySet()) {
            ++i;
            POLocalRearrange lr = new POLocalRearrange(this.genKey());
            try {
                lr.setIndex(i);
            } catch (ExecException e) {
                throw new PlanException(e.getMessage(), e.getErrorCode(), e.getErrorSource(), (Throwable) e);
            }
            lr.setResultType(TUPLE_TYPE);
            lr.setKeyType(TUPLE_TYPE);
            lr.setPlans(inpPlans.get(inpPhyOp));
            this.LRs[i] = lr;
        }
    }

    /** Generates a fresh operator key in the same scope as the skewed join operator. */
    private OperatorKey genKey() {
        String scope = this.poSkewedJoin.getOperatorKey().scope;
        return new OperatorKey(scope, NodeIdGenerator.getGenerator().getNextNodeId(scope));
    }

    /**
     * Returns the distribution's reducer count when it is positive, otherwise the default.
     *
     * @param fromDistribution value read by {@link #loadKeyDistribution}; -1 when unavailable
     * @param defaultParallelism fallback parallelism
     */
    private static Integer effectiveParallelism(Integer fromDistribution, Integer defaultParallelism) {
        return fromDistribution <= 0 ? defaultParallelism : fromDistribution;
    }

    /**
     * Parses the broadcast key distribution into key -> (first reducer, extra reducer count).
     *
     * <p>The first broadcast tuple holds a map with {@code "partition.list"} (a bag of
     * key fields followed by min and max reducer index) and {@code "totalreducers"}.
     *
     * @param keyDist broadcast distribution; may be null or empty
     * @param totalReducers one-element out-parameter; receives the total reducer count,
     *     or -1 when no distribution is available
     * @return the per-key reducer ranges; empty when no distribution is available
     */
    private static Map<Tuple, Pair<Integer, Integer>> loadKeyDistribution(Broadcast<List<Tuple>> keyDist, Integer[] totalReducers) {
        Map<Tuple, Pair<Integer, Integer>> reducerMap = new HashMap<Tuple, Pair<Integer, Integer>>();
        totalReducers[0] = -1;
        if (keyDist == null || keyDist.value() == null || keyDist.value().isEmpty()) {
            log.warn("Empty dist file: ");
            return reducerMap;
        }
        try {
            TupleFactory tf = TupleFactory.getInstance();
            Tuple distTuple = keyDist.value().get(0);
            Map<?, ?> distMap = (Map<?, ?>) distTuple.get(0);
            DataBag partitionList = (DataBag) distMap.get("partition.list");
            totalReducers[0] = Integer.valueOf("" + distMap.get("totalreducers"));
            for (Tuple idxTuple : partitionList) {
                int maxPos = idxTuple.size() - 1;
                Integer maxIndex = (Integer) idxTuple.get(maxPos);
                Integer minIndex = (Integer) idxTuple.get(maxPos - 1);
                if (maxIndex < minIndex) {
                    // Presumably the range wraps past the last reducer. Unroll it so that
                    // max - min is the number of extra reducers.
                    maxIndex = totalReducers[0] + maxIndex;
                }
                Tuple keyTuple = tf.newTuple();
                for (int i = 0; i < maxPos - 1; ++i) {
                    keyTuple.append(idxTuple.get(i));
                }
                reducerMap.put(keyTuple, new Pair<Integer, Integer>(minIndex, maxIndex - minIndex));
            }
        } catch (ExecException e) {
            log.warn(e.getMessage(), e);
        }
        return reducerMap;
    }

    /** Builds the partitioner sized from the key distribution, or the default parallelism. */
    private SkewedJoinPartitioner buildPartitioner(Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
        Integer[] reducers = new Integer[1];
        SkewedJoinConverter.loadKeyDistribution(keyDist, reducers);
        return new SkewedJoinPartitioner(effectiveParallelism(reducers[0], defaultParallelism));
    }

    /**
     * Chooses inner/left/right/full outer join from the operator's inner flags and flattens
     * each joined pair into a single output tuple.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private JavaRDD<Tuple> doJoin(JavaPairRDD<PartitionIndexedKey, Tuple> skewIndexedJavaPairRDD, JavaPairRDD<PartitionIndexedKey, Tuple> streamIndexedJavaPairRDD, SkewedJoinPartitioner partitioner, Broadcast<List<Tuple>> keyDist) {
        boolean[] innerFlags = this.poSkewedJoin.getInnerFlags();
        int[] schemaSize = new int[]{0, 0};
        for (int i = 0; i < 2; ++i) {
            if (this.poSkewedJoin.getSchema(i) != null) {
                schemaSize[i] = this.poSkewedJoin.getSchema(i).size();
            }
        }
        ToValueFunction toValueFun = new ToValueFunction(innerFlags, schemaSize, keyDist);
        JavaPairRDD resultKeyValue;
        if (innerFlags[0] && innerFlags[1]) {
            resultKeyValue = skewIndexedJavaPairRDD.join(streamIndexedJavaPairRDD, partitioner);
        } else if (innerFlags[0]) {
            resultKeyValue = skewIndexedJavaPairRDD.leftOuterJoin(streamIndexedJavaPairRDD, partitioner);
        } else if (innerFlags[1]) {
            resultKeyValue = skewIndexedJavaPairRDD.rightOuterJoin(streamIndexedJavaPairRDD, partitioner);
        } else {
            resultKeyValue = skewIndexedJavaPairRDD.fullOuterJoin(streamIndexedJavaPairRDD, partitioner);
        }
        return resultKeyValue.mapPartitions(toValueFun);
    }

    /**
     * Routes keys that carry an explicit partition id to that partition and hash-partitions
     * everything else.
     */
    private static class SkewedJoinPartitioner
    extends Partitioner {
        private final int numPartitions;

        public SkewedJoinPartitioner(int parallelism) {
            this.numPartitions = parallelism;
        }

        public int numPartitions() {
            return this.numPartitions;
        }

        public int getPartition(Object idxKey) {
            Object key = idxKey;
            if (idxKey instanceof PartitionIndexedKey) {
                PartitionIndexedKey pKey = (PartitionIndexedKey) idxKey;
                if (pKey.getPartitionId() >= 0) {
                    return pKey.getPartitionId();
                }
                key = pKey.getKey();
            }
            // Hash fallback. A null key or a non-PartitionIndexedKey no longer throws.
            int code = (key == null ? 0 : key.hashCode()) % this.numPartitions;
            return code >= 0 ? code : code + this.numPartitions;
        }
    }

    /**
     * Keys each streamed-side record and replicates it once per reducer in its key's range.
     * Keys absent from the distribution are emitted once, with partition id -1.
     */
    private static class StreamPartitionIndexKeyFunction
    implements FlatMapFunction<Tuple, Tuple2<PartitionIndexedKey, Tuple>> {
        private final SkewedJoinConverter poSkewedJoin;
        private final Broadcast<List<Tuple>> keyDist;
        private final Integer defaultParallelism;
        private transient boolean initialized = false;
        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;
        private transient Integer parallelism;

        public StreamPartitionIndexKeyFunction(SkewedJoinConverter poSkewedJoin, Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
            this.poSkewedJoin = poSkewedJoin;
            this.keyDist = keyDist;
            this.defaultParallelism = defaultParallelism;
        }

        public Iterable<Tuple2<PartitionIndexedKey, Tuple>> call(Tuple tuple) throws Exception {
            if (!this.initialized) {
                Integer[] reducers = new Integer[1];
                this.reducerMap = SkewedJoinConverter.loadKeyDistribution(this.keyDist, reducers);
                this.parallelism = effectiveParallelism(reducers[0], this.defaultParallelism);
                this.initialized = true;
            }
            this.poSkewedJoin.LRs[1].attachInput(tuple);
            Result lrOut = this.poSkewedJoin.LRs[1].getNextTuple();
            Tuple lrTuple = (Tuple) lrOut.result;
            byte index = (Byte) lrTuple.get(0);
            Tuple key = (Tuple) lrTuple.get(1);

            Pair<Integer, Integer> indexes = this.reducerMap.get(key);
            int reducerIdx = indexes == null ? -1 : indexes.first;
            int extraCopies = indexes == null ? 0 : indexes.second;
            List<Tuple2<PartitionIndexedKey, Tuple>> out = new ArrayList<Tuple2<PartitionIndexedKey, Tuple>>();
            for (int copy = 0; copy <= extraCopies; ++copy) {
                if (reducerIdx >= this.parallelism) {
                    reducerIdx = 0; // wrap around past the last reducer
                }
                out.add(new Tuple2<PartitionIndexedKey, Tuple>(new PartitionIndexedKey(index, key, reducerIdx), tuple));
                ++reducerIdx;
            }
            return out;
        }
    }

    /**
     * Keys each skewed-side record. A key present in the distribution cycles round-robin
     * through its reducer range, one record at a time.
     */
    private static class SkewPartitionIndexKeyFunction
    extends AbstractFunction1<Tuple, Tuple2<PartitionIndexedKey, Tuple>>
    implements Serializable {
        private final SkewedJoinConverter poSkewedJoin;
        private final Broadcast<List<Tuple>> keyDist;
        private final Integer defaultParallelism;
        private transient boolean initialized = false;
        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;
        private transient Integer parallelism = -1;
        /** Last reducer index handed out per key, used for round-robin. */
        private transient Map<Tuple, Integer> currentIndexMap;

        public SkewPartitionIndexKeyFunction(SkewedJoinConverter poSkewedJoin, Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
            this.poSkewedJoin = poSkewedJoin;
            this.keyDist = keyDist;
            this.defaultParallelism = defaultParallelism;
        }

        /**
         * @throws RuntimeException wrapping the ExecException if key extraction fails. The
         *     previous behavior printed the error to stdout and returned a null pair, which
         *     only broke later, inside the join shuffle.
         */
        public Tuple2<PartitionIndexedKey, Tuple> apply(Tuple tuple) {
            this.poSkewedJoin.LRs[0].attachInput(tuple);
            try {
                Result lrOut = this.poSkewedJoin.LRs[0].getNextTuple();
                Tuple lrTuple = (Tuple) lrOut.result;
                byte index = (Byte) lrTuple.get(0);
                Tuple keyTuple = (Tuple) lrTuple.get(1);
                int partitionId = this.getPartitionId(keyTuple);
                return new Tuple2<PartitionIndexedKey, Tuple>(new PartitionIndexedKey(index, keyTuple, partitionId), tuple);
            } catch (ExecException e) {
                throw new RuntimeException("Failed to extract skewed join key from tuple", e);
            }
        }

        /** Returns the next round-robin partition for the key, or -1 if it is not skewed. */
        private Integer getPartitionId(Tuple keyTuple) {
            if (!this.initialized) {
                Integer[] reducers = new Integer[1];
                this.reducerMap = SkewedJoinConverter.loadKeyDistribution(this.keyDist, reducers);
                this.parallelism = effectiveParallelism(reducers[0], this.defaultParallelism);
                this.currentIndexMap = Maps.newHashMap();
                this.initialized = true;
            }
            Pair<Integer, Integer> indexes = this.reducerMap.get(keyTuple);
            if (indexes == null) {
                return -1;
            }
            int first = indexes.first;
            int last = first + indexes.second;
            Integer previous = this.currentIndexMap.get(keyTuple);
            int curIndex;
            if (previous == null || previous == -1 || previous >= last) {
                curIndex = first;
            } else {
                curIndex = previous + 1;
            }
            this.currentIndexMap.put(keyTuple, curIndex);
            return curIndex % this.parallelism;
        }
    }

    /** IndexedKey that also carries its target partition (-1 means hash-partition). */
    private static class PartitionIndexedKey
    extends IndexedKey {
        private int partitionId;

        public PartitionIndexedKey(byte index, Object key) {
            this(index, key, -1);
        }

        public PartitionIndexedKey(byte index, Object key, int pid) {
            super(index, key);
            this.partitionId = pid;
        }

        public int getPartitionId() {
            return this.partitionId;
        }

        private void setPartitionId(int pid) {
            this.partitionId = pid;
        }

        @Override
        public String toString() {
            return "PartitionIndexedKey{index=" + this.getIndex() + ", partitionId=" + this.getPartitionId() + ", key=" + this.getKey() + '}';
        }
    }

    /**
     * Flattens joined (left, right) pairs into output tuples. A missing outer side is padded
     * with nulls. An unmatched streamed record replicated to several reducers of a skewed key
     * is emitted only at the key's first reducer, so it does not appear more than once.
     */
    private static class ToValueFunction<L, R>
    implements FlatMapFunction<Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>>, Tuple>,
    Serializable {
        private final boolean[] innerFlags;
        private final int[] schemaSize;
        private final Broadcast<List<Tuple>> keyDist;
        private transient boolean initialized = false;
        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;

        public ToValueFunction(boolean[] innerFlags, int[] schemaSize, Broadcast<List<Tuple>> keyDist) {
            this.innerFlags = innerFlags;
            this.schemaSize = schemaSize;
            this.keyDist = keyDist;
        }

        public Iterable<Tuple> call(Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> input) {
            return new Tuple2TransformIterable(input);
        }

        /** True when the key is non-skewed or this partition is its first reducer. */
        private boolean isFirstReduceKey(PartitionIndexedKey pKey) {
            if (pKey.getPartitionId() == -1) {
                return true;
            }
            if (!this.initialized) {
                Integer[] reducers = new Integer[1];
                this.reducerMap = SkewedJoinConverter.loadKeyDistribution(this.keyDist, reducers);
                this.initialized = true;
            }
            Pair<Integer, Integer> indexes = this.reducerMap.get(pKey.getKey());
            return indexes == null || pKey.getPartitionId() == indexes.first.intValue();
        }

        /** True for a right-padded record that must only be emitted at the first reducer. */
        private boolean shouldSkip(Tuple2<PartitionIndexedKey, Tuple2<L, R>> record) {
            if (this.innerFlags[0]) {
                return false;
            }
            Optional<?> leftOption = (Optional<?>) record._2()._1();
            return !leftOption.isPresent() && !this.isFirstReduceKey(record._1());
        }

        /**
         * Concatenates left and right fields. On failure, logs the error and returns null,
         * matching the previous contract.
         */
        private Tuple toResultTuple(Tuple2<PartitionIndexedKey, Tuple2<L, R>> record) {
            try {
                Tuple result = TupleFactory.getInstance().newTuple();
                appendAll(result, sideTuple(record._2()._1(), this.innerFlags[0], this.schemaSize[0]));
                appendAll(result, sideTuple(record._2()._2(), this.innerFlags[1], this.schemaSize[1]));
                if (log.isDebugEnabled()) {
                    log.debug("MJC: Result = " + result.toDelimitedString(" "));
                }
                return result;
            } catch (Exception e) {
                log.warn("Failed to build skewed join output tuple", e);
                return null;
            }
        }

        /** Unwraps one join side; an absent outer side becomes {@code width} nulls. */
        private static Tuple sideTuple(Object side, boolean inner, int width) {
            if (inner) {
                return (Tuple) side;
            }
            Optional<?> option = (Optional<?>) side;
            if (option.isPresent()) {
                return (Tuple) option.get();
            }
            Tuple padded = TupleFactory.getInstance().newTuple();
            for (int i = 0; i < width; ++i) {
                padded.append(null);
            }
            return padded;
        }

        private static void appendAll(Tuple target, Tuple source) throws ExecException {
            for (int i = 0; i < source.size(); ++i) {
                target.append(source.get(i));
            }
        }

        private class Tuple2TransformIterable
        implements Iterable<Tuple> {
            final Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> in;

            Tuple2TransformIterable(Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> input) {
                this.in = input;
            }

            @Override
            public Iterator<Tuple> iterator() {
                return new SkippingIterator(this.in);
            }
        }

        /**
         * Lookahead iterator that drops records rejected by shouldSkip before they are
         * reported. This replaces a recursive next() call, which emitted null at the end of a
         * partition and could overflow the stack on long runs of skipped records.
         */
        private class SkippingIterator
        implements Iterator<Tuple> {
            private final Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> in;
            private Tuple2<PartitionIndexedKey, Tuple2<L, R>> pending;

            SkippingIterator(Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> in) {
                this.in = in;
            }

            private void advance() {
                while (this.pending == null && this.in.hasNext()) {
                    Tuple2<PartitionIndexedKey, Tuple2<L, R>> candidate = this.in.next();
                    if (!ToValueFunction.this.shouldSkip(candidate)) {
                        this.pending = candidate;
                    }
                }
            }

            @Override
            public boolean hasNext() {
                this.advance();
                return this.pending != null;
            }

            @Override
            public Tuple next() {
                this.advance();
                if (this.pending == null) {
                    throw new NoSuchElementException();
                }
                Tuple2<PartitionIndexedKey, Tuple2<L, R>> current = this.pending;
                this.pending = null;
                return ToValueFunction.this.toResultTuple(current);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        }
    }
}

