---
datasets:
- edinburghcstr/ami
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
---
# AST-based Speaker Identification on AMI
## Model description
This model is a **fine-tuned** version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
for speaker classification on the AMI Meeting Corpus. It was trained to distinguish **50** speakers (if you fine-tune on a different speaker set, change `num_labels` to match) using 128-bin mel-spectrograms of 1024 frames.
- **Base architecture**: Audio Spectrogram Transformer (AST)
- **Training**: ~10 epochs, batch size=4, learning rate=1e-5, AdamW optimizer, mixed precision
- **Data**: Stratified samples from AMI train/validation/test splits
- **Performance**: Poor; this was only a small experiment toward speaker diarization.
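The card does not say how the stratified samples were drawn or how speaker labels were indexed. A minimal hypothetical sketch: cap the number of clips per speaker within a split so that every speaker is represented evenly, then map the sorted speaker IDs to contiguous class indices. The `(clip_id, speaker)` input format, the helper names, and the `max_per_speaker` cap are assumptions, not the author's actual pipeline.

```python
import random
from collections import defaultdict


def stratified_sample(clips, max_per_speaker, seed=0):
    """Draw up to `max_per_speaker` clips per speaker (hypothetical helper).

    `clips` is a list of (clip_id, speaker_id) pairs from one split.
    """
    by_speaker = defaultdict(list)
    for clip_id, speaker in clips:
        by_speaker[speaker].append(clip_id)

    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    sample = []
    for speaker in sorted(by_speaker):
        ids = by_speaker[speaker]
        k = min(max_per_speaker, len(ids))  # speakers with few clips keep all
        sample.extend((clip_id, speaker) for clip_id in rng.sample(ids, k))
    return sample


def build_label_maps(speakers):
    """Map speaker IDs to contiguous class indices 0..num_labels-1."""
    unique = sorted(set(speakers))
    label2id = {spk: i for i, spk in enumerate(unique)}
    id2label = {i: spk for spk, i in label2id.items()}
    return label2id, id2label


if __name__ == "__main__":
    # Toy data only: 10 clips spread over 3 made-up speaker IDs.
    clips = [(f"clip{i}", f"spk{i % 3}") for i in range(10)]
    sample = stratified_sample(clips, max_per_speaker=2)
    label2id, id2label = build_label_maps(spk for _, spk in sample)
    print(len(sample), label2id)
```

The resulting dictionaries can be passed to `ASTForAudioClassification.from_pretrained` together with `num_labels=len(label2id)` and `ignore_mismatched_sizes=True` when replacing the base model's AudioSet head with a speaker head.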
## How to use
```python
from transformers import AutoProcessor, ASTForAudioClassification
import torch
import numpy as np
# 1) Load the model and processor
MODEL_ID = "agutig/AST_diarizer"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID)
model.eval()
# 2) Prepare a 1-second audio sample (or load your own)
sr = 16000
audio = np.random.randn(sr).astype(np.float32)
# Alternatively:
# import librosa
# audio, _ = librosa.load("your_audio.wav", sr=sr)
# 3) Preprocess and run inference
inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape [1, num_labels]
probs = torch.softmax(logits, dim=-1)[0]
pred_i = int(probs.argmax())
print(f"Predicted speaker index: {pred_i}")
```
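The example above passes a 1-second clip, while the model always sees 1024 spectrogram frames. As a sketch, assuming the processor computes Kaldi-style filterbanks with a 25 ms window and a 10 ms hop at 16 kHz (the defaults of the AST feature extractor in `transformers`) and pads or truncates to 1024 frames, the arithmetic below shows how much of each input is real audio. The constants and helper names are illustrative and are not read from this model's configuration.

```python
# Frame arithmetic for a Kaldi-style filterbank front end (a sketch).
# Assumptions: 16 kHz audio, 25 ms window, 10 ms hop, only full windows
# counted (snip_edges), and a fixed model input of 1024 frames.
SAMPLE_RATE = 16_000
WIN_SAMPLES = SAMPLE_RATE * 25 // 1000  # 400 samples per analysis window
HOP_SAMPLES = SAMPLE_RATE * 10 // 1000  # 160 samples between window starts
MAX_FRAMES = 1024                       # model input length in frames


def num_frames(num_samples: int) -> int:
    """Number of full analysis windows that fit in a signal."""
    if num_samples < WIN_SAMPLES:
        return 0
    return 1 + (num_samples - WIN_SAMPLES) // HOP_SAMPLES


def samples_for_frames(frames: int) -> int:
    """Shortest signal length (in samples) that yields `frames` windows."""
    return WIN_SAMPLES + (frames - 1) * HOP_SAMPLES


if __name__ == "__main__":
    real = num_frames(SAMPLE_RATE)  # frames produced by 1 s of audio
    print(f"1 s of audio: {real} real frames, {MAX_FRAMES - real} padded")
    needed = samples_for_frames(MAX_FRAMES)
    print(f"{MAX_FRAMES} frames: {needed} samples ({needed / SAMPLE_RATE:.3f} s)")
```

Under these assumptions a 1-second clip fills fewer than 100 of the 1024 frames and the rest is padding, while a full input corresponds to just over 10 seconds of audio.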
## Usage with `pipeline`
```python
from transformers import pipeline
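speaker_id = pipeline(
    task="audio-classification",
    model="agutig/AST_diarizer",
    top_k=50,  # return scores for all 50 speakers (the default is the top 5)
)
results = speaker_id("path/to/audio.wav")
print(results)  # list of {"score": ..., "label": ...} dicts, highest score first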
```
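Both snippets above classify a single clip. Using the model for diarization requires running it over a longer recording in windows and merging the per-window predictions; the card does not say how this was done. The following is a minimal sketch under that gap, assuming fixed windows of `win_s` seconds with hop `hop_s` and a hypothetical `predict_speaker(chunk)` callable that wraps the inference code above and returns a speaker index. Consecutive windows with the same prediction are merged into one segment.

```python
import numpy as np


def sliding_windows(audio, sr, win_s=1.0, hop_s=1.0):
    """Return (start_s, end_s, chunk) tuples covering a 1-D waveform."""
    win = int(win_s * sr)
    hop = int(hop_s * sr)
    starts = list(range(0, max(len(audio) - win, 0) + 1, hop))
    if starts[-1] + win < len(audio):
        starts.append(len(audio) - win)  # final window flush with the end
    return [
        (start / sr, min(start + win, len(audio)) / sr, audio[start:start + win])
        for start in starts
    ]


def merge_segments(windows):
    """Merge consecutive (start, end, label) windows that share a label."""
    segments = []
    for start, end, label in windows:
        if segments and segments[-1][2] == label and start <= segments[-1][1]:
            segments[-1][1] = max(segments[-1][1], end)  # extend current segment
        else:
            segments.append([start, end, label])
    return [tuple(seg) for seg in segments]


if __name__ == "__main__":
    sr = 16_000
    audio = np.zeros(int(3.5 * sr), dtype=np.float32)  # placeholder signal

    def predict_speaker(chunk):
        # Hypothetical stand-in for the processor + model call shown above.
        return 0

    preds = [(s, e, predict_speaker(c)) for s, e, c in sliding_windows(audio, sr)]
    print(merge_segments(preds))
```

In practice `predict_speaker` would run `processor` and `model` on each chunk exactly as in the first example and return `int(logits.argmax())`. With overlapping windows (`hop_s < win_s`), segments of different speakers can overlap; a fuller system would resolve those overlaps explicitly.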
## Evaluation & Benchmarks
Classification:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/ZfvDY9M32wTtsePzwJV3v.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/BUm30OuKmUWehOIqFjUdO.png)
Embeddings:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/-WBY39T4M4f9pRrZGjqMk.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/ZEWyyjfZtxUZjzFgywqJI.png)
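The card does not describe how the embedding plots were produced or how embeddings were compared. One common way to compare speaker embeddings, shown here as a sketch under that assumption, is cosine similarity: clips from the same speaker should score higher than clips from different speakers. The embedding matrix is assumed to have one row per clip (for example, a pooled hidden state obtained by calling the model with `output_hidden_states=True`); that extraction step and the helper names are hypothetical.

```python
import numpy as np


def cosine_similarity_matrix(emb):
    """Pairwise cosine similarity between rows of an (n_clips, dim) array."""
    emb = np.asarray(emb, dtype=np.float64)
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.clip(norms, 1e-12, None)  # guard against zero vectors
    return unit @ unit.T


def same_vs_different(sim, labels):
    """Mean similarity of same-speaker pairs vs. different-speaker pairs."""
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)  # drop self-similarity
    return sim[same & off_diag].mean(), sim[~same].mean()
```

A clear gap between the two means suggests the embeddings separate speakers; note that `same_vs_different` returns `nan` for the first value if every speaker has only one clip.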
## License
- **Model**: Apache 2.0
- **Base code (AST AudioSet)**: MIT License