import os
import tempfile

import torch
import torchaudio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    pipeline
)
from pyannote.audio import Pipeline
import librosa
import numpy as np
from pydub import AudioSegment


class SpeechProcessor:
    def __init__(self):
        # Load Whisper for ASR
        self.whisper_processor = WhisperProcessor.from_pretrained(
            "openai/whisper-medium"
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-medium"
        )

        # Load speaker diarization (gated model; requires a Hugging Face token)
        self.diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.environ.get("HF_TOKEN")
        )

    def process_audio(self, audio_path, language="id"):
        """
        Process an audio file for ASR and speaker diarization.
        """
        # Convert to WAV if needed
        audio_path = self._ensure_wav_format(audio_path)

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Speaker diarization
        diarization = self.diarization_pipeline(audio_path)

        # Process each speaker segment
        transcript_segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            # Extract segment audio
            start_sample = int(turn.start * sample_rate)
            end_sample = int(turn.end * sample_rate)
            segment_waveform = waveform[:, start_sample:end_sample]

            # ASR on segment
            text = self._transcribe_segment(
                segment_waveform, sample_rate, language
            )

            transcript_segments.append({
                "start": round(turn.start, 2),
                "end": round(turn.end, 2),
                "speaker": speaker,
                "text": text
            })

        return self._merge_consecutive_segments(transcript_segments)

    def _transcribe_segment(self, waveform, sample_rate, language):
        """
        Transcribe an audio segment using Whisper.
        """
        # Downmix to mono; Whisper expects a single channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to 16 kHz if needed (Whisper's expected sampling rate)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Prepare input features
        input_features = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # Generate transcription
        forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        )
        predicted_ids = self.whisper_model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448
        )
        transcription = self.whisper_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]

        return transcription.strip()

    def _ensure_wav_format(self, audio_path):
        """
        Convert audio to WAV format if needed.
        """
        if not audio_path.endswith('.wav'):
            audio = AudioSegment.from_file(audio_path)
            wav_path = tempfile.mktemp(suffix='.wav')
            audio.export(wav_path, format='wav')
            return wav_path
        return audio_path

    def _merge_consecutive_segments(self, segments):
        """
        Merge consecutive segments from the same speaker.
        """
        if not segments:
            return segments

        merged = [segments[0]]
        for current in segments[1:]:
            last = merged[-1]
            # Merge if same speaker and close in time (< 1 second gap)
            if (last['speaker'] == current['speaker'] and
                    current['start'] - last['end'] < 1.0):
                last['end'] = current['end']
                last['text'] += ' ' + current['text']
            else:
                merged.append(current)

        return merged
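
# Minimal usage sketch, not part of the class above. Assumptions: HF_TOKEN is
# set in the environment for the gated pyannote model, and "meeting.wav" is a
# hypothetical placeholder path to a local recording; language="id" matches the
# Indonesian default used by process_audio.
if __name__ == "__main__":
    processor = SpeechProcessor()
    segments = processor.process_audio("meeting.wav", language="id")
    for seg in segments:
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']}: {seg['text']}")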