package edu.stanford.nlp.scoref;

import edu.stanford.nlp.scoref.SimpleLinearClassifier;
import edu.stanford.nlp.stats.Counter;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;

/**
 * A pairwise model: a {@link SimpleLinearClassifier} that scores {@link Example}s using features
 * produced by a {@link MetaFeatureExtractor}, together with the training settings (number of
 * training examples, epochs, singleton ratio) it was configured with.
 *
 * <p>Instances are created through {@link Builder}, obtained via {@link #newBuilder}.
 */
public class PairwiseModel {
    /** Model name; also used as the subdirectory name in {@link #getDefaultOutputPath()}. */
    public final String name;
    private final int trainingExamples;
    private final int epochs;
    protected final SimpleLinearClassifier classifier;
    /**
     * Multiplier applied to the update weight in
     * {@link #learn(Example, Example, Map, Compressor, double)} when only one of the two examples
     * is present; a value of 0 disables those updates entirely.
     */
    private final double singletonRatio;
    /** Result of {@code StatisticalCorefUtils.fieldValues(builder)}; written to the "config" file. */
    private final String str;
    protected final MetaFeatureExtractor meta;

    /**
     * Fluent builder for {@link PairwiseModel}. Unset options keep the defaults shown on the
     * fields below.
     */
    public static class Builder {
        private final String name;
        private final MetaFeatureExtractor meta;
        // Not read anywhere in this class. NOTE(review): presumably kept so that
        // StatisticalCorefUtils.fieldValues(builder) records it in the config - confirm before removing.
        private final String source = StatisticalCorefTrainer.extractedFeaturesFile;
        private int trainingExamples = 100000000;
        private int epochs = 8;
        private SimpleLinearClassifier.Loss loss = SimpleLinearClassifier.log();
        private SimpleLinearClassifier.LearningRateSchedule learningRateSchedule = SimpleLinearClassifier.adaGrad(0.05d, 30.0d);
        private double regularizationStrength = 1.0E-7d;
        private double singletonRatio = 0.3d;
        private String modelFile = null;

        /**
         * Creates a builder.
         *
         * @param name model name
         * @param meta extractor used to turn examples into feature counters
         */
        public Builder(String name, MetaFeatureExtractor meta) {
            this.name = name;
            this.meta = meta;
        }

        /** Sets the value reported by {@link PairwiseModel#getNumTrainingExamples()}. */
        public Builder trainingExamples(int i) {
            this.trainingExamples = i;
            return this;
        }

        /** Sets the value reported by {@link PairwiseModel#getNumEpochs()}. */
        public Builder epochs(int i) {
            this.epochs = i;
            return this;
        }

        /** Sets the weight multiplier used for single-example updates (0 disables them). */
        public Builder singletonRatio(double d) {
            this.singletonRatio = d;
            return this;
        }

        /** Sets the loss passed to the underlying {@link SimpleLinearClassifier}. */
        public Builder loss(SimpleLinearClassifier.Loss loss) {
            this.loss = loss;
            return this;
        }

        /** Sets the regularization strength passed to the underlying classifier. */
        public Builder regularizationStrength(double d) {
            this.regularizationStrength = d;
            return this;
        }

        /** Sets the learning-rate schedule passed to the underlying classifier. */
        public Builder learningRateSchedule(SimpleLinearClassifier.LearningRateSchedule learningRateSchedule) {
            this.learningRateSchedule = learningRateSchedule;
            return this;
        }

        /**
         * Sets the model to load. A path ending in ".ser" or ".gz" is used as-is; any other value
         * is treated as a model name under {@code StatisticalCorefTrainer.pairwiseModelsPath}.
         */
        public Builder modelPath(String str) {
            this.modelFile = str;
            return this;
        }

        /** Builds the configured {@link PairwiseModel}. */
        public PairwiseModel build() {
            return new PairwiseModel(this);
        }
    }

    /** Returns a new {@link Builder} for a model with the given name and feature extractor. */
    public static Builder newBuilder(String name, MetaFeatureExtractor meta) {
        return new Builder(name, meta);
    }

    /**
     * Creates a model from a builder's settings and constructs its classifier, passing the resolved
     * model path (or {@code null} when no model path was set).
     */
    public PairwiseModel(Builder builder) {
        this.name = builder.name;
        this.meta = builder.meta;
        this.trainingExamples = builder.trainingExamples;
        this.epochs = builder.epochs;
        this.singletonRatio = builder.singletonRatio;
        this.classifier = new SimpleLinearClassifier(builder.loss, builder.learningRateSchedule,
                builder.regularizationStrength, resolveModelPath(builder.modelFile));
        this.str = StatisticalCorefUtils.fieldValues(builder);
    }

    /**
     * Resolves the builder's model path into the file handed to the classifier.
     *
     * @param modelFile explicit ".ser"/".gz" file, a model name, or {@code null}
     * @return {@code null} if {@code modelFile} is null; {@code modelFile} itself if it ends in
     *     ".ser" or ".gz"; otherwise {@code <pairwiseModelsPath><modelFile>/model.ser}, which
     *     matches the layout produced by {@link #writeModel()}
     */
    private static String resolveModelPath(String modelFile) {
        if (modelFile == null) {
            return null;
        }
        if (modelFile.endsWith(".ser") || modelFile.endsWith(".gz")) {
            return modelFile;
        }
        return StatisticalCorefTrainer.pairwiseModelsPath + modelFile + "/model.ser";
    }

