package edu.emory.mathcs.nlp.component.template;

import edu.emory.mathcs.nlp.component.template.config.NLPConfig;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.state.NLPState;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.component.template.util.NLPFlag;
import edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.Instance;
import edu.emory.mathcs.nlp.learning.util.MLUtils;
import java.io.InputStream;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/emory/mathcs/nlp/component/template/OnlineComponent.class */
/**
 * Abstract base class for NLP components that are trained and applied through an
 * {@link OnlineOptimizer}.
 *
 * <p>The feature template, the document-based switch and the optimizer are serialized with the
 * component. The hyper-parameters, configuration, run-mode flag and evaluator are
 * {@code transient}, so they are not part of the serialized form and must be set again after
 * deserialization.
 *
 * @param <N> node type processed by this component
 * @param <S> state type built over the input nodes and driven by {@link #process(NLPState)}
 */
public abstract class OnlineComponent<N extends AbstractNLPNode<N>, S extends NLPState<N>> implements NLPComponent<N>, Serializable {
    private static final long serialVersionUID = 59819173578703335L;

    /** Builds a feature vector from a state at each step of {@link #process(NLPState)}. */
    protected FeatureTemplate<N, S> feature_template;
    /** When {@code true}, {@link #process(List)} builds one state over the whole sentence list. */
    protected boolean document_based;
    /** Learner used both to train on instances and to score feature vectors. */
    protected OnlineOptimizer optimizer;
    /** Hyper-parameters; also consulted for the LOLS gold/prediction choice during training. */
    protected transient HyperParameter hyper_parameter;
    /** Component configuration; supplies the feature-template element. */
    protected transient NLPConfig<N> config;
    /** Current run mode (train / decode / evaluate); not set by the constructors. */
    protected transient NLPFlag flag;
    /** Evaluator used in evaluate mode. */
    protected transient Eval eval;

    /**
     * Creates a component without loading a configuration.
     *
     * @param z {@code true} to process a list of sentences as one document-level state,
     *     {@code false} to process each sentence independently
     */
    public OnlineComponent(boolean z) {
        setDocumentBased(z);
    }

    /**
     * Creates a component and loads its configuration from a stream.
     *
     * @param z see {@link #OnlineComponent(boolean)}
     * @param inputStream configuration source, handed to {@link #setConfiguration(InputStream)}
     */
    public OnlineComponent(boolean z, InputStream inputStream) {
        this(z);
        // NOTE(review): an overridable method is called from the constructor. It is kept so that
        // subclass overrides of setConfiguration(InputStream) still take effect.
        setConfiguration(inputStream);
    }

    /** @return the online learner used for training and scoring */
    public OnlineOptimizer getOptimizer() {
        return this.optimizer;
    }

    /** @param onlineOptimizer the online learner to use for training and scoring */
    public void setOptimizer(OnlineOptimizer onlineOptimizer) {
        this.optimizer = onlineOptimizer;
    }

    /** @return the current hyper-parameters, or {@code null} if none were set */
    public HyperParameter getHyperParameter() {
        return this.hyper_parameter;
    }

    /** @param hyperParameter the hyper-parameters to use */
    public void setHyperParameter(HyperParameter hyperParameter) {
        this.hyper_parameter = hyperParameter;
    }

    /** @return the feature template used to extract features from states */
    public FeatureTemplate<N, S> getFeatureTemplate() {
        return this.feature_template;
    }

    /** @param featureTemplate the feature template used to extract features from states */
    public void setFeatureTemplate(FeatureTemplate<N, S> featureTemplate) {
        this.feature_template = featureTemplate;
    }

    /**
     * Builds a new feature template from the configuration's feature-template element and the
     * current hyper-parameters, replacing any existing template.
     *
     * @throws NullPointerException if no configuration has been set
     */
    public void initFeatureTemplate() {
        this.feature_template = new FeatureTemplate<>(this.config.getFeatureTemplateElement(), getHyperParameter());
    }

    /** @return the evaluator, or {@code null} if none has been set or created */
    public Eval getEval() {
        return this.eval;
    }

    /** @param eval the evaluator to use in evaluate mode */
    public void setEval(Eval eval) {
        this.eval = eval;
    }

    /** @return the current run mode */
    public NLPFlag getFlag() {
        return this.flag;
    }

    /**
     * Sets the run mode. On switching to {@link NLPFlag#EVALUATE} with no evaluator yet, one is
     * created with {@link #createEvaluator()}.
     *
     * @param nLPFlag the new run mode
     */
    public void setFlag(NLPFlag nLPFlag) {
        this.flag = nLPFlag;
        if (nLPFlag == NLPFlag.EVALUATE && this.eval == null) {
            setEval(createEvaluator());
        }
    }

    /** @return the current configuration */
    public NLPConfig<N> getConfiguration() {
        return this.config;
    }

    /** @param nLPConfig the configuration to use */
    public void setConfiguration(NLPConfig<N> nLPConfig) {
        this.config = nLPConfig;
    }

    /**
     * Reads a configuration from {@code inputStream}, installs it and returns it.
     *
     * @param inputStream configuration source
     * @return the newly created configuration
     */
    public NLPConfig<N> setConfiguration(InputStream inputStream) {
        NLPConfig<N> nLPConfig = new NLPConfig<>(inputStream);
        setConfiguration(nLPConfig);
        return nLPConfig;
    }

