agutig committed on
Commit eceb983 · verified · 1 Parent(s): c69e940

Update README.md

  1. README.md +78 -1
README.md CHANGED
@@ -1,4 +1,81 @@
  ---
  datasets:
  - edinburghcstr/ami
- ---
+ base_model:
+ - MIT/ast-finetuned-audioset-10-10-0.4593
+ ---
+ # AST-based Speaker Identification on AMI
+
+ ## Model description
+
+ This model is a **fine-tuned** version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
+ for speaker classification on the AMI Meeting Corpus. It was trained on **50** speakers (adjust `num_labels` if your speaker count differs), using 128-bin mel-spectrograms of 1024 frames.
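To make the "1024 frames" figure concrete, here is a rough sketch of how many spectrogram frames a clip produces and how much of the 1024-frame input it fills. The 25 ms window, 10 ms hop, and Kaldi-style framing without edge padding are assumptions (typical for AST feature extraction, but not stated in this README); only the 128 mel bins and 1024 frames come from the description above.

```python
# Assumed framing parameters (not stated in the README):
# 16 kHz audio, 25 ms window, 10 ms hop, no edge padding.
SAMPLE_RATE = 16_000
WIN = SAMPLE_RATE * 25 // 1000   # 400 samples per analysis window
HOP = SAMPLE_RATE * 10 // 1000   # 160 samples between frame starts

# Values taken from the model description.
MAX_FRAMES = 1024
N_MELS = 128


def num_frames(num_samples: int) -> int:
    """Number of full analysis windows that fit in the clip."""
    if num_samples < WIN:
        return 0
    return 1 + (num_samples - WIN) // HOP


def samples_to_fill(frames: int) -> int:
    """Smallest clip length (in samples) that yields `frames` frames."""
    return WIN + (frames - 1) * HOP


clip_frames = num_frames(SAMPLE_RATE)        # a 1-second clip
full_samples = samples_to_fill(MAX_FRAMES)   # clip that fills the input

print(f"1 s clip -> {clip_frames} frames, padded to ({MAX_FRAMES}, {N_MELS})")
print(f"{MAX_FRAMES} frames need {full_samples / SAMPLE_RATE:.3f} s of audio")
```

Under these assumptions, a 1-second clip occupies only a small fraction of the input and the rest is padding, so longer segments (up to roughly ten seconds) may give the model more to work with.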
+
+ - **Base architecture**: Audio Spectrogram Transformer (AST)
+ - **Training**: ~10 epochs, batch size 4, learning rate 1e-5, AdamW optimizer, mixed precision
+ - **Data**: Stratified samples from the AMI train/validation/test splits
+ - **Performance**: Not good; this was just a small experiment for diarization
+
+
+
+ ## How to use
+
+ ```python
+ from transformers import AutoProcessor, ASTForAudioClassification
+ import torch
+ import numpy as np
+
+ # 1) Load the model and processor
+ MODEL_ID = "agutig/AST_diarizer"
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
+ model = ASTForAudioClassification.from_pretrained(MODEL_ID)
+ model.eval()
+
+ # 2) Prepare a 1-second audio sample (random noise here; or load your own)
+ sr = 16000
+ audio = np.random.randn(sr).astype(np.float32)
+ # Alternatively:
+ # import librosa
+ # audio, _ = librosa.load("your_audio.wav", sr=sr)
+
+ # 3) Preprocess and run inference
+ inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
+ with torch.no_grad():
+     logits = model(**inputs).logits  # shape [1, num_labels]
+ probs = torch.softmax(logits, dim=-1)[0]
+ pred_i = int(probs.argmax())
+
+ print(f"Predicted speaker index: {pred_i}")
+ ```
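The snippet above reports only the single most likely speaker index. As a minimal sketch of how the logits could be turned into a ranked shortlist, the following uses plain Python so the arithmetic is visible; the `top_k` helper, the `speaker_{i}` fallback names, and the example logits are hypothetical. In practice `torch.topk` on `probs` together with `model.config.id2label` would serve the same purpose.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3, id2label=None):
    """Return the k most probable (label, probability) pairs, best first.

    `id2label` is an optional index -> name mapping; indices missing from
    it fall back to a hypothetical "speaker_<index>" name.
    """
    id2label = id2label or {}
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label.get(i, f"speaker_{i}"), probs[i]) for i in ranked[:k]]


# Hypothetical logits for a 4-speaker model, for illustration only.
example_logits = [0.2, 2.5, -1.0, 1.1]
for label, prob in top_k(example_logits, k=2):
    print(f"{label}: {prob:.3f}")
```

Looking at the runner-up probabilities can help judge how confident a prediction is, which matters for a model whose accuracy is described as limited.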
+
+ ## Usage with `pipeline`
+
+ ```python
+ from transformers import pipeline
+
+ speaker_id = pipeline(
+     task="audio-classification",
+     model="agutig/AST_diarizer",
+     # top_k sets how many labels are returned (the default is 5);
+     # 50 covers every speaker. return_all_scores is a text-classification
+     # option and has no effect on this pipeline.
+     top_k=50,
+ )
+
+ results = speaker_id("path/to/audio.wav")
+ print(results)
+ ```
+
+ ## Evaluation & Benchmarks
+
+ | Metric                             | Value |
+ |------------------------------------|-------|
+ | Accuracy (test)                    | 0.XX  |
+ | Adjusted Rand Index (ARI)          | 0.YY  |
+ | Normalized Mutual Information (NMI)| 0.ZZ  |
+
+ _Fill in actual values._
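The README does not say how the three metrics in the table would be computed, so the following is only a sketch of their standard definitions in plain Python. Accuracy requires predicted labels to match the reference speaker IDs exactly, while ARI and NMI compare groupings and ignore how the clusters are named, which is why they suit diarization-style evaluation. The NMI here uses arithmetic-mean normalization, which is an assumption about the intended variant; scikit-learn's `accuracy_score`, `adjusted_rand_score`, and `normalized_mutual_info_score` are the usual ready-made equivalents.

```python
import math
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the reference."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def _pairs(n):
    """Number of unordered pairs among n items."""
    return n * (n - 1) // 2


def adjusted_rand_index(y_true, y_pred):
    """Chance-corrected pair-counting agreement between two labelings."""
    n = len(y_true)
    if n < 2:
        return 1.0
    cells = Counter(zip(y_true, y_pred))
    rows, cols = Counter(y_true), Counter(y_pred)
    sum_cells = sum(_pairs(c) for c in cells.values())
    sum_rows = sum(_pairs(c) for c in rows.values())
    sum_cols = sum(_pairs(c) for c in cols.values())
    expected = sum_rows * sum_cols / _pairs(n)
    maximum = (sum_rows + sum_cols) / 2
    if maximum == expected:
        return 1.0
    return (sum_cells - expected) / (maximum - expected)


def _entropy(counts, n):
    return -sum((c / n) * math.log(c / n) for c in counts)


def normalized_mutual_info(y_true, y_pred):
    """Mutual information divided by the mean of the two label entropies."""
    n = len(y_true)
    cells = Counter(zip(y_true, y_pred))
    rows, cols = Counter(y_true), Counter(y_pred)
    mi = sum(
        (c / n) * math.log(n * c / (rows[t] * cols[p]))
        for (t, p), c in cells.items()
    )
    mean_h = (_entropy(rows.values(), n) + _entropy(cols.values(), n)) / 2
    return 1.0 if mean_h == 0 else mi / mean_h


# Toy labels for illustration only: same grouping, swapped names.
reference = [0, 0, 1, 1]
predicted = [1, 1, 0, 0]
print(accuracy(reference, predicted))
print(adjusted_rand_index(reference, predicted))
print(normalized_mutual_info(reference, predicted))
```

In the toy case the grouping is identical but the names are swapped, so accuracy is zero while ARI and NMI report perfect agreement.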
+
+
+
+ ## License
+
+ - **Model**: Apache 2.0
+ - **Base code (AST AudioSet)**: MIT License