---
license: apache-2.0
datasets:
- 012shin/fake-audio-detection-augmented
language:
- en
metrics:
- accuracy
- f1
- recall
- precision
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
pipeline_tag: audio-classification
library_name: transformers
tags:
- audio
- audio-classification
- fake-audio-detection
- ast
model-index:
  - name: ast-fakeaudio-detector
    results:
      - task:
          type: audio-classification
          name: Audio Classification
        dataset:
          name: fake-audio-detection-augmented
          type: 012shin/fake-audio-detection-augmented
        metrics:
          - type: accuracy
            value: 0.9662
          - type: f1
            value: 0.9710
          - type: precision
            value: 0.9692
          - type: recall
            value: 0.9728
---
# AST Fine-tuned for Fake Audio Detection
This model is a fine-tuned version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for detecting fake (synthetic) audio. The original AST (Audio Spectrogram Transformer) classification head, which predicts the 527 AudioSet classes, was replaced with a binary classification head that outputs a single logit.
## Model Description
- **Base Model**: MIT/ast-finetuned-audioset-10-10-0.4593 (AST pretrained on AudioSet)
- **Task**: Binary classification (fake/real audio detection)
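- **Input**: 16 kHz mono audio converted to a log-Mel spectrogram of 1024 time frames × 128 mel bins (10 ms hop, about 10.24 s of audio)
- **Output**: A single logit whose sigmoid is the probability that the audio is fake (0: real audio, 1: fake audio)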
- **Training Hardware**: 2x NVIDIA T4 GPUs
## Training Configuration
```python
{
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'n_iterations': 10000,
    'batch_size': 8,
    'gradient_accumulation_steps': 8,
    'validate_every': 500,
    'val_samples': 5000
}
```
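With `batch_size=8` and `gradient_accumulation_steps=8`, each optimizer update sees 8 × 8 = 64 examples. The card does not include its training loop, so the sketch below is only an illustration on a hypothetical toy model (a single weight `w` fit with mean squared error, not the AST model or its actual loss). It shows the property accumulation relies on: summing the gradients of `loss / accum_steps` over eight micro-batches of eight gives the same gradient as one batch of 64.

```python
import random


def grad_mse(w, batch):
    """Gradient d/dw of mean((w*x - y)**2) over `batch` for the toy model y_hat = w*x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, micro_batches):
    """Add up per-micro-batch gradients of loss / accum_steps, as repeated
    loss.backward() calls would before a single optimizer step."""
    accum_steps = len(micro_batches)
    total = 0.0
    for mb in micro_batches:
        total += grad_mse(w, mb) / accum_steps
    return total


batch_size, accum_steps = 8, 8  # values from the training configuration
random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(batch_size * accum_steps)]
data = [(x, 3.0 * x + random.gauss(0.0, 0.1)) for x in xs]  # hypothetical toy data
micro_batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

w = 0.5
full = grad_mse(w, data)                    # one batch of 64
accum = accumulated_grad(w, micro_batches)  # 8 micro-batches of 8
print(f"effective batch size: {batch_size * accum_steps}")
print(f"full-batch grad = {full:.6f}, accumulated grad = {accum:.6f}")
```

The match is exact here because every micro-batch has the same size; if the last micro-batch were smaller, dividing by `accum_steps` would weight its examples slightly differently from a true batch of 64.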
## Dataset Distribution
The model was trained on the [012shin/fake-audio-detection-augmented](https://huggingface.co/datasets/012shin/fake-audio-detection-augmented) dataset, which has the following class distribution:
```
Training Set (80%):
- Fake Audio (0): 43,460 samples (63.69%)
- Real Audio (1): 24,776 samples (36.31%)
Test Set (20%):
- Fake Audio (0): 10,776 samples (63.17%)
- Real Audio (1): 6,284 samples (36.83%)
```
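Because the classes are imbalanced, it helps to compare the reported accuracy with a trivial baseline that always predicts the majority class (fake, per the counts above). The snippet below recomputes the percentages from the listed counts; the helper name `class_shares` is illustrative.

```python
# Class counts copied from the distribution above.
splits = {
    "train": {"fake": 43_460, "real": 24_776},
    "test": {"fake": 10_776, "real": 6_284},
}


def class_shares(counts):
    """Return each class's share of a split, as a percentage."""
    total = sum(counts.values())
    return {label: 100.0 * n / total for label, n in counts.items()}


for name, counts in splits.items():
    shares = class_shares(counts)
    print(name, {label: round(pct, 2) for label, pct in shares.items()})

# Always answering "fake" would already be right this often on the test split.
majority_baseline = max(class_shares(splits["test"]).values()) / 100.0
print(f"majority-class baseline accuracy: {majority_baseline:.4f}")

train_total = sum(splits["train"].values())
test_total = sum(splits["test"].values())
print(f"train fraction: {train_total / (train_total + test_total):.2f}")
```

On the test split this baseline reaches about 63.2% accuracy, which is the reference point for the reported 96.62%.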
## Model Performance
Final metrics on the validation set:
- Accuracy: 0.9662 (96.62%)
- F1 Score: 0.9710 (97.10%)
- Precision: 0.9692 (96.92%)
- Recall: 0.9728 (97.28%)
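F1 is the harmonic mean of precision and recall, so the reported F1 can be recomputed from the other two figures as a quick consistency check:

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall: 2PR / (P + R)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


reported = {"precision": 0.9692, "recall": 0.9728, "f1": 0.9710}
f1 = f1_score(reported["precision"], reported["recall"])
print(f"recomputed F1 = {f1:.4f} (reported {reported['f1']:.4f})")
```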
# Usage Guide
## 1. Environment Setup
First, clone the AST repository and install the required dependencies:
```python
# Run these two commands in a shell first (prefix them with "!" in a Colab/Jupyter cell):
#   git clone https://github.com/YuanGongND/ast.git
#   pip install timm==0.4.5 wget
# Then, in Python, make the cloned repository importable
import sys
sys.path.append('./ast')
# Required imports
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from src.models import ASTModel
```
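The setup above pins `timm==0.4.5` for the AST code, and newer `timm` releases may not be compatible with it. A small standard-library check like the one below (the helper name `check_pins` is hypothetical) can confirm what is actually installed before importing `src.models`:

```python
from importlib import metadata


def check_pins(pins):
    """Compare installed package versions against expected pins.

    Returns {package: (expected, installed)} for every mismatch;
    `installed` is None when the package is missing.
    """
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches


problems = check_pins({"timm": "0.4.5"})
if problems:
    for pkg, (want, have) in problems.items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
else:
    print("all pinned versions match")
```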
## 2. Model Implementation
Define the `BinaryAST` wrapper, which builds the AST backbone and replaces its classification head with a single-logit binary head:
```python
class BinaryAST(nn.Module):
    def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
        super().__init__()
        # Initialize AST base model (527 = AudioSet classes; the head is replaced below)
        self.ast = ASTModel(
            label_dim=527,
            input_fdim=128,
            input_tdim=1024,
            imagenet_pretrain=True,
            audioset_pretrain=False,
            model_size='base384'
        )
        # Load pretrained weights if available
        if os.path.exists(pretrained_path):
            print(f"Loading pretrained weights from {pretrained_path}")
            state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
            # strict=False skips keys that do not match this model's parameter names
            self.ast.load_state_dict(state_dict, strict=False)
        # Binary classification head: one logit per clip
        self.ast.mlp_head = nn.Sequential(
            nn.LayerNorm(768),
            nn.Dropout(0.3),
            nn.Linear(768, 1)
        )

    def forward(self, x):
        # x: (batch, 1024 time frames, 128 mel bins) -> logits of shape (batch, 1)
        return self.ast(x)
```
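The replacement head maps the 768-dimensional pooled AST embedding to a single logit. The pure-Python sketch below mirrors `LayerNorm(768) → Dropout(0.3) → Linear(768, 1)` at inference time, when dropout is inactive after `model.eval()`. The embedding and weights are hypothetical placeholders (LayerNorm at its default scale of 1 and shift of 0, arbitrary linear weights), not the fine-tuned values; it only illustrates the shape of the computation.

```python
import math


def layer_norm(v, eps=1e-5):
    """Zero-mean, unit-variance normalization with scale=1 and shift=0,
    i.e. nn.LayerNorm at its default initialization."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)  # biased variance, as LayerNorm uses
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def binary_head(embedding, weight, bias):
    """LayerNorm -> Dropout (identity in eval mode) -> Linear(768, 1)."""
    normed = layer_norm(embedding)
    # Dropout(0.3) only zeroes activations while training, so it is skipped here.
    return sum(w * x for w, x in zip(weight, normed)) + bias  # a single logit


dim = 768
embedding = [math.sin(i) for i in range(dim)]      # stand-in for a pooled AST embedding
weight = [0.01 * math.cos(i) for i in range(dim)]  # hypothetical head weights
logit = binary_head(embedding, weight, bias=0.0)

# Trainable parameters in the head: LayerNorm scale + shift, Linear weights + bias.
head_params = 2 * dim + dim + 1
print(f"logit = {logit:.4f}, head parameters = {head_params}")
```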
## 3. Audio Processing Function
Function to preprocess audio files for model input:
```python
def process_audio(file_path, sr=16000):
    """
    Process audio file for model inference.
    Args:
        file_path (str): Path to audio file
        sr (int): Target sample rate (default: 16000)
    Returns:
        torch.Tensor: Processed mel spectrogram of shape (1024 time frames, 128 mel bins)
    """
    # Load audio
    audio_tensor, orig_sr = torchaudio.load(file_path)
    print(f"Initial tensor shape: {audio_tensor.shape}, sample_rate={orig_sr}")
    # Convert to mono if needed
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
    # Resample to target sample rate
    if orig_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_sr, sr)
        audio_tensor = resampler(audio_tensor)
    # Create mel spectrogram
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=128,
        n_fft=2048,
        hop_length=160
    )(audio_tensor)
    spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)
    # Post-process spectrogram: (1, n_mels, time) -> (time, n_mels)
    spec_db = spec_db.squeeze(0).transpose(0, 1)
    spec_db = (spec_db + 4.26) / (4.57 * 2)  # Normalize
    # Ensure correct length (pad/trim to 1024 frames)
    target_len = 1024
    if spec_db.shape[0] < target_len:
        pad = torch.zeros(target_len - spec_db.shape[0], 128)
        spec_db = torch.cat([spec_db, pad], dim=0)
    else:
        spec_db = spec_db[:target_len, :]
    return spec_db
```
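With `hop_length=160` at 16 kHz, each spectrogram frame advances 10 ms, so the fixed 1024-frame input covers roughly 10.24 s of audio: shorter clips are zero-padded and longer ones keep only their beginning. The helper below estimates this for a given clip length. It assumes `MelSpectrogram`'s default `center=True`, under which the frame count is `n_samples // hop_length + 1`; the function names are illustrative.

```python
SAMPLE_RATE = 16_000  # target rate used by process_audio
HOP_LENGTH = 160      # 10 ms hop at 16 kHz
TARGET_FRAMES = 1024  # frames the model expects


def num_frames(n_samples, hop_length=HOP_LENGTH):
    """Frame count of a centered STFT (torchaudio's default center=True)."""
    return n_samples // hop_length + 1


def describe_clip(duration_s, sr=SAMPLE_RATE):
    """Report whether a clip of `duration_s` seconds gets padded or truncated."""
    frames = num_frames(int(round(duration_s * sr)))
    if frames < TARGET_FRAMES:
        return f"{duration_s:.2f} s -> {frames} frames, padded with {TARGET_FRAMES - frames} zero frames"
    if frames > TARGET_FRAMES:
        return f"{duration_s:.2f} s -> {frames} frames, last {frames - TARGET_FRAMES} frames dropped"
    return f"{duration_s:.2f} s -> exactly {TARGET_FRAMES} frames"


for seconds in (0.5, 5.0, 10.0, 30.0):
    print(describe_clip(seconds))

print(f"audio covered by {TARGET_FRAMES} frames: ~{TARGET_FRAMES * HOP_LENGTH / SAMPLE_RATE:.2f} s")
```

For example, a 0.5 s clip fills only 51 of the 1024 frames, so about 95% of the model input is zero padding.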
## 4. Model Loading and Inference
Example of loading the model and running inference:
```python
# Initialize and load model
model = BinaryAST()
# Path to the fine-tuned checkpoint; adjust to where you saved it
checkpoint = torch.load('/content/final_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disables dropout in the classification head
# Process audio file
spec = process_audio('path_to_audio.mp3')
# Visualize spectrogram (optional)
plt.figure(figsize=(10, 3))
plt.imshow(spec.numpy().T, aspect='auto', origin='lower')
plt.title('Mel Spectrogram')
plt.xlabel('Time Frames')
plt.ylabel('Mel Bins')
plt.colorbar()
plt.show()
# Run inference
spec_batch = spec.unsqueeze(0)  # add batch dimension: (1, 1024, 128)
with torch.no_grad():
    output = model(spec_batch)  # raw logit, shape (1, 1)
    prob_fake = torch.sigmoid(output).item()
print(f"Probability of fake audio: {prob_fake:.4f}")
print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
```
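The model returns a raw logit, and `torch.sigmoid` turns it into the probability of the fake class. Because the sigmoid is monotonic with sigmoid(0) = 0.5, thresholding the probability at 0.5 is the same as checking whether the logit is positive. The standard-library sketch below makes that decision rule explicit; the `threshold` parameter is an illustrative addition, since the card itself only uses 0.5.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def classify(logit, threshold=0.5):
    """Map a model logit to (probability_of_fake, label)."""
    prob_fake = sigmoid(logit)
    return prob_fake, ("FAKE" if prob_fake > threshold else "REAL")


for logit in (-3.0, -0.1, 0.0, 0.1, 3.0):
    prob, label = classify(logit)
    print(f"logit={logit:+.1f}  p(fake)={prob:.4f}  -> {label}")
```

Raising the threshold above 0.5 flags fewer clips as fake, trading fewer false alarms for more missed synthetic clips.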
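## Key Notes
- Audio files must be readable by `torchaudio.load`; MP3 support depends on the installed torchaudio backend
- `process_audio` converts stereo input to mono and resamples it to the 16 kHz rate the model expects
- Only the first 1024 frames (about 10.24 s) are scored; shorter clips are zero-padded
- The model outputs a single logit; applying a sigmoid gives the probability of fake audio, and values above 0.5 are classified as fake
- Spectrogram visualization is optional but useful for debugging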
## Limitations
Important considerations when using this model:
1. The model works best with 16 kHz audio input
2. Performance may vary on types of audio manipulation not present in the training data
3. Very short audio clips (under 1 second) may not yield reliable results
4. The model should not be used as the sole basis for deciding whether audio is real or fake
## Training Details
The training process involved:
1. Loading the base AST model pretrained on AudioSet
2. Replacing the classification head with a binary classifier
3. Fine-tuning on the fake audio detection dataset for 10,000 iterations
4. Using gradient accumulation (8 steps) with a batch size of 8, for an effective batch size of 64
5. Running validation every 500 steps on 5,000 validation samples
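
The card does not say whether an "iteration" is one micro-batch or one optimizer update. Assuming the latter, purely for illustration, the loop skeleton below shows how the configuration values interact: eight micro-batch backward passes per update, and a validation pass every 500 updates. `run_schedule` and the `CallCounter` callbacks are hypothetical placeholders for the real forward/backward, optimizer, and evaluation code.

```python
class CallCounter:
    """Hypothetical stand-in for a training callback that just counts calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, *args):
        self.calls += 1


def run_schedule(n_iterations, accum_steps, validate_every,
                 backward_step, optimizer_step, validate):
    """Gradient-accumulation loop skeleton.

    Assumption (not stated in the card): one iteration = one optimizer update
    built from `accum_steps` micro-batches.
    """
    for iteration in range(1, n_iterations + 1):
        for _ in range(accum_steps):
            backward_step()      # e.g. (loss / accum_steps).backward() on one batch of 8
        optimizer_step()         # e.g. optimizer.step(); optimizer.zero_grad()
        if iteration % validate_every == 0:
            validate(iteration)  # e.g. score val_samples held-out clips


backward, update = CallCounter(), CallCounter()
val_points = []
run_schedule(n_iterations=10_000, accum_steps=8, validate_every=500,
             backward_step=backward, optimizer_step=update,
             validate=val_points.append)
print(f"backward passes: {backward.calls}, optimizer updates: {update.calls}")
print(f"validation runs: {len(val_points)} at iterations {val_points[0]}..{val_points[-1]}")
```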