package lemming.lemma.toutanova;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Aligner;
import lemming.lemma.toutanova.ToutanovaTrainer;
import marmot.util.DynamicWeights;
import marmot.util.Encoder;
import marmot.util.SymbolTable;

/* loaded from: input_file:lemming/lemma/toutanova/ToutanovaModel.class */
public class ToutanovaModel implements Serializable {
    /** Serialization version of this {@link Serializable} model. */
    private static final long serialVersionUID = 1;
    /** Output symbols addressed by their table index; filled in {@code init} from the entries of {@code output_table_}. */
    private String[] alphabet_;
    /** Table of output segments; {@code createOutputTable} rebuilds it and inserts {@code COPY_SYMBOL} first. */
    private SymbolTable<String> output_table_;
    /** Table of POS tags; only created in {@code init} when the options request POS usage, otherwise {@code null}. */
    private SymbolTable<String> pos_table_;
    /** Length of the longest input segment seen by {@code createOutputTable}. */
    private int max_input_segment_length_;
    /** Bit width for output symbol indexes ({@code Encoder.bitsNeeded(output_table_.size())}). */
    private int num_output_bits;
    /** Table mapping form characters to indexes; filled by {@code addIndexes}. */
    private SymbolTable<Character> char_table;
    /** Forms of the training instances passed to {@code init}; backs {@code isOOV}. */
    private Set<String> form_vocab_;
    /** Scratch encoder used to build feature keys; transient, recreated by {@code setupTemp}. */
    private transient Encoder encoder;
    /** Saved encoder state used by {@code addAffixes} to roll back appended context; transient. */
    private transient Encoder.State encoder_state;
    /** Bit width for character indexes ({@code Encoder.bitsNeeded(char_table.size())}). */
    private int num_char_bits;
    /** Bit width for POS tag indexes, or -1 when no POS table exists. */
    private int num_pos_bits;
    /** Consumer that accumulates a score over the features fed to it. */
    private IndexScorer scorer_;
    /** Consumer that applies a weight update for the features fed to it. */
    private IndexUpdater updater_;
    /** True when the configured decoder order is below 1; transition features are then skipped. */
    private boolean use_zero_order_;
    /** Bit width for input segment lengths ({@code Encoder.bitsNeeded(max_input_segment_length_)}). */
    private int max_input_segment_length_bits_;
    /** Weight vector shared by {@code scorer_} and {@code updater_}. */
    private DynamicWeights weights_;
    /** Presumably the bit width for context segment lengths; {@code addSegment} appends lengths with the literal 6. */
    private static final int length_bits_ = 6;
    /** Bit width of the feature template id that prefixes every encoded feature; set in the static initializer. */
    private static final int FEATURE_BITS;
    /** Template id matching the value appended by {@code consumeTransitionFeature}. */
    private static final int TRANS_FEAT = 0;
    /** Template id matching the value appended by {@code consumeOutputFeature}. */
    private static final int OUTPUT_FEAT = 1;
    /** Template id matching the value appended by {@code consumePairFeature}. */
    private static final int PAIR_FEAT = 2;
    /** Symbol inserted first into every freshly created output table. */
    private static final String COPY_SYMBOL = "<COPY>";
    /** Largest context window (in characters) used by {@code addAffixes}; overwritten from the options in {@code init}. */
    private int max_window = 2;
    /** Assertion switch (marked synthetic by the decompiler); set in the static initializer and checked in {@code createOutputTable}. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Initializes the model from the training data: builds the output, character and
     * (optionally) POS tables, derives the bit widths used for feature encoding and
     * creates fresh weights together with the scorer and updater that use them.
     *
     * @param toutanovaOptions options supplying window size, alphabet filter threshold,
     *                         POS usage, the random source for the weights and the decoder
     * @param list             training instances; their forms become the known vocabulary
     *                         and they may add new entries to the character/POS tables
     * @param list2            optional second set of instances (may be {@code null}); they are
     *                         indexed without inserting new symbols
     */
    public void init(ToutanovaTrainer.ToutanovaOptions toutanovaOptions, List<ToutanovaInstance> list, List<ToutanovaInstance> list2) {
        final Logger log = Logger.getLogger(getClass().getName());
        max_window = toutanovaOptions.getMaxWindowSize();

        // First pass over the alignments.
        createOutputTable(toutanovaOptions, list);
        log.info("Output alphabet size: " + output_table_.size());
        log.info("Max input segment length: " + max_input_segment_length_);

        // Optional second pass after flagging instances that use rare output symbols.
        if (toutanovaOptions.getFilterAlphabet() > 0) {
            filterRareOutputSymbols(toutanovaOptions, list);
            createOutputTable(toutanovaOptions, list);
            log.info("Output alphabet size: " + output_table_.size());
            log.info("Max input segment length: " + max_input_segment_length_);
        }

        char_table = new SymbolTable<>();
        if (toutanovaOptions.getUsePos()) {
            pos_table_ = new SymbolTable<>();
        }

        form_vocab_ = new HashSet<>();
        for (ToutanovaInstance trainInstance : list) {
            form_vocab_.add(trainInstance.getInstance().getForm());
        }

        // Only the first list is allowed to grow the character/POS tables.
        addIndexes(list, true);
        if (list2 != null) {
            addIndexes(list2, false);
        }

        num_output_bits = Encoder.bitsNeeded(output_table_.size());

        // Snapshot index -> symbol into an array before the table's reverse lookup is switched off.
        alphabet_ = new String[output_table_.size()];
        for (Map.Entry<String, Integer> entry : output_table_.entrySet()) {
            int index = entry.getValue().intValue();
            alphabet_[index] = entry.getKey();
        }
        output_table_.setBidirectional(false);

        num_char_bits = Encoder.bitsNeeded(char_table.size());
        num_pos_bits = -1;
        if (pos_table_ != null) {
            num_pos_bits = Encoder.bitsNeeded(pos_table_.size());
        }

        weights_ = new DynamicWeights(toutanovaOptions.getRandom());
        // One table instance is handed to both consumers so they resolve features identically.
        SymbolTable sharedTable = new SymbolTable();
        scorer_ = new IndexScorer(weights_, sharedTable, num_pos_bits);
        updater_ = new IndexUpdater(weights_, sharedTable, num_pos_bits);

        use_zero_order_ = toutanovaOptions.getDecoderInstance().getOrder() < 1;
        setupTemp();
    }

