package experimental.analyzer.simple;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.OptimizationException;
import experimental.analyzer.Analyzer;
import experimental.analyzer.AnalyzerInstance;
import experimental.analyzer.AnalyzerReading;
import experimental.analyzer.AnalyzerTag;
import experimental.analyzer.AnalyzerTrainer;
import experimental.analyzer.simple.SimpleAnalyzer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import marmot.util.Counter;
import marmot.util.Mutable;
import net.arnx.jsonic.JSONException;
import org.javatuples.Pair;
import vn.corenlp.tokenizer.StringConst;

/* loaded from: input_file:experimental/analyzer/simple/SimpleAnalyzerTrainer.class */
public class SimpleAnalyzerTrainer extends AnalyzerTrainer {
    private SimpleAnalyzer.Mode train_mode_;
    private SimpleAnalyzer.Mode tag_mode_;
    private double penalty_;
    public static final String MODE = "mode";
    public static final String PENALTY = "penalty";
    public static final String PAIR_CONSTRAINT = "pair-constraint";
    public static final String PAIR_CONSTRAINT_THRESHOLD = "pair-constraint-threshold";
    private boolean optimize_threshold_ = false;
    private boolean mallet_ = false;
    private PairConstraint pair_constraint_ = PairConstraint.weighted;
    private double pair_constraint_threshold_ = 0.1d;
    private Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts_ = null;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:experimental/analyzer/simple/SimpleAnalyzerTrainer$PairConstraint.class */
    public enum PairConstraint {
        simple,
        weighted,
        none
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:experimental/analyzer/simple/SimpleAnalyzerTrainer$TagStats.class */
    public static class TagStats {
        Counter<AnalyzerTag> tag_counts;
        Counter<Pair<AnalyzerTag, AnalyzerTag>> tag_tag_counts;

        public TagStats() {
            this.tag_counts = new Counter<>();
            this.tag_tag_counts = new Counter<>();
            this.tag_counts = new Counter<>();
            this.tag_tag_counts = new Counter<>();
        }
    }

    @Override // experimental.analyzer.AnalyzerTrainer
    public Analyzer train(Collection<AnalyzerInstance> collection) {
        System.err.format("Num instances: %d\n", Integer.valueOf(collection.size()));
        this.tag_mode_ = SimpleAnalyzer.Mode.binary;
        this.train_mode_ = SimpleAnalyzer.Mode.binary;
        if (this.options_.containsKey(MODE)) {
            SimpleAnalyzer.Mode valueOf = SimpleAnalyzer.Mode.valueOf(this.options_.get(MODE));
            this.tag_mode_ = valueOf;
            this.train_mode_ = valueOf;
        }
        if (this.options_.containsKey(PAIR_CONSTRAINT)) {
            this.pair_constraint_ = PairConstraint.valueOf(this.options_.get(PAIR_CONSTRAINT));
        }
        if (this.options_.containsKey(PAIR_CONSTRAINT_THRESHOLD)) {
            this.pair_constraint_threshold_ = Double.valueOf(this.options_.get(PAIR_CONSTRAINT_THRESHOLD)).doubleValue();
        }
        System.err.format("Modes: %s / %s\n", this.tag_mode_, this.train_mode_);
        this.penalty_ = 1.0d;
        if (this.options_.containsKey("penalty")) {
            this.penalty_ = Double.valueOf(this.options_.get("penalty")).doubleValue();
        }
        System.err.format("Penalty: %g\n", Double.valueOf(this.penalty_));
        Collection<Pair<AnalyzerTag, AnalyzerTag>> coupledTags = 0 != 0 ? getCoupledTags(collection) : null;
        if (this.pair_constraint_ != PairConstraint.none) {
            preparePairConstraints(collection);
        }
        LinkedList linkedList = new LinkedList();
        for (AnalyzerInstance analyzerInstance : collection) {
            linkedList.add(new SimpleAnalyzerInstance(analyzerInstance, AnalyzerReading.toTags(analyzerInstance.getReadings())));
        }
        SimpleAnalyzerModel simpleAnalyzerModel = new SimpleAnalyzerModel();
        simpleAnalyzerModel.init(linkedList, this.options_.containsKey(AnalyzerTrainer.FLOAT_DICT_) ? this.options_.get(AnalyzerTrainer.FLOAT_DICT_) : null);
        if (this.mallet_) {
            run_mallet(simpleAnalyzerModel, linkedList);
        } else {
            run_sgd(simpleAnalyzerModel, linkedList, 10, true, 0.1d);
        }
        double d = 0.01d;
        if (this.optimize_threshold_) {
            d = new SimpleThresholdOptimizer(false).findTreshold(simpleAnalyzerModel, collection, this.tag_mode_);
            System.err.println("Best threshold on train: " + d);
        }
        return new SimpleAnalyzer(simpleAnalyzerModel, d, this.tag_mode_, coupledTags);
    }

