/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

public class MultiScheme
extends RandomizableMultipleClassifiersCombiner {
    static final long serialVersionUID = 5710744346128957520L;
    protected Classifier m_Classifier;
    protected int m_ClassifierIndex;
    protected int m_NumXValFolds;

    public String globalInfo() {
        return "Class for selecting a classifier from among several using cross validation on the training data or the performance on the training data. Performance is measured based on percent correct (classification) or mean-squared error (regression).";
    }

    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(1);
        newVector.addElement(new Option("\tUse cross validation for model selection using the\n\tgiven number of folds. (default 0, is to\n\tuse training error)", "X", 1, "-X <number of folds>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option)enu.nextElement());
        }
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String numFoldsString = Utils.getOption('X', options);
        if (numFoldsString.length() != 0) {
            this.setNumFolds(Integer.parseInt(numFoldsString));
        } else {
            this.setNumFolds(0);
        }
        super.setOptions(options);
    }

    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 2];
        int current = 0;
        options[current++] = "-X";
        options[current++] = "" + this.getNumFolds();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        return options;
    }

    @Override
    public String classifiersTipText() {
        return "The classifiers to be chosen from.";
    }

    @Override
    public void setClassifiers(Classifier[] classifiers) {
        this.m_Classifiers = classifiers;
    }

    @Override
    public Classifier[] getClassifiers() {
        return this.m_Classifiers;
    }

    @Override
    public Classifier getClassifier(int index) {
        return this.m_Classifiers[index];
    }

    @Override
    protected String getClassifierSpec(int index) {
        if (this.m_Classifiers.length < index) {
            return "";
        }
        Classifier c = this.getClassifier(index);
        if (c instanceof OptionHandler) {
            return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler)((Object)c)).getOptions());
        }
        return c.getClass().getName();
    }

    @Override
    public String seedTipText() {
        return "The seed used for randomizing the data for cross-validation.";
    }

    @Override
    public void setSeed(int seed) {
        this.m_Seed = seed;
    }

    @Override
    public int getSeed() {
        return this.m_Seed;
    }

    public String numFoldsTipText() {
        return "The number of folds used for cross-validation (if 0, performance on training data will be used).";
    }

    public int getNumFolds() {
        return this.m_NumXValFolds;
    }

    public void setNumFolds(int numFolds) {
        this.m_NumXValFolds = numFolds;
    }

    @Override
    public String debugTipText() {
        return "Whether debug information is output to console.";
    }

    @Override
    public void setDebug(boolean debug) {
        this.m_Debug = debug;
    }

    @Override
    public boolean getDebug() {
        return this.m_Debug;
    }

    public int getBestClassifierIndex() {
        return this.m_ClassifierIndex;
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        if (this.m_Classifiers.length == 0) {
            throw new Exception("No base classifiers have been set!");
        }
        this.getCapabilities().testWithFail(data);
        Instances newData = new Instances(data);
        newData.deleteWithMissingClass();
        Random random = new Random(this.m_Seed);
        newData.randomize(random);
        if (newData.classAttribute().isNominal() && this.m_NumXValFolds > 1) {
            newData.stratify(this.m_NumXValFolds);
        }
        Instances train = newData;
        Instances test = newData;
        Classifier bestClassifier = null;
        int bestIndex = -1;
        double bestPerformance = Double.NaN;
        int numClassifiers = this.m_Classifiers.length;
        for (int i = 0; i < numClassifiers; ++i) {
            Evaluation evaluation;
            Classifier currentClassifier = this.getClassifier(i);
            if (this.m_NumXValFolds > 1) {
                evaluation = new Evaluation(newData);
                for (int j = 0; j < this.m_NumXValFolds; ++j) {
                    train = newData.trainCV(this.m_NumXValFolds, j, new Random(1L));
                    test = newData.testCV(this.m_NumXValFolds, j);
                    currentClassifier.buildClassifier(train);
                    evaluation.setPriors(train);
                    evaluation.evaluateModel(currentClassifier, test, new Object[0]);
                }
            } else {
                currentClassifier.buildClassifier(train);
                evaluation = new Evaluation(train);
                evaluation.evaluateModel(currentClassifier, test, new Object[0]);
            }
            double error = evaluation.errorRate();
            if (this.m_Debug) {
                System.err.println("Error rate: " + Utils.doubleToString(error, 6, 4) + " for classifier " + currentClassifier.getClass().getName());
            }
            if (i != 0 && !(error < bestPerformance)) continue;
            bestClassifier = currentClassifier;
            bestPerformance = error;
            bestIndex = i;
        }
        this.m_ClassifierIndex = bestIndex;
        if (this.m_NumXValFolds > 1) {
            bestClassifier.buildClassifier(newData);
        }
        this.m_Classifier = bestClassifier;
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.m_Classifier.distributionForInstance(instance);
    }

    public String toString() {
        if (this.m_Classifier == null) {
            return "MultiScheme: No model built yet.";
        }
        String result = "MultiScheme selection using";
        result = this.m_NumXValFolds > 1 ? result + " cross validation error" : result + " error on training data";
        result = result + " from the following:\n";
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            result = result + '\t' + this.getClassifierSpec(i) + '\n';
        }
        result = result + "Selected scheme: " + this.getClassifierSpec(this.m_ClassifierIndex) + "\n\n" + this.m_Classifier.toString();
        return result;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8048 $");
    }

    public static void main(String[] argv) {
        MultiScheme.runClassifier(new MultiScheme(), argv);
    }
}