    /** @return {@code true} if a sentence list is processed as one document-level state */
    public boolean isDocumentBased() {
        return this.document_based;
    }

    /** @param z {@code true} to process a sentence list as one document-level state */
    public void setDocumentBased(boolean z) {
        this.document_based = z;
    }

    /** @return {@code true} if the run mode is {@link NLPFlag#TRAIN} */
    public boolean isTrain() {
        return this.flag == NLPFlag.TRAIN;
    }

    /** @return {@code true} if the run mode is {@link NLPFlag#DECODE} */
    public boolean isDecode() {
        return this.flag == NLPFlag.DECODE;
    }

    /** @return {@code true} if the run mode is {@link NLPFlag#EVALUATE} */
    public boolean isEvaluate() {
        return this.flag == NLPFlag.EVALUATE;
    }

    /**
     * Processes one sentence: builds a state with {@code initState} and runs it through
     * {@link #process(NLPState)}.
     *
     * @param nArr nodes of the sentence
     */
    @Override
    public void process(N[] nArr) {
        // initState returns S, so this resolves to process(S). No cast is needed, and casting
        // S to OnlineComponent would not compile.
        process(initState(nArr));
    }

    /**
     * Processes a list of sentences. In document-based mode a single state spans the whole
     * list. Otherwise each sentence is processed on its own with {@link #process(Object[])}.
     *
     * @param list sentences to process, each given as its node array
     */
    @Override
    public void process(List<N[]> list) {
        if (this.document_based) {
            process(initState(list));
            return;
        }
        for (N[] sentence : list) {
            process(sentence);
        }
    }

    /**
     * Runs {@code s} to termination according to the current run mode.
     *
     * <p>At each step a feature vector is extracted from the state:
     * <ul>
     *   <li><b>Training.</b> The optimizer is trained on an {@link Instance} built from the
     *       state's oracle. The label applied to the state is either the gold label or the
     *       model's prediction, as chosen by the hyper-parameters' LOLS setting.</li>
     *   <li><b>Other modes.</b> The optimizer scores the vector and {@link #getPrediction}
     *       picks the labels.</li>
     * </ul>
     * After termination, decode and evaluate modes call {@link #postProcess}. Evaluate mode
     * also passes {@link #eval} to {@code s.evaluate}.
     *
     * @param s the state to drive
     * @return {@code s}; if not decoding and {@code s.saveOracle()} returns {@code false},
     *     it is returned without being processed
     */
    public S process(S s) {
        // saveOracle() is only invoked outside decode mode, thanks to short-circuit evaluation.
        if (!isDecode() && !s.saveOracle()) {
            return s;
        }
        // Label slots handed to s.next. In training only slot 0 is overwritten, so slot 1
        // stays -1. In other modes the array is replaced by getPrediction's result.
        int[] labels = {0, -1};
        while (!s.isTerminate()) {
            FeatureVector x = this.feature_template.createFeatureVector(s, isTrain());
            float[] scores;
            if (isTrain()) {
                Instance instance = new Instance(s.getOracle(), x);
                this.optimizer.train(instance);
                scores = instance.getScores();
                putLabel(instance.getStringLabel(), instance.getGoldLabel());
                labels[0] = this.hyper_parameter.getLOLS().chooseGold() ? instance.getGoldLabel() : getPrediction(s, scores)[0];
            } else {
                scores = this.optimizer.scores(x);
                labels = getPrediction(s, scores);
            }
            s.next(this.optimizer.getLabelMap(), labels, scores);
        }
        if (isDecode() || isEvaluate()) {
            postProcess(s);
            if (isEvaluate()) {
                s.evaluate(this.eval);
            }
        }
        return s;
    }

    /**
     * Chooses labels from a score vector. The default returns {@code MLUtils.argmax2(fArr)}.
     * Presumably that is the best and second-best label indices (TODO confirm). Subclasses
     * may override this to apply state-dependent constraints.
     *
     * @param s current state (unused by the default implementation)
     * @param fArr per-label scores
     * @return label indices passed to {@code s.next}
     */
    protected int[] getPrediction(S s, float[] fArr) {
        return MLUtils.argmax2(fArr);
    }

    /**
     * Hook called once per training step with the instance's string label and gold label
     * index. It does nothing by default.
     *
     * @param str string form of the label
     * @param i gold label index
     */
    protected void putLabel(String str, int i) {
    }

    /**
     * Creates the processing state for a single sentence.
     *
     * @param nArr nodes of the sentence
     * @return a new state
     */
    protected abstract S initState(N[] nArr);

    /**
     * Creates one processing state spanning all sentences. This is used in document-based
     * mode.
     *
     * @param list sentences, each given as its node array
     * @return a new state
     */
    protected abstract S initState(List<N[]> list);

    /** @return a new evaluator for this component; used when switching to evaluate mode */
    public abstract Eval createEvaluator();

    /**
     * Finalizes a terminated state in decode and evaluate modes.
     *
     * @param s the terminated state
     */
    protected abstract void postProcess(S s);
}