    /* JADX WARN: Multi-variable type inference failed */
    private void preparePairConstraints(Collection<AnalyzerInstance> collection) {
        TagStats tagStates = getTagStates(collection);
        this.relative_counts_ = new HashMap();
        for (Map.Entry<Pair<AnalyzerTag, AnalyzerTag>, Double> entry : tagStates.tag_tag_counts.entrySet()) {
            Pair<AnalyzerTag, AnalyzerTag> key = entry.getKey();
            Double value = entry.getValue();
            addRelativeProb((AnalyzerTag) key.getValue0(), (AnalyzerTag) key.getValue1(), value, tagStates.tag_counts.count(key.getValue0()), this.relative_counts_);
            addRelativeProb((AnalyzerTag) key.getValue1(), (AnalyzerTag) key.getValue0(), value, tagStates.tag_counts.count(key.getValue1()), this.relative_counts_);
        }
        for (Map.Entry<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> entry2 : this.relative_counts_.entrySet()) {
            AnalyzerTag key2 = entry2.getKey();
            Map<AnalyzerTag, Mutable<Double>> value2 = entry2.getValue();
            value2.put(key2, new Mutable<>(Double.valueOf(1.0d)));
            double d = 0.0d;
            Iterator<Mutable<Double>> it2 = value2.values().iterator();
            while (it2.hasNext()) {
                d += it2.next().get().doubleValue();
            }
            for (Mutable<Double> mutable : value2.values()) {
                mutable.set(Double.valueOf(mutable.get().doubleValue() / d));
            }
            System.err.println(key2 + StringConst.SPACE + value2);
        }
    }

    private void addRelativeProb(AnalyzerTag analyzerTag, AnalyzerTag analyzerTag2, Double d, Double d2, Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> map) {
        double doubleValue = d.doubleValue() / d2.doubleValue();
        if (doubleValue > this.pair_constraint_threshold_) {
            Map<AnalyzerTag, Mutable<Double>> map2 = map.get(analyzerTag);
            if (map2 == null) {
                map2 = new HashMap();
                map.put(analyzerTag, map2);
            }
            if (!$assertionsDisabled && map2.containsKey(analyzerTag2)) {
                throw new AssertionError();
            }
            map2.put(analyzerTag2, new Mutable<>(Double.valueOf(doubleValue)));
        }
    }

    private void run_sgd(SimpleAnalyzerModel simpleAnalyzerModel, Collection<SimpleAnalyzerInstance> collection, int i, boolean z, double d) {
        LinkedList linkedList = new LinkedList(collection);
        SimpleAnalyzerObjective simpleAnalyzerObjective = new SimpleAnalyzerObjective(this.penalty_, simpleAnalyzerModel, collection, this.train_mode_, this.relative_counts_, this.pair_constraint_);
        int i2 = 0;
        Random random = new Random(42L);
        for (int i3 = 0; i3 < i; i3++) {
            if (z) {
                System.err.println("step: " + i3);
            }
            Collections.shuffle(linkedList, random);
            Iterator it2 = linkedList.iterator();
            while (it2.hasNext()) {
                simpleAnalyzerObjective.update((SimpleAnalyzerInstance) it2.next(), d / (1.0d + (i2 / linkedList.size())), true);
                i2++;
            }
        }
    }

