package org.grobid.trainer.evaluation;

import java.util.Iterator;
import java.util.Set;
import java.util.TreeMap;
import org.grobid.core.exceptions.GrobidException;
import org.grobid.core.utilities.TextUtilities;

/* loaded from: input_file:org/grobid/trainer/evaluation/Stats.class */
/**
 * Accumulates per-label evaluation counts (false positives, false negatives, observed and
 * expected instances) and derives per-label and averaged precision, recall, F1 and accuracy.
 *
 * <p>Labels are created lazily the first time they are referenced through any accessor. The
 * labels {@code "<other>"}, {@code "base"} and {@code "O"} are tracked but excluded from the
 * reported results and from the averaged metrics.
 *
 * <p>Derived metrics are cached and recomputed only when counts have changed since the last
 * computation. No synchronization is performed.
 */
public final class Stats {
    private boolean requiredToRecomputeMetrics = true;
    private double cumulated_tp = 0.0d;
    private double cumulated_fp = 0.0d;
    private double cumulated_tn = 0.0d;
    private double cumulated_fn = 0.0d;
    private double cumulated_f1 = 0.0d;
    private double cumulated_accuracy = 0.0d;
    private double cumulated_precision = 0.0d;
    private double cumulated_recall = 0.0d;
    private double cumulated_expected = 0.0d;
    // Number of non-ignored labels with at least one expected instance (denominator of macro averages).
    private int totalValidFields = 0;
    private final TreeMap<String, LabelStat> labelStats = new TreeMap<>();

    /**
     * Returns the labels seen so far, in sorted order.
     *
     * @return a live view of the label keys (backed by the internal map)
     */
    public Set<String> getLabels() {
        return this.labelStats.keySet();
    }

    /** Records one false positive for {@code label}. */
    public void incrementFalsePositive(String label) {
        incrementFalsePositive(label, 1);
    }

    /** Records {@code count} false positives for {@code label}. */
    public void incrementFalsePositive(String label, int count) {
        requireLabelStat(label).incrementFalsePositive(count);
        this.requiredToRecomputeMetrics = true;
    }

    /** Records one false negative for {@code label}. */
    public void incrementFalseNegative(String label) {
        incrementFalseNegative(label, 1);
    }

    /** Records {@code count} false negatives for {@code label}. */
    public void incrementFalseNegative(String label, int count) {
        requireLabelStat(label).incrementFalseNegative(count);
        this.requiredToRecomputeMetrics = true;
    }

    /** Records one observed (correctly predicted) instance for {@code label}. */
    public void incrementObserved(String label) {
        incrementObserved(label, 1);
    }

    /** Records {@code count} observed (correctly predicted) instances for {@code label}. */
    public void incrementObserved(String label, int count) {
        requireLabelStat(label).incrementObserved(count);
        this.requiredToRecomputeMetrics = true;
    }

    /** Records one expected (gold) instance for {@code label}. */
    public void incrementExpected(String label) {
        incrementExpected(label, 1);
    }

    /** Records {@code count} expected (gold) instances for {@code label}. */
    public void incrementExpected(String label, int count) {
        requireLabelStat(label).incrementExpected(count);
        this.requiredToRecomputeMetrics = true;
    }

    /**
     * Returns the statistics holder for {@code str}, creating and registering an empty one if
     * the label has not been seen before.
     *
     * @param str the label
     * @return the (possibly newly created) statistics for that label
     */
    public LabelStat getLabelStat(String str) {
        if (this.labelStats.containsKey(str)) {
            return this.labelStats.get(str);
        }
        LabelStat created = LabelStat.create();
        this.labelStats.put(str, created);
        this.requiredToRecomputeMetrics = true;
        return created;
    }

    /** Returns the number of distinct labels tracked, including ignored ones. */
    public int size() {
        return this.labelStats.size();
    }

    /** Returns the precision of {@code label} (creates the label if unknown). */
    public double getPrecision(String label) {
        return requireLabelStat(label).getPrecision();
    }

    /** Returns the recall of {@code label} (creates the label if unknown). */
    public double getRecall(String label) {
        return requireLabelStat(label).getRecall();
    }

    /** Returns the F1 score of {@code label} (creates the label if unknown). */
    public double getF1Score(String label) {
        return requireLabelStat(label).getF1Score();
    }

