/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.matrix.Maths;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;

/**
 * Classifier that evaluates a PMML RegressionModel.
 * <p>
 * The model element's {@code functionName} attribute selects numeric
 * regression ({@code "regression"}) or classification
 * ({@code "classification"}). Each {@code RegressionTable} child contributes
 * an intercept plus predictor contributions; for classification each table
 * writes its score into the slot of its own target category. The optional
 * {@code normalizationMethod} attribute controls how raw scores are
 * transformed into the final output.
 */
public class Regression
extends PMMLClassifier
implements Serializable {
    private static final long serialVersionUID = -5551125528409488634L;
    /** Value of the model's {@code algorithmName} attribute; null if absent or empty. */
    protected String m_algorithmName;
    /** One entry per {@code RegressionTable} element, in document order. */
    protected RegressionTable[] m_regressionTables;
    /** Transform applied to the raw table scores. */
    protected Normalization m_normalizationMethod = Normalization.NONE;

    /**
     * Builds the model from its PMML representation.
     *
     * @param model the PMML RegressionModel element
     * @param dataDictionary the data dictionary, as a set of instances
     * @param miningSchema the mining schema of the model
     * @throws Exception if {@code functionName} is neither "regression" nor
     *     "classification", or if a regression table cannot be parsed
     */
    public Regression(Element model, Instances dataDictionary, MiningSchema miningSchema) throws Exception {
        super(dataDictionary, miningSchema);
        int functionType;
        String fName = model.getAttribute("functionName");
        if (fName.equals("regression")) {
            functionType = RegressionTable.REGRESSION;
        } else if (fName.equals("classification")) {
            functionType = RegressionTable.CLASSIFICATION;
        } else {
            throw new Exception("[PMML Regression] Function name not defined in pmml!");
        }
        String algName = model.getAttribute("algorithmName");
        if (algName != null && algName.length() > 0) {
            this.m_algorithmName = algName;
        }
        this.m_normalizationMethod = Regression.determineNormalization(model);
        this.setUpRegressionTables(model, functionType);
    }

    /**
     * Parses every {@code RegressionTable} element below {@code model}.
     *
     * @param model the PMML RegressionModel element
     * @param functionType {@link RegressionTable#REGRESSION} or
     *     {@link RegressionTable#CLASSIFICATION}
     * @throws Exception if there are no tables or a table cannot be parsed
     */
    private void setUpRegressionTables(Element model, int functionType) throws Exception {
        NodeList tableList = model.getElementsByTagName("RegressionTable");
        int numTables = tableList.getLength();
        if (numTables == 0) {
            throw new Exception("[Regression] no regression tables defined!");
        }
        this.m_regressionTables = new RegressionTable[numTables];
        for (int i = 0; i < numTables; ++i) {
            Node table = tableList.item(i);
            if (table.getNodeType() == Node.ELEMENT_NODE) {
                this.m_regressionTables[i] = new RegressionTable((Element) table, functionType, this.m_miningSchema);
            }
        }
    }

    /**
     * Maps the model's {@code normalizationMethod} attribute to a
     * {@link Normalization}; unrecognised or missing values map to NONE.
     *
     * @param model the PMML RegressionModel element
     * @return the normalization method to apply
     */
    private static Normalization determineNormalization(Element model) {
        String normName = model.getAttribute("normalizationMethod");
        if (normName.equals("simplemax")) {
            return Normalization.SIMPLEMAX;
        } else if (normName.equals("softmax")) {
            return Normalization.SOFTMAX;
        } else if (normName.equals("logit")) {
            return Normalization.LOGIT;
        } else if (normName.equals("probit")) {
            return Normalization.PROBIT;
        } else if (normName.equals("cloglog")) {
            return Normalization.CLOGLOG;
        } else if (normName.equals("exp")) {
            return Normalization.EXP;
        } else if (normName.equals("loglog")) {
            return Normalization.LOGLOG;
        } else if (normName.equals("cauchit")) {
            return Normalization.CAUCHIT;
        }
        return Normalization.NONE;
    }

    /**
     * Returns a textual description of the model: PMML version, creator
     * application, algorithm name, mining schema, every regression table
     * and the normalization method (if any).
     */
    @Override
    public String toString() {
        StringBuffer temp = new StringBuffer();
        temp.append("PMML version " + this.getPMMLVersion());
        if (!this.getCreatorApplication().equals("?")) {
            temp.append("\nApplication: " + this.getCreatorApplication());
        }
        if (this.m_algorithmName != null) {
            temp.append("\nPMML Model: " + this.m_algorithmName);
        }
        temp.append("\n\n");
        temp.append(this.m_miningSchema);
        for (RegressionTable table : this.m_regressionTables) {
            temp.append(table);
        }
        if (this.m_normalizationMethod != Normalization.NONE) {
            temp.append("Normalization: " + this.m_normalizationMethod);
        }
        temp.append("\n");
        return temp.toString();
    }

    /**
     * Computes the prediction for an instance.
     * <p>
     * For a numeric target the result has a single element holding the
     * predicted value; otherwise it has one element per class value.
     * If any non-class input is missing, the target's default value or
     * prior probabilities are returned when target meta data exists;
     * otherwise a warning is logged and NaN (numeric) or all-zero
     * probabilities are returned.
     *
     * @param inst the instance to predict
     * @return the prediction / class distribution
     * @throws Exception if the instance cannot be mapped to the mining schema
     *     or the normalization method is unknown
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (!this.m_initialized) {
            this.mapToMiningSchema(inst.dataset());
        }
        Instances schema = this.m_miningSchema.getFieldsAsInstances();
        Attribute classAtt = schema.classAttribute();
        boolean numericTarget = classAtt.isNumeric();
        double[] preds = numericTarget ? new double[1] : new double[classAtt.numValues()];
        double[] incoming = this.m_fieldsMap.instanceToSchema(inst, this.m_miningSchema);

        if (Regression.hasMissingInputValue(incoming, schema.classIndex())) {
            return this.predictionForMissingInput(preds, classAtt);
        }

        for (int i = 0; i < this.m_regressionTables.length; ++i) {
            this.m_regressionTables[i].predict(preds, incoming);
        }
        this.applyNormalizationMethod(preds, numericTarget);

        if (numericTarget && this.m_miningSchema.hasTargetMetaData()) {
            TargetMetaInfo targetData = this.m_miningSchema.getTargetMetaData();
            preds[0] = targetData.applyMinMaxRescaleCast(preds[0]);
        }
        return preds;
    }

    /**
     * Returns true if any value other than the class value is missing.
     *
     * @param incoming the instance mapped to mining-schema order
     * @param classIndex index of the class attribute in {@code incoming}
     */
    private static boolean hasMissingInputValue(double[] incoming, int classIndex) {
        for (int i = 0; i < incoming.length; ++i) {
            if (i != classIndex && Utils.isMissingValue(incoming[i])) {
                return true;
            }
        }
        return false;
    }

    /**
     * Fills {@code preds} with the fallback prediction used when an input is
     * missing: the target default value / prior probabilities if target meta
     * data is available; otherwise NaN (numeric target) or zeros, after
     * logging a warning.
     *
     * @param preds zero-initialised output array, sized for the target
     * @param classAtt the class attribute of the mining schema
     * @return {@code preds}
     */
    private double[] predictionForMissingInput(double[] preds, Attribute classAtt) throws Exception {
        if (!this.m_miningSchema.hasTargetMetaData()) {
            String message = "[Regression] WARNING: Instance to predict has missing value(s) but there is no missing value handling meta data and no prior probabilities/default value to fall back to. No prediction will be made (" + (classAtt.isNominal() || classAtt.isString() ? "zero probabilities output)." : "NaN output).");
            if (this.m_log == null) {
                System.err.println(message);
            } else {
                this.m_log.logMessage(message);
            }
            if (classAtt.isNumeric()) {
                preds[0] = Utils.missingValue();
            }
            return preds;
        }
        TargetMetaInfo targetData = this.m_miningSchema.getTargetMetaData();
        if (classAtt.isNumeric()) {
            preds[0] = targetData.getDefaultValue();
        } else {
            for (int i = 0; i < classAtt.numValues(); ++i) {
                preds[i] = targetData.getPriorProbability(classAtt.value(i));
            }
        }
        return preds;
    }

    /**
     * Applies {@link #m_normalizationMethod} to the raw scores in place,
     * using the PMML RegressionModel formulas.
     * <p>
     * For a non-numeric target the transformed scores are then rescaled to
     * sum to one. For a numeric target there is a single value, and that
     * value itself is the prediction: rescaling it would always yield 1.0,
     * so it is left as transformed.
     *
     * @param preds raw scores, modified in place
     * @param numericTarget true if the model predicts a numeric target
     * @throws Exception if the normalization method is unknown
     */
    private void applyNormalizationMethod(double[] preds, boolean numericTarget) throws Exception {
        switch (this.m_normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                break;
            case SOFTMAX:
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = Math.exp(preds[i]);
                }
                if (preds.length == 1) {
                    // Single-output softmax: treat the implicit other score as
                    // exp(0) = 1, i.e. exp(y) / (exp(y) + 1).
                    preds[0] = preds[0] / (preds[0] + 1.0);
                    return;
                }
                break;
            case LOGIT:
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = 1.0 / (1.0 + Math.exp(-preds[i]));
                }
                break;
            case PROBIT:
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = Maths.pnorm(preds[i]);
                }
                break;
            case CLOGLOG:
                // PMML: p = 1 - exp(-exp(y))
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = 1.0 - Math.exp(-Math.exp(preds[i]));
                }
                break;
            case EXP:
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = Math.exp(preds[i]);
                }
                break;
            case LOGLOG:
                // PMML: p = exp(-exp(-y))
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = Math.exp(-Math.exp(-preds[i]));
                }
                break;
            case CAUCHIT:
                for (int i = 0; i < preds.length; ++i) {
                    preds[i] = 0.5 + Math.atan(preds[i]) / Math.PI;
                }
                break;
            default:
                throw new Exception("[Regression] unknown normalization method");
        }
        if (!numericTarget) {
            Utils.normalize(preds);
        }
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8048 $");
    }

    /** Supported values of the PMML {@code normalizationMethod} attribute. */
    static enum Normalization {
        NONE,
        SIMPLEMAX,
        SOFTMAX,
        LOGIT,
        PROBIT,
        CLOGLOG,
        EXP,
        LOGLOG,
        CAUCHIT;

    }

    /**
     * One PMML {@code RegressionTable}: an intercept plus predictor
     * contributions. The score goes to slot 0 of the prediction array, or to
     * the slot of {@link #m_targetCategory} for classification tables.
     */
    static class RegressionTable
    implements Serializable {
        private static final long serialVersionUID = -5259866093996338995L;
        /** Function type for a numeric-target table. */
        public static final int REGRESSION = 0;
        /** Function type for a table scoring one target category. */
        public static final int CLASSIFICATION = 1;
        protected int m_functionType = 0;
        protected MiningSchema m_miningSchema;
        protected double m_intercept = 0.0;
        /** Index of the class value this table scores; -1 for regression tables. */
        protected int m_targetCategory = -1;
        protected ArrayList<Predictor> m_predictors = new ArrayList<Predictor>();
        protected ArrayList<PredictorTerm> m_predictorTerms = new ArrayList<PredictorTerm>();

        /** Renders the table as "class[=category] = term + ... + intercept". */
        @Override
        public String toString() {
            Instances miningSchema = this.m_miningSchema.getFieldsAsInstances();
            StringBuffer temp = new StringBuffer();
            temp.append("Regression table:\n");
            temp.append(miningSchema.classAttribute().name());
            if (this.m_functionType == CLASSIFICATION) {
                temp.append("=" + miningSchema.classAttribute().value(this.m_targetCategory));
            }
            temp.append(" =\n\n");
            for (int i = 0; i < this.m_predictors.size(); ++i) {
                temp.append(this.m_predictors.get(i).toString() + " +\n");
            }
            for (int i = 0; i < this.m_predictorTerms.size(); ++i) {
                temp.append(this.m_predictorTerms.get(i).toString() + " +\n");
            }
            temp.append(Utils.doubleToString(this.m_intercept, 12, 4));
            temp.append("\n\n");
            return temp.toString();
        }

        /**
         * Parses a {@code RegressionTable} element.
         *
         * @param table the RegressionTable element
         * @param functionType {@link #REGRESSION} or {@link #CLASSIFICATION}
         * @param mSchema the mining schema of the enclosing model
         * @throws Exception if a classification table has no matching target
         *     category, or a predictor cannot be parsed
         */
        protected RegressionTable(Element table, int functionType, MiningSchema mSchema) throws Exception {
            this.m_miningSchema = mSchema;
            this.m_functionType = functionType;
            Instances miningSchema = this.m_miningSchema.getFieldsAsInstances();
            String intercept = table.getAttribute("intercept");
            if (intercept.length() > 0) {
                this.m_intercept = Double.parseDouble(intercept);
            }
            if (this.m_functionType == CLASSIFICATION) {
                String targetCat = table.getAttribute("targetCategory");
                if (targetCat.length() > 0) {
                    Attribute classA = miningSchema.classAttribute();
                    for (int i = 0; i < classA.numValues(); ++i) {
                        if (classA.value(i).equals(targetCat)) {
                            this.m_targetCategory = i;
                            break;
                        }
                    }
                }
                if (this.m_targetCategory == -1) {
                    throw new Exception("[RegressionTable] No target categories defined for classification");
                }
            }
            NodeList numericPs = table.getElementsByTagName("NumericPredictor");
            for (int i = 0; i < numericPs.getLength(); ++i) {
                Node nP = numericPs.item(i);
                if (nP.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictors.add(new NumericPredictor((Element) nP, miningSchema));
                }
            }
            NodeList categoricalPs = table.getElementsByTagName("CategoricalPredictor");
            for (int i = 0; i < categoricalPs.getLength(); ++i) {
                Node cP = categoricalPs.item(i);
                if (cP.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictors.add(new CategoricalPredictor((Element) cP, miningSchema));
                }
            }
            NodeList predictorTerms = table.getElementsByTagName("PredictorTerm");
            for (int i = 0; i < predictorTerms.getLength(); ++i) {
                Node pT = predictorTerms.item(i);
                if (pT.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictorTerms.add(new PredictorTerm((Element) pT, miningSchema));
                }
            }
        }

        /**
         * Writes this table's score (intercept plus all contributions) into
         * its slot of {@code preds}, overwriting any previous value there.
         *
         * @param preds the prediction array
         * @param input the instance in mining-schema order
         */
        public void predict(double[] preds, double[] input) {
            if (this.m_targetCategory == -1) {
                preds[0] = this.m_intercept;
            } else {
                preds[this.m_targetCategory] = this.m_intercept;
            }
            for (int i = 0; i < this.m_predictors.size(); ++i) {
                this.m_predictors.get(i).add(preds, input);
            }
            for (int i = 0; i < this.m_predictorTerms.size(); ++i) {
                this.m_predictorTerms.get(i).add(preds, input);
            }
        }

        /**
         * An interaction term: coefficient times the product of one or more
         * numeric fields, each referenced by a {@code FieldRef} element.
         */
        protected class PredictorTerm
        implements Serializable {
            private static final long serialVersionUID = 5493100145890252757L;
            protected double m_coefficient = 1.0;
            /** Mining-schema indexes of the multiplied fields. */
            protected int[] m_indexes;
            /** Names of the multiplied fields (parallel to {@link #m_indexes}). */
            protected String[] m_fieldNames;

            /**
             * Parses a {@code PredictorTerm} element.
             *
             * @param predictorTerm the PredictorTerm element
             * @param miningSchema the mining schema as instances
             * @throws Exception if the coefficient is malformed, there are no
             *     FieldRef elements, a FieldRef has no field name, or a field
             *     is missing from the schema or is not numeric
             */
            protected PredictorTerm(Element predictorTerm, Instances miningSchema) throws Exception {
                String coeff = predictorTerm.getAttribute("coefficient");
                if (coeff != null && coeff.length() > 0) {
                    try {
                        this.m_coefficient = Double.parseDouble(coeff);
                    } catch (NumberFormatException ex) {
                        throw new Exception("[PredictorTerm] unable to parse coefficient", ex);
                    }
                }
                NodeList fields = predictorTerm.getElementsByTagName("FieldRef");
                int numFields = fields.getLength();
                if (numFields == 0) {
                    // Without fields, m_indexes would stay null and add()/toString() would fail later.
                    throw new Exception("[PredictorTerm] no FieldRef elements defined");
                }
                this.m_indexes = new int[numFields];
                this.m_fieldNames = new String[numFields];
                for (int i = 0; i < numFields; ++i) {
                    Node fieldRef = fields.item(i);
                    String fieldName = fieldRef.getNodeType() == Node.ELEMENT_NODE
                            ? ((Element) fieldRef).getAttribute("field") : null;
                    if (fieldName == null || fieldName.length() == 0) {
                        // Skipping would leave index 0 in this slot and silently use the wrong field.
                        throw new Exception("[PredictorTerm] FieldRef without a field name");
                    }
                    boolean found = false;
                    for (int j = 0; j < miningSchema.numAttributes(); ++j) {
                        if (!miningSchema.attribute(j).name().equals(fieldName)) {
                            continue;
                        }
                        if (!miningSchema.attribute(j).isNumeric()) {
                            throw new Exception("[PredictorTerm] field is not continuous: " + fieldName);
                        }
                        found = true;
                        this.m_indexes[i] = j;
                        this.m_fieldNames[i] = fieldName;
                        break;
                    }
                    if (!found) {
                        throw new Exception("[PredictorTerm] Unable to find field " + fieldName + " in mining schema!");
                    }
                }
            }

            /** Renders the term as "(coefficient * field1 * field2 ...)". */
            @Override
            public String toString() {
                StringBuffer result = new StringBuffer();
                result.append("(" + Utils.doubleToString(this.m_coefficient, 12, 4));
                for (int i = 0; i < this.m_fieldNames.length; ++i) {
                    result.append(" * " + this.m_fieldNames[i]);
                }
                result.append(")");
                return result.toString();
            }

            /**
             * Adds coefficient times the product of the referenced inputs to
             * the enclosing table's slot of {@code preds}.
             */
            public void add(double[] preds, double[] input) {
                int indx = RegressionTable.this.m_targetCategory != -1
                        ? RegressionTable.this.m_targetCategory : 0;
                double result = this.m_coefficient;
                for (int i = 0; i < this.m_indexes.length; ++i) {
                    result *= input[this.m_indexes[i]];
                }
                preds[indx] += result;
            }
        }

        /**
         * Indicator predictor: contributes its coefficient when the input
         * attribute takes the configured value, and nothing otherwise.
         */
        protected class CategoricalPredictor
        extends Predictor {
            private static final long serialVersionUID = 3077920125549906819L;
            protected String m_valueName;
            /** Index of {@link #m_valueName} within the schema attribute. */
            protected int m_valueIndex;

            /**
             * Parses a {@code CategoricalPredictor} element. For a string
             * attribute, the value is first added to the schema attribute
             * so that it has an index.
             *
             * @throws Exception if the value attribute is empty or cannot be
             *     found in the schema attribute
             */
            protected CategoricalPredictor(Element predictor, Instances miningSchema) throws Exception {
                super(predictor, miningSchema);
                this.m_valueIndex = -1;
                String valName = predictor.getAttribute("value");
                if (valName.length() == 0) {
                    throw new Exception("[CategoricalPredictor] attribute value not specified!");
                }
                this.m_valueName = valName;
                Attribute att = miningSchema.attribute(this.m_miningSchemaAttIndex);
                if (att.isString()) {
                    att.addStringValue(this.m_valueName);
                }
                this.m_valueIndex = att.indexOfValue(this.m_valueName);
                if (this.m_valueIndex == -1) {
                    throw new Exception("[CategoricalPredictor] unable to find value " + this.m_valueName + " in mining schema attribute " + att.name());
                }
            }

            @Override
            public String toString() {
                return super.toString() + this.m_name + "=" + this.m_valueName;
            }

            @Override
            public void add(double[] preds, double[] input) {
                if (this.m_valueIndex == (int) input[this.m_miningSchemaAttIndex]) {
                    int slot = RegressionTable.this.m_targetCategory == -1
                            ? 0 : RegressionTable.this.m_targetCategory;
                    preds[slot] += this.m_coefficient;
                }
            }
        }

        /** Numeric predictor: contributes coefficient * value^exponent. */
        protected class NumericPredictor
        extends Predictor {
            private static final long serialVersionUID = -4335075205696648273L;
            /** Exponent applied to the input value; 1 when not specified. */
            protected double m_exponent;

            /** Parses a {@code NumericPredictor} element (optional {@code exponent}). */
            protected NumericPredictor(Element predictor, Instances miningSchema) throws Exception {
                super(predictor, miningSchema);
                this.m_exponent = 1.0;
                String exponent = predictor.getAttribute("exponent");
                if (exponent.length() > 0) {
                    this.m_exponent = Double.parseDouble(exponent);
                }
            }

            @Override
            public String toString() {
                String output = super.toString() + this.m_name;
                if (this.m_exponent > 1.0 || this.m_exponent < 1.0) {
                    output = output + "^" + Utils.doubleToString(this.m_exponent, 4);
                }
                return output;
            }

            @Override
            public void add(double[] preds, double[] input) {
                int slot = RegressionTable.this.m_targetCategory == -1
                        ? 0 : RegressionTable.this.m_targetCategory;
                preds[slot] += this.m_coefficient * Math.pow(input[this.m_miningSchemaAttIndex], this.m_exponent);
            }
        }

        /**
         * Base class for table predictors: resolves the {@code name}
         * attribute to a mining-schema index and reads the optional
         * {@code coefficient} (default 1).
         */
        static abstract class Predictor
        implements Serializable {
            private static final long serialVersionUID = 7043831847273383618L;
            protected String m_name;
            protected int m_miningSchemaAttIndex = -1;
            protected double m_coefficient = 1.0;

            /**
             * @throws Exception if no schema attribute has the predictor's name
             */
            protected Predictor(Element predictor, Instances miningSchema) throws Exception {
                this.m_name = predictor.getAttribute("name");
                for (int i = 0; i < miningSchema.numAttributes(); ++i) {
                    Attribute temp = miningSchema.attribute(i);
                    if (temp.name().equals(this.m_name)) {
                        this.m_miningSchemaAttIndex = i;
                    }
                }
                if (this.m_miningSchemaAttIndex == -1) {
                    throw new Exception("[Predictor] unable to find matching attribute for predictor " + this.m_name);
                }
                String coeff = predictor.getAttribute("coefficient");
                if (coeff.length() > 0) {
                    this.m_coefficient = Double.parseDouble(coeff);
                }
            }

            @Override
            public String toString() {
                return Utils.doubleToString(this.m_coefficient, 12, 4) + " * ";
            }

            /**
             * Adds this predictor's contribution for {@code var2} (the input
             * in mining-schema order) to {@code var1} (the prediction array).
             */
            public abstract void add(double[] var1, double[] var2);
        }
    }
}