    /**
     * Creates the transient scratch objects (encoder and its state holder).
     * Called at the end of {@code init} and again after deserialization, because
     * both fields are {@code transient}.
     */
    private void setupTemp() {
        encoder = new Encoder(10);
        encoder_state = new Encoder.State();
    }

    /**
     * Serialization hook: restores the serialized fields, then rebuilds the
     * transient encoder objects that are not part of the stream.
     *
     * @param objectInputStream stream the model is being read from
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     * @throws IOException            if reading from the stream fails
     */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        // Default field restoration must run first.
        objectInputStream.defaultReadObject();

        setupTemp();
    }

    /**
     * Rebuilds the output symbol table from the aligned training instances and attaches
     * the gold {@code Result} to every instance.
     *
     * <p>Instances flagged as rare get a {@code null} result and contribute nothing to the
     * table. For every other instance the form/lemma alignment is turned into pairs; each
     * pair contributes the cumulative end offset of its input segment and an output index.
     * Also recomputes the maximum input segment length and its bit width.
     *
     * @param toutanovaOptions training options (not read by this method)
     * @param list             training instances; every non-rare instance must carry an alignment
     */
    private void createOutputTable(ToutanovaTrainer.ToutanovaOptions toutanovaOptions, List<ToutanovaInstance> list) {
        this.output_table_ = new SymbolTable<>(true);
        this.output_table_.insert(COPY_SYMBOL);
        this.max_input_segment_length_ = 0;
        for (ToutanovaInstance toutanovaInstance : list) {
            if (toutanovaInstance.isRare()) {
                toutanovaInstance.setResult(null);
                continue;
            }
            String form = toutanovaInstance.getInstance().getForm();
            // Kept in the explicit flag form used by this file instead of an `assert` statement.
            if (!$assertionsDisabled && toutanovaInstance.getAlignment() == null) {
                throw new AssertionError();
            }
            List<Aligner.Pair> pairs = Aligner.Pair.toPairs(form, toutanovaInstance.getInstance().getLemma(), toutanovaInstance.getAlignment());
            List<Integer> inputEnds = new ArrayList<>(pairs.size());
            List<Integer> outputs = new ArrayList<>(pairs.size());
            int consumed = 0;
            for (Aligner.Pair pair : pairs) {
                int segmentLength = pair.getInputSegment().length();
                this.max_input_segment_length_ = Math.max(this.max_input_segment_length_, segmentLength);
                consumed += segmentLength;
                inputEnds.add(Integer.valueOf(consumed));
                // Identical segments keep index 0; COPY_SYMBOL is the first symbol inserted above
                // (assumes SymbolTable numbers symbols from 0 in insertion order — TODO confirm).
                int outputIndex = 0;
                if (!pair.getInputSegment().equals(pair.getOutputSegment())) {
                    // The symbol itself is looked up; the former cast to SymbolTable<String> was invalid.
                    outputIndex = this.output_table_.toIndex(pair.getOutputSegment(), true);
                }
                outputs.add(Integer.valueOf(outputIndex));
            }
            Result result = new Result(this, outputs, inputEnds, form);
            // Sanity check: the encoded result must reproduce the gold lemma.
            if (!$assertionsDisabled && !result.getOutput().equals(toutanovaInstance.getInstance().getLemma())) {
                throw new AssertionError();
            }
            toutanovaInstance.setResult(result);
        }
        this.max_input_segment_length_bits_ = Encoder.bitsNeeded(this.max_input_segment_length_);
    }

    /**
     * Flags every training instance whose gold result uses at least one rare output symbol.
     *
     * <p>A symbol is rare when it occurs at most {@code getFilterAlphabet()} times across the
     * results of all instances. The number of such symbols is logged; the logged count uses
     * the same threshold as the filter and ignores symbols that never occur.
     *
     * @param toutanovaOptions options supplying the rarity threshold
     * @param list             training instances; each must already carry a result
     */
    private void filterRareOutputSymbols(ToutanovaTrainer.ToutanovaOptions toutanovaOptions, List<ToutanovaInstance> list) {
        Logger logger = Logger.getLogger(getClass().getName());
        final int threshold = toutanovaOptions.getFilterAlphabet();

        // Occurrence count per output symbol index.
        int[] counts = new int[this.output_table_.size()];
        for (ToutanovaInstance instance : list) {
            for (int output : instance.getResult().getOutputs()) {
                counts[output]++;
            }
        }

        // Report exactly the symbols that trigger filtering below. Unused symbols are
        // skipped because no instance can be flagged on their account.
        int numRareSymbols = 0;
        for (int count : counts) {
            if (count > 0 && count <= threshold) {
                numRareSymbols++;
            }
        }
        logger.info(String.format("Num rare output symbols (<= %d): %d", Integer.valueOf(threshold), Integer.valueOf(numRareSymbols)));

        for (ToutanovaInstance instance : list) {
            boolean rare = false;
            for (int output : instance.getResult().getOutputs()) {
                if (counts[output] <= threshold) {
                    rare = true;
                    break;
                }
            }
            instance.setRare(rare);
        }
    }

    /**
     * @return the table of output symbols built by the last {@code createOutputTable} run
     */
    public SymbolTable<String> getOutputTable() {
        return output_table_;
    }

    /**
     * @return the length of the longest input segment found while building the output table
     */
    public int getMaxInputSegmentLength() {
        return max_input_segment_length_;
    }

    /**
     * Resolves an output index to its symbol.
     *
     * @param i index of the output symbol
     * @return the symbol from the cached alphabet array when it exists, otherwise the
     *         symbol looked up in the output table
     */
    public String getOutput(int i) {
        if (alphabet_ != null) {
            return alphabet_[i];
        }
        // The alphabet array does not exist yet, so ask the table directly.
        return output_table_.toSymbol(Integer.valueOf(i));
    }

    /**
     * Feeds the transition feature (previous output, current output) plus its
     * context-augmented variants to the consumer. Does nothing when there is no
     * previous output.
     *
     * @param indexConsumer     receiver of the encoded features
     * @param toutanovaInstance instance the features are extracted from
     * @param i                 first boundary of the current input segment, forwarded to {@code addAffixes}
     * @param i2                second boundary of the current input segment, forwarded to {@code addAffixes}
     * @param i3                index of the previous output symbol; negative means "none"
     * @param i4                index of the current output symbol
     */
    public void consumeTransitionFeature(IndexConsumer indexConsumer, ToutanovaInstance toutanovaInstance, int i, int i2, int i3, int i4) {
        if (i3 >= 0) {
            encoder.reset();
            encoder.append(TRANS_FEAT, FEATURE_BITS);
            encoder.append(i3, num_output_bits);
            encoder.append(i4, num_output_bits);
            indexConsumer.consume(toutanovaInstance, encoder);
            addAffixes(toutanovaInstance, indexConsumer, i, i2);
        }
    }

    /**
     * Extends the feature currently held in the encoder with surrounding character
     * context and feeds each extended variant to the consumer. For every window size
     * from 1 to {@code max_window} three families are produced, in this order:
     * left and right context together, left context only, right context only.
     * The encoder is restored to its incoming state after every variant.
     *
     * @param toutanovaInstance instance whose form character indexes supply the context
     * @param indexConsumer     receiver of the encoded features
     * @param i                 position the left context ends at (exclusive)
     * @param i2                position after which the right context starts
     */
    private void addAffixes(ToutanovaInstance toutanovaInstance, IndexConsumer indexConsumer, int i, int i2) {
        // NOTE(review): the right context starts at i2 + 1, so the character at i2 itself is
        // never part of it — confirm this is intended given how callers define i2.
        final int rightStart = i2 + 1;

        // Left and right context combined.
        for (int window = 1; window <= max_window; window++) {
            encoder.storeState(encoder_state);
            addSegment(toutanovaInstance.getFormCharIndexes(), i - window, i);
            addSegment(toutanovaInstance.getFormCharIndexes(), rightStart, rightStart + window);
            indexConsumer.consume(toutanovaInstance, encoder);
            encoder.restoreState(encoder_state);
        }

        // Left context only.
        for (int window = 1; window <= max_window; window++) {
            encoder.storeState(encoder_state);
            addSegment(toutanovaInstance.getFormCharIndexes(), i - window, i);
            indexConsumer.consume(toutanovaInstance, encoder);
            encoder.restoreState(encoder_state);
        }

        // Right context only.
        for (int window = 1; window <= max_window; window++) {
            encoder.storeState(encoder_state);
            addSegment(toutanovaInstance.getFormCharIndexes(), rightStart, rightStart + window);
            indexConsumer.consume(toutanovaInstance, encoder);
            encoder.restoreState(encoder_state);
        }
    }

    /**
     * Appends the character span {@code [i, i2)} to the encoder: first the span length,
     * then one character index per position. Positions outside the array are encoded
     * with {@code char_table.size()} as a padding value. Encoding stops early (leaving
     * the already appended part in place) at the first negative character index.
     *
     * @param iArr character indexes of the form
     * @param i    first position of the span (may be negative)
     * @param i2   position after the last one of the span (may exceed the array length)
     */
    private void addSegment(int[] iArr, int i, int i2) {
        encoder.append(i2 - i, length_bits_);
        for (int pos = i; pos < i2; pos++) {
            final int charIndex;
            if (pos >= 0 && pos < iArr.length) {
                charIndex = iArr[pos];
            } else {
                // Out of range: use the padding value.
                charIndex = char_table.size();
            }
            if (charIndex < 0) {
                return;
            }
            encoder.append(charIndex, num_char_bits);
        }
    }

    /**
     * Feeds the output-only feature for one output symbol, followed by its
     * context-augmented variants, to the consumer.
     *
     * @param indexConsumer     receiver of the encoded features
     * @param toutanovaInstance instance the features are extracted from
     * @param i                 first boundary of the input segment, forwarded to {@code addAffixes}
     * @param i2                second boundary of the input segment, forwarded to {@code addAffixes}
     * @param i3                index of the output symbol
     */
    public void consumeOutputFeature(IndexConsumer indexConsumer, ToutanovaInstance toutanovaInstance, int i, int i2, int i3) {
        encoder.reset();
        encoder.append(OUTPUT_FEAT, FEATURE_BITS);
        encoder.append(i3, num_output_bits);

        indexConsumer.consume(toutanovaInstance, encoder);
        addAffixes(toutanovaInstance, indexConsumer, i, i2);
    }

    /**
     * Feeds the pair feature (output symbol together with the characters of the input
     * segment {@code [i, i2)}) and its context-augmented variants to the consumer.
     * If any character of the segment has a negative index, nothing is consumed.
     *
     * @param indexConsumer     receiver of the encoded features
     * @param toutanovaInstance instance the features are extracted from
     * @param i                 start of the input segment (inclusive)
     * @param i2                end of the input segment (exclusive)
     * @param i3                index of the output symbol
     */
    public void consumePairFeature(IndexConsumer indexConsumer, ToutanovaInstance toutanovaInstance, int i, int i2, int i3) {
        final int[] chars = toutanovaInstance.getFormCharIndexes();
        final int segmentLength = i2 - i;

        encoder.reset();
        encoder.append(PAIR_FEAT, FEATURE_BITS);
        encoder.append(i3, num_output_bits);
        // NOTE(review): the length is appended twice with different widths; kept as is
        // because removing either append would change the encoded feature.
        encoder.append(segmentLength, max_input_segment_length_bits_);
        encoder.append(segmentLength, 4);

        int pos = i;
        while (pos < i2) {
            final int charIndex = chars[pos];
            if (charIndex < 0) {
                // A negative character index aborts the whole feature.
                return;
            }
            encoder.append(charIndex, num_char_bits);
            pos++;
        }

        indexConsumer.consume(toutanovaInstance, encoder);
        addAffixes(toutanovaInstance, indexConsumer, i, i2);
    }

    /**
     * Feeds both feature families of a single (input segment, output symbol) decision
     * to the consumer: the pair feature first, then the output feature.
     *
     * @param indexConsumer     receiver of the encoded features
     * @param toutanovaInstance instance the features are extracted from
     * @param i                 start of the input segment
     * @param i2                end of the input segment
     * @param i3                index of the output symbol
     */
    private void consumeOutputPair(IndexConsumer indexConsumer, ToutanovaInstance toutanovaInstance, int i, int i2, int i3) {
        // Order matters for consumers that accumulate state: pair before output.
        consumePairFeature(indexConsumer, toutanovaInstance, i, i2, i3);

        consumeOutputFeature(indexConsumer, toutanovaInstance, i, i2, i3);
    }

    /**
     * Feeds transition features to the consumer unless the model is zero-order, in
     * which case the call is a no-op.
     *
     * @param indexConsumer     receiver of the encoded features
     * @param toutanovaInstance instance the features are extracted from
     * @param i                 start of the current input segment
     * @param i2                end of the current input segment
     * @param i3                index of the previous output symbol
     * @param i4                index of the current output symbol
     */
    private void consumeTransition(IndexConsumer indexConsumer, ToutanovaInstance toutanovaInstance, int i, int i2, int i3, int i4) {
        if (!use_zero_order_) {
            consumeTransitionFeature(indexConsumer, toutanovaInstance, i, i2, i3, i4);
        }
    }

    /**
     * Scores a single (input segment, output symbol) decision using the pair and
     * output features.
     *
     * @param toutanovaInstance instance being scored
     * @param i                 start of the input segment
     * @param i2                end of the input segment
     * @param i3                index of the output symbol
     * @return the score accumulated by the scorer for this decision
     */
    public double getPairScore(ToutanovaInstance toutanovaInstance, int i, int i2, int i3) {
        // Start from a clean accumulator.
        scorer_.reset();

        consumeOutputPair(scorer_, toutanovaInstance, i, i2, i3);
        return scorer_.getScore();
    }

    /**
     * Scores the transition between two consecutive output symbols. Note the argument
     * order: the output symbols come first, the segment boundaries last.
     *
     * @param toutanovaInstance instance being scored
     * @param i                 index of the previous output symbol
     * @param i2                index of the current output symbol
     * @param i3                start of the current input segment
     * @param i4                end of the current input segment
     * @return the score accumulated by the scorer; transition features are skipped
     *         entirely when the model is zero-order
     */
    public double getTransitionScore(ToutanovaInstance toutanovaInstance, int i, int i2, int i3, int i4) {
        // Start from a clean accumulator.
        scorer_.reset();

        // consumeTransition expects (start, end, previous output, current output).
        consumeTransition(scorer_, toutanovaInstance, i3, i4, i, i2);
        return scorer_.getScore();
    }

    /**
     * Scores a complete segmentation/output sequence for an instance.
     *
     * <p>Walks outputs and input end offsets in lockstep. Each step contributes the
     * transition features (from the second step on) and the pair/output features of
     * the segment that starts where the previous one ended.
     *
     * @param toutanovaInstance instance being scored
     * @param result            sequence of output indexes and matching input end offsets
     * @return the total score accumulated over all steps
     */
    public double getScore(ToutanovaInstance toutanovaInstance, Result result) {
        scorer_.reset();

        Iterator<Integer> outputs = result.getOutputs().iterator();
        Iterator<Integer> inputEnds = result.getInputs().iterator();
        int previousOutput = -1;
        int segmentStart = 0;
        while (outputs.hasNext()) {
            final int output = outputs.next().intValue();
            final int segmentEnd = inputEnds.next().intValue();
            if (previousOutput >= 0) {
                consumeTransition(scorer_, toutanovaInstance, segmentStart, segmentEnd, previousOutput, output);
            }
            consumeOutputPair(scorer_, toutanovaInstance, segmentStart, segmentEnd, output);

            // The next segment starts where this one ended.
            previousOutput = output;
            segmentStart = segmentEnd;
        }
        return scorer_.getScore();
    }

    /**
     * Applies a weight update of size {@code d} to every feature that fires for the
     * given segmentation/output sequence. The traversal is the same as when scoring:
     * transition features from the second step on, then pair/output features.
     *
     * @param toutanovaInstance instance the result belongs to
     * @param result            sequence of output indexes and matching input end offsets
     * @param d                 update value handed to the updater before traversal
     */
    public void update(ToutanovaInstance toutanovaInstance, Result result, double d) {
        updater_.setUpdate(d);

        Iterator<Integer> outputs = result.getOutputs().iterator();
        Iterator<Integer> inputEnds = result.getInputs().iterator();
        int previousOutput = -1;
        int segmentStart = 0;
        while (outputs.hasNext()) {
            final int output = outputs.next().intValue();
            final int segmentEnd = inputEnds.next().intValue();
            if (previousOutput >= 0) {
                consumeTransition(updater_, toutanovaInstance, segmentStart, segmentEnd, previousOutput, output);
            }
            consumeOutputPair(updater_, toutanovaInstance, segmentStart, segmentEnd, output);

            // The next segment starts where this one ended.
            previousOutput = output;
            segmentStart = segmentEnd;
        }
    }

    /**
     * Computes character (and POS) indexes for every instance of the list.
     *
     * @param list instances to index
     * @param z    insertion flag handed to the symbol table lookups for each instance
     */
    public void addIndexes(List<ToutanovaInstance> list, boolean z) {
        for (ToutanovaInstance instance : list) {
            addIndexes(instance, z);
        }
    }

    /**
     * Stores the character indexes of the instance's form on the instance and, when a
     * POS table exists and the instance has a POS tag, the tag index as well.
     * Rare instances are left untouched.
     *
     * @param toutanovaInstance instance to index
     * @param z                 insertion flag handed to the symbol table lookups; -1 is
     *                          passed as the accompanying default index
     */
    public void addIndexes(ToutanovaInstance toutanovaInstance, boolean z) {
        if (toutanovaInstance.isRare()) {
            return;
        }

        final String form = toutanovaInstance.getInstance().getForm();
        final int[] charIndexes = new int[form.length()];
        for (int pos = 0; pos < form.length(); pos++) {
            charIndexes[pos] = char_table.toIndex(Character.valueOf(form.charAt(pos)), -1, z);
        }
        toutanovaInstance.setFormCharIndexes(charIndexes);

        if (pos_table_ != null) {
            final String posTag = toutanovaInstance.getInstance().getPosTag();
            if (posTag != null) {
                toutanovaInstance.setPosTagIndex(pos_table_.toIndex(posTag, -1, z));
            }
        }
    }

    /**
     * @return the weights currently used by this model
     */
    public DynamicWeights getWeights() {
        return weights_;
    }

    /**
     * Replaces the model weights and propagates the new weights to the scorer and
     * the updater so all three keep referring to the same object.
     *
     * @param dynamicWeights weights to use from now on
     */
    public void setWeights(DynamicWeights dynamicWeights) {
        weights_ = dynamicWeights;

        // Keep both consumers in sync with the model.
        scorer_.setWeights(dynamicWeights);
        updater_.setWeights(dynamicWeights);
    }

    /**
     * Tells whether the form of an instance is out of vocabulary.
     *
     * @param lemmaInstance instance whose form is checked
     * @return {@code true} if the form is not among the training forms collected in {@code init}
     */
    public boolean isOOV(LemmaInstance lemmaInstance) {
        final boolean known = form_vocab_.contains(lemmaInstance.getForm());
        return !known;
    }

    // Class initialization: sets the assertion switch and the feature-id bit width.
    static {
        // True when assertions are NOT enabled for this class; guards the explicit
        // AssertionError checks in createOutputTable.
        $assertionsDisabled = !ToutanovaModel.class.desiredAssertionStatus();
        // Width of the template id prefix; 2 equals PAIR_FEAT, the largest template id
        // (presumably bitsNeeded takes the largest value to encode — TODO confirm).
        FEATURE_BITS = Encoder.bitsNeeded(2);
    }
}
