package marmot.core;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import marmot.core.lattice.SequenceViterbiLattice;
import marmot.core.lattice.SumLattice;
import marmot.core.lattice.ZeroOrderSumLattice;
import marmot.core.lattice.ZeroOrderViterbiLattice;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:marmot/core/PerceptronTrainer.class */
/**
 * Trains a {@link Tagger} with the structured perceptron, optionally averaging the weights.
 *
 * <p>Each training sequence's sum lattice is decoded with a Viterbi lattice. Whenever the decoded
 * candidate indices differ from the gold candidates, the gold path is updated with {@code +1}
 * and the decoded path with {@code -1}.
 */
public class PerceptronTrainer implements Trainer {
    private int steps_;
    private boolean shuffle_;
    private boolean verbose_;
    private boolean averaging_;
    private long seed_;

    /**
     * Runs {@code steps_} perceptron epochs over {@code collection}, updating the tagger's
     * weight vector in place.
     *
     * <p>With averaging enabled, the tagger's weights are overwritten at the end of every epoch
     * with the average of the weight vectors over all steps so far. The next epoch then continues
     * from those averaged weights. Feature-set extension is switched off when this method returns.
     *
     * @param tagger     tagger whose weight vector is trained
     * @param collection training sequences; they are copied, so the caller's collection is never
     *                   reordered by shuffling
     * @param evaluator  run after every epoch, but only when verbose output is enabled; may be
     *                   {@code null}
     */
    @Override // marmot.core.Trainer
    public void train(Tagger tagger, Collection<Sequence> collection, Evaluator evaluator) {
        Random random = shuffle_ ? (seed_ == 0 ? new Random() : new Random(seed_)) : null;
        List<Sequence> instances = new ArrayList<>(collection);
        int numInstances = instances.size();
        // Progress is reported about four times per epoch.
        int reportInterval = Math.max(numInstances / 4, 1);

        WeightVector weightVector = tagger.getWeightVector();
        assert weightVector != null;
        // Running (scaled) sum of all weight vectors seen so far; only allocated when averaging.
        // NOTE(review): sized once from the initial weights -- confirm the weight array cannot grow
        // while feature-set extension is active, or this accumulator will be too short.
        double[] weightSums = averaging_ ? new double[weightVector.getWeights().length] : null;
        Model model = tagger.getModel();

        for (int step = 0; step < steps_; step++) {
            if (verbose_) {
                System.err.println("step: " + step);
            }
            if (shuffle_) {
                Collections.shuffle(instances, random);
            }

            int processed = 0;
            long startTime = System.currentTimeMillis();
            for (Sequence sequence : instances) {
                trainOnSequence(tagger, model, weightVector, weightSums, sequence,
                        numInstances - processed);
                processed++;
                if (processed % reportInterval == 0 && verbose_) {
                    double seconds = (System.currentTimeMillis() - startTime) / 1000.0;
                    System.err.format("Processed %d sentences at %g sentence/s \n",
                            processed, processed / seconds);
                }
            }

            // With no instances the average would be 0/0 (NaN) and wipe out every weight.
            if (averaging_ && numInstances > 0) {
                averageEpoch(weightVector, weightSums, step, numInstances);
            }

            if (evaluator != null && verbose_) {
                weightVector.setExtendFeatureSet(false);
                evaluator.eval(tagger);
                weightVector.setExtendFeatureSet(true);
            }
        }
        weightVector.setExtendFeatureSet(false);
    }

