import os
import tempfile

import torch
import torchaudio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    pipeline
)
from pyannote.audio import Pipeline
import librosa
import numpy as np
from pydub import AudioSegment


class SpeechProcessor:
    """Speech-to-text with speaker diarization (Whisper + pyannote)."""

    def __init__(self):
        # Load Whisper for ASR
        self.whisper_processor = WhisperProcessor.from_pretrained(
            "openai/whisper-medium"
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-medium"
        )
        # Load speaker diarization (requires a Hugging Face access token)
        self.diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.environ.get("HF_TOKEN")
        )

    def process_audio(self, audio_path, language="id"):
        """
        Process an audio file for ASR and speaker diarization.
        """
        # Convert to WAV if needed
        audio_path = self._ensure_wav_format(audio_path)

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Speaker diarization
        diarization = self.diarization_pipeline(audio_path)

        # Process each speaker segment
        transcript_segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            # Extract segment audio
            start_sample = int(turn.start * sample_rate)
            end_sample = int(turn.end * sample_rate)
            segment_waveform = waveform[:, start_sample:end_sample]

            # ASR on segment
            text = self._transcribe_segment(
                segment_waveform,
                sample_rate,
                language
            )

            transcript_segments.append({
                "start": round(turn.start, 2),
                "end": round(turn.end, 2),
                "speaker": speaker,
                "text": text
            })

        return self._merge_consecutive_segments(transcript_segments)

    def _transcribe_segment(self, waveform, sample_rate, language):
        """
        Transcribe an audio segment using Whisper.
        """
        # Downmix to mono if the segment has multiple channels
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to Whisper's expected 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Prepare input features
        input_features = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # Generate transcription in the requested language
        forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
            language=language,
            task="transcribe"
        )
        predicted_ids = self.whisper_model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448
        )
        transcription = self.whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

        return transcription.strip()

    def _ensure_wav_format(self, audio_path):
        """
        Convert audio to WAV format if needed.
        """
        if not audio_path.endswith('.wav'):
            audio = AudioSegment.from_file(audio_path)
            # Export the converted audio to a named temporary WAV file
            fd, wav_path = tempfile.mkstemp(suffix='.wav')
            os.close(fd)
            audio.export(wav_path, format='wav')
            return wav_path
        return audio_path

    def _merge_consecutive_segments(self, segments):
        """
        Merge consecutive segments from the same speaker.
        """
        if not segments:
            return segments

        merged = [segments[0]]
        for current in segments[1:]:
            last = merged[-1]
            # Merge if same speaker and the gap is under one second
            if (last['speaker'] == current['speaker'] and
                    current['start'] - last['end'] < 1.0):
                last['end'] = current['end']
                last['text'] += ' ' + current['text']
            else:
                merged.append(current)
        return merged
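

# Minimal usage sketch (illustrative only): the input file name below is a
# placeholder, not part of the original class; language="id" selects Indonesian.
if __name__ == "__main__":
    processor = SpeechProcessor()
    segments = processor.process_audio("meeting.wav", language="id")
    for seg in segments:
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']}: {seg['text']}")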