package edu.emory.mathcs.nlp.learning.util;

import edu.emory.mathcs.nlp.learning.activation.ActivationFunction;
import edu.emory.mathcs.nlp.learning.initialization.WeightGenerator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/emory/mathcs/nlp/learning/util/WeightVector.class */
/**
 * A weight vector made of two independently stored parts:
 * <ul>
 *   <li>a sparse part, created as a {@link ColumnMajorVector}, scored against the sparse
 *       features of a {@link FeatureVector};</li>
 *   <li>a dense part, created as a {@link RowMajorVector}, scored against its dense
 *       features.</li>
 * </ul>
 * An optional {@link ActivationFunction} is applied to the combined score array.
 */
public class WeightVector implements Serializable {
    private static final long serialVersionUID = -3283251983046316463L;

    /**
     * Number of partner features paired with the top-ranked feature by
     * {@link #getTopFeatureCombinations(FeatureVector, int, int)}.
     */
    private static final int DEFAULT_COMBINATION_PARTNERS = 2;

    private ActivationFunction activation_function;
    private MajorVector sparse_weight_vector;
    private MajorVector dense_weight_vector;

    /** Creates a weight vector with no activation function. */
    public WeightVector() {
        this(null);
    }

    /**
     * Creates a weight vector with an empty sparse (column-major) part and an empty
     * dense (row-major) part.
     *
     * @param activationFunction function applied to scores, or {@code null} for none
     */
    public WeightVector(ActivationFunction activationFunction) {
        setSparseWeightVector(new ColumnMajorVector());
        setDenseWeightVector(new RowMajorVector());
        setActivationFunction(activationFunction);
    }

    /**
     * Returns one of the two parts.
     *
     * @param sparse {@code true} for the sparse part, {@code false} for the dense part
     * @return the selected part
     */
    public MajorVector getMajorVector(boolean sparse) {
        return sparse ? sparse_weight_vector : dense_weight_vector;
    }

    /** @return the sparse part of this weight vector */
    public MajorVector getSparseWeightVector() {
        return sparse_weight_vector;
    }

    /** @param majorVector the new sparse part */
    public void setSparseWeightVector(MajorVector majorVector) {
        sparse_weight_vector = majorVector;
    }

    /** @return the dense part of this weight vector */
    public MajorVector getDenseWeightVector() {
        return dense_weight_vector;
    }

    /** @param majorVector the new dense part */
    public void setDenseWeightVector(MajorVector majorVector) {
        dense_weight_vector = majorVector;
    }

    /** @return the activation function, or {@code null} if none is set */
    public ActivationFunction getActivationFunction() {
        return activation_function;
    }

    /** @param activationFunction the activation function, or {@code null} to disable it */
    public void setActivationFunction(ActivationFunction activationFunction) {
        activation_function = activationFunction;
    }

    /** @return {@code true} if an activation function is set */
    public boolean hasActivationFunction() {
        return activation_function != null;
    }

    /**
     * Returns the label size reported by the sparse part. This is also the length of the
     * array returned by {@link #scores(FeatureVector)}.
     */
    public int getLabelSize() {
        return sparse_weight_vector.getLabelSize();
    }

    /** Same as {@link #expand(int, int, int, WeightGenerator)} with a {@code null} generator. */
    public boolean expand(int sparseFeatureSize, int denseFeatureSize, int labelSize) {
        return expand(sparseFeatureSize, denseFeatureSize, labelSize, null);
    }

    /**
     * Expands both parts by forwarding to {@code MajorVector.expand}. The sparse part
     * receives {@code (labelSize, sparseFeatureSize, weightGenerator)} and the dense part
     * receives {@code (labelSize, denseFeatureSize, weightGenerator)}.
     * NOTE(review): parameter names reflect this forwarding order; confirm against
     * MajorVector.expand.
     *
     * @return {@code true} if either part's expand call returned {@code true}
     */
    public boolean expand(int sparseFeatureSize, int denseFeatureSize, int labelSize, WeightGenerator weightGenerator) {
        boolean expanded = sparse_weight_vector.expand(labelSize, sparseFeatureSize, weightGenerator);
        // Non-short-circuit '|' on purpose: the dense part must be expanded even when
        // the sparse part already reported a change.
        expanded |= dense_weight_vector.expand(labelSize, denseFeatureSize, weightGenerator);
        return expanded;
    }

