/*
 * Decompiled with CFR 0.152.
 */
package org.pentaho.di.scoring;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.vfs.FileObject;
import org.pentaho.di.core.logging.LogChannelInterface;
import org.pentaho.di.core.row.RowDataUtil;
import org.pentaho.di.core.row.RowMetaInterface;
import org.pentaho.di.core.row.ValueMetaInterface;
import org.pentaho.di.core.variables.VariableSpace;
import org.pentaho.di.core.vfs.KettleVFS;
import org.pentaho.di.i18n.BaseMessages;
import org.pentaho.di.scoring.WekaScoringClusterer;
import org.pentaho.di.scoring.WekaScoringMeta;
import org.pentaho.di.scoring.WekaScoringModel;
import org.pentaho.di.trans.step.BaseStepData;
import org.pentaho.di.trans.step.StepDataInterface;
import weka.clusterers.Clusterer;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.pmml.PMMLFactory;
import weka.core.pmml.PMMLModel;
import weka.core.xml.XStream;

public class WekaScoringData
extends BaseStepData
implements StepDataInterface {
    public static final int NO_MATCH = -1;
    public static final int TYPE_MISMATCH = -2;
    protected RowMetaInterface m_outputRowMeta;
    private double[] m_vals = null;
    protected WekaScoringModel m_model;
    protected WekaScoringModel m_defaultModel;
    private int[] m_mappingIndexes;
    protected boolean m_updateIncrementalModel = false;

    public void setModel(WekaScoringModel model) {
        this.m_model = model;
    }

    public WekaScoringModel getModel() {
        return this.m_model;
    }

    public void setDefaultModel(WekaScoringModel model) {
        this.m_defaultModel = model;
    }

    public WekaScoringModel getDefaultModel() {
        return this.m_defaultModel;
    }

    public RowMetaInterface getOutputRowMeta() {
        return this.m_outputRowMeta;
    }

    public void setOutputRowMeta(RowMetaInterface rmi) {
        this.m_outputRowMeta = rmi;
    }

    public void mapIncomingRowMetaData(Instances header, RowMetaInterface inputRowMeta, boolean updateIncrementalModel, LogChannelInterface log) {
        this.m_mappingIndexes = WekaScoringData.findMappings(header, inputRowMeta);
        this.m_updateIncrementalModel = updateIncrementalModel;
        if (this.m_updateIncrementalModel && this.m_model.isSupervisedLearningModel()) {
            if (this.m_model.isUpdateableModel()) {
                if (this.m_mappingIndexes[header.classIndex()] == -1 || this.m_mappingIndexes[header.classIndex()] == -2) {
                    this.m_updateIncrementalModel = false;
                    log.logError(BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringMeta.Log.NoMatchForClass", (String[])new String[0]));
                }
            } else {
                this.m_updateIncrementalModel = false;
                log.logError(BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringMeta.Log.ModelNotUpdateable", (String[])new String[0]));
            }
        }
    }

    public static boolean modelFileExists(String modelFile, VariableSpace space) throws Exception {
        modelFile = space.environmentSubstitute(modelFile);
        FileObject modelF = KettleVFS.getFileObject((String)modelFile);
        return modelF.exists();
    }

    /*
     * WARNING - void declaration
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public static WekaScoringModel loadSerializedModel(String modelFile, LogChannelInterface log, VariableSpace space) throws Exception {
        void var3_7;
        Object var3_3 = null;
        Instances header = null;
        int[] ignoredAttsForClustering = null;
        FileObject modelF = KettleVFS.getFileObject((String)(modelFile = space.environmentSubstitute(modelFile)));
        if (!modelF.exists()) {
            throw new Exception(BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoring.Error.NonExistentModelFile", (String[])new String[]{space.environmentSubstitute(modelFile)}));
        }
        InputStream is = KettleVFS.getInputStream((FileObject)modelF);
        BufferedInputStream buff = new BufferedInputStream(is);
        if (modelFile.toLowerCase().endsWith(".xml")) {
            PMMLModel pMMLModel = PMMLFactory.getPMMLModel((InputStream)buff, null);
            header = pMMLModel.getMiningSchema().getMiningSchemaAsInstances();
            buff.close();
        } else if (modelFile.toLowerCase().endsWith(".xstreammodel")) {
            log.logBasic(BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringData.Log.LoadXMLModel", (String[])new String[0]));
            if (!XStream.isPresent()) {
                buff.close();
                throw new Exception(BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringData.Error.CantLoadXMLModel", (String[])new String[0]));
            }
            Vector vector = (Vector)XStream.read((InputStream)buff);
            Object e = vector.elementAt(0);
            if (vector.size() == 2) {
                header = (Instances)vector.elementAt(1);
            }
            buff.close();
        } else {
            void var9_16;
            BufferedInputStream bufferedInputStream = buff;
            if (modelFile.toLowerCase().endsWith(".gz")) {
                GZIPInputStream gZIPInputStream = new GZIPInputStream(buff);
            }
            ObjectInputStream oi = new ObjectInputStream((InputStream)var9_16);
            Object object = oi.readObject();
            header = (Instances)oi.readObject();
            if (object instanceof Clusterer) {
                try {
                    ignoredAttsForClustering = (int[])oi.readObject();
                }
                catch (Exception ex) {
                    // empty catch block
                }
            }
            oi.close();
        }
        WekaScoringModel wekaScoringModel = WekaScoringModel.createScorer(var3_7);
        wekaScoringModel.setHeader(header);
        if (wekaScoringModel instanceof WekaScoringClusterer && ignoredAttsForClustering != null) {
            ((WekaScoringClusterer)wekaScoringModel).setAttributesToIgnore(ignoredAttsForClustering);
        }
        wekaScoringModel.setLog(log);
        return wekaScoringModel;
    }

    public static void saveSerializedModel(WekaScoringModel wsm, File saveTo) throws Exception {
        Object model = wsm.getModel();
        Instances header = wsm.getHeader();
        OutputStream os = new FileOutputStream(saveTo);
        if (saveTo.getName().toLowerCase().endsWith(".gz")) {
            os = new GZIPOutputStream(os);
        }
        ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(os));
        oos.writeObject(model);
        oos.writeObject(header);
        oos.close();
    }

    public static int[] findMappings(Instances header, RowMetaInterface inputRowMeta) {
        int i;
        int[] mappingIndexes = new int[header.numAttributes()];
        HashMap<String, Integer> inputFieldLookup = new HashMap<String, Integer>();
        for (i = 0; i < inputRowMeta.size(); ++i) {
            ValueMetaInterface inField = inputRowMeta.getValueMeta(i);
            inputFieldLookup.put(inField.getName(), i);
        }
        for (i = 0; i < header.numAttributes(); ++i) {
            Attribute temp = header.attribute(i);
            String attName = temp.name();
            Integer matchIndex = (Integer)inputFieldLookup.get(attName);
            boolean ok = false;
            int status = -1;
            if (matchIndex != null) {
                ValueMetaInterface tempField = inputRowMeta.getValueMeta(matchIndex.intValue());
                if (tempField.isNumeric() || tempField.isBoolean()) {
                    if (temp.isNumeric()) {
                        ok = true;
                        status = 0;
                    } else {
                        status = -2;
                    }
                } else if (tempField.isString()) {
                    if (temp.isNominal() || temp.isString()) {
                        ok = true;
                        status = 0;
                    } else {
                        status = -2;
                    }
                } else {
                    status = -2;
                }
            }
            mappingIndexes[i] = ok ? matchIndex : status;
        }
        return mappingIndexes;
    }

    public Object[][] generatePredictions(RowMetaInterface inputMeta, RowMetaInterface outputMeta, List<Object[]> inputRows, WekaScoringMeta meta) throws Exception {
        int[] mappingIndexes = this.m_mappingIndexes;
        WekaScoringModel model = this.getModel();
        boolean outputProbs = meta.getOutputProbabilities();
        boolean supervised = model.isSupervisedLearningModel();
        Attribute classAtt = null;
        if (supervised) {
            classAtt = model.getHeader().classAttribute();
        }
        Instances batch = new Instances(model.getHeader(), inputRows.size());
        for (Object[] r : inputRows) {
            Instance inst = this.constructInstance(inputMeta, r, mappingIndexes, model, true);
            batch.add(inst);
        }
        double[][] preds = model.distributionsForInstances(batch);
        Object[][] result = new Object[preds.length][];
        for (int i = 0; i < preds.length; ++i) {
            Object newVal;
            Object[] resultRow = RowDataUtil.resizeArray((Object[])inputRows.get(i), (int)outputMeta.size());
            int index = inputMeta.size();
            double[] prediction = preds[i];
            if (prediction.length == 1 || !outputProbs) {
                if (supervised) {
                    if (classAtt.isNumeric()) {
                        Double newVal2 = new Double(prediction[0]);
                        resultRow[index++] = newVal2;
                    } else {
                        int maxProb = Utils.maxIndex((double[])prediction);
                        if (prediction[maxProb] > 0.0) {
                            newVal = classAtt.value(maxProb);
                            resultRow[index++] = newVal;
                        } else {
                            newVal = BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringData.Message.UnableToPredict", (String[])new String[0]);
                            resultRow[index++] = newVal;
                        }
                    }
                } else {
                    int maxProb = Utils.maxIndex((double[])prediction);
                    if (prediction[maxProb] > 0.0) {
                        newVal = new Double(maxProb);
                        resultRow[index++] = newVal;
                    } else {
                        newVal = BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringData.Message.UnableToPredictCluster", (String[])new String[0]);
                        resultRow[index++] = newVal;
                    }
                }
            } else {
                for (int j = 0; j < prediction.length; ++j) {
                    newVal = new Double(prediction[j]);
                    resultRow[index++] = newVal;
                }
            }
            result[i] = resultRow;
        }
        return result;
    }

    public Object[] generatePrediction(RowMetaInterface inputMeta, RowMetaInterface outputMeta, Object[] inputRow, WekaScoringMeta meta) throws Exception {
        int[] mappingIndexes = this.m_mappingIndexes;
        WekaScoringModel model = this.getModel();
        boolean outputProbs = meta.getOutputProbabilities();
        boolean supervised = model.isSupervisedLearningModel();
        Attribute classAtt = null;
        if (supervised) {
            classAtt = model.getHeader().classAttribute();
        }
        Instance toScore = this.constructInstance(inputMeta, inputRow, mappingIndexes, model, false);
        double[] prediction = model.distributionForInstance(toScore);
        if (meta.getUpdateIncrementalModel() && model.isUpdateableModel() && !toScore.isMissing(toScore.classIndex())) {
            model.update(toScore);
        }
        Object[] resultRow = RowDataUtil.resizeArray((Object[])inputRow, (int)outputMeta.size());
        int index = inputMeta.size();
        if (prediction.length == 1 || !outputProbs) {
            if (supervised) {
                if (classAtt.isNumeric()) {
                    Double newVal = new Double(prediction[0]);
                    resultRow[index++] = newVal;
                } else {
                    int maxProb = Utils.maxIndex((double[])prediction);
                    if (prediction[maxProb] > 0.0) {
                        String newVal = classAtt.value(maxProb);
                        resultRow[index++] = newVal;
                    } else {
                        String newVal = BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringData.Message.UnableToPredict", (String[])new String[0]);
                        resultRow[index++] = newVal;
                    }
                }
            } else {
                int maxProb = Utils.maxIndex((double[])prediction);
                if (prediction[maxProb] > 0.0) {
                    Double newVal = new Double(maxProb);
                    resultRow[index++] = newVal;
                } else {
                    String newVal = BaseMessages.getString(WekaScoringMeta.PKG, (String)"WekaScoringData.Message.UnableToPredictCluster", (String[])new String[0]);
                    resultRow[index++] = newVal;
                }
            }
        } else {
            for (int i = 0; i < prediction.length; ++i) {
                Double newVal = new Double(prediction[i]);
                resultRow[index++] = newVal;
            }
        }
        return resultRow;
    }

    private Instance constructInstance(RowMetaInterface inputMeta, Object[] inputRow, int[] mappingIndexes, WekaScoringModel model, boolean freshVector) {
        Instances header = model.getHeader();
        if (this.m_vals == null || freshVector) {
            this.m_vals = new double[header.numAttributes()];
        }
        for (int i = 0; i < header.numAttributes(); ++i) {
            if (mappingIndexes[i] >= 0) {
                try {
                    Object inputVal = inputRow[mappingIndexes[i]];
                    Attribute temp = header.attribute(i);
                    ValueMetaInterface tempField = inputMeta.getValueMeta(mappingIndexes[i]);
                    int fieldType = tempField.getType();
                    if (tempField.isNull(inputVal)) {
                        this.m_vals[i] = Utils.missingValue();
                        continue;
                    }
                    switch (temp.type()) {
                        case 0: {
                            if (fieldType == 4) {
                                Boolean b = tempField.getBoolean(inputVal);
                                if (b.booleanValue()) {
                                    this.m_vals[i] = 1.0;
                                    break;
                                }
                                this.m_vals[i] = 0.0;
                                break;
                            }
                            if (fieldType == 5) {
                                Long t = tempField.getInteger(inputVal);
                                this.m_vals[i] = t.longValue();
                                break;
                            }
                            Double n = tempField.getNumber(inputVal);
                            this.m_vals[i] = n;
                            break;
                        }
                        case 1: {
                            String s = tempField.getString(inputVal);
                            int index = temp.indexOfValue(s);
                            if (index < 0) {
                                this.m_vals[i] = Utils.missingValue();
                                break;
                            }
                            this.m_vals[i] = index;
                            break;
                        }
                        case 2: {
                            String s = tempField.getString(inputVal);
                            temp.setStringValue(s);
                            this.m_vals[i] = 0.0;
                            break;
                        }
                        default: {
                            this.m_vals[i] = Utils.missingValue();
                        }
                    }
                }
                catch (Exception e) {
                    this.m_vals[i] = Utils.missingValue();
                }
                continue;
            }
            this.m_vals[i] = Utils.missingValue();
        }
        DenseInstance newInst = new DenseInstance(1.0, this.m_vals);
        newInst.setDataset(header);
        return newInst;
    }
}

