---
license: apache-2.0
datasets:
- 012shin/fake-audio-detection-augmented
language:
- en
metrics:
- accuracy
- f1
- recall
- precision
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
pipeline_tag: audio-classification
library_name: transformers
tags:
- audio
- audio-classification
- fake-audio-detection
- ast
model-index:
- name: ast-fakeaudio-detector
  results:
  - task:
      type: audio-classification
      name: Audio Classification
    dataset:
      name: fake-audio-detection-augmented
      type: 012shin/fake-audio-detection-augmented
    metrics:
    - type: accuracy
      value: 0.9662
    - type: f1
      value: 0.9710
    - type: precision
      value: 0.9692
    - type: recall
      value: 0.9728
---
|
|
|
# AST Fine-tuned for Fake Audio Detection |
|
|
|
This model is a fine-tuned version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for detecting fake/synthetic audio. The original AST (Audio Spectrogram Transformer) classification head was replaced with a single-output binary classification layer, and the model was then fine-tuned for fake audio detection.
|
|
|
## Model Description |
|
|
|
- **Base Model**: MIT/ast-finetuned-audioset-10-10-0.4593 (AST pretrained on AudioSet)
- **Task**: Binary classification (fake/real audio detection)
- **Input**: Audio converted to a mel spectrogram (128 mel bins, 1024 time frames)
- **Output**: Binary prediction (0: real audio, 1: fake audio)
- **Training Hardware**: 2x NVIDIA T4 GPUs
|
|
|
## Training Configuration |
|
|
|
```python
{
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'n_iterations': 10000,
    'batch_size': 8,
    'gradient_accumulation_steps': 8,
    'validate_every': 500,
    'val_samples': 5000
}
```
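
With `batch_size` 8 and `gradient_accumulation_steps` 8, the effective batch size is 64. Below is a minimal sketch of that accumulation scheme, assuming a standard PyTorch loop (illustrative only, not the authors' exact training code; the linear model and random tensors are stand-ins for BinaryAST and real spectrogram batches):

```python
import torch
from torch import nn

accum_steps = 8  # gradient_accumulation_steps from the config above

model = nn.Linear(128, 1)  # stand-in for the BinaryAST model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
criterion = nn.BCEWithLogitsLoss()

# Dummy (features, label) micro-batches of size 8
loader = [(torch.randn(8, 128), torch.randint(0, 2, (8,)).float())
          for _ in range(16)]

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = criterion(model(x).squeeze(-1), y) / accum_steps  # scale per micro-batch
    loss.backward()                                          # gradients accumulate
    if (step + 1) % accum_steps == 0:
        optimizer.step()       # one update per 8 micro-batches = effective batch 64
        optimizer.zero_grad()
```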
|
|
|
## Dataset Distribution |
|
|
|
The model was trained on the [012shin/fake-audio-detection-augmented](https://huggingface.co/datasets/012shin/fake-audio-detection-augmented) dataset, which has the following class distribution:
|
|
|
```
Training Set (80%):
- Fake Audio (0): 43,460 samples (63.69%)
- Real Audio (1): 24,776 samples (36.31%)

Test Set (20%):
- Fake Audio (0): 10,776 samples (63.17%)
- Real Audio (1):  6,284 samples (36.83%)
```
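
To sanity-check this distribution yourself, something like the snippet below should work with the `datasets` library (the `train` split and `label` column names are assumptions; consult the dataset card for the actual schema):

```python
from collections import Counter
from datasets import load_dataset

# Count labels in the training split; split/column names are assumptions
train = load_dataset("012shin/fake-audio-detection-augmented", split="train")
counts = Counter(train["label"])
total = sum(counts.values())
for label, n in sorted(counts.items()):
    print(f"label {label}: {n:,} samples ({n / total:.2%})")
```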
|
|
|
## Model Performance |
|
|
|
Final metrics on the validation set (5,000 samples, per `val_samples` in the training configuration):

- Accuracy: 0.9662 (96.62%)
- F1 Score: 0.9710 (97.10%)
- Precision: 0.9692 (96.92%)
- Recall: 0.9728 (97.28%)
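
These follow the standard binary-classification definitions; a minimal sketch of how to reproduce them from thresholded predictions with scikit-learn (the arrays are placeholders, not the actual validation outputs):

```python
from sklearn.metrics import (accuracy_score, f1_score,
                             precision_score, recall_score)

# Placeholder labels/predictions; in practice these come from the val split
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 0, 1]

print(f"Accuracy:  {accuracy_score(y_true, y_pred):.4f}")
print(f"F1:        {f1_score(y_true, y_pred):.4f}")
print(f"Precision: {precision_score(y_true, y_pred):.4f}")
print(f"Recall:    {recall_score(y_true, y_pred):.4f}")
```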
|
|
|
# Usage Guide |
|
|
|
## 1. Environment Setup |
|
First, clone the AST repository and install required dependencies: |
|
```bash
# Clone the AST repository and install dependencies
git clone https://github.com/YuanGongND/ast.git
pip install timm==0.4.5 wget
```

```python
# Make the AST repo's src package importable, then import dependencies
import sys
sys.path.append('./ast')

import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from src.models import ASTModel
```
|
|
|
## 2. Model Implementation |
|
Implement the BinaryAST model class: |
|
```python
class BinaryAST(nn.Module):
    def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
        super().__init__()
        # Initialize the AST base model (AudioSet config: 527 labels,
        # 128 mel bins x 1024 time frames, base384 backbone)
        self.ast = ASTModel(
            label_dim=527,
            input_fdim=128,
            input_tdim=1024,
            imagenet_pretrain=True,
            audioset_pretrain=False,
            model_size='base384'
        )

        # Load pretrained AudioSet weights if available
        if os.path.exists(pretrained_path):
            print(f"Loading pretrained weights from {pretrained_path}")
            state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
            self.ast.load_state_dict(state_dict, strict=False)

        # Replace the 527-way head with a single-logit binary head
        self.ast.mlp_head = nn.Sequential(
            nn.LayerNorm(768),
            nn.Dropout(0.3),
            nn.Linear(768, 1)
        )

    def forward(self, x):
        # x: (batch, time frames, mel bins) -> (batch, 1) raw logit
        return self.ast(x)
```
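
A quick shape sanity check (if the AudioSet `.pth` is absent, the AST backbone keeps its ImageNet initialization, which is fine for checking tensor shapes):

```python
model = BinaryAST()
dummy = torch.randn(2, 1024, 128)  # (batch, time frames, mel bins)
print(model(dummy).shape)          # expected: torch.Size([2, 1])
```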
|
|
|
## 3. Audio Processing Function |
|
Function to preprocess audio files for model input: |
|
```python
def process_audio(file_path, sr=16000):
    """
    Process an audio file for model inference.

    Args:
        file_path (str): Path to audio file
        sr (int): Target sample rate (default: 16000)

    Returns:
        torch.Tensor: Processed mel spectrogram (1024 x 128)
    """
    # Load audio
    audio_tensor, orig_sr = torchaudio.load(file_path)
    print(f"Initial tensor shape: {audio_tensor.shape}, sample_rate={orig_sr}")

    # Convert to mono if needed
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)

    # Resample to the target sample rate
    if orig_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_sr, sr)
        audio_tensor = resampler(audio_tensor)

    # Create mel spectrogram
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=128,
        n_fft=2048,
        hop_length=160
    )(audio_tensor)
    spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)

    # Post-process: (1, mel, time) -> (time, mel), then normalize
    spec_db = spec_db.squeeze(0).transpose(0, 1)
    spec_db = (spec_db + 4.26) / (4.57 * 2)  # AudioSet stats, AST convention

    # Ensure correct length (pad/trim to 1024 frames)
    target_len = 1024
    if spec_db.shape[0] < target_len:
        pad = torch.zeros(target_len - spec_db.shape[0], 128)
        spec_db = torch.cat([spec_db, pad], dim=0)
    else:
        spec_db = spec_db[:target_len, :]

    return spec_db
```
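
Note that 1024 frames at a 160-sample hop and 16 kHz cover about 10.24 seconds of audio: longer clips are truncated and shorter ones zero-padded. A quick usage check (the path is a placeholder):

```python
spec = process_audio('sample.wav')  # placeholder path
print(spec.shape)                   # torch.Size([1024, 128])
```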
|
|
|
## 4. Model Loading and Inference |
|
Example of loading the model and running inference: |
|
```python
# Initialize and load the fine-tuned model
model = BinaryAST()
checkpoint = torch.load('/content/final_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Process an audio file
spec = process_audio('path_to_audio.mp3')

# Visualize the spectrogram (optional)
plt.figure(figsize=(10, 3))
plt.imshow(spec.numpy().T, aspect='auto', origin='lower')
plt.title('Mel Spectrogram')
plt.xlabel('Time Frames')
plt.ylabel('Mel Bins')
plt.colorbar()
plt.show()

# Run inference
spec_batch = spec.unsqueeze(0)  # add batch dimension: (1, 1024, 128)
with torch.no_grad():
    output = model(spec_batch)
    prob_fake = torch.sigmoid(output).item()

print(f"Probability of fake audio: {prob_fake:.4f}")
print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
```
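
To score several files at once, the processed spectrograms can be stacked into a batch (the paths below are placeholders):

```python
# Batch inference over multiple clips
paths = ['clip1.wav', 'clip2.wav']
batch = torch.stack([process_audio(p) for p in paths])  # (N, 1024, 128)
with torch.no_grad():
    probs = torch.sigmoid(model(batch)).squeeze(-1)
for path, prob in zip(paths, probs):
    print(f"{path}: {'FAKE' if prob > 0.5 else 'REAL'} ({prob:.4f})")
```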
|
|
|
## Key Notes

- Ensure audio files are accessible and in a supported format
- The model expects 16 kHz sample-rate input
- Stereo input is converted to mono
- The model outputs a probability score (>0.5 indicates fake audio)
- Visualizing spectrograms is optional but useful for debugging
|
|
|
|
|
## Limitations |
|
|
|
Important considerations when using this model:

1. The model works best with 16 kHz audio input
2. Performance may degrade on manipulation types not represented in the training data
3. Very short clips (<1 second) might not produce reliable results
4. The model should not be used as the sole determiner of whether audio is real or fake
|
|
|
## Training Details |
|
|
|
The training process involved:

1. Loading the base AST model pretrained on AudioSet
2. Replacing the classification head with a binary classifier
3. Fine-tuning on the fake audio detection dataset for 10,000 iterations
4. Using gradient accumulation (8 steps) with batch size 8 (effective batch size 64)
5. Running validation every 500 steps
5. Implementing validation checks every 500 steps |