package org.grobid.trainer;

import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Random;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.text.RandomStringGenerator;
import org.grobid.core.GrobidModel;
import org.grobid.core.GrobidModels;
import org.grobid.core.engines.tagging.GenericTagger;
import org.grobid.core.engines.tagging.GrobidCRFEngine;
import org.grobid.core.engines.tagging.TaggerFactory;
import org.grobid.core.exceptions.GrobidException;
import org.grobid.core.factory.GrobidFactory;
import org.grobid.core.utilities.GrobidProperties;
import org.grobid.core.utilities.TextUtilities;
import org.grobid.trainer.evaluation.EvaluationUtilities;
import org.grobid.trainer.evaluation.LabelResult;
import org.grobid.trainer.evaluation.ModelStats;
import org.grobid.trainer.evaluation.Stats;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grobid/trainer/AbstractTrainer.class */
public abstract class AbstractTrainer implements Trainer {
    /** Logger used by this class; protected, so subclasses can log through it as well. */
    protected static final Logger LOGGER = LoggerFactory.getLogger(AbstractTrainer.class);
    /** Suffix appended to the existing model file name when it is set aside by {@code renameModels}. */
    public static final String OLD_MODEL_EXT = ".old";
    /** Suffix of the file a freshly trained model is written to before it replaces the existing one. */
    public static final String NEW_MODEL_EXT = ".new";
    /** Forwarded to the trainer through {@code setEpsilon}; the value 0 means "leave the trainer's own setting". */
    protected double epsilon = 0.0d;
    /** Forwarded to the trainer through {@code setWindow}; the value 0 means "leave the trainer's own setting". */
    protected int window = 0;
    /** Forwarded to the trainer through {@code setNbMaxIterations}; the value 0 means "leave the trainer's own setting". */
    protected int nbMaxIterations = 0;
    /** Model this trainer works on; assigned in the constructor. */
    protected GrobidModel model;
    /** Temporary file receiving the generated training data (left unset for the DUMMY model). */
    private File trainDataPath;
    /** Temporary file receiving the generated evaluation data (left unset for the DUMMY model). */
    private File evalDataPath;
    /** Tagger for {@link #model}, created on first use by {@code getTagger()}. */
    private GenericTagger tagger;
    /** Generator of lower-case suffixes used to name the per-fold model files (left unset for the DUMMY model). */
    private RandomStringGenerator randomStringGenerator;

    public AbstractTrainer(GrobidModel grobidModel) {
        GrobidFactory.getInstance().createEngine();
        this.model = grobidModel;
        if (grobidModel.equals(GrobidModels.DUMMY)) {
            return;
        }
        this.trainDataPath = getTempTrainingDataPath();
        this.evalDataPath = getTempEvaluationDataPath();
        this.randomStringGenerator = new RandomStringGenerator.Builder().withinRange(97, 122).build();
    }

    /**
     * Stores the training parameters that are later handed to the trainer.
     * A value of 0 for any of them means the corresponding trainer setter is not called.
     *
     * @param d  value stored as {@code epsilon}
     * @param i  value stored as {@code window}
     * @param i2 value stored as {@code nbMaxIterations}
     */
    public void setParams(double d, int i, int i2) {
        epsilon = d;
        window = i;
        nbMaxIterations = i2;
    }

    /**
     * Convenience overload: delegates to the four-argument variant with no
     * evaluation output file and a ratio of 1.0.
     * NOTE(review): presumably a ratio of 1.0 sends every example to the training
     * output - confirm against the concrete implementations.
     *
     * @param file  first file argument, passed through unchanged
     * @param file2 second file argument, passed through unchanged
     * @return whatever the four-argument variant returns
     */
    @Override // org.grobid.trainer.Trainer
    public int createCRFPPData(File file, File file2) {
        final File evaluationOutput = null;
        final double ratio = 1.0d;
        return createCRFPPData(file, file2, evaluationOutput, ratio);
    }