    private void run_mallet(SimpleAnalyzerModel simpleAnalyzerModel, Collection<SimpleAnalyzerInstance> collection) {
        Logger logger = Logger.getLogger(getClass().getName());
        logger.info("Start optimization");
        SimpleAnalyzerObjective simpleAnalyzerObjective = new SimpleAnalyzerObjective(this.penalty_, simpleAnalyzerModel, collection, this.train_mode_, this.relative_counts_, this.pair_constraint_);
        LimitedMemoryBFGS limitedMemoryBFGS = new LimitedMemoryBFGS(simpleAnalyzerObjective);
        Logger.getLogger(limitedMemoryBFGS.getClass().getName()).setLevel(Level.OFF);
        simpleAnalyzerObjective.setParameters(simpleAnalyzerModel.getWeights());
        try {
            limitedMemoryBFGS.optimize(1);
            for (int i = 0; i < 200; i++) {
                if (limitedMemoryBFGS.isConverged()) {
                    break;
                }
                limitedMemoryBFGS.optimize(1);
                logger.info(String.format("Iteration: %3d / %3d: %g", Integer.valueOf(i + 1), Integer.valueOf(JSONException.PARSE_ERROR), Double.valueOf(simpleAnalyzerObjective.getValue())));
            }
        } catch (OptimizationException e) {
        } catch (IllegalArgumentException e2) {
        }
    }

    private TagStats getTagStates(Collection<AnalyzerInstance> collection) {
        TagStats tagStats = new TagStats();
        Iterator<AnalyzerInstance> it2 = collection.iterator();
        while (it2.hasNext()) {
            Collection<AnalyzerTag> tags = AnalyzerReading.toTags(it2.next().getReadings());
            Iterator<AnalyzerTag> it3 = tags.iterator();
            while (it3.hasNext()) {
                tagStats.tag_counts.increment(it3.next(), Double.valueOf(1.0d));
            }
            ArrayList arrayList = new ArrayList(tags);
            for (int i = 0; i < arrayList.size(); i++) {
                AnalyzerTag analyzerTag = (AnalyzerTag) arrayList.get(i);
                for (int i2 = i + 1; i2 < arrayList.size(); i2++) {
                    AnalyzerTag analyzerTag2 = (AnalyzerTag) arrayList.get(i2);
                    if (analyzerTag.hashCode() < analyzerTag2.hashCode()) {
                        tagStats.tag_tag_counts.increment(new Pair<>(analyzerTag2, analyzerTag), Double.valueOf(1.0d));
                    } else {
                        tagStats.tag_tag_counts.increment(new Pair<>(analyzerTag, analyzerTag2), Double.valueOf(1.0d));
                    }
                }
            }
        }
        return tagStats;
    }

    /* JADX WARN: Multi-variable type inference failed */
    private Collection<Pair<AnalyzerTag, AnalyzerTag>> getCoupledTags(Collection<AnalyzerInstance> collection) {
        TagStats tagStates = getTagStates(collection);
        LinkedList linkedList = new LinkedList();
        for (Map.Entry<Pair<AnalyzerTag, AnalyzerTag>, Double> entry : tagStates.tag_tag_counts.entrySet()) {
            Pair<AnalyzerTag, AnalyzerTag> key = entry.getKey();
            double doubleValue = tagStates.tag_counts.count(key.getValue0()).doubleValue();
            if (!$assertionsDisabled && doubleValue >= collection.size()) {
                throw new AssertionError();
            }
            double doubleValue2 = tagStates.tag_counts.count(key.getValue1()).doubleValue();
            if (!$assertionsDisabled && doubleValue2 >= collection.size()) {
                throw new AssertionError();
            }
            double doubleValue3 = entry.getValue().doubleValue();
            if (!$assertionsDisabled && doubleValue3 >= collection.size()) {
                throw new AssertionError();
            }
            if (entry.getValue().doubleValue() >= 10.0d && doubleValue3 / Math.sqrt(doubleValue * doubleValue2) > 0.99d) {
                linkedList.add(key);
            }
        }
        System.err.println("|Coupled|: " + linkedList.size());
        System.err.println("Coupled: " + linkedList);
        return linkedList;
    }

    static {
        $assertionsDisabled = !SimpleAnalyzerTrainer.class.desiredAssertionStatus();
    }
}
