Update README.md

README.md CHANGED
@@ -56,8 +56,8 @@ This model is a binary classification head fine-tuned version of [MIT/ast-finetu
 {
     'learning_rate': 1e-5,
     'weight_decay': 0.01,
-    'n_iterations':
-    'batch_size':
+    'n_iterations': 1500,
+    'batch_size': 16,
     'gradient_accumulation_steps': 8,
     'validate_every': 500,
     'val_samples': 5000
@@ -66,16 +66,16 @@ This model is a binary classification head fine-tuned version of [MIT/ast-finetu
 
 ## Dataset Distribution
 
-The model was trained on
+The model was trained on a filtered dataset with the following class distribution:
 
 ```
-Training Set
-- Fake Audio (0):
-- Real Audio (1): 24,
+Training Set:
+- Fake Audio (0): 29,089 samples (53.97%)
+- Real Audio (1): 24,813 samples (46.03%)
 
-Test Set
-- Fake Audio (0):
-- Real Audio (1): 6,
+Test Set:
+- Fake Audio (0): 7,229 samples (53.64%)
+- Real Audio (1): 6,247 samples (46.36%)
 ```
 
 ## Model Performance
@@ -88,151 +88,29 @@ Final metrics on validation set:
 
 # Usage Guide
 
-##
-First, clone the AST repository and install required dependencies:
+## Model Usage
 ```python
-
-git clone https://github.com/YuanGongND/ast.git
-import sys
-sys.path.append('./ast')
-cd ast
-
-# Install dependencies
-pip install timm==0.4.5 wget
-
-# Required imports
-import os
-import torch
+from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
 import torchaudio
-import matplotlib.pyplot as plt
-import numpy as np
-from torch import nn
-from src.models import ASTModel
-```
+import torch
 
-
-
-```python
-class BinaryAST(nn.Module):
-    def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
-        super().__init__()
-        # Initialize AST base model
-        self.ast = ASTModel(
-            label_dim=527,
-            input_fdim=128,
-            input_tdim=1024,
-            imagenet_pretrain=True,
-            audioset_pretrain=False,
-            model_size='base384'
-        )
-
-        # Load pretrained weights if available
-        if os.path.exists(pretrained_path):
-            print(f"Loading pretrained weights from {pretrained_path}")
-            state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
-            self.ast.load_state_dict(state_dict, strict=False)
-
-        # Binary classification head
-        self.ast.mlp_head = nn.Sequential(
-            nn.LayerNorm(768),
-            nn.Dropout(0.3),
-            nn.Linear(768, 1)
-        )
-
-    def forward(self, x):
-        return self.ast(x)
-```
+# Load audio file
+waveform, sample_rate = torchaudio.load("path_to_audio.ogg")
 
-
-
-```python
-def process_audio(file_path, sr=16000):
-    """
-    Process audio file for model inference.
-
-    Args:
-        file_path (str): Path to audio file
-        sr (int): Target sample rate (default: 16000)
-
-    Returns:
-        torch.Tensor: Processed mel spectrogram (1024 x 128)
-    """
-    # Load audio
-    audio_tensor, orig_sr = torchaudio.load(file_path)
-    print(f"Initial tensor shape: {audio_tensor.shape}, sample_rate={orig_sr}")
-
-    # Convert to mono if needed
-    if audio_tensor.shape[0] > 1:
-        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
-
-    # Resample to target sample rate
-    if orig_sr != sr:
-        resampler = torchaudio.transforms.Resample(orig_sr, sr)
-        audio_tensor = resampler(audio_tensor)
-
-    # Create mel spectrogram
-    mel_spec = torchaudio.transforms.MelSpectrogram(
-        sample_rate=sr,
-        n_mels=128,
-        n_fft=2048,
-        hop_length=160
-    )(audio_tensor)
-    spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)
-
-    # Post-process spectrogram
-    spec_db = spec_db.squeeze(0).transpose(0, 1)
-    spec_db = (spec_db + 4.26) / (4.57 * 2)  # Normalize
-
-    # Ensure correct length (pad/trim to 1024 frames)
-    target_len = 1024
-    if spec_db.shape[0] < target_len:
-        pad = torch.zeros(target_len - spec_db.shape[0], 128)
-        spec_db = torch.cat([spec_db, pad], dim=0)
-    else:
-        spec_db = spec_db[:target_len, :]
-
-    return spec_db
-```
+# Initialize model and feature extractor
+model_name = "WpythonW/ast-fakeaudio-detector"
+extractor = AutoFeatureExtractor.from_pretrained(model_name)
+model = AutoModelForAudioClassification.from_pretrained(model_name)
 
-
-
-```python
-# Initialize and load model
-model = BinaryAST()
-checkpoint = torch.load('/content/final_model.pth', map_location='cpu')
-model.load_state_dict(checkpoint['model_state_dict'])
-model.eval()
-
-# Process audio file
-spec = process_audio('path_to_audio.mp3')
-
-# Visualize spectrogram (optional)
-plt.figure(figsize=(10, 3))
-plt.imshow(spec.numpy().T, aspect='auto', origin='lower')
-plt.title('Mel Spectrogram')
-plt.xlabel('Time Frames')
-plt.ylabel('Mel Bins')
-plt.colorbar()
-plt.show()
-
-# Run inference
-spec_batch = spec.unsqueeze(0)
+# Process audio and get predictions
+inputs = extractor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
 with torch.no_grad():
-    output = model(spec_batch)
-    prob_fake = torch.sigmoid(output).item()
+    logits = model(**inputs).logits
+    probabilities = torch.nn.functional.softmax(logits, dim=-1)
 
-print(f"Probability of fake audio: {prob_fake:.2%}")
-print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
+print(f"Probability of fake audio: {probabilities[0][0]:.2%}")
 ```
 
-## Key Notes:
-- Ensure audio files are accessible and in a supported format
-- The model expects 16kHz sample rate input
-- Input audio is converted to mono if stereo
-- The model outputs probability scores (>0.5 indicates fake audio)
-- Visualization of spectrograms is optional but useful for debugging
-
-
 ## Limitations
 
 Important considerations when using this model:
@@ -246,6 +124,6 @@ Important considerations when using this model:
 The training process involved:
 1. Loading the base AST model pretrained on AudioSet
 2. Replacing the classification head with a binary classifier
-3. Fine-tuning on the fake audio detection dataset for
-4. Using gradient accumulation (8 steps) with batch size
+3. Fine-tuning on the fake audio detection dataset for 1500 iterations
+4. Using gradient accumulation (8 steps) with batch size 16
 5. Implementing validation checks every 500 steps
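One caveat on the updated usage snippet above: `torchaudio.load` returns audio at the file's native sample rate, while the feature extractor is called with `sampling_rate=16000` and the model expects 16 kHz mono input (as the removed Key Notes stated). The sketch below shows one way to normalize inputs first; it is a minimal illustration, and `load_mono_16k` is a hypothetical helper name, not part of this commit.

```python
import torch
import torchaudio

def load_mono_16k(path: str, target_sr: int = 16000) -> torch.Tensor:
    """Load audio, downmix to mono, and resample to 16 kHz (hypothetical helper)."""
    waveform, sr = torchaudio.load(path)
    if waveform.shape[0] > 1:        # stereo/multichannel -> mono
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:              # resample only if rates differ
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
    return waveform.squeeze(0)       # 1-D tensor for the feature extractor

# e.g.: inputs = extractor(load_mono_16k("path_to_audio.ogg"),
#                          sampling_rate=16000, return_tensors="pt")
```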
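The new snippet reads `probabilities[0][0]` as the fake-audio probability, which relies on index 0 being the fake class, consistent with the Fake Audio (0) / Real Audio (1) labeling above. Here is a short continuation that restores the FAKE/REAL printout from the removed code; reading `model.config.id2label` assumes the hosted config defines that mapping, which is not verified here.

```python
# Continues from the usage snippet above: probabilities has shape (1, 2),
# with index 0 = fake and index 1 = real per this README's labeling.
prob_fake = probabilities[0][0].item()
print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")

# If the hosted config defines id2label (an assumption, not verified),
# the class names can be read from the model instead of hard-coding them:
pred_idx = int(probabilities.argmax(dim=-1).item())
print("Config label:", model.config.id2label[pred_idx])
```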
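Steps 3 and 4 of the training process describe 1500 iterations with gradient accumulation over 8 steps at batch size 16, i.e. an effective batch of 16 x 8 = 128 samples per optimizer update. Below is a minimal sketch of that schedule, assuming generic `model`, `optimizer`, `loss_fn`, and `loader` objects (illustrative names, not the author's actual training script).

```python
from itertools import cycle
import torch

N_ITERATIONS = 1500   # optimizer updates, matching the config above
ACCUM_STEPS = 8       # micro-batches accumulated per update (batch_size=16 each)

def train(model, optimizer, loss_fn, loader, device="cpu"):
    model.train()
    batches = cycle(loader)          # sketch simplification: loop the loader
    for step in range(N_ITERATIONS):
        optimizer.zero_grad()
        for _ in range(ACCUM_STEPS):
            x, y = next(batches)
            loss = loss_fn(model(x.to(device)), y.to(device))
            (loss / ACCUM_STEPS).backward()  # average gradients across micro-batches
        optimizer.step()
        # per the config above, validation would run every 500 steps
        # on 5000 held-out samples
```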