# Kabatubare's picture
# Update app.py
# 09f8c18 verified
import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import traceback
import matplotlib.pyplot as plt
import io
from PIL import Image
# AudioSeal is an optional dependency. When the import fails the name
# ``AudioSeal`` is simply left undefined; detect_watermark checks for it in
# globals() before trying to use it.
try:
    from audioseal import AudioSeal
except ImportError as e:
    print(f"AudioSeal could not be imported: {e}")
else:
    print("AudioSeal is available for watermark detection.")
def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and bring it to ``target_sample_rate``.

    Args:
        audio_file_path: Location of the audio file, handed to
            ``torchaudio.load`` unchanged.
        target_sample_rate: Sample rate (Hz) the returned waveform should have.

    Returns:
        A ``(waveform, sample_rate)`` tuple. ``waveform`` is the tensor
        produced by ``torchaudio.load``, resampled when the file's own rate
        differs from the target; ``sample_rate`` is always
        ``target_sample_rate``.
    """
    audio, native_rate = torchaudio.load(audio_file_path)

    # Already at the requested rate: hand the samples back untouched.
    if native_rate == target_sample_rate:
        return audio, target_sample_rate

    to_target_rate = torchaudio.transforms.Resample(
        orig_freq=native_rate,
        new_freq=target_sample_rate,
    )
    return to_target_rate(audio), target_sample_rate
def extract_mfcc_features(waveform, sample_rate, n_mfcc=40, n_mels=128, win_length=400, hop_length=160, n_fft=400):
    """Compute MFCCs for ``waveform`` and average them over time.

    Args:
        waveform: Audio tensor whose last dimension is time.
        sample_rate: Sample rate of ``waveform`` in Hz.
        n_mfcc: Number of cepstral coefficients to keep.
        n_mels: Number of mel filter banks.
        win_length: STFT window length in samples.
        hop_length: Hop between successive STFT windows in samples.
        n_fft: FFT size in samples. Must be at least ``win_length``; it was
            previously fixed at 400, which made any larger ``win_length``
            unusable.

    Returns:
        Tensor with the time axis averaged away, i.e. ``(..., n_mfcc)``.
    """
    mfcc_transform = T.MFCC(
        sample_rate=sample_rate,
        n_mfcc=n_mfcc,
        melkwargs={
            'n_fft': n_fft,
            'n_mels': n_mels,
            'hop_length': hop_length,
            'win_length': win_length,
        },
    )
    mfcc = mfcc_transform(waveform)
    # The frame (time) axis is the last one for any input rank, so reduce over
    # -1 rather than a fixed index: identical for (channels, time) input, and
    # also correct for 1-D or batched waveforms.
    return mfcc.mean(dim=-1)
def plot_spectrogram(waveform, sample_rate):
    """Render the spectrogram of ``waveform`` as an axis-free PNG image.

    Args:
        waveform: Audio tensor. A 1-D tensor is treated as a single channel;
            for 2-D input only the first row is drawn.
        sample_rate: Not used by the rendering; kept so existing callers keep
            working.

    Returns:
        A ``PIL.Image.Image`` holding the rendered spectrogram (dB scale,
        'hot' colormap, low frequencies at the bottom).
    """
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)  # Ensure waveform is 2D

    # .numpy() below refuses tensors that track gradients or live on an
    # accelerator, so normalise to a plain CPU tensor first.
    waveform = waveform.detach().cpu()

    spectrogram = T.Spectrogram()(waveform)
    spectrogram_db = T.AmplitudeToDB()(spectrogram)

    buf = io.BytesIO()
    # Draw on an explicit figure/axes instead of pyplot's implicit "current
    # figure", and always close that figure, even if drawing or saving fails,
    # so repeated calls from a long-running server do not pile up figures.
    fig = plt.figure(figsize=(10, 4))
    try:
        ax = fig.add_subplot(111)
        ax.imshow(spectrogram_db[0].numpy(), cmap='hot', aspect='auto', origin='lower')
        ax.axis('off')  # Hide axes for a clean image
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # Decode now so the pixels are in memory before returning
    return image
# Loaded detectors keyed by model name, so the model is not reloaded on every
# request.
_DETECTOR_CACHE = {}


def _get_detector(model_name="audioseal_detector_16bits"):
    """Return the AudioSeal detector for ``model_name``, loading it on first use."""
    if model_name not in _DETECTOR_CACHE:
        _DETECTOR_CACHE[model_name] = AudioSeal.load_detector(model_name)
    return _DETECTOR_CACHE[model_name]


def detect_watermark(audio_data, sample_rate=None):
    """Check an audio file for an AudioSeal watermark.

    Args:
        audio_data: Path to the audio file (the Gradio input is declared with
            ``type="filepath"``). ``None`` means nothing was uploaded.
        sample_rate: Accepted for backward compatibility and ignored. The
            audio is always resampled to 16 kHz here and that rate is what
            the detector is told, so a caller-supplied value can no longer
            disagree with the actual samples. Optional because the Gradio
            interface supplies only the audio input.

    Returns:
        A ``(message, spectrogram_image)`` tuple. ``spectrogram_image`` is a
        PIL image, or ``None`` when no audio was provided.
    """
    if audio_data is None:
        return "No audio provided.", None

    # Load first: the spectrogram is needed on every path, and it has to be
    # drawn from samples, not from the file path.
    waveform, sr = load_and_resample_audio(audio_data, target_sample_rate=16000)
    spectrogram_image = plot_spectrogram(waveform, sr)

    # Ensure AudioSeal is available
    if 'AudioSeal' not in globals():
        return "AudioSeal not available", spectrogram_image

    # The detector takes (batch, channels, samples) and the released AudioSeal
    # models are mono, so fold multi-channel audio down to one channel
    # instead of passing extra channels through.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    detector = _get_detector()
    with torch.no_grad():  # Inference only; no autograd graph needed
        results, _messages = detector.forward(waveform.unsqueeze(0), sample_rate=sr)

    # Channel 1 of the detector output is the per-sample watermark
    # probability; average it into a single score for the clip.
    detect_probs = results[:, 1, :]
    result = detect_probs.mean().cpu().item()
    message = f"Detection result: {'Watermarked Audio' if result > 0.5 else 'Not watermarked'}"
    return message, spectrogram_image
def main(audio_file_path):
    """Run watermark detection on one audio file and print the verdict.

    ``detect_watermark`` loads its first argument as a file path, so the path
    itself is passed on. The earlier version split the audio into 5-second
    tensor batches and passed those instead, which could not be loaded; its
    ``[:-1]`` slice also dropped the last chunk unconditionally, leaving
    nothing to concatenate for clips shorter than 10 seconds.

    Args:
        audio_file_path: Path to the audio file to analyse.
    """
    # Only the post-resampling rate is needed here; detect_watermark reads the
    # samples itself.
    _, resampled_sr = load_and_resample_audio(audio_file_path)
    # detect_watermark returns (message, spectrogram image); only the text is
    # meaningful on the command line.
    message, _ = detect_watermark(audio_file_path, resampled_sr)
    print(message)
# Gradio UI: a single audio upload, delivered to detect_watermark as a file
# path, producing a text verdict and a spectrogram image.
interface = gr.Interface(
    title="Deep Fake Defender: AudioSeal Watermark Detection",
    description="Analyzes audio to detect AI-generated content.",
    fn=detect_watermark,
    inputs=gr.Audio(label="Upload your audio", type="filepath"),
    outputs=["text", "image"],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()