    /**
     * Creates a new weight vector that shares this vector's activation function. Its parts
     * come from calling {@code createZeroVector()} on each of this vector's parts.
     */
    public WeightVector createZeroVector() {
        WeightVector zero = new WeightVector(activation_function);
        zero.setSparseWeightVector(sparse_weight_vector.createZeroVector());
        zero.setDenseWeightVector(dense_weight_vector.createZeroVector());
        return zero;
    }

    /** @return the sum of the non-zero weight counts reported by both parts */
    public int countNonZeroWeights() {
        return sparse_weight_vector.countNonZeroWeights() + dense_weight_vector.countNonZeroWeights();
    }

    /**
     * Same as {@link #getTopFeatureCombinations(FeatureVector, int, int, int)}, pairing the
     * top feature with at most two partners.
     */
    public List<int[]> getTopFeatureCombinations(FeatureVector featureVector, int label, int competingLabel) {
        return getTopFeatureCombinations(featureVector, label, competingLabel, DEFAULT_COMBINATION_PARTNERS);
    }

    /**
     * Pairs the sparse feature that most favors {@code label} over {@code competingLabel}
     * with the next-best favoring features.
     *
     * <p>For each sparse feature index {@code f} in {@code featureVector}, the margin
     * {@code w[label][f] - w[competingLabel][f]} is computed on the sparse part. Only
     * features with a strictly positive margin are kept. They are ranked by the reverse
     * natural order of {@link SparsePrediction} (presumably highest margin first; verify
     * against SparsePrediction#compareTo). The top feature is then paired with each of the
     * following features, up to {@code maxPartners} of them.
     *
     * @param featureVector  the features to inspect; if it has no sparse vector, an empty
     *                       list is returned
     * @param label          the label whose weights are the minuend
     * @param competingLabel the label whose weights are subtracted
     * @param maxPartners    maximum number of pairs to return; values {@code <= 0} yield an
     *                       empty list
     * @return a mutable list of {@code {topFeatureIndex, partnerFeatureIndex}} pairs
     */
    public List<int[]> getTopFeatureCombinations(FeatureVector featureVector, int label, int competingLabel, int maxPartners) {
        List<int[]> combinations = new ArrayList<>();
        if (maxPartners <= 0 || !featureVector.hasSparseVector()) {
            return combinations;
        }

        List<SparsePrediction> favoring = new ArrayList<>();
        for (Iterator<SparseItem> it = featureVector.getSparseVector().iterator(); it.hasNext(); ) {
            int index = it.next().getIndex();
            float margin = sparse_weight_vector.get(label, index) - sparse_weight_vector.get(competingLabel, index);
            // Only allocate a prediction for features that are actually kept.
            if (margin > 0.0f) {
                favoring.add(new SparsePrediction(index, margin));
            }
        }
        if (favoring.isEmpty()) {
            return combinations;
        }

        Collections.sort(favoring, Collections.reverseOrder());
        // SparsePrediction was constructed with the feature index, so getLabel() yields it back.
        int topIndex = favoring.get(0).getLabel();
        int end = Math.min(favoring.size(), maxPartners + 1);
        for (int k = 1; k < end; k++) {
            combinations.add(new int[]{topIndex, favoring.get(k).getLabel()});
        }
        return combinations;
    }

    /**
     * Scores {@code featureVector} into a new array of length {@link #getLabelSize()}.
     *
     * @return the (optionally activated) scores
     */
    public float[] scores(FeatureVector featureVector) {
        float[] scores = new float[getLabelSize()];
        addScores(featureVector, scores);
        return scores;
    }

    /**
     * Adds the sparse and/or dense contributions of {@code featureVector} into {@code scores}.
     * If an activation function is set, it is then applied to the whole array, including
     * any values {@code scores} held before the call.
     */
    public void addScores(FeatureVector featureVector, float[] scores) {
        if (featureVector.hasSparseVector()) {
            sparse_weight_vector.addScores(featureVector.getSparseVector(), scores);
        }
        if (featureVector.hasDenseVector()) {
            dense_weight_vector.addScores(featureVector.getDenseVector(), scores);
        }
        if (hasActivationFunction()) {
            activation_function.apply(scores);
        }
    }
}