    /**
     * Recomputes true negatives per label and the cumulated totals used by the averaged metrics,
     * if any count changed since the last computation.
     */
    public void computeMetrics() {
        for (LabelStat labelStat : this.labelStats.values()) {
            if (labelStat.hasChanged()) {
                this.requiredToRecomputeMetrics = true;
                break;
            }
        }
        if (!this.requiredToRecomputeMetrics) {
            return;
        }

        // Start from zero: without this, every recomputation would add on top of stale totals.
        resetAccumulators();

        int totalFields = 0;
        for (LabelStat labelStat : this.labelStats.values()) {
            totalFields += labelStat.getObserved() + labelStat.getFalseNegative() + labelStat.getFalsePositive();
        }

        for (String label : getLabels()) {
            if (isIgnoredLabel(label)) {
                continue;
            }
            LabelStat labelStat = getLabelStat(label);
            int observed = labelStat.getObserved();
            int falsePositive = labelStat.getFalsePositive();
            int falseNegative = labelStat.getFalseNegative();
            int trueNegative = (totalFields - observed) - (falsePositive + falseNegative);
            labelStat.setTrueNegative(trueNegative);

            int expected = labelStat.getExpected();
            if (expected == 0) {
                // Labels absent from the gold data do not contribute to the averages.
                continue;
            }
            this.totalValidFields++;
            this.cumulated_tp += observed;
            this.cumulated_fp += falsePositive;
            this.cumulated_tn += trueNegative;
            this.cumulated_fn += falseNegative;
            this.cumulated_expected += expected;
            this.cumulated_f1 += labelStat.getF1Score();
            this.cumulated_accuracy += labelStat.getAccuracy();
            this.cumulated_precision += labelStat.getPrecision();
            this.cumulated_recall += labelStat.getRecall();
        }
        this.requiredToRecomputeMetrics = false;
    }

    /**
     * Returns per-label results for every non-ignored label, keyed and sorted by label.
     *
     * @return a new map from label to its accuracy/precision/recall/F1/support
     */
    public TreeMap<String, LabelResult> getLabelsResults() {
        computeMetrics();
        TreeMap<String, LabelResult> results = new TreeMap<>();
        for (String label : getLabels()) {
            if (isIgnoredLabel(label)) {
                continue;
            }
            LabelStat labelStat = getLabelStat(label);
            LabelResult labelResult = new LabelResult(label);
            labelResult.setAccuracy(labelStat.getAccuracy());
            labelResult.setPrecision(labelStat.getPrecision());
            labelResult.setRecall(labelStat.getRecall());
            labelResult.setF1Score(labelStat.getF1Score());
            labelResult.setSupport(labelStat.getSupport());
            results.put(label, labelResult);
        }
        return results;
    }

    /**
     * Micro-averaged accuracy: (TP + TN) / (TP + FP + TN + FN) over counts pooled across all
     * valid labels, capped at 1.0; 0.0 when there are no counts.
     */
    public double getMicroAverageAccuracy() {
        computeMetrics();
        double total = this.cumulated_tp + this.cumulated_fp + this.cumulated_tn + this.cumulated_fn;
        if (total == 0.0d) {
            return 0.0d;
        }
        return Math.min(1.0d, (this.cumulated_tp + this.cumulated_tn) / total);
    }

    /**
     * Macro-averaged accuracy: mean of per-label accuracies over valid labels, capped at 1.0;
     * 0.0 when there are no valid labels.
     */
    public double getMacroAverageAccuracy() {
        computeMetrics();
        if (this.totalValidFields == 0) {
            return 0.0d;
        }
        return Math.min(1.0d, this.cumulated_accuracy / this.totalValidFields);
    }

    /** Micro-averaged precision: pooled TP / (TP + FP), capped at 1.0; 0.0 when undefined. */
    public double getMicroAveragePrecision() {
        computeMetrics();
        double d = 0.0d;
        if (this.cumulated_tp + this.cumulated_fp != 0.0d) {
            d = this.cumulated_tp / (this.cumulated_tp + this.cumulated_fp);
        }
        return Math.min(1.0d, d);
    }

    /** Macro-averaged precision: mean of per-label precisions, capped at 1.0. */
    public double getMacroAveragePrecision() {
        computeMetrics();
        if (this.totalValidFields == 0) {
            return 0.0d;
        }
        return Math.min(1.0d, this.cumulated_precision / this.totalValidFields);
    }

    /** Micro-averaged recall: pooled TP / pooled expected, capped at 1.0; 0.0 when undefined. */
    public double getMicroAverageRecall() {
        computeMetrics();
        double d = 0.0d;
        if (this.cumulated_expected != 0.0d) {
            d = this.cumulated_tp / this.cumulated_expected;
        }
        return Math.min(1.0d, d);
    }

