import os
import tempfile

import torch
import torchaudio
import numpy as np
import librosa
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    pipeline
)
from pyannote.audio import Pipeline
from pydub import AudioSegment
class SpeechProcessor:
    def __init__(self):
        # Load Whisper for ASR
        self.whisper_processor = WhisperProcessor.from_pretrained(
            "openai/whisper-medium"
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-medium"
        )

        # Load speaker diarization pipeline (requires a Hugging Face token
        # with access to the gated pyannote models)
        self.diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.environ.get("HF_TOKEN")
        )
    def process_audio(self, audio_path, language="id"):
        """
        Process an audio file with ASR and speaker diarization.
        """
        # Convert to WAV if needed
        audio_path = self._ensure_wav_format(audio_path)

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Speaker diarization
        diarization = self.diarization_pipeline(audio_path)

        # Process each speaker segment
        transcript_segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            # Extract the audio for this speaker turn
            start_sample = int(turn.start * sample_rate)
            end_sample = int(turn.end * sample_rate)
            segment_waveform = waveform[:, start_sample:end_sample]

            # ASR on the segment
            text = self._transcribe_segment(
                segment_waveform,
                sample_rate,
                language
            )

            transcript_segments.append({
                "start": round(turn.start, 2),
                "end": round(turn.end, 2),
                "speaker": speaker,
                "text": text
            })

        return self._merge_consecutive_segments(transcript_segments)
    def _transcribe_segment(self, waveform, sample_rate, language):
        """
        Transcribe an audio segment using Whisper.
        """
        # Resample to the 16 kHz rate Whisper expects, if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Downmix to mono: the Whisper feature extractor expects a 1-D signal
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)

        # Prepare input features
        input_features = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # Generate transcription in the requested language
        forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
            language=language,
            task="transcribe"
        )
        predicted_ids = self.whisper_model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448
        )
        transcription = self.whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

        return transcription.strip()
    def _ensure_wav_format(self, audio_path):
        """
        Convert the audio file to WAV format if needed.
        """
        if not audio_path.endswith('.wav'):
            audio = AudioSegment.from_file(audio_path)
            # mkstemp instead of the deprecated tempfile.mktemp
            fd, wav_path = tempfile.mkstemp(suffix='.wav')
            os.close(fd)
            audio.export(wav_path, format='wav')
            return wav_path
        return audio_path
    def _merge_consecutive_segments(self, segments):
        """
        Merge consecutive segments from the same speaker.
        """
        if not segments:
            return segments

        merged = [segments[0]]
        for current in segments[1:]:
            last = merged[-1]
            # Merge if same speaker and the gap is shorter than one second
            if (last['speaker'] == current['speaker'] and
                    current['start'] - last['end'] < 1.0):
                last['end'] = current['end']
                last['text'] += ' ' + current['text']
            else:
                merged.append(current)

        return merged
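
# Minimal usage sketch (not part of the original file): assumes HF_TOKEN is set
# in the environment and that "meeting.wav" is a hypothetical local recording.
if __name__ == "__main__":
    processor = SpeechProcessor()
    segments = processor.process_audio("meeting.wav", language="id")
    for seg in segments:
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']}: {seg['text']}")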