---
datasets:
- edinburghcstr/ami
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
---

# AST-based Speaker Identification on AMI

## Model description

This model is a **fine-tuned** version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for speaker classification on the AMI Meeting Corpus. It was trained on **50** speakers using 128-bin mel-spectrograms of 1024 frames. If you fine-tune it on a different set of speakers, set `num_labels` to match.

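To make "1024 frames" concrete, the sketch below estimates how many filterbank frames a clip produces. It assumes Kaldi-style framing (25 ms window, 10 ms hop, `snip_edges=True`) at 16 kHz, which is what the AST feature extractor is understood to use, and that shorter inputs are padded up to 1024 frames. Treat it as an approximation and check it against your installed `transformers` version.

```python
# Estimate filterbank frame counts for AST inputs.
# ASSUMPTION: Kaldi-style framing (25 ms window, 10 ms hop, snip_edges=True)
# at 16 kHz, with inputs padded or truncated to 1024 frames.
SAMPLE_RATE = 16_000
WINDOW = SAMPLE_RATE * 25 // 1000  # 400 samples per frame
HOP = SAMPLE_RATE * 10 // 1000     # 160 samples between frame starts
MAX_FRAMES = 1024                  # AST input length in frames


def num_frames(num_samples: int) -> int:
    """Frames produced from a clip before padding/truncation to MAX_FRAMES."""
    if num_samples < WINDOW:
        return 0
    return 1 + (num_samples - WINDOW) // HOP


def samples_to_fill(max_frames: int = MAX_FRAMES) -> int:
    """Shortest clip length (in samples) that yields max_frames frames."""
    return WINDOW + (max_frames - 1) * HOP


real = num_frames(SAMPLE_RATE)          # 1-second clip
print(real, MAX_FRAMES - real)          # 98 926  (real frames, padded frames)
print(samples_to_fill() / SAMPLE_RATE)  # 10.255  (seconds of audio needed)
```

Under these assumptions a 1-second clip fills only 98 of the 1024 input frames, and about 10.26 s of audio is needed to fill the input completely.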

- **Base architecture**: Audio Spectrogram Transformer (AST)
- **Training**: ~10 epochs, batch size 4, learning rate 1e-5, AdamW optimizer, mixed precision
- **Data**: Stratified samples from the AMI train/validation/test splits
- **Performance**: Poor. This was only a small exploratory experiment toward speaker diarization.

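The card does not say how the stratification was done. One plausible reading, assumed here, is stratification by speaker: draw up to a fixed number of segments per speaker so that frequent talkers do not dominate the subset. `stratified_by_speaker`, its `per_speaker` cap, and the segment IDs are hypothetical names for illustration only.

```python
import random
from collections import defaultdict


def stratified_by_speaker(segments, per_speaker, seed=0):
    """Return up to `per_speaker` (segment_id, speaker_id) pairs per speaker.

    Speakers with fewer than `per_speaker` segments contribute all of them.
    A fixed seed keeps the subset reproducible across runs.
    """
    by_speaker = defaultdict(list)
    for seg_id, speaker in segments:
        by_speaker[speaker].append(seg_id)

    rng = random.Random(seed)
    subset = []
    for speaker in sorted(by_speaker):  # deterministic speaker order
        ids = by_speaker[speaker]
        for seg_id in rng.sample(ids, min(per_speaker, len(ids))):
            subset.append((seg_id, speaker))
    return subset


# Hypothetical segment IDs and speaker labels
segments = [("seg1", "A"), ("seg2", "A"), ("seg3", "A"), ("seg4", "B")]
print(stratified_by_speaker(segments, per_speaker=2))
# -> two randomly chosen "A" segments followed by ("seg4", "B")
```

Applying the same function separately to each split keeps train, validation, and test balanced in the same way.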

## How to use

```python
from transformers import AutoFeatureExtractor, ASTForAudioClassification
import torch
import numpy as np

# 1) Load the model and its feature extractor
MODEL_ID = "agutig/AST_diarizer"
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID)
model.eval()

# 2) Prepare a 1-second, 16 kHz mono clip (random noise here; load your own instead)
sr = 16000
audio = np.random.randn(sr).astype(np.float32)
# Alternatively:
# import librosa
# audio, _ = librosa.load("your_audio.wav", sr=sr)  # resamples to 16 kHz mono

# 3) Preprocess and run inference
inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape [1, num_labels]
probs = torch.softmax(logits, dim=-1)[0]
pred_i = int(probs.argmax())

# Map the class index back to the speaker label stored in the model config
label = model.config.id2label[pred_i]
print(f"Predicted speaker: {label} (index {pred_i}, p={probs[pred_i].item():.3f})")
```

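The model scores one fixed-length clip at a time. To apply it to a full meeting recording for diarization, one hypothetical approach is to slide a window over the waveform and classify each chunk with the inference code above. The 1 s window and 0.5 s hop are illustrative assumptions, not the settings used in training, and `sliding_windows` is a hypothetical helper.

```python
import numpy as np


def sliding_windows(audio, sr=16_000, win_s=1.0, hop_s=0.5):
    """Yield (start_s, end_s, chunk) windows over a 1-D waveform.

    hop_s must be positive. The final window may be shorter than win_s
    so that the end of the recording is still covered.
    """
    win = int(round(win_s * sr))
    hop = int(round(hop_s * sr))
    start = 0
    while True:
        end = min(start + win, len(audio))
        yield start / sr, end / sr, audio[start:end]
        if end >= len(audio):
            break
        start += hop


audio = np.zeros(35_200, dtype=np.float32)  # 2.2 s of silence at 16 kHz
for start_s, end_s, chunk in sliding_windows(audio):
    # In practice, pass `chunk` through the feature extractor and model above.
    print(f"{start_s:.1f}-{end_s:.1f} s: {len(chunk)} samples")
```

For this 2.2 s input the generator yields four windows, the last one 0.7 s long. Overlapping windows smooth predictions at speaker changes at the cost of more forward passes.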

## Usage with `pipeline`

```python
from transformers import pipeline

speaker_id = pipeline(
    task="audio-classification",
    model="agutig/AST_diarizer",
    top_k=50,  # return scores for all 50 speakers (the default is the top 5)
)

# Accepts a file path, or a dict such as {"raw": audio, "sampling_rate": 16000}
results = speaker_id("path/to/audio.wav")
print(results)  # list of {"score": float, "label": str}, highest score first
```

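Per-window predictions can be collapsed into speaker turns. The sketch assumes a time-ordered list of `(start_s, end_s, label)` tuples, for example obtained by classifying consecutive chunks of a recording and keeping the top label `results[0]["label"]` from each pipeline call. When windows overlap, a turn is cut where the first window of the next speaker begins; this is a simple, hypothetical heuristic, not the method used in this experiment, and the labels `"A"`/`"B"` are placeholders.

```python
def merge_turns(window_preds):
    """Collapse time-ordered (start_s, end_s, label) windows into speaker turns.

    Consecutive windows with the same label are merged. When windows overlap,
    the previous turn is cut at the start of the first window of the next label.
    """
    turns = []
    for start, end, label in window_preds:
        if turns and turns[-1][2] == label:
            turns[-1][1] = max(turns[-1][1], end)  # extend the current turn
        else:
            if turns:
                turns[-1][1] = min(turns[-1][1], start)  # cut at the change
            turns.append([start, end, label])
    return [tuple(turn) for turn in turns]


# Placeholder labels; real labels come from model.config.id2label
preds = [(0.0, 1.0, "A"), (0.5, 1.5, "A"), (1.0, 2.0, "B"), (1.5, 2.2, "B")]
print(merge_turns(preds))  # [(0.0, 1.0, 'A'), (1.0, 2.2, 'B')]
```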

## Evaluation & Benchmarks

Classification: *(figures not available)*

Embeddings: *(figures not available)*

## License

- **Model**: Apache 2.0
- **Base code (AST AudioSet)**: MIT License