---
license: apache-2.0
datasets:
- 012shin/fake-audio-detection-augmented
language:
- en
metrics:
- accuracy
- f1
- recall
- precision
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
pipeline_tag: audio-classification
library_name: transformers
tags:
- audio
- audio-classification
- fake-audio-detection
- ast
model-index:
- name: ast-fakeaudio-detector
  results:
  - task:
      type: audio-classification
      name: Audio Classification
    dataset:
      name: fake-audio-detection-augmented
      type: 012shin/fake-audio-detection-augmented
    metrics:
    - type: accuracy
      value: 0.9662
    - type: f1
      value: 0.9710
    - type: precision
      value: 0.9692
    - type: recall
      value: 0.9728
---

# AST Fine-tuned for Fake Audio Detection

This model is a fine-tuned version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for detecting fake/synthetic audio. The original AST (Audio Spectrogram Transformer) classification head was replaced with a binary classification layer and the model was fine-tuned for fake audio detection.

## Model Description

- **Base Model**: MIT/ast-finetuned-audioset-10-10-0.4593 (AST pretrained on AudioSet)
- **Task**: Binary classification (fake/real audio detection)
- **Input**: Audio converted to a Mel spectrogram (128 mel bins, 1024 time frames)
- **Output**: Binary prediction (0: real audio, 1: fake audio)
- **Training Hardware**: 2x NVIDIA T4 GPUs

## Training Configuration

```python
{
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'n_iterations': 10000,
    'batch_size': 8,
    'gradient_accumulation_steps': 8,
    'validate_every': 500,
    'val_samples': 5000
}
```

## Dataset Distribution

The model was trained on the [012shin/fake-audio-detection-augmented](https://huggingface.co/datasets/012shin/fake-audio-detection-augmented) dataset with the following class distribution:

```
Training Set (80%):
- Fake Audio (0): 43,460 samples (63.69%)
- Real Audio (1): 24,776 samples (36.31%)

Test Set (20%):
- Fake Audio (0): 10,776 samples (63.17%)
- Real Audio (1):  6,284 samples (36.83%)
```

## Model Performance

Final metrics on the validation set:

- Accuracy: 0.9662 (96.62%)
- F1 Score: 0.9710 (97.10%)
- Precision: 0.9692 (96.92%)
- Recall: 0.9728 (97.28%)

# Usage Guide

## 1. Environment Setup

First, clone the AST repository and install the required dependencies:

```bash
# Clone the AST repository
git clone https://github.com/YuanGongND/ast.git

# Install dependencies
pip install timm==0.4.5 wget
```

Then make the repository importable and load the required modules:

```python
import sys
sys.path.append('./ast')  # so that `src.models` resolves inside the cloned repo

# Required imports
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from src.models import ASTModel
```

## 2. Model Implementation

Implement the `BinaryAST` model class:

```python
class BinaryAST(nn.Module):
    def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
        super().__init__()

        # Initialize the AST base model
        self.ast = ASTModel(
            label_dim=527,          # AudioSet label space; head is replaced below
            input_fdim=128,         # mel bins
            input_tdim=1024,        # time frames
            imagenet_pretrain=True,
            audioset_pretrain=False,
            model_size='base384'
        )

        # Load pretrained AudioSet weights if available
        if os.path.exists(pretrained_path):
            print(f"Loading pretrained weights from {pretrained_path}")
            state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
            self.ast.load_state_dict(state_dict, strict=False)

        # Replace the 527-class head with a binary classification head
        self.ast.mlp_head = nn.Sequential(
            nn.LayerNorm(768),
            nn.Dropout(0.3),
            nn.Linear(768, 1)
        )

    def forward(self, x):
        return self.ast(x)
```
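Before moving on, you can optionally sanity-check the model with a dummy input. This is a minimal check (not part of the original card) that assumes the repository and path setup from step 1 are in place:

```python
# Optional sanity check with a dummy spectrogram batch
model = BinaryAST()
model.eval()

dummy = torch.zeros(1, 1024, 128)  # (batch, time_frames, mel_bins)
with torch.no_grad():
    logits = model(dummy)

print(logits.shape)  # expected: torch.Size([1, 1]) -- one raw logit per clip
```

The model returns a single raw logit per clip; probabilities are obtained with a sigmoid, as shown in step 4.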
## 3. Audio Processing Function

Function to preprocess audio files for model input:

```python
def process_audio(file_path, sr=16000):
    """
    Process an audio file for model inference.

    Args:
        file_path (str): Path to the audio file
        sr (int): Target sample rate (default: 16000)

    Returns:
        torch.Tensor: Processed mel spectrogram (1024 x 128)
    """
    # Load audio
    audio_tensor, orig_sr = torchaudio.load(file_path)
    print(f"Initial tensor shape: {audio_tensor.shape}, sample_rate={orig_sr}")

    # Convert to mono if needed
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)

    # Resample to the target sample rate
    if orig_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_sr, sr)
        audio_tensor = resampler(audio_tensor)

    # Create the mel spectrogram (hop_length=160 at 16 kHz -> 100 frames per second)
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=128,
        n_fft=2048,
        hop_length=160
    )(audio_tensor)
    spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)

    # Post-process: (freq, time) -> (time, freq)
    spec_db = spec_db.squeeze(0).transpose(0, 1)
    spec_db = (spec_db + 4.26) / (4.57 * 2)  # normalize with the mean/std used in training

    # Ensure correct length (pad/trim to 1024 frames, i.e. ~10.24 seconds)
    target_len = 1024
    if spec_db.shape[0] < target_len:
        pad = torch.zeros(target_len - spec_db.shape[0], 128)
        spec_db = torch.cat([spec_db, pad], dim=0)
    else:
        spec_db = spec_db[:target_len, :]

    return spec_db
```

## 4. Model Loading and Inference

Example of loading the model and running inference:

```python
# Initialize and load the fine-tuned model
model = BinaryAST()
checkpoint = torch.load('/content/final_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Process an audio file
spec = process_audio('path_to_audio.mp3')

# Visualize the spectrogram (optional)
plt.figure(figsize=(10, 3))
plt.imshow(spec.numpy().T, aspect='auto', origin='lower')
plt.title('Mel Spectrogram')
plt.xlabel('Time Frames')
plt.ylabel('Mel Bins')
plt.colorbar()
plt.show()

# Run inference
spec_batch = spec.unsqueeze(0)
with torch.no_grad():
    output = model(spec_batch)
    prob_fake = torch.sigmoid(output).item()

print(f"Probability of fake audio: {prob_fake:.4f}")
print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
```

## Key Notes

- Ensure audio files are accessible and in a format supported by torchaudio
- The model expects 16 kHz sample rate input
- Stereo input is converted to mono
- The model outputs a probability score; values above 0.5 indicate fake audio
- Visualizing the spectrogram is optional but useful for debugging

## Limitations

Important considerations when using this model:

1. The model works best with 16 kHz audio input
2. Performance may vary with types of audio manipulation not present in the training data
3. Very short audio clips (< 1 second) might not produce reliable results
4. The model should not be used as the sole determiner for real/fake audio detection

## Training Details

The training process involved:

1. Loading the base AST model pretrained on AudioSet
2. Replacing the classification head with a binary classifier
3. Fine-tuning on the fake audio detection dataset for 10,000 iterations
4. Using gradient accumulation (8 steps) with batch size 8
5. Running validation every 500 steps

A minimal sketch of this loop is given below.
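The exact training script is not part of this card. The following is an illustrative sketch of the loop described above, reconstructed from the stated configuration; the `train_loader` and `val_loader` DataLoaders (yielding `(spectrogram, label)` batches with label 1 = fake) and the `evaluate` helper are assumptions, not the authors' actual code:

```python
import torch
from torch import nn

@torch.no_grad()
def evaluate(model, val_loader, device):
    """Compute validation accuracy (illustrative helper)."""
    model.eval()
    correct, total = 0, 0
    for spec, label in val_loader:
        logits = model(spec.to(device)).squeeze(-1)
        preds = (torch.sigmoid(logits) > 0.5).long().cpu()
        correct += (preds == label).sum().item()
        total += label.numel()
    print(f"Validation accuracy: {correct / total:.4f}")

def train(model, train_loader, val_loader, device='cuda'):
    model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

    accum_steps = 8        # gradient_accumulation_steps
    n_iterations = 10000
    validate_every = 500

    model.train()
    optimizer.zero_grad()
    data_iter = iter(train_loader)
    for step in range(1, n_iterations + 1):
        try:
            spec, label = next(data_iter)
        except StopIteration:          # restart the loader when exhausted
            data_iter = iter(train_loader)
            spec, label = next(data_iter)

        logits = model(spec.to(device)).squeeze(-1)
        loss = criterion(logits, label.float().to(device)) / accum_steps
        loss.backward()

        # Step the optimizer only every `accum_steps` iterations
        # (effective batch size: 8 * 8 = 64)
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        if step % validate_every == 0:
            evaluate(model, val_loader, device)
            model.train()
```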