    /** Macro-averaged recall: mean of per-label recalls, capped at 1.0. */
    public double getMacroAverageRecall() {
        computeMetrics();
        if (this.totalValidFields == 0) {
            return 0.0d;
        }
        return Math.min(1.0d, this.cumulated_recall / this.totalValidFields);
    }

    /** Returns the number of non-ignored labels that have at least one expected instance. */
    public int getTotalValidFields() {
        computeMetrics();
        return this.totalValidFields;
    }

    /** Micro-averaged F1: harmonic mean of micro precision and micro recall; 0.0 when both are 0. */
    public double getMicroAverageF1() {
        double precision = getMicroAveragePrecision();
        double recall = getMicroAverageRecall();
        double d = 0.0d;
        if (precision + recall != 0.0d) {
            d = ((2.0d * precision) * recall) / (precision + recall);
        }
        return d;
    }

    /** Macro-averaged F1: mean of per-label F1 scores, capped at 1.0. */
    public double getMacroAverageF1() {
        computeMetrics();
        if (this.totalValidFields == 0) {
            return 0.0d;
        }
        return Math.min(1.0d, this.cumulated_f1 / this.totalValidFields);
    }

    /**
     * Renders a fixed-width text table of per-label metrics (as percentages) followed by micro-
     * and macro-averaged rows. Ignored labels are omitted.
     */
    public String getOldReport() {
        computeMetrics();
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", "label", "accuracy", "precision", "recall", "f1", "support"));
        long totalSupport = 0;
        for (String label : getLabels()) {
            if (isIgnoredLabel(label)) {
                continue;
            }
            LabelStat labelStat = getLabelStat(label);
            long support = labelStat.getSupport();
            sb.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", label,
                    TextUtilities.formatTwoDecimals(labelStat.getAccuracy() * 100.0d),
                    TextUtilities.formatTwoDecimals(labelStat.getPrecision() * 100.0d),
                    TextUtilities.formatTwoDecimals(labelStat.getRecall() * 100.0d),
                    TextUtilities.formatTwoDecimals(labelStat.getF1Score() * 100.0d),
                    String.valueOf(support)));
            totalSupport += support;
        }
        sb.append("\n");
        sb.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", "all (micro avg.)",
                TextUtilities.formatTwoDecimals(getMicroAverageAccuracy() * 100.0d),
                TextUtilities.formatTwoDecimals(getMicroAveragePrecision() * 100.0d),
                TextUtilities.formatTwoDecimals(getMicroAverageRecall() * 100.0d),
                TextUtilities.formatTwoDecimals(getMicroAverageF1() * 100.0d),
                String.valueOf(totalSupport)));
        sb.append(String.format("%-20s %-12s %-12s %-12s %-12s %-7s\n", "all (macro avg.)",
                TextUtilities.formatTwoDecimals(getMacroAverageAccuracy() * 100.0d),
                TextUtilities.formatTwoDecimals(getMacroAveragePrecision() * 100.0d),
                TextUtilities.formatTwoDecimals(getMacroAverageRecall() * 100.0d),
                TextUtilities.formatTwoDecimals(getMacroAverageF1() * 100.0d),
                String.valueOf(totalSupport)));
        return sb.toString();
    }

    /** Returns whether {@code label} is excluded from reported results and averaged metrics. */
    private static boolean isIgnoredLabel(String label) {
        return label.equals("<other>") || label.equals("base") || label.equals("O");
    }

    /** Looks up (creating if needed) the stats for {@code label}, failing loudly on null. */
    private LabelStat requireLabelStat(String label) {
        LabelStat labelStat = getLabelStat(label);
        if (labelStat == null) {
            // Defensive only: getLabelStat creates missing labels and does not return null.
            throw new GrobidException("Unknown label: " + label);
        }
        return labelStat;
    }

    /** Zeroes every cumulated total so a recomputation starts from a clean slate. */
    private void resetAccumulators() {
        this.cumulated_tp = 0.0d;
        this.cumulated_fp = 0.0d;
        this.cumulated_tn = 0.0d;
        this.cumulated_fn = 0.0d;
        this.cumulated_f1 = 0.0d;
        this.cumulated_accuracy = 0.0d;
        this.cumulated_precision = 0.0d;
        this.cumulated_recall = 0.0d;
        this.cumulated_expected = 0.0d;
        this.totalValidFields = 0;
    }
}
