# meeting-minutes-ai/utils/speech_processor.py
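"""Speech processing utilities for meeting-minutes-ai: Whisper ASR combined
with pyannote speaker diarization to produce speaker-attributed transcripts."""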
import os
import tempfile

import torchaudio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from pyannote.audio import Pipeline
from pydub import AudioSegment

class SpeechProcessor:
    def __init__(self):
        # Load Whisper for ASR
        self.whisper_processor = WhisperProcessor.from_pretrained(
            "openai/whisper-medium"
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-medium"
        )

        # Load speaker diarization (the pyannote model is gated, so
        # HF_TOKEN must belong to an account with access granted)
        self.diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.environ.get("HF_TOKEN")
        )
    def process_audio(self, audio_path, language="id"):
        """
        Process an audio file with ASR and speaker diarization
        """
        # Convert to WAV if needed
        audio_path = self._ensure_wav_format(audio_path)

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Mix down to mono so segment slicing and Whisper see one channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Speaker diarization
        diarization = self.diarization_pipeline(audio_path)

        # Process each speaker segment
        transcript_segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            # Extract segment audio
            start_sample = int(turn.start * sample_rate)
            end_sample = int(turn.end * sample_rate)
            segment_waveform = waveform[:, start_sample:end_sample]

            # ASR on segment
            text = self._transcribe_segment(
                segment_waveform,
                sample_rate,
                language
            )

            transcript_segments.append({
                "start": round(turn.start, 2),
                "end": round(turn.end, 2),
                "speaker": speaker,
                "text": text
            })

        return self._merge_consecutive_segments(transcript_segments)
    def _transcribe_segment(self, waveform, sample_rate, language):
        """
        Transcribe an audio segment using Whisper
        """
        # Resample to Whisper's expected 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Prepare input
        input_features = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # Generate transcription in the requested language
        forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
            language=language,
            task="transcribe"
        )
        predicted_ids = self.whisper_model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448
        )
        transcription = self.whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        return transcription.strip()
    def _ensure_wav_format(self, audio_path):
        """
        Convert audio to WAV format if needed
        """
        if not audio_path.lower().endswith('.wav'):
            audio = AudioSegment.from_file(audio_path)
            # NamedTemporaryFile replaces the deprecated, race-prone
            # tempfile.mktemp
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
                wav_path = tmp.name
            audio.export(wav_path, format='wav')
            return wav_path
        return audio_path
    def _merge_consecutive_segments(self, segments):
        """
        Merge consecutive segments from the same speaker
        """
        if not segments:
            return segments

        merged = [segments[0]]
        for current in segments[1:]:
            last = merged[-1]
            # Merge if same speaker and gap shorter than one second
            if (last['speaker'] == current['speaker'] and
                    current['start'] - last['end'] < 1.0):
                last['end'] = current['end']
                last['text'] += ' ' + current['text']
            else:
                merged.append(current)
        return merged
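

# ---------------------------------------------------------------------------
# Minimal usage sketch: assumes HF_TOKEN is exported in the environment and
# that "meeting.mp3" is a placeholder path to a real recording.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    processor = SpeechProcessor()
    segments = processor.process_audio("meeting.mp3", language="id")
    for seg in segments:
        print(f"[{seg['start']:7.2f}-{seg['end']:7.2f}] "
              f"{seg['speaker']}: {seg['text']}")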