    /** Returns {@code <pairwiseModelsPath><name>/}, the directory used by {@link #writeModel()}. */
    public String getDefaultOutputPath() {
        return StatisticalCorefTrainer.pairwiseModelsPath + this.name + "/";
    }

    /** Returns the underlying classifier. */
    public SimpleLinearClassifier getClassifier() {
        return this.classifier;
    }

    /** Writes the model into {@link #getDefaultOutputPath()}. */
    public void writeModel() throws Exception {
        writeModel(getDefaultOutputPath());
    }

    /**
     * Writes the model into directory {@code str}, creating it (and any missing parents) if needed.
     * Three files are produced inside the directory: {@code config} (the builder field dump),
     * {@code weights} (the classifier's printed weight vector) and {@code model.ser} (via
     * {@code classifier.writeWeights}).
     *
     * @param str output directory; a trailing separator is optional
     * @throws IOException if the directory cannot be created
     * @throws Exception if writing any of the files fails
     */
    public void writeModel(String str) throws Exception {
        File outputDir = new File(str);
        // mkdirs (not mkdir) so missing parent directories are created as well.
        outputDir.mkdirs();
        if (!outputDir.isDirectory()) {
            throw new IOException("Could not create model output directory: " + outputDir);
        }
        // Resolve every file inside outputDir so the result does not depend on whether `str`
        // ends with a separator (previously "config" was written next to the directory otherwise).
        try (PrintWriter configWriter = new PrintWriter(new File(outputDir, "config"), "UTF-8")) {
            configWriter.print(this.str);
        }
        try (PrintWriter weightsWriter = new PrintWriter(new File(outputDir, "weights"), "UTF-8")) {
            this.classifier.printWeightVector(weightsWriter);
        }
        this.classifier.writeWeights(new File(outputDir, "model.ser").getPath());
    }

    /** Maps an example's label to a classifier target: +1 if the label is exactly 1.0, else -1. */
    private static double binaryLabel(Example example) {
        return example.label == 1.0d ? 1.0d : -1.0d;
    }

    /**
     * Performs one training update on {@code example} with weight 1.0.
     *
     * @param example training example; its label is mapped to +1 (label 1.0) or -1 (anything else)
     * @param map feature vectors passed to the feature extractor
     * @param compressor compressor passed to the feature extractor
     */
    public void learn(Example example, Map<Integer, CompressedFeatureVector> map, Compressor<String> compressor) {
        this.classifier.learn(this.meta.getFeatures(example, map, compressor), binaryLabel(example), 1.0d);
    }

    /**
     * Performs one training update on {@code example} with the given weight.
     *
     * @param example training example; its label is mapped to +1 (label 1.0) or -1 (anything else)
     * @param map feature vectors passed to the feature extractor
     * @param compressor compressor passed to the feature extractor
     * @param d weight of this update
     */
    public void learn(Example example, Map<Integer, CompressedFeatureVector> map, Compressor<String> compressor, double d) {
        this.classifier.learn(this.meta.getFeatures(example, map, compressor), binaryLabel(example), d);
    }

    /**
     * Trains on a positive/negative pair. When both examples are present, {@code example} is
     * learned with target +1 and {@code example2} with target -1, each with weight {@code d}. When
     * only one is present, it is learned with weight {@code d * singletonRatio}, and that update is
     * skipped entirely if {@code singletonRatio} is 0.
     *
     * @param example example trained as positive; may be null
     * @param example2 example trained as negative; may be null
     * @param map feature vectors passed to the feature extractor
     * @param compressor compressor passed to the feature extractor
     * @param d base weight of the update(s)
     */
    public void learn(Example example, Example example2, Map<Integer, CompressedFeatureVector> map, Compressor<String> compressor, double d) {
        Counter<String> positiveFeatures = example == null ? null : this.meta.getFeatures(example, map, compressor);
        Counter<String> negativeFeatures = example2 == null ? null : this.meta.getFeatures(example2, map, compressor);
        if (example != null && example2 != null) {
            this.classifier.learn(positiveFeatures, 1.0d, d);
            this.classifier.learn(negativeFeatures, -1.0d, d);
        } else if (this.singletonRatio != 0.0d) {
            if (example != null) {
                this.classifier.learn(positiveFeatures, 1.0d, d * this.singletonRatio);
            }
            if (example2 != null) {
                this.classifier.learn(negativeFeatures, -1.0d, d * this.singletonRatio);
            }
        }
    }

    /**
     * Returns the classifier's output ({@code classifier.label}) for the example's features.
     * NOTE(review): whether this is a probability or a raw score depends on SimpleLinearClassifier.
     */
    public double predict(Example example, Map<Integer, CompressedFeatureVector> map, Compressor<String> compressor) {
        return this.classifier.label(this.meta.getFeatures(example, map, compressor));
    }

    /** Returns the configured number of training examples. */
    public int getNumTrainingExamples() {
        return this.trainingExamples;
    }

    /** Returns the configured number of training epochs. */
    public int getNumEpochs() {
        return this.epochs;
    }
}
