WpythonW committed on
Commit 6f2fece · verified · 1 Parent(s): e19274f

Update README.md

Files changed (1)
  1. README.md +25 -147
README.md CHANGED
@@ -56,8 +56,8 @@ This model is a binary classification head fine-tuned version of [MIT/ast-finetu
 {
 'learning_rate': 1e-5,
 'weight_decay': 0.01,
- 'n_iterations': 10000,
- 'batch_size': 8,
+ 'n_iterations': 1500,
+ 'batch_size': 16,
 'gradient_accumulation_steps': 8,
 'validate_every': 500,
 'val_samples': 5000
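With these settings, each optimizer step processes batch_size × gradient_accumulation_steps = 16 × 8 = 128 samples (the previous configuration gave 8 × 8 = 64). The training script itself is not part of this commit; the sketch below only illustrates how these hyperparameters would typically combine, with the model, loss, and data loader as assumptions:

```python
import itertools
import torch
from torch import nn

def train(model, loader, cfg):
    # Hypothetical trainer built around the config above -- not the author's script.
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=cfg['learning_rate'],
                                  weight_decay=cfg['weight_decay'])
    criterion = nn.BCEWithLogitsLoss()  # assumed loss for a 1-logit binary head
    batches = itertools.cycle(loader)
    for step in range(cfg['n_iterations']):
        optimizer.zero_grad()
        # Accumulate gradients over 8 micro-batches: 16 * 8 = 128 samples per step
        for _ in range(cfg['gradient_accumulation_steps']):
            specs, labels = next(batches)
            loss = criterion(model(specs).squeeze(-1), labels.float())
            (loss / cfg['gradient_accumulation_steps']).backward()
        optimizer.step()
        if (step + 1) % cfg['validate_every'] == 0:
            pass  # evaluate on cfg['val_samples'] held-out samples
```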
@@ -66,16 +66,16 @@ This model is a binary classification head fine-tuned version of [MIT/ast-finetu
 
 ## Dataset Distribution
 
- The model was trained on [012shin/fake-audio-detection-augmented](https://huggingface.co/datasets/012shin/fake-audio-detection-augmented) dataset with the following class distribution:
+ The model was trained on a filtered dataset with the following class distribution:
 
 ```
- Training Set (80%):
- - Fake Audio (0): 43,460 samples (63.69%)
- - Real Audio (1): 24,776 samples (36.31%)
+ Training Set:
+ - Fake Audio (0): 29,089 samples (53.97%)
+ - Real Audio (1): 24,813 samples (46.03%)
 
- Test Set (20%):
- - Fake Audio (0): 10,776 samples (63.17%)
- - Real Audio (1): 6,284 samples (36.83%)
+ Test Set:
+ - Fake Audio (0): 7,229 samples (53.64%)
+ - Real Audio (1): 6,247 samples (46.36%)
 ```
 
 ## Model Performance
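For reference, the updated counts sum to 53,902 training and 13,476 test samples (67,378 total), which still corresponds to an 80/20 split (53,902 / 67,378 ≈ 0.80); the filtering mainly rebalances the classes, from roughly 64/36 fake-to-real down to about 54/46.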
@@ -88,151 +88,29 @@ Final metrics on validation set:
 
 # Usage Guide
 
- ## 1. Environment Setup
- First, clone the AST repository and install required dependencies:
+ ## Model Usage
 ```python
- # Clone AST repository and set up path
- git clone https://github.com/YuanGongND/ast.git
- import sys
- sys.path.append('./ast')
- cd ast
-
- # Install dependencies
- pip install timm==0.4.5 wget
-
- # Required imports
- import os
- import torch
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
 import torchaudio
- import matplotlib.pyplot as plt
- import numpy as np
- from torch import nn
- from src.models import ASTModel
- ```
+ import torch
 
- ## 2. Model Implementation
- Implement the BinaryAST model class:
- ```python
- class BinaryAST(nn.Module):
-     def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
-         super().__init__()
-         # Initialize AST base model
-         self.ast = ASTModel(
-             label_dim=527,
-             input_fdim=128,
-             input_tdim=1024,
-             imagenet_pretrain=True,
-             audioset_pretrain=False,
-             model_size='base384'
-         )
-
-         # Load pretrained weights if available
-         if os.path.exists(pretrained_path):
-             print(f"Loading pretrained weights from {pretrained_path}")
-             state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
-             self.ast.load_state_dict(state_dict, strict=False)
-
-         # Binary classification head
-         self.ast.mlp_head = nn.Sequential(
-             nn.LayerNorm(768),
-             nn.Dropout(0.3),
-             nn.Linear(768, 1)
-         )
-
-     def forward(self, x):
-         return self.ast(x)
- ```
+ # Load audio file
+ waveform, sample_rate = torchaudio.load("path_to_audio.ogg")
 
- ## 3. Audio Processing Function
- Function to preprocess audio files for model input:
- ```python
- def process_audio(file_path, sr=16000):
-     """
-     Process audio file for model inference.
-
-     Args:
-         file_path (str): Path to audio file
-         sr (int): Target sample rate (default: 16000)
-
-     Returns:
-         torch.Tensor: Processed mel spectrogram (1024 x 128)
-     """
-     # Load audio
-     audio_tensor, orig_sr = torchaudio.load(file_path)
-     print(f"Initial tensor shape: {audio_tensor.shape}, sample_rate={orig_sr}")
-
-     # Convert to mono if needed
-     if audio_tensor.shape[0] > 1:
-         audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
-
-     # Resample to target sample rate
-     if orig_sr != sr:
-         resampler = torchaudio.transforms.Resample(orig_sr, sr)
-         audio_tensor = resampler(audio_tensor)
-
-     # Create mel spectrogram
-     mel_spec = torchaudio.transforms.MelSpectrogram(
-         sample_rate=sr,
-         n_mels=128,
-         n_fft=2048,
-         hop_length=160
-     )(audio_tensor)
-     spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)
-
-     # Post-process spectrogram
-     spec_db = spec_db.squeeze(0).transpose(0, 1)
-     spec_db = (spec_db + 4.26) / (4.57 * 2)  # Normalize
-
-     # Ensure correct length (pad/trim to 1024 frames)
-     target_len = 1024
-     if spec_db.shape[0] < target_len:
-         pad = torch.zeros(target_len - spec_db.shape[0], 128)
-         spec_db = torch.cat([spec_db, pad], dim=0)
-     else:
-         spec_db = spec_db[:target_len, :]
-
-     return spec_db
- ```
+ # Initialize model and feature extractor
+ model_name = "WpythonW/ast-fakeaudio-detector"
+ extractor = AutoFeatureExtractor.from_pretrained(model_name)
+ model = AutoModelForAudioClassification.from_pretrained(model_name)
 
- ## 4. Model Loading and Inference
- Example of loading the model and running inference:
- ```python
- # Initialize and load model
- model = BinaryAST()
- checkpoint = torch.load('/content/final_model.pth', map_location='cpu')
- model.load_state_dict(checkpoint['model_state_dict'])
- model.eval()
-
- # Process audio file
- spec = process_audio('path_to_audio.mp3')
-
- # Visualize spectrogram (optional)
- plt.figure(figsize=(10, 3))
- plt.imshow(spec.numpy().T, aspect='auto', origin='lower')
- plt.title('Mel Spectrogram')
- plt.xlabel('Time Frames')
- plt.ylabel('Mel Bins')
- plt.colorbar()
- plt.show()
-
- # Run inference
- spec_batch = spec.unsqueeze(0)
+ # Process audio and get predictions
+ inputs = extractor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
 with torch.no_grad():
-     output = model(spec_batch)
-     prob_fake = torch.sigmoid(output).item()
+     logits = model(**inputs).logits
+     probabilities = torch.nn.functional.softmax(logits, dim=-1)
 
- print(f"Probability of fake audio: {prob_fake:.4f}")
- print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
+ print(f"Probability of fake audio: {probabilities[0][0]:.2%}")
 ```
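Two caveats on the snippet added above: the `sampling_rate=16000` argument tells the feature extractor what rate the data already has (it does not resample, so non-16 kHz files should be resampled first, and `waveform.squeeze()` assumes a mono file), and reading `probabilities[0][0]` as the fake score assumes class index 0 is fake, consistent with the dataset labels above. An alternative sketch that delegates decoding and resampling to the high-level pipeline API, assuming the hosted config carries the label names:

```python
# Audio-classification pipeline variant; "path_to_audio.ogg" is a placeholder.
from transformers import pipeline

classifier = pipeline("audio-classification", model="WpythonW/ast-fakeaudio-detector")
for pred in classifier("path_to_audio.ogg"):
    print(f"{pred['label']}: {pred['score']:.2%}")
```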
 
- ## Key Notes:
- - Ensure audio files are accessible and in a supported format
- - The model expects 16kHz sample rate input
- - Input audio is converted to mono if stereo
- - The model outputs probability scores (>0.5 indicates fake audio)
- - Visualization of spectrograms is optional but useful for debugging
-
-
 ## Limitations
 
 Important considerations when using this model:
@@ -246,6 +124,6 @@ Important considerations when using this model:
 The training process involved:
 1. Loading the base AST model pretrained on AudioSet
 2. Replacing the classification head with a binary classifier
- 3. Fine-tuning on the fake audio detection dataset for 10000 iterations
- 4. Using gradient accumulation (8 steps) with batch size 8
+ 3. Fine-tuning on the fake audio detection dataset for 1500 iterations
+ 4. Using gradient accumulation (8 steps) with batch size 16
 5. Implementing validation checks every 500 steps
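In sample terms, the new schedule covers about 1,500 × 128 = 192,000 examples, roughly 3.6 passes over the 53,902-sample training set, compared with 10,000 × 64 = 640,000 examples (about 9.4 passes over the old 68,236-sample set) before.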