/*
 * Decompiled with CFR 0.152.
 */
package org.pentaho.di.engine.spark.impl.ops.groupby;

import java.text.MessageFormat;
import java.util.Optional;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.pentaho.di.core.row.RowMetaInterface;
import org.pentaho.di.engine.api.ExecutionContext;
import org.pentaho.di.engine.api.model.Operation;
import org.pentaho.di.engine.api.model.Transformation;
import org.pentaho.di.engine.spark.api.BaseSparkOperation;
import org.pentaho.di.engine.spark.api.SparkOperation;
import org.pentaho.di.engine.spark.impl.accumulators.MetricsAccumulator;
import org.pentaho.di.engine.spark.impl.functions.RowToSparkDatasetConverter;
import org.pentaho.di.engine.spark.impl.functions.SparkToKettleRowFunction;
import org.pentaho.di.engine.spark.impl.functions.sql.functions.UDAggregateFunctionFactory;
import org.pentaho.di.engine.spark.impl.ops.groupby.builder.GroupBySparkSqlQueryBuilder;
import org.pentaho.di.engine.spark.impl.ops.groupby.model.GroupByAelMeta;
import org.pentaho.di.engine.spark.impl.ops.groupby.model.GroupByAelMetaValidator;
import org.pentaho.di.engine.spark.impl.ops.groupby.model.GroupByAggregationTypeEnum;
import org.pentaho.di.engine.spark.util.MetaHelper;
import org.pentaho.di.engine.spark.util.Util;
import org.pentaho.di.trans.TransMeta;
import org.pentaho.di.trans.step.StepMeta;
import org.pentaho.di.trans.step.StepMetaInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for Spark-backed "Group By" operations.
 *
 * <p>{@link #apply(SparkOperation.Subscriber)} converts the incoming Kettle rows into a Spark
 * {@code Dataset}, exposes it as a temporary view named {@value #TABLE_NAME_GROUP_BY}, runs a
 * Spark SQL query produced by {@link GroupBySparkSqlQueryBuilder} against that view and hands
 * the result back to the subscriber as Kettle rows. Subclasses provide the step-specific
 * metadata mapping through {@link #mapToGroupByAelMeta(RowMetaInterface, RowMetaInterface)}.
 */
public abstract class BaseGroupBySparkOperation
extends BaseSparkOperation {
    private static final long serialVersionUID = 1778585150509557393L;
    // Static so the logger is not part of instance state: this class declares a
    // serialVersionUID, and lambdas that log no longer need to capture `this`.
    private static final Logger LOG = LoggerFactory.getLogger(BaseGroupBySparkOperation.class);
    private static final String TABLE_NAME_GROUP_BY = "groupByStepData";
    private static final String LOG_MESSAGE_PRINT_DATAFRAME_INCOMMING_ROWS = "Group By DataFrame Schema - Incomming Rows";
    private static final String LOG_MESSAGE_PRINT_DATAFRAME_OUTGOIING_ROWS = "Group By DataFrame Schema - Result Outgoing Rows";
    private static final String LOG_MESSAGE_PRINT_SQL_QUERY = "Executing the following Spark SQL Query for Group By transformation ";
    // SLF4J placeholders. The previous java.text.MessageFormat pattern wrapped {0}/{1} in single
    // quotes, which MessageFormat treats as escapes, so the arguments were never substituted.
    private static final String LOG_MESSAGE_REGISTER_UDAF = "registering UDAF '{}' with class '{}'";
    protected final StepMetaInterface stepMetaInterface;
    protected final JavaSparkContext sparkContext;
    protected final RowMetaInterface inRowMeta;
    protected final RowMetaInterface outRowMeta;
    protected MetricsAccumulator metricsAccumulator = MetricsAccumulator.empty();

    /**
     * Creates the operation and resolves the row metadata for the step identified by
     * {@code operation.getId()}.
     *
     * @param operation        logical operation backing this Spark operation
     * @param transformation   transformation used to look up the step's input and output row metadata
     * @param stepMeta         step metadata; its {@link StepMetaInterface} is retained
     * @param sparkContext     Spark context used to create an empty RDD when the subscriber has no input
     * @param executionContext not used by this constructor
     */
    protected BaseGroupBySparkOperation(Operation operation, Transformation transformation, StepMeta stepMeta, JavaSparkContext sparkContext, ExecutionContext executionContext) {
        super(operation);
        this.stepMetaInterface = stepMeta.getStepMetaInterface();
        this.inRowMeta = MetaHelper.getPrevStepFields((TransMeta)MetaHelper.getTransMeta((Transformation)transformation), (String)operation.getId());
        this.outRowMeta = MetaHelper.getRowMeta((TransMeta)MetaHelper.getTransMeta((Transformation)transformation), (String)operation.getId());
        this.sparkContext = sparkContext;
    }

    /**
     * Returns the logical operation this Spark operation was created for.
     *
     * @return an {@link Optional} wrapping the operation
     */
    public Optional<Operation> getLogicalOperation() {
        return Optional.of(this.operation);
    }

    /**
     * Executes the group-by: validates the mapped metadata, runs the generated Spark SQL query
     * over the subscriber's input and sets the converted result as the subscriber's output.
     *
     * <p>As a side effect, {@link #metricsAccumulator} is replaced by the subscriber's accumulator.
     *
     * @param subscriber source of the input RDD and metrics accumulator, and sink for the output RDD
     */
    public void apply(SparkOperation.Subscriber subscriber) {
        GroupByAelMeta groupByAelMeta = this.mapToGroupByAelMeta(this.inRowMeta, this.outRowMeta);
        GroupByAelMetaValidator.init(groupByAelMeta).validate();
        this.metricsAccumulator = subscriber.getMetricsAccumulator();

        StepMeta parentStepMeta = this.stepMetaInterface.getParentStepMeta();
        String stepId = parentStepMeta.getStepID();

        // Fall back to an empty RDD when the subscriber provides no input.
        JavaRDD input = subscriber.getInput().orElseGet(() -> this.sparkContext.emptyRDD());
        Dataset<Row> dataFrame = RowToSparkDatasetConverter.convert((JavaRDD<org.pentaho.di.engine.api.model.Row>)input, this.inRowMeta, stepId, this.metricsAccumulator);
        // createOrReplaceTempView is the non-deprecated replacement for registerTempTable.
        dataFrame.createOrReplaceTempView(TABLE_NAME_GROUP_BY);
        this.logDataFrame(dataFrame, LOG_MESSAGE_PRINT_DATAFRAME_INCOMMING_ROWS);

        SparkSession spark = Util.getSparkSession();
        this.registerAllCustomUDAggregateFunctions(spark);

        String sqlQuery = GroupBySparkSqlQueryBuilder.init(groupByAelMeta, TABLE_NAME_GROUP_BY).buildGroupBySparkSqlQuery();
        // logSqlQuery appends the closing quote itself, so only the opening quote is added here.
        this.logSqlQuery(sqlQuery, LOG_MESSAGE_PRINT_SQL_QUERY + "'" + groupByAelMeta.getStepName());

        Dataset<Row> groupByDataFrame = spark.sql(sqlQuery);
        this.logDataFrame(groupByDataFrame, LOG_MESSAGE_PRINT_DATAFRAME_OUTGOIING_ROWS);

        JavaRDD rowJavaRDD = groupByDataFrame.toJavaRDD();
        JavaRDD output = rowJavaRDD.mapPartitions(new SparkToKettleRowFunction(parentStepMeta, this.metricsAccumulator).asRegisteredFunction(stepId).toFlatMap());
        subscriber.setOutput(output);
    }

    /**
     * Registers every aggregation type that has a custom user-defined aggregate function with
     * the given session, under the name returned by {@code getSqlAggFunctionName()}.
     *
     * @param spark session whose UDF registry receives the functions
     */
    protected void registerAllCustomUDAggregateFunctions(SparkSession spark) {
        LOG.debug("Registering All Custom Spark SQL User Defined Aggregation Functions...");
        GroupByAggregationTypeEnum.getAllAggregateTypesWithCustomUdaFunction().forEach(aggregateType -> {
            UserDefinedAggregateFunction udaf = UDAggregateFunctionFactory.newInstance((Class)aggregateType.getSqlAggFunctionUserDefinedClass());
            LOG.debug("   " + LOG_MESSAGE_REGISTER_UDAF, aggregateType.getSqlAggFunctionName(), udaf.getClass().getName());
            spark.udf().register(aggregateType.getSqlAggFunctionName(), udaf);
        });
    }

    /**
     * When debug logging is enabled, logs two header lines and calls {@code printSchema()} and
     * {@code show()} on the data frame. Does nothing otherwise, so {@code show()} is only
     * invoked while debugging.
     *
     * @param dataFrame data frame to describe
     * @param message   label included in the logged header lines
     */
    protected void logDataFrame(Dataset<Row> dataFrame, String message) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        LOG.debug("DataFrames Schema - {}:  ", message);
        dataFrame.printSchema();
        LOG.debug("DataFrames table - {}:  ", message);
        dataFrame.show();
    }

    /**
     * Logs the SQL query at debug level as {@code message + "':  \n" + sqlQuery}. Because a
     * closing single quote is appended to {@code message}, callers should end the message with
     * the (opening-quoted) name only.
     *
     * @param sqlQuery query text to log
     * @param message  prefix placed before the query
     */
    protected void logSqlQuery(String sqlQuery, String message) {
        LOG.debug("{}':  \n{}", message, sqlQuery);
    }

    /**
     * Maps the step's metadata to the {@link GroupByAelMeta} model used for validation and for
     * building the SQL query.
     *
     * @param var1 input row metadata ({@link #inRowMeta} when called from {@code apply})
     * @param var2 output row metadata ({@link #outRowMeta} when called from {@code apply})
     * @return the group-by metadata model
     */
    protected abstract GroupByAelMeta mapToGroupByAelMeta(RowMetaInterface var1, RowMetaInterface var2);
}