    /**
     * Decodes one sequence and, if the prediction differs from the gold candidates, applies
     * the perceptron update (and the matching update to the averaging accumulator).
     *
     * @param remaining number of steps left in the current epoch, including this one. The
     *                  accumulator update is scaled by it so the change counts once for every
     *                  remaining step.
     */
    private void trainOnSequence(Tagger tagger, Model model, WeightVector weightVector,
            double[] weightSums, Sequence sequence, int remaining) {
        SumLattice lattice = tagger.getSumLattice(true, sequence);
        List<List<State>> candidates = lattice.getCandidates();
        List<Integer> predicted = decode(tagger, model, lattice, candidates);
        List<Integer> gold = lattice.getGoldCandidates();
        if (gold.equals(predicted)) {
            return;
        }

        update(weightVector, candidates, gold, 1.0);
        update(weightVector, candidates, predicted, -1.0);

        if (averaging_) {
            assert remaining > 0;
            double[] weights = weightVector.getWeights();
            // Temporarily install the accumulator so update() writes into it; always restore the
            // real weights, even if an update throws.
            weightVector.setWeights(weightSums);
            try {
                update(weightVector, candidates, gold, remaining);
                update(weightVector, candidates, predicted, -remaining);
            } finally {
                weightVector.setWeights(weights);
            }
        }
    }

    /**
     * Returns the Viterbi candidate index for each position of {@code lattice}. A zero-order
     * decoder is used for zero-order lattices; otherwise a sequence decoder is anchored at the
     * model's boundary state for the last tagging level.
     */
    private List<Integer> decode(Tagger tagger, Model model, SumLattice lattice,
            List<List<State>> candidates) {
        if (lattice instanceof ZeroOrderSumLattice) {
            return new ZeroOrderViterbiLattice(candidates, 1, false)
                    .getViterbiSequence().getStates();
        }
        State boundary = model.getBoundaryState(tagger.getNumLevels() - 1);
        return new SequenceViterbiLattice(candidates, boundary, 1, false)
                .getViterbiSequence().getStates();
    }

    /**
     * Ends epoch {@code step} of an averaged run. It overwrites the tagger's weights with the
     * average over all {@code (step + 1) * numInstances} steps so far, then rescales the
     * accumulator for the next epoch.
     */
    private static void averageEpoch(WeightVector weightVector, double[] weightSums, int step,
            int numInstances) {
        double[] weights = weightVector.getWeights();
        // Computed in double so (step + 1) * numInstances cannot overflow int.
        double numSteps = (double) (step + 1) * numInstances;
        assert numSteps > 0.0;
        // The next epoch starts from the averaged weights S / numSteps. Over its numInstances
        // steps those starting weights contribute S / (step + 1), so the sum must become
        // S * (step + 2) / (step + 1). This must be real division: the integer form yields 2
        // after the first epoch and 1 after every later one.
        double scale = (step + 2.0) / (step + 1);
        assert scale > 0.0 && scale < 2.00001;
        for (int k = 0; k < weights.length; k++) {
            weights[k] = weightSums[k] / numSteps;
            weightSums[k] *= scale;
        }
    }

    /**
     * Calls {@code updateWeights} with {@code delta} for every state on {@code path}, and for
     * that state's transition from the index chosen at the previous position.
     *
     * @param candidates per-position candidate states
     * @param path       index into {@code candidates.get(position)} for each position
     * @param delta      amount passed to {@code updateWeights}
     */
    private void update(WeightVector weightVector, List<List<State>> candidates,
            List<Integer> path, double delta) {
        // NOTE(review): the first position uses transition index 0 -- presumably the boundary;
        // confirm against State.getTransition.
        int previous = 0;
        for (int position = 0; position < path.size(); position++) {
            int index = path.get(position);
            State state = candidates.get(position).get(index);
            weightVector.updateWeights(state, delta, false);
            weightVector.updateWeights(state.getTransition(previous), delta, true);
            previous = index;
        }
    }

    /** Reads iteration count, shuffling, verbosity, averaging and random seed from {@code options}. */
    @Override // marmot.core.Trainer
    public void setOptions(Options options) {
        steps_ = options.getNumIterations();
        shuffle_ = options.getShuffle();
        verbose_ = options.getVerbose();
        averaging_ = options.getAveraging();
        seed_ = options.getSeed();
    }
}
