GigaAM Emotion Classifier
Fine-tuned version of the conformer-based GigaAM model waveletdeboshir/gigaam-rnnt for 5-class speech emotion recognition, trained on the nixiieee/dusha_balanced dataset (a cropped version of the original Dusha dataset).
Model Description
- Architecture: Encoder from the pretrained RNNT model plus a custom 3-layer classifier head with LayerNorm and dropout.
- Labels: 5 emotion classes.
- Pooling: Length-aware average pooling over encoder frames, using the attention mask (see the sketch after this list).
- Output: Logits over the emotion classes.
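A minimal, self-contained sketch of that pooling step (tensor shapes here are illustrative; in the model it runs over the conformer encoder's output frames):

```python
import torch

# Illustrative shapes: batch of 2 utterances, 5 encoder frames, hidden size 4
hidden_states = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second utterance: 3 valid frames

lengths = attention_mask.sum(dim=1, keepdim=True)      # true lengths, shape (2, 1)
masked = hidden_states * attention_mask.unsqueeze(-1)  # zero out padded frames
pooled = masked.sum(dim=1) / lengths                   # mean over valid frames, shape (2, 4)
print(pooled.shape)  # torch.Size([2, 4])
```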
Installation
The model was trained with the following library versions, and using exactly these versions is recommended:

```
torch==2.5.1 torchaudio==2.5.1 transformers==4.49.0 accelerate==1.5.2
```
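After installing them (e.g. via pip), you can verify what your environment actually provides with a quick Python check (a sketch; it only prints the installed versions):

```python
import torch
import torchaudio
import transformers
import accelerate

# Compare against the pinned versions above
print("torch:", torch.__version__)
print("torchaudio:", torchaudio.__version__)
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)
```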
Usage
```python
from transformers import (
    PreTrainedModel,
    AutoConfig,
    AutoModel,
    AutoProcessor,
)
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
import torch.nn as nn
import torchaudio


class EmotionClassifier(nn.Module):
    def __init__(self, hidden_size, num_labels=5, dropout=0.2):
        super().__init__()
        self.pool_norm = nn.LayerNorm(hidden_size)
        self.pre_dropout = nn.Dropout(dropout)
        mid1 = max(hidden_size // 2, num_labels * 4)
        mid2 = max(hidden_size // 4, num_labels * 2)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, mid1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(mid1),
            nn.Linear(mid1, mid2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(mid2),
            nn.Linear(mid2, num_labels),
        )

    def forward(self, hidden_states, attention_mask=None):
        if attention_mask is not None:
            # Length-aware average pooling: sum valid frames, divide by true lengths
            lengths = attention_mask.sum(dim=1, keepdim=True)
            masked = hidden_states * attention_mask.unsqueeze(-1)
            pooled = masked.sum(dim=1) / lengths
        else:
            pooled = hidden_states.mean(dim=1)
        x = self.pool_norm(pooled)
        x = self.pre_dropout(x)
        logits = self.classifier(x)
        return logits

class ModelForEmotionClassification(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config, model_name, num_labels=5, dropout=0.2):
        super().__init__(config)
        # Reuse only the conformer encoder from the pretrained RNNT model
        self.encoder = AutoModel.from_pretrained(
            model_name, trust_remote_code=True
        ).model.encoder
        hidden_size = config.encoder['d_model']
        self.classifier = EmotionClassifier(
            hidden_size, num_labels=num_labels, dropout=dropout
        )
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        input_lengths: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
    ) -> SequenceClassifierOutput:
        encoded, out_lens = self.encoder(input_features, input_lengths)
        # Encoder returns (batch, d_model, time); pooling expects (batch, time, d_model)
        hidden_states = encoded.transpose(1, 2)
        if attention_mask is None:
            # Build a frame-validity mask from the encoder output lengths
            max_t = hidden_states.size(1)
            attention_mask = (
                torch.arange(max_t, device=out_lens.device)
                .unsqueeze(0)
                .lt(out_lens.unsqueeze(1))
                .long()
            )
        logits = self.classifier(hidden_states, attention_mask=attention_mask)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits)

model_name = "nixiieee/gigaam-rnnt-emotion-classifier-dusha"
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = ModelForEmotionClassification.from_pretrained(
    model_name, config=config, model_name=model_name
)
model.eval()

# Load audio
wav, sr = torchaudio.load("audio.wav")

# Resample to 16 kHz if necessary
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)

input_features = processor(wav[0], sampling_rate=16000, return_tensors="pt")

# This is a classification model, so call it directly rather than via generate()
with torch.no_grad():
    outputs = model(**input_features)

pred = outputs.logits.argmax(dim=-1).item()
print("Predicted emotion:", config.id2label[pred])
```