/*
 * Decompiled with CFR 0.152.
 */
package moa.evaluation;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceImpl;
import com.yahoo.labs.samoa.instances.Prediction;
import java.io.Serializable;
import java.util.TreeSet;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.evaluation.ClassificationPerformanceEvaluator;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;

/**
 * Prequential evaluator for two-class problems that reports AUC-style and
 * imbalance-aware measures over a sliding window of the most recent examples.
 *
 * <p>The window length is {@link #widthOption}; a width of 0 disables windowing,
 * making the windowed statistics cumulative over the whole stream. Only scoring
 * classifiers are supported: results must be supplied through
 * {@link #addResult(Example, double[])}, and class index 1 is treated as the
 * positive class.
 */
public class WindowAUCImbalancedPerformanceEvaluator extends AbstractOptionHandler
        implements ClassificationPerformanceEvaluator {

    private static final long serialVersionUID = 1L;

    /** Number of most recent examples the windowed statistics are computed over (0 = no window). */
    public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 500);

    /** Number of labelled, positively weighted examples counted since the last (re)initialization. */
    protected double totalObservedInstances = 0.0;

    /** Windowed statistics; created by {@link #reset(int)}. */
    private Estimator aucEstimator;

    /** Running accuracy of a "predict the current majority class" baseline, used for KappaM. */
    private SimpleEstimator weightMajorityClassifier;

    /** Number of classes of the evaluated stream; 0 until {@link #reset(int)} has been called. */
    protected int numClasses;

    /**
     * Clears all statistics while keeping the current class count.
     *
     * <p>Before the first labelled example has been seen the class count is not
     * known yet, so there is nothing to clear: the estimators are created by the
     * next {@link #addResult(Example, double[])} call instead.
     */
    @Override
    public void reset() {
        if (this.numClasses == 0) {
            // A zero counter makes addResult() initialize the estimators from the first instance.
            this.totalObservedInstances = 0.0;
            return;
        }
        this.reset(this.numClasses);
    }

    /**
     * Re-initializes all statistics for a stream with {@code numClasses} classes.
     *
     * @param numClasses number of classes of the stream; must be exactly 2
     * @throws IllegalArgumentException if {@code numClasses} is not 2
     */
    public void reset(int numClasses) {
        if (numClasses != 2) {
            String reason = numClasses > 2 ? "Too many classes (" : "Too few classes (";
            throw new IllegalArgumentException(reason + numClasses
                    + "). AUC evaluation can be performed only for two-class problems!");
        }
        // Must be assigned before the Estimator is built: its constructor sizes its arrays from it.
        this.numClasses = numClasses;
        this.aucEstimator = new Estimator(this.widthOption.getValue());
        this.weightMajorityClassifier = new SimpleEstimator();
        this.totalObservedInstances = 0.0;
    }

    /**
     * Adds the classifier's votes for one labelled example to the evaluation.
     *
     * <p>Examples with a missing class value or a weight that is not strictly
     * positive are ignored. The score used by the AUC measures is the vote for
     * class 1 divided by the sum of both votes; it is 0 when the classifier did
     * not return exactly two votes or when that ratio is NaN.
     *
     * @param exampleInstance the labelled example
     * @param classVotes the classifier's votes, indexed by class
     * @throws IllegalArgumentException if the first counted example's dataset does not have two classes
     */
    @Override
    public void addResult(Example<Instance> exampleInstance, double[] classVotes) {
        Instance inst = exampleInstance.getData();
        double weight = inst.weight();
        // Written as !(weight > 0) so that a NaN weight is skipped too.
        if (inst.classIsMissing() || !(weight > 0.0)) {
            return;
        }
        int trueClass = (int) inst.classValue();
        if (this.totalObservedInstances == 0.0) {
            // The class count is only known once data arrives.
            this.reset(inst.dataset().numClasses());
        }
        this.totalObservedInstances += 1.0;

        double normalizedVote = 0.0;
        if (classVotes.length == 2) {
            normalizedVote = classVotes[1] / (classVotes[0] + classVotes[1]);
        }
        if (Double.isNaN(normalizedVote)) {
            // e.g. both votes are 0
            normalizedVote = 0.0;
        }
        boolean correct = Utils.maxIndex(classVotes) == trueClass;
        this.aucEstimator.add(normalizedVote, trueClass == 1, correct);

        // NOTE(review): the majority class is read after the current example was added to the
        // window, so the baseline has already seen the label it is scored on; the baseline is
        // also weighted and unwindowed, unlike the windowed accuracy. Confirm this is intended.
        int majorityClass = this.aucEstimator.getRatio() <= 1.0 ? 0 : 1;
        this.weightMajorityClassifier.add(majorityClass == trueClass ? weight : 0.0);
    }

    /**
     * Returns the current measurements.
     *
     * <p>The estimators are only created by {@link #reset(int)}, so this throws
     * {@link NullPointerException} if called before the first labelled example.
     */
    @Override
    public Measurement[] getPerformanceMeasurements() {
        return new Measurement[] {
            new Measurement("classified instances", this.totalObservedInstances),
            new Measurement("AUC", this.aucEstimator.getAUC()),
            new Measurement("sAUC", this.aucEstimator.getScoredAUC()),
            new Measurement("Accuracy", this.aucEstimator.getAccuracy()),
            new Measurement("Kappa", this.aucEstimator.getKappa()),
            new Measurement("Periodical holdout AUC", this.aucEstimator.getHoldoutAUC()),
            new Measurement("Pos/Neg ratio", this.aucEstimator.getRatio()),
            new Measurement("G-Mean", this.aucEstimator.getGMean()),
            new Measurement("Recall", this.aucEstimator.getRecall()),
            new Measurement("KappaM", this.aucEstimator.getKappaM())
        };
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
        Measurement.getMeasurementsDescription(this.getPerformanceMeasurements(), sb, indent);
    }

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        // Nothing to prepare: estimators are created lazily by reset(int).
    }

    /** Returns the windowed estimator, or {@code null} before {@link #reset(int)} has run. */
    public Estimator getAucEstimator() {
        return this.aucEstimator;
    }

    /**
     * Not supported: this evaluator needs the raw class votes.
     *
     * @throws RuntimeException always
     */
    @Override
    public void addResult(Example<Instance> arg0, Prediction arg1) {
        throw new RuntimeException("Designed for scoring classifiers");
    }

    /**
     * Sliding-window statistics over the most recent {@code size} examples.
     *
     * <p>Each example contributes its score and true class (for the AUC measures)
     * and whether the classifier predicted it correctly (for accuracy, kappa,
     * G-mean and recall). Instance weights are not used here; every example counts as 1.
     */
    public class Estimator implements Serializable {
        private static final long serialVersionUID = 1L;

        /** Scores in the window, ordered by {@link Score#compareTo}. */
        protected TreeSet<Score> sortedScores = new TreeSet<>();
        /** Snapshot of {@link #sortedScores} taken every {@code size} additions. */
        protected TreeSet<Score> holdoutSortedScores = new TreeSet<>();
        /** Circular buffer of the scores currently in the window. */
        protected Score[] window;
        /** Circular buffer parallel to {@link #window}: 1.0 if that example was predicted correctly, else 0.0. */
        protected double[] predictions;
        /** Total number of examples added; {@code posWindow % size} is the next buffer slot. */
        // NOTE(review): int counter; after Integer.MAX_VALUE additions it overflows and the slot index goes negative.
        protected int posWindow;
        protected int size;
        protected double numPos;
        protected double numNeg;
        protected double holdoutNumPos;
        protected double holdoutNumNeg;
        protected double correctPredictions;
        protected double correctPositivePredictions;
        /** Per-class count of true classes in the window (confusion-matrix column sums). */
        protected double[] columnKappa;
        /** Per-class count of predicted classes in the window (confusion-matrix row sums). */
        protected double[] rowKappa;

        /**
         * Creates an empty estimator.
         *
         * @param sizeWindow number of examples kept in the window; 0 disables windowing; must not be negative
         */
        public Estimator(int sizeWindow) {
            this.size = sizeWindow;
            this.window = new Score[sizeWindow];
            this.predictions = new double[sizeWindow];
            // New arrays are zero-filled; all other counters start at Java's default of 0.
            this.rowKappa = new double[WindowAUCImbalancedPerformanceEvaluator.this.numClasses];
            this.columnKappa = new double[WindowAUCImbalancedPerformanceEvaluator.this.numClasses];
        }

        /**
         * Adds one example, evicting the oldest one once the window is full.
         *
         * @param score the classifier's score for the positive class
         * @param isPositive whether the true class is the positive class (index 1)
         * @param correctPrediction whether the classifier's top-voted class was the true class
         */
        public void add(double score, boolean isPositive, boolean correctPrediction) {
            if (this.size > 0 && this.posWindow % this.size == 0) {
                // Periodically freeze the current window for the "periodical holdout" AUC.
                this.holdoutSortedScores = new TreeSet<>(this.sortedScores);
                this.holdoutNumPos = this.numPos;
                this.holdoutNumNeg = this.numNeg;
            }
            if (this.size > 0 && this.posWindow >= this.size) {
                this.evict(this.posWindow % this.size);
            }

            Score newScore = new Score(score, this.posWindow, isPositive);
            this.sortedScores.add(newScore);

            double correct = correctPrediction ? 1.0 : 0.0;
            this.correctPredictions += correct;
            if (isPositive) {
                this.correctPositivePredictions += correct;
            }

            int trueClass = isPositive ? 1 : 0;
            // Two classes only: a wrong prediction is the other class.
            int predictedClass = correctPrediction ? trueClass : 1 - trueClass;
            this.rowKappa[predictedClass] += 1.0;
            this.columnKappa[trueClass] += 1.0;

            if (isPositive) {
                this.numPos += 1.0;
            } else {
                this.numNeg += 1.0;
            }

            if (this.size > 0) {
                int slot = this.posWindow % this.size;
                this.window[slot] = newScore;
                this.predictions[slot] = correct;
            }
            ++this.posWindow;
        }

        /** Removes the example stored in buffer slot {@code slot} from every running statistic. */
        private void evict(int slot) {
            Score oldest = this.window[slot];
            double oldestCorrect = this.predictions[slot];

            this.sortedScores.remove(oldest);
            this.correctPredictions -= oldestCorrect;
            if (oldest.isPositive) {
                this.correctPositivePredictions -= oldestCorrect;
                this.numPos -= 1.0;
            } else {
                this.numNeg -= 1.0;
            }

            int oldestTrueClass = oldest.isPositive ? 1 : 0;
            int oldestPredictedClass = oldestCorrect == 1.0 ? oldestTrueClass : 1 - oldestTrueClass;
            this.rowKappa[oldestPredictedClass] -= 1.0;
            this.columnKappa[oldestTrueClass] -= 1.0;
        }

        /**
         * Computes the AUC of a score ordering, counting a tied positive/negative pair as 1/2.
         *
         * @param scores scores in {@link Score#compareTo} order (descending value, positives first on ties)
         * @param positives number of positive scores in {@code scores}
         * @param negatives number of negative scores in {@code scores}
         * @return the AUC, or 1.0 if either class is absent
         */
        private double computeAUC(TreeSet<Score> scores, double positives, double negatives) {
            if (positives == 0.0 || negatives == 0.0) {
                return 1.0;
            }
            double area = 0.0;
            double positivesSoFar = 0.0;        // positives with score >= current position
            double positivesStrictlyAbove = 0.0; // positives with score > last positive score
            double lastPositiveScore = Double.MAX_VALUE;
            for (Score s : scores) {
                if (s.isPositive) {
                    if (s.value != lastPositiveScore) {
                        positivesStrictlyAbove = positivesSoFar;
                        lastPositiveScore = s.value;
                    }
                    positivesSoFar += 1.0;
                } else if (s.value == lastPositiveScore) {
                    // Negative tied with the latest positives: tied pairs count half.
                    area += (positivesSoFar + positivesStrictlyAbove) / 2.0;
                } else {
                    area += positivesSoFar;
                }
            }
            return area / (positives * negatives);
        }

        /** Returns the AUC over the current window (1.0 if the window lacks either class). */
        public double getAUC() {
            return this.computeAUC(this.sortedScores, this.numPos, this.numNeg);
        }

        /**
         * Returns the AUC over the last periodic snapshot of the window: 0.0 if the
         * snapshot is empty, 1.0 if it lacks either class.
         */
        public double getHoldoutAUC() {
            if (this.holdoutSortedScores.isEmpty()) {
                return 0.0;
            }
            return this.computeAUC(this.holdoutSortedScores, this.holdoutNumPos, this.holdoutNumNeg);
        }

        /**
         * Returns the scored AUC over the current window: the average, over
         * positive/negative pairs ranked in the right order, of the score gap
         * (tied pairs count half). Returns 1.0 if the window lacks either class.
         */
        public double getScoredAUC() {
            double AOC = 0.0;            // sum over positives of the negative scores ranked above them
            double AUC = 0.0;            // sum over negatives of the positive scores ranked above them
            double r = 0.0;              // running sum of negative scores
            double prevr = 0.0;
            double c = 0.0;              // running sum of positive scores
            double prevc = 0.0;
            double lastPosScore = Double.MAX_VALUE;
            double lastNegScore = Double.MAX_VALUE;
            if (this.numPos == 0.0 || this.numNeg == 0.0) {
                return 1.0;
            }
            for (Score s : this.sortedScores) {
                if (s.isPositive) {
                    if (s.value != lastPosScore) {
                        prevc = c;
                        lastPosScore = s.value;
                    }
                    c += s.value;
                    if (s.value == lastNegScore) {
                        AOC += (r + prevr) / 2.0;
                        continue;
                    }
                    AOC += r;
                    continue;
                }
                if (s.value != lastNegScore) {
                    prevr = r;
                    lastNegScore = s.value;
                }
                r += s.value;
                if (s.value == lastPosScore) {
                    AUC += (c + prevc) / 2.0;
                    continue;
                }
                AUC += c;
            }
            double R_minus = (this.numPos * r - AOC) / (this.numPos * this.numNeg);
            double R_plus = AUC / (this.numPos * this.numNeg);
            return R_plus - R_minus;
        }

        /** Returns positives / negatives in the window, or {@code Double.MAX_VALUE} when there are no negatives. */
        public double getRatio() {
            if (this.numNeg == 0.0) {
                return Double.MAX_VALUE;
            }
            return this.numPos / this.numNeg;
        }

        /** Number of examples currently in the window (all observed examples when windowing is disabled). */
        private double windowLength() {
            double observed = WindowAUCImbalancedPerformanceEvaluator.this.totalObservedInstances;
            return this.size > 0 ? Math.min((double) this.size, observed) : observed;
        }

        /** Returns the fraction of correct predictions in the window, or 0.0 before any example. */
        public double getAccuracy() {
            if (!(WindowAUCImbalancedPerformanceEvaluator.this.totalObservedInstances > 0.0)) {
                return 0.0;
            }
            return this.correctPredictions / this.windowLength();
        }

        /** Returns Cohen's kappa over the window (NaN when undefined, e.g. before any example). */
        public double getKappa() {
            double p0 = this.getAccuracy();
            double n = this.windowLength();
            double pc = 0.0;
            for (int i = 0; i < WindowAUCImbalancedPerformanceEvaluator.this.numClasses; ++i) {
                // Chance agreement: product of predicted-class and true-class frequencies.
                pc += this.rowKappa[i] / n * (this.columnKappa[i] / n);
            }
            return (p0 - pc) / (1.0 - pc);
        }

        /** Returns kappa measured against the majority-class baseline instead of chance. */
        private double getKappaM() {
            double p0 = this.getAccuracy();
            double pc = WindowAUCImbalancedPerformanceEvaluator.this.weightMajorityClassifier.estimation();
            return (p0 - pc) / (1.0 - pc);
        }

        /** Returns the geometric mean of positive and negative recall (NaN if the window lacks a class). */
        public double getGMean() {
            double positiveAccuracy = this.correctPositivePredictions / this.numPos;
            double negativeAccuracy = (this.correctPredictions - this.correctPositivePredictions) / this.numNeg;
            return Math.sqrt(positiveAccuracy * negativeAccuracy);
        }

        /** Returns the recall of the positive class (NaN if the window has no positives). */
        public double getRecall() {
            return this.correctPositivePredictions / this.numPos;
        }

        /**
         * One scored example. Ordered by descending score, then positives before
         * negatives, then by ascending insertion position, so each entry is unique.
         */
        public class Score implements Comparable<Score>, Serializable {
            private static final long serialVersionUID = 1L;

            protected double value;
            protected int posWindow;
            protected boolean isPositive;

            public Score(double value, int position, boolean isPositive) {
                this.value = value;
                this.posWindow = position;
                this.isPositive = isPositive;
            }

            @Override
            public int compareTo(Score o) {
                if (o.value < this.value) {
                    return -1;
                }
                if (o.value > this.value) {
                    return 1;
                }
                if (!o.isPositive && this.isPositive) {
                    return -1;
                }
                if (o.isPositive && !this.isPositive) {
                    return 1;
                }
                if (o.posWindow > this.posWindow) {
                    return -1;
                }
                if (o.posWindow < this.posWindow) {
                    return 1;
                }
                return 0;
            }

            /** Two scores are equal when they come from the same insertion position. */
            @Override
            public boolean equals(Object o) {
                return o instanceof Score && ((Score) o).posWindow == this.posWindow;
            }

            /** Consistent with {@link #equals(Object)}: derived from the insertion position only. */
            @Override
            public int hashCode() {
                return this.posWindow;
            }
        }
    }

    /** Running mean of the added values (NaN before the first value). */
    public class SimpleEstimator implements Serializable {
        private static final long serialVersionUID = 1L;

        protected double len;
        protected double sum;

        public void add(double value) {
            this.sum += value;
            this.len += 1.0;
        }

        public double estimation() {
            return this.sum / this.len;
        }
    }
}

