GigaAM Emotion Classifier

Fine-tuned version of the GigaAM (Conformer-based) model waveletdeboshir/gigaam-rnnt for 5-class emotion recognition on speech, trained on the nixiieee/dusha_balanced dataset (a cropped, class-balanced version of the original Dusha dataset).


Model Description

  • Architecture: Encoder from the pretrained RNNT model plus a custom 3-layer classifier head with LayerNorm and dropout.
  • Labels: 5 emotion classes.
  • Pooling: Length-aware average pooling using the attention mask (a minimal sketch follows this list).
  • Output: Logits over emotion classes.
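
As a minimal illustration of the length-aware pooling (the real implementation appears in the Usage section below), on toy tensors it reduces to:

import torch

# toy batch: 2 sequences, max length 4, hidden size 3
hidden_states = torch.randn(2, 4, 3)
# second sequence has only 2 valid frames; the rest is padding
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

lengths = attention_mask.sum(dim=1, keepdim=True)      # (2, 1) valid-frame counts
masked = hidden_states * attention_mask.unsqueeze(-1)  # zero out padded frames
pooled = masked.sum(dim=1) / lengths                   # mean over valid frames only
print(pooled.shape)  # torch.Size([2, 3])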

Installation

The model was trained with the following library versions:

torch==2.5.1 torchaudio==2.5.1 transformers==4.49.0 accelerate==1.5.2

Using exactly these versions is recommended.
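
In a fresh environment, that corresponds to (assuming a standard pip setup):

pip install torch==2.5.1 torchaudio==2.5.1 transformers==4.49.0 accelerate==1.5.2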


Usage

from transformers import (
    PreTrainedModel,
    AutoConfig,
    AutoModel,
    AutoProcessor,
)
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
import torch.nn as nn
import torchaudio

class EmotionClassifier(nn.Module):
    def __init__(self, hidden_size, num_labels=5, dropout=0.2):
        super().__init__()
        self.pool_norm = nn.LayerNorm(hidden_size)
        self.pre_dropout = nn.Dropout(dropout)

        # intermediate widths for the two hidden layers of the MLP head
        mid1 = max(hidden_size // 2, num_labels * 4)
        mid2 = max(hidden_size // 4, num_labels * 2)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, mid1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(mid1),
            nn.Linear(mid1, mid2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(mid2),
            nn.Linear(mid2, num_labels),
        )

    def forward(self, hidden_states, attention_mask=None):
        if attention_mask is not None:
            # length-aware average pooling: average only over non-padded frames
            lengths = attention_mask.sum(dim=1, keepdim=True)
            masked = hidden_states * attention_mask.unsqueeze(-1)
            pooled = masked.sum(dim=1) / lengths
        else:
            pooled = hidden_states.mean(dim=1)
        x = self.pool_norm(pooled)
        x = self.pre_dropout(x)
        logits = self.classifier(x)
        return logits
    
class ModelForEmotionClassification(PreTrainedModel):
    config_class = AutoConfig

    def __init__(
        self, config, model_name, num_labels=5, dropout=0.2
    ):
        super().__init__(config)
        self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True).model.encoder
        hidden_size = config.encoder['d_model']
        self.classifier = EmotionClassifier(
            hidden_size, num_labels=num_labels, dropout=dropout
        )
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        input_lengths: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None
    ) -> SequenceClassifierOutput:
        # the encoder returns (batch, d_model, time); transpose to (batch, time, d_model)
        encoded, out_lens = self.encoder(input_features, input_lengths)
        hidden_states = encoded.transpose(1, 2)

        if attention_mask is None:
            # derive a frame-level mask from the encoder output lengths
            max_t = hidden_states.size(1)
            attention_mask = (
                torch.arange(max_t, device=out_lens.device)
                .unsqueeze(0)
                .lt(out_lens.unsqueeze(1))
                .long()
            )

        logits = self.classifier(hidden_states, attention_mask=attention_mask)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)

model_name = "nixiieee/gigaam-rnnt-emotion-classifier-dusha"
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = ModelForEmotionClassification.from_pretrained(model_name, config=config, model_name=model_name)
model.eval()

# load audio
wav, sr = torchaudio.load("audio.wav")
# resample to 16 kHz if necessary
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
inputs = processor(wav[0], sampling_rate=16000, return_tensors="pt")

# the model is a classifier, so call it directly (it has no generate method)
with torch.no_grad():
    outputs = model(**inputs)

pred = outputs.logits.argmax(dim=-1).item()
print("Predicted emotion:", config.id2label[pred])