    /**
     * Generates the training data from the corpus, trains a new model into a
     * {@code .new} file next to the current model and, unless the DELFT engine is
     * configured, swaps it in through {@link #renameModels(File, File)}.
     *
     * <p>Parameters left at 0 ({@code epsilon}, {@code window}, {@code nbMaxIterations})
     * are not pushed to the trainer.
     */
    @Override // org.grobid.trainer.Trainer
    public void train() {
        final File trainingData = this.trainDataPath;
        createCRFPPData(getCorpusPath(), trainingData);

        final GenericTrainer trainer = TrainerFactory.getTrainer();
        if (this.epsilon != 0.0d) {
            trainer.setEpsilon(this.epsilon);
        }
        if (this.window != 0) {
            trainer.setWindow(this.window);
        }
        if (this.nbMaxIterations != 0) {
            trainer.setNbMaxIterations(this.nbMaxIterations);
        }

        final File currentModel = GrobidProperties.getModelPath(this.model);
        final File modelDirectory = new File(currentModel.getAbsolutePath()).getParentFile();
        if (!modelDirectory.exists()) {
            LOGGER.warn("Cannot find the destination directory " + modelDirectory.getAbsolutePath() + " for the model " + this.model.getModelName() + ". Creating it.");
            // mkdirs() also creates missing ancestors; its result is checked so that a
            // failure is reported here rather than surfacing later as a training error.
            if (!modelDirectory.mkdirs() && !modelDirectory.isDirectory()) {
                LOGGER.error("Unable to create the destination directory " + modelDirectory.getAbsolutePath());
            }
        }

        final File newModel = new File(currentModel.getAbsolutePath() + NEW_MODEL_EXT);
        trainer.train(getTemplatePath(), trainingData, newModel, GrobidProperties.getNBThreads().intValue(), this.model);

        // The freshly trained file only replaces the current model for non-DELFT engines.
        if (GrobidProperties.getGrobidCRFEngine() != GrobidCRFEngine.DELFT) {
            renameModels(currentModel, newModel);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Replaces the current model file by a newly trained one.
     *
     * <p>If the current model exists it is first renamed with the {@code .old}
     * suffix; when that fails a warning is logged and the new model is left where
     * it is. Otherwise the new model is renamed to the current model path, with a
     * warning if that rename fails.
     *
     * @param file  path of the current model
     * @param file2 path of the newly trained model
     */
    public void renameModels(File file, File file2) {
        if (file.exists()) {
            final File backup = new File(file.getAbsolutePath() + OLD_MODEL_EXT);
            final boolean backedUp = file.renameTo(backup);
            if (!backedUp) {
                LOGGER.warn("Unable to rename old model file: " + file.getAbsolutePath());
                return;
            }
        }

        final boolean installed = file2.renameTo(file);
        if (!installed) {
            LOGGER.warn("Unable to rename new model file: " + file2);
        }
    }

    /**
     * Evaluates the current model, passing {@code false} as the flag of
     * {@link #evaluate(boolean)}.
     *
     * @return the textual evaluation report
     */
    @Override // org.grobid.trainer.Trainer
    public String evaluate() {
        final boolean flag = false;
        return evaluate(flag);
    }

    /**
     * Generates the evaluation data from the evaluation corpus and evaluates the
     * tagger returned by {@link #getTagger()} on it.
     *
     * @param z flag handed to {@code ModelStats.toString(boolean)} when rendering the report
     * @return the textual evaluation report
     */
    @Override // org.grobid.trainer.Trainer
    public String evaluate(boolean z) {
        final File evaluationData = this.evalDataPath;
        createCRFPPData(getEvalCorpusPath(), evaluationData);
        final ModelStats stats = EvaluationUtilities.evaluateStandard(evaluationData.getAbsolutePath(), getTagger());
        return stats.toString(z);
    }

    /**
     * Generates the evaluation data from the evaluation corpus and evaluates the
     * supplied tagger on it.
     *
     * @param genericTagger the tagger to evaluate
     * @param z             flag handed to {@code ModelStats.toString(boolean)} when rendering the report
     * @return the textual evaluation report
     */
    @Override // org.grobid.trainer.Trainer
    public String evaluate(GenericTagger genericTagger, boolean z) {
        final File evaluationData = this.evalDataPath;
        createCRFPPData(getEvalCorpusPath(), evaluationData);
        final ModelStats stats = EvaluationUtilities.evaluateStandard(evaluationData.getAbsolutePath(), genericTagger);
        return stats.toString(z);
    }

    /**
     * Splits the corpus into training and evaluation data, trains a new model on
     * the training part and evaluates the resulting tagger on the evaluation part.
     *
     * <p>As in {@link #train()}, the newly trained file only replaces the current
     * model when the configured engine is not DELFT.
     *
     * @param d ratio handed to the four-argument {@code createCRFPPData}; must not be {@code null}
     * @return the textual evaluation report
     */
    @Override // org.grobid.trainer.Trainer
    public String splitTrainEvaluate(Double d) {
        final File trainingData = this.trainDataPath;
        createCRFPPData(getCorpusPath(), trainingData, this.evalDataPath, d.doubleValue());

        final GenericTrainer trainer = TrainerFactory.getTrainer();
        if (this.epsilon != 0.0d) {
            trainer.setEpsilon(this.epsilon);
        }
        if (this.window != 0) {
            trainer.setWindow(this.window);
        }
        if (this.nbMaxIterations != 0) {
            trainer.setNbMaxIterations(this.nbMaxIterations);
        }

        final File currentModel = GrobidProperties.getModelPath(this.model);
        final File modelDirectory = new File(currentModel.getAbsolutePath()).getParentFile();
        if (!modelDirectory.exists()) {
            LOGGER.warn("Cannot find the destination directory " + modelDirectory.getAbsolutePath() + " for the model " + this.model.getModelName() + ". Creating it.");
            // mkdirs() also creates missing ancestors; its result is checked so that a
            // failure is reported here rather than surfacing later as a training error.
            if (!modelDirectory.mkdirs() && !modelDirectory.isDirectory()) {
                LOGGER.error("Unable to create the destination directory " + modelDirectory.getAbsolutePath());
            }
        }

        final File newModel = new File(currentModel.getAbsolutePath() + NEW_MODEL_EXT);
        trainer.train(getTemplatePath(), trainingData, newModel, GrobidProperties.getNBThreads().intValue(), this.model);

        // Same rule as train(): do not shuffle model files around for the DELFT engine,
        // otherwise the current model would be moved to ".old" with nothing to replace it.
        if (GrobidProperties.getGrobidCRFEngine() != GrobidCRFEngine.DELFT) {
            renameModels(currentModel, newModel);
        }

        return EvaluationUtilities.evaluateStandard(this.evalDataPath.getAbsolutePath(), getTagger()).toString();
    }

    /**
     * Runs the n-fold evaluation, passing {@code false} as the flag of
     * {@link #nFoldEvaluate(int, boolean)}.
     *
     * @param i number of folds
     * @return the textual n-fold report
     */
    @Override // org.grobid.trainer.Trainer
    public String nFoldEvaluate(int i) {
        final boolean flag = false;
        return nFoldEvaluate(i, flag);
    }

    /**
     * Runs an n-fold cross-validation over the training corpus.
     *
     * <p>The corpus is converted to training data, shuffled, and split into
     * {@code i} folds. For each fold a model is trained into a temporary
     * {@code .wapiti} file and evaluated on the held-out part. The report contains
     * the per-fold results, the worst and best fold (by micro-average F1), the
     * per-label and overall averages, and instance-level averages. The temporary
     * model and data files are deleted before returning.
     *
     * @param i number of folds
     * @param z flag handed to {@code ModelStats.toString(boolean)} for the per-fold reports
     * @return the textual report
     * @throws GrobidException if no fold statistics are available for the summary
     */
    @Override // org.grobid.trainer.Trainer
    public String nFoldEvaluate(int i, boolean z) {
        final File corpusData = this.trainDataPath;
        createCRFPPData(getCorpusPath(), corpusData);
        final GenericTrainer trainer = TrainerFactory.getTrainer();
        // Random suffix for the fold model file names (presumably to keep files of separate runs apart).
        final String runSuffix = this.randomStringGenerator.generate(10);
        final List<ImmutablePair<String, String>> folds =
                splitNFold(loadAndShuffle(Paths.get(corpusData.getAbsolutePath())), i);

        if (this.epsilon != 0.0d) {
            trainer.setEpsilon(this.epsilon);
        }
        if (this.window != 0) {
            trainer.setWindow(this.window);
        }
        if (this.nbMaxIterations != 0) {
            trainer.setNbMaxIterations(this.nbMaxIterations);
        }

        final File tmpDirectory = new File(GrobidProperties.getTempPath().getAbsolutePath());
        if (!tmpDirectory.exists()) {
            LOGGER.warn("Cannot find the destination directory " + tmpDirectory);
        }

        final List<String> temporaryFiles = new ArrayList<>();
        final StringBuilder report = new StringBuilder();
        report.append("Recap results for each fold:").append("\n\n");

        // --- train and evaluate every fold (left = training data, right = evaluation data) ---
        final List<ModelStats> foldStats = new ArrayList<>();
        int foldIndex = 0;
        for (ImmutablePair<String, String> fold : folds) {
            final String banner = "====================== Fold " + foldIndex + " ====================== ";
            report.append("\n");
            report.append(banner).append("\n");
            System.out.println(banner);

            final File foldModelFile = new File(tmpDirectory + File.separator + getModel().getModelName()
                    + "_nfold_" + foldIndex + "_" + runSuffix + ".wapiti");
            report.append("Saving model in " + foldModelFile).append("\n");
            temporaryFiles.add(foldModelFile.getAbsolutePath());
            temporaryFiles.add(fold.getLeft());
            temporaryFiles.add(fold.getRight());

            report.append("Training input data: " + fold.getLeft()).append("\n");
            trainer.train(getTemplatePath(), new File(fold.getLeft()), foldModelFile,
                    GrobidProperties.getNBThreads().intValue(), this.model);

            report.append("Evaluation input data: " + fold.getRight()).append("\n");
            final ModelStats stats = EvaluationUtilities.evaluateStandard(
                    fold.getRight(), TaggerFactory.getTagger(foldModel(tmpDirectory, foldModelFile)));
            report.append(stats.toString(z));
            report.append("\n");
            report.append("\n");

            foldStats.add(stats);
            foldIndex++;
        }

        // --- worst / best fold by micro-average F1 ---
        report.append("\n").append("Summary results: ").append("\n");
        final Comparator<ModelStats> byMicroF1 = (left, right) -> {
            final double leftF1 = left.getFieldStats().getMicroAverageF1();
            final double rightF1 = right.getFieldStats().getMicroAverageF1();
            if (leftF1 > rightF1) {
                return 1;
            }
            return leftF1 < rightF1 ? -1 : 0;
        };
        final ModelStats worst = foldStats.stream().min(byMicroF1).orElseThrow(() ->
                new GrobidException("Something wrong when computing evaluations - worst model metrics not found. "));
        report.append("Worst fold").append("\n");
        report.append(worst.toString()).append("\n");
        final ModelStats best = foldStats.stream().max(byMicroF1).orElseThrow(() ->
                new GrobidException("Something wrong when computing evaluations - best model metrics not found. "));
        report.append("Best fold:").append("\n");
        report.append(best.toString()).append("\n").append("\n");

        // --- per-label sums over all folds ---
        report.append("Average over " + i + " folds: ").append("\n");
        final Map<String, LabelResult> perLabel = new TreeMap<>();
        int totalInstances = 0;
        int correctInstances = 0;
        for (ModelStats stats : foldStats) {
            totalInstances += stats.getTotalInstances();
            correctInstances += stats.getCorrectInstance();
            for (Map.Entry<String, LabelResult> entry : stats.getFieldStats().getLabelsResults().entrySet()) {
                final String label = entry.getKey();
                final LabelResult foldResult = entry.getValue();
                LabelResult sum = perLabel.get(label);
                if (sum == null) {
                    sum = new LabelResult(label);
                    perLabel.put(label, sum);
                    sum.setAccuracy(foldResult.getAccuracy());
                    sum.setF1Score(foldResult.getF1Score());
                    sum.setRecall(foldResult.getRecall());
                    sum.setPrecision(foldResult.getPrecision());
                    sum.setSupport(foldResult.getSupport());
                } else {
                    sum.setAccuracy(sum.getAccuracy() + foldResult.getAccuracy());
                    sum.setF1Score(sum.getF1Score() + foldResult.getF1Score());
                    sum.setRecall(sum.getRecall() + foldResult.getRecall());
                    sum.setPrecision(sum.getPrecision() + foldResult.getPrecision());
                    sum.setSupport(sum.getSupport() + foldResult.getSupport());
                }
            }
        }

        // --- turn the sums into averages (support stays a total) and print them ---
        report.append(String.format("\n%-20s %-12s %-12s %-12s %-12s %-7s\n\n", "label", "accuracy", "precision", "recall", "f1", "support"));
        for (LabelResult labelResult : perLabel.values()) {
            labelResult.setAccuracy(labelResult.getAccuracy() / foldStats.size());
            labelResult.setF1Score(labelResult.getF1Score() / foldStats.size());
            labelResult.setPrecision(labelResult.getPrecision() / foldStats.size());
            labelResult.setRecall(labelResult.getRecall() / foldStats.size());
            report.append(labelResult.toString());
        }

        // --- overall micro averages ---
        final double averageAccuracy = foldStats.stream()
                .mapToDouble(stats -> stats.getFieldStats().getMicroAverageAccuracy())
                .average()
                .orElseThrow(() -> new GrobidException("Missing average accuracy. Something went wrong. Please check. "));
        final double averageF1 = foldStats.stream()
                .mapToDouble(stats -> stats.getFieldStats().getMicroAverageF1())
                .average()
                .orElseThrow(() -> new GrobidException("Missing average F1. Something went wrong. Please check. "));
        final double averagePrecision = foldStats.stream()
                .mapToDouble(stats -> stats.getFieldStats().getMicroAveragePrecision())
                .average()
                .orElseThrow(() -> new GrobidException("Missing average precision. Something went wrong. Please check. "));
        final double averageRecall = foldStats.stream()
                .mapToDouble(stats -> stats.getFieldStats().getMicroAverageRecall())
                .average()
                .orElseThrow(() -> new GrobidException("Missing average recall. Something went wrong. Please check. "));
        report.append("\n");
        report.append(String.format("%-20s %-12s %-12s %-12s %-7s\n", "all ",
                TextUtilities.formatTwoDecimals(averageAccuracy * 100.0d),
                TextUtilities.formatTwoDecimals(averagePrecision * 100.0d),
                TextUtilities.formatTwoDecimals(averageRecall * 100.0d),
                TextUtilities.formatTwoDecimals(averageF1 * 100.0d)));

        // --- instance-level averages ---
        report.append("\n===== Instance-level results =====\n\n");
        // Cast before dividing: int / int would truncate the averages and skew the recall below.
        final double averageTotalInstances = (double) totalInstances / i;
        final double averageCorrectInstances = (double) correctInstances / i;
        report.append(String.format("%-27s %s\n", "Total expected instances:", TextUtilities.formatTwoDecimals(averageTotalInstances)));
        report.append(String.format("%-27s %s\n", "Correct instances:", TextUtilities.formatTwoDecimals(averageCorrectInstances)));
        report.append(String.format("%-27s %s\n", "Instance-level recall:", TextUtilities.formatTwoDecimals((averageCorrectInstances / averageTotalInstances) * 100.0d)));

        // --- best-effort cleanup of the per-fold model and data files ---
        for (String temporaryFile : temporaryFiles) {
            try {
                Files.delete(Paths.get(temporaryFile));
            } catch (IOException e) {
                LOGGER.warn("Error while performing the cleanup after n-fold cross-validation. Cannot delete the file: " + temporaryFile, e);
            }
        }
        return report.toString();
    }

    /** Builds a model descriptor that points the tagger at one fold's temporary model file. */
    private GrobidModel foldModel(final File folder, final File modelFile) {
        return new GrobidModel() {
            public String getFolderName() {
                return folder.getAbsolutePath();
            }

            public String getModelPath() {
                return modelFile.getAbsolutePath();
            }

            public String getModelName() {
                return AbstractTrainer.this.model.getModelName();
            }

            public String getTemplateName() {
                return AbstractTrainer.this.model.getTemplateName();
            }
        };
    }

    /**
     * Splits the examples into {@code i} folds and writes, for every fold, one
     * temporary evaluation file (the fold itself) and one temporary training file
     * (all remaining examples). Examples are separated by a blank line.
     *
     * <p>All folds have {@code floor(size / i)} examples except the last one, which
     * also takes the remainder.
     *
     * @param list the examples, one string per example
     * @param i    number of folds; must be positive and not larger than the number of examples
     * @return one pair per fold: left = training file path, right = evaluation file path
     * @throws IllegalArgumentException if {@code i} is not positive or there are fewer examples than folds
     * @throws GrobidException          if a fold file cannot be written
     */
    protected List<ImmutablePair<String, String>> splitNFold(List<String> list, int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("The number of folds must be positive, got " + i);
        }
        final int total = CollectionUtils.size(list);
        final int foldSize = Math.floorDiv(total, i);
        if (foldSize == 0) {
            throw new IllegalArgumentException("There aren't enough training data for n-fold evaluation with fold of size " + i);
        }

        return IntStream.range(0, i).mapToObj(foldIndex -> {
            final int evaluationStart = foldSize * foldIndex;
            // The last fold absorbs the examples left over by the integer division.
            final int evaluationEnd = (foldIndex == i - 1) ? total : evaluationStart + foldSize;

            final List<String> evaluationExamples = list.subList(evaluationStart, evaluationEnd);
            final List<String> trainingExamples = new ArrayList<>();
            trainingExamples.addAll(list.subList(0, evaluationStart));
            trainingExamples.addAll(list.subList(evaluationEnd, total));

            final String evaluationPath = getTempEvaluationDataPath().getAbsolutePath();
            dumpExamples(evaluationPath, evaluationExamples, "Error when dumping n-fold evaluation data into files. ");

            final String trainingPath = getTempTrainingDataPath().getAbsolutePath();
            dumpExamples(trainingPath, trainingExamples, "Error when dumping n-fold training data into files. ");

            final ImmutablePair<String, String> fold = new ImmutablePair<>(trainingPath, evaluationPath);
            return fold;
        }).collect(Collectors.toList());
    }

    /** Writes the examples to the given file, separated by blank lines, closing the writer even on failure. */
    private void dumpExamples(String targetPath, List<String> examples, String failureMessage) {
        try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(targetPath))) {
            writer.write(String.join("\n\n", examples));
            writer.write("\n");
        } catch (IOException e) {
            throw new GrobidException(failureMessage, e);
        }
    }

    /**
     * Loads the examples with {@link #load(Path)} and shuffles them in place using
     * a fixed seed, so the resulting order is the same on every run.
     *
     * @param path file to load the examples from
     * @return the shuffled examples
     */
    protected List<String> loadAndShuffle(Path path) {
        final List<String> examples = load(path);
        final Random fixedSeedRandom = new Random(839374947498L);
        Collections.shuffle(examples, fixedSeedRandom);
        return examples;
    }

    /**
     * Reads a file whose examples are groups of consecutive non-blank lines
     * separated by one or more blank lines.
     *
     * @param path the file to read
     * @return one string per example, its lines joined with {@code "\n"}; empty if the file has no non-blank line
     * @throws GrobidException if the file cannot be opened
     */
    public List<String> load(Path path) {
        final List<String> examples = new ArrayList<>();
        // try-with-resources: the stream holds an open file and must be closed even if reading fails.
        try (Stream<String> lines = Files.lines(path)) {
            List<String> currentExample = new ArrayList<>();
            for (String line : lines.collect(Collectors.toList())) {
                if (StringUtils.isBlank(line)) {
                    // A blank line ends the current example; consecutive blank lines add nothing.
                    if (CollectionUtils.isNotEmpty(currentExample)) {
                        examples.add(String.join("\n", currentExample));
                    }
                    currentExample = new ArrayList<>();
                } else {
                    currentExample.add(line);
                }
            }
            // The last example is not necessarily followed by a blank line.
            if (CollectionUtils.isNotEmpty(currentExample)) {
                examples.add(String.join("\n", currentExample));
            }
            return examples;
        } catch (IOException e) {
            throw new GrobidException("Error in n-fold, when loading training data. Failing. ", e);
        }
    }

    /**
     * Creates a new, empty temporary file for training data in the GROBID temp
     * directory; the file name starts with the model name and ends with {@code .train}.
     *
     * @return the newly created file
     * @throws RuntimeException if the file cannot be created (the I/O error is kept as cause)
     */
    protected final File getTempTrainingDataPath() {
        try {
            return File.createTempFile(this.model.getModelName(), ".train", GrobidProperties.getTempPath());
        } catch (IOException e) {
            // Keep the original I/O error as cause so the reason of the failure is not lost.
            throw new RuntimeException("Unable to create a temporary training file for model: " + this.model, e);
        }
    }

    /**
     * Creates a new, empty temporary file for evaluation data in the GROBID temp
     * directory; the file name starts with the model name and ends with {@code .test}.
     *
     * @return the newly created file
     * @throws RuntimeException if the file cannot be created (the I/O error is kept as cause)
     */
    protected final File getTempEvaluationDataPath() {
        try {
            return File.createTempFile(this.model.getModelName(), ".test", GrobidProperties.getTempPath());
        } catch (IOException e) {
            // Keep the original I/O error as cause so the reason of the failure is not lost.
            throw new RuntimeException("Unable to create a temporary evaluation file for model: " + this.model, e);
        }
    }

    /**
     * Returns the tagger for this trainer's model, obtaining it from
     * {@code TaggerFactory} on the first call and reusing it afterwards.
     *
     * @return the tagger for the current model
     */
    protected GenericTagger getTagger() {
        GenericTagger current = this.tagger;
        if (current == null) {
            current = TaggerFactory.getTagger(this.model);
            this.tagger = current;
        }
        return current;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Locates the trainer resources directory: {@code <grobid-home>/../grobid-trainer/resources}
     * when it exists, otherwise the relative directory {@code resources}.
     *
     * @return the resources directory (not guaranteed to exist in the fallback case)
     */
    public static File getFilePath2Resources() {
        final String besideGrobidHome = String.join(File.separator,
                GrobidProperties.get_GROBID_HOME_PATH().getAbsoluteFile().toString(),
                "..",
                "grobid-trainer",
                "resources");
        final File candidate = new File(besideGrobidHome);
        return candidate.exists() ? candidate : new File("resources");
    }

    /**
     * Resolves the training corpus location for this trainer's model through
     * {@code GrobidProperties}, relative to the trainer resources directory.
     *
     * @return the corpus path
     */
    protected File getCorpusPath() {
        final File resources = getFilePath2Resources();
        return GrobidProperties.getCorpusPath(resources, this.model);
    }

    /**
     * Resolves the template location for this trainer's own model.
     *
     * @return the template path
     */
    protected File getTemplatePath() {
        final GrobidModel ownModel = this.model;
        return getTemplatePath(ownModel);
    }

    /**
     * Resolves the template location for the given model through
     * {@code GrobidProperties}, relative to the trainer resources directory.
     *
     * @param grobidModel the model whose template is wanted
     * @return the template path
     */
    protected File getTemplatePath(GrobidModel grobidModel) {
        final File resources = getFilePath2Resources();
        return GrobidProperties.getTemplatePath(resources, grobidModel);
    }

    /**
     * Resolves the evaluation corpus location for this trainer's model through
     * {@code GrobidProperties}, relative to the trainer resources directory.
     *
     * @return the evaluation corpus path
     */
    protected File getEvalCorpusPath() {
        final File resources = getFilePath2Resources();
        return GrobidProperties.getEvalCorpusPath(resources, this.model);
    }

    /**
     * Returns the {@code dataset/patent/evaluation} directory below the trainer
     * resources directory.
     *
     * @return the directory as a {@code File}; its existence is not checked
     */
    public static File getEvalCorpusBasePath() {
        final String evaluationDirectory = String.join(File.separator,
                getFilePath2Resources().getAbsolutePath(),
                "dataset",
                "patent",
                "evaluation");
        return new File(evaluationDirectory);
    }

    /**
     * @return the model this trainer was created for
     */
    @Override // org.grobid.trainer.Trainer
    public GrobidModel getModel() {
        return model;
    }

    /**
     * Trains the given trainer's model and prints the elapsed time in
     * milliseconds to standard output.
     *
     * @param trainer the trainer to run
     */
    public static void runTraining(Trainer trainer) {
        final long startedAt = System.currentTimeMillis();
        trainer.train();
        final StringBuilder message = new StringBuilder();
        message.append("Model for ").append(trainer.getModel()).append(" created in ");
        message.append(System.currentTimeMillis() - startedAt).append(" ms");
        System.out.println(message.toString());
    }

    /**
     * @return the temporary file used for the generated evaluation data
     *         ({@code null} when the trainer was created for the DUMMY model)
     */
    public File getEvalDataPath() {
        return evalDataPath;
    }

    /**
     * Evaluates the trainer's model and appends the elapsed time to the report.
     *
     * @param trainer the trainer to evaluate
     * @param z       flag forwarded to {@code Trainer.evaluate(boolean)}
     * @return the evaluation report followed by a line stating the elapsed milliseconds
     * @throws GrobidException wrapping any exception raised during the evaluation
     */
    public static String runEvaluation(Trainer trainer, boolean z) {
        // Taken before the work starts; reading the clock twice in one expression always yields ~0 ms.
        final long startedAt = System.currentTimeMillis();
        try {
            final String report = trainer.evaluate(z);
            final long elapsed = System.currentTimeMillis() - startedAt;
            return report + "\n\nEvaluation for " + trainer.getModel() + " model is realized in " + elapsed + " ms";
        } catch (Exception e) {
            throw new GrobidException("An exception occurred while evaluating Grobid.", e);
        }
    }

    /**
     * Evaluates the trainer's model with the flag set to {@code false}. Unlike the
     * two-argument overload, the report is returned as is: no timing line is
     * appended and exceptions are not wrapped.
     *
     * @param trainer the trainer to evaluate
     * @return the evaluation report
     */
    public static String runEvaluation(Trainer trainer) {
        final boolean flag = false;
        return trainer.evaluate(flag);
    }

    /**
     * Runs split/train/evaluate on the trainer and appends the elapsed time to the report.
     *
     * @param trainer the trainer to run
     * @param d       ratio forwarded to {@code Trainer.splitTrainEvaluate(Double)}
     * @return the evaluation report followed by a line stating the elapsed milliseconds
     * @throws GrobidException wrapping any exception raised during training or evaluation
     */
    public static String runSplitTrainingEvaluation(Trainer trainer, Double d) {
        // Taken before the work starts; reading the clock twice in one expression always yields ~0 ms.
        final long startedAt = System.currentTimeMillis();
        try {
            final String report = trainer.splitTrainEvaluate(d);
            final long elapsed = System.currentTimeMillis() - startedAt;
            return report + "\n\nSplit, training and evaluation for " + trainer.getModel() + " model is realized in " + elapsed + " ms";
        } catch (Exception e) {
            throw new GrobidException("An exception occurred while evaluating Grobid.", e);
        }
    }

    /**
     * Runs the n-fold evaluation and writes the report to {@code path}, with the
     * flag of the four-argument overload set to {@code false}.
     *
     * @param trainer the trainer to evaluate
     * @param i       number of folds
     * @param path    file receiving the report
     */
    public static void runNFoldEvaluation(Trainer trainer, int i, Path path) {
        final boolean flag = false;
        runNFoldEvaluation(trainer, i, path, flag);
    }

    /**
     * Runs the n-fold evaluation and writes the resulting report, followed by a
     * newline, to {@code path}.
     *
     * @param trainer the trainer to evaluate
     * @param i       number of folds
     * @param path    file receiving the report
     * @param z       flag forwarded to the n-fold evaluation
     * @throws GrobidException if the evaluation fails or the report cannot be written
     */
    public static void runNFoldEvaluation(Trainer trainer, int i, Path path, boolean z) {
        final String report = runNFoldEvaluation(trainer, i, z);
        // try-with-resources closes the writer even when a write fails.
        try (BufferedWriter writer = Files.newBufferedWriter(path)) {
            writer.write(report);
            writer.write("\n");
        } catch (IOException e) {
            throw new GrobidException("Error when dumping n-fold training data into files. ", e);
        }
    }

    /**
     * Runs the n-fold evaluation with the flag of the three-argument overload set
     * to {@code false}.
     *
     * @param trainer the trainer to evaluate
     * @param i       number of folds
     * @return the n-fold report
     */
    public static String runNFoldEvaluation(Trainer trainer, int i) {
        final boolean flag = false;
        return runNFoldEvaluation(trainer, i, flag);
    }

    /**
     * Runs the n-fold evaluation on the trainer and appends the elapsed time to the report.
     *
     * @param trainer the trainer to evaluate
     * @param i       number of folds
     * @param z       flag forwarded to {@code Trainer.nFoldEvaluate(int, boolean)}
     * @return the n-fold report followed by a line stating the elapsed milliseconds
     * @throws GrobidException wrapping any exception raised during the evaluation
     */
    public static String runNFoldEvaluation(Trainer trainer, int i, boolean z) {
        // Taken before the work starts; reading the clock twice in one expression always yields ~0 ms.
        final long startedAt = System.currentTimeMillis();
        try {
            final String report = trainer.nFoldEvaluate(i, z);
            final long elapsed = System.currentTimeMillis() - startedAt;
            return report + "\n\nN-Fold evaluation for " + trainer.getModel() + " model is realized in " + elapsed + " ms";
        } catch (Exception e) {
            throw new GrobidException("An exception occurred while evaluating Grobid.", e);
        }
    }

    /**
     * Chooses which of two writers should receive the next example.
     *
     * <p>When exactly one writer is non-null, that writer is returned. Otherwise
     * (both present or both absent) a random draw decides: {@code writer} when
     * {@code Math.random() <= d}, else {@code writer2}.
     *
     * @param writer  first candidate, may be {@code null}
     * @param writer2 second candidate, may be {@code null}
     * @param d       threshold compared against {@code Math.random()}
     * @return the selected writer; {@code null} only when both arguments are {@code null}
     */
    public Writer dispatchExample(Writer writer, Writer writer2, double d) {
        if (writer == null && writer2 != null) {
            return writer2;
        }
        if (writer != null && writer2 == null) {
            return writer;
        }
        if (Math.random() <= d) {
            return writer;
        }
        return writer2;
    }
}
