# NOTE(review): import order is kept exactly as in the original. torch and
# faster_whisper may load native CUDA libraries at import time, so regrouping
# these imports might not be behavior-neutral -- confirm before reordering.
from faster_whisper import WhisperModel
from transformers import pipeline
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
import time
import sys
from pathlib import Path
import glob
import ctypes
import numpy as np
from settings import (
    DEBUG_MODE,
    MODEL_PATH_V2_FAST,
    MODEL_PATH_V1,
    LEFT_CHANNEL_TEMP_PATH,
    RIGHT_CHANNEL_TEMP_PATH,
    RESAMPLING_FREQ,
    BATCH_SIZE,
    TASK,
)

# Word tokens with an optional inner apostrophe ("don't") and at most one
# trailing ".,!?" mark. Any other character (quotes, hyphens, colons, ...) is
# not part of a token and is therefore dropped by post_process_transcription.
_TOKEN_RE = re.compile(r"\b\w+'?\w*\b[.,!?]?")
# A 1-3 character unit immediately repeated 2+ more times ("hahaha", "mmm").
# Digits are excluded from the unit so numbers such as "1000" or "1999" are
# not collapsed.
_STUTTER_RE = re.compile(r"([^\W\d]{1,3})(\1{2,})")
# Capturing group so re.split keeps the tags in its output.
_SPEAKER_TAG_RE = re.compile(r"(\[SPEAKER_\d{2}\])")
_TWO_DIGITS_RE = re.compile(r"\d{2}")


def load_cudnn():
    """Make cuDNN libraries discoverable for GPU inference.

    Does nothing when CUDA is unavailable. On Windows, torch's bundled ``lib``
    directory is added to the DLL search path. On Linux, every
    ``libcudnn_cnn*.so*`` under ``<site-packages>/nvidia/cudnn/lib`` is loaded
    with ``RTLD_GLOBAL``. Other platforms are left untouched.

    Diagnostics are printed only when ``DEBUG_MODE`` is set. A missing
    directory or a library that fails to load is reported, not raised.
    """
    if not torch.cuda.is_available():
        if DEBUG_MODE:
            print("[INFO] CUDA is not available, skipping cuDNN setup.")
        return

    if DEBUG_MODE:
        print(f"[INFO] sys.platform: {sys.platform}")

    if sys.platform == "win32":
        _add_torch_dll_directory()
    elif sys.platform == "linux":
        _preload_linux_cudnn()
    elif DEBUG_MODE:
        print("[WARNING] sys.platform is not win32 or linux")


def _add_torch_dll_directory():
    """Windows: add torch's bundled ``lib`` directory to the DLL search path."""
    torch_lib_dir = Path(torch.__file__).parent / "lib"
    if not torch_lib_dir.exists():
        if DEBUG_MODE:
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
        return
    os.add_dll_directory(str(torch_lib_dir))
    if DEBUG_MODE:
        print(f"[INFO] Added DLL directory: {torch_lib_dir}")


def _preload_linux_cudnn():
    """Linux: load ``libcudnn_cnn*.so*`` from site-packages/nvidia/cudnn/lib globally."""
    # torch.__file__ is torch/__init__.py, so parents[1] is the directory
    # that contains the torch package.
    site_packages = Path(torch.__file__).resolve().parents[1]
    cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
    if not cudnn_dir.exists():
        if DEBUG_MODE:
            print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
        return

    matching_files = sorted(glob.glob(str(cudnn_dir / "libcudnn_cnn*.so*")))
    if not matching_files:
        if DEBUG_MODE:
            print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
        return

    for so_path in matching_files:
        try:
            # RTLD_GLOBAL exposes the symbols to libraries loaded afterwards.
            ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
        except OSError as e:
            if DEBUG_MODE:
                print(f"[WARNING] Failed to load {so_path}: {e}")
            continue
        if DEBUG_MODE:
            print(f"[INFO] Loaded: {so_path}")


def get_settings():
    """Choose the inference device and the compute type to request.

    Returns:
        tuple[str, str]: ``(device, compute_type)``. ``device`` is ``"cuda"``
        when torch reports CUDA as available, otherwise ``"cpu"``.
        ``compute_type`` is always ``"default"``, leaving the precision choice
        to the model backend.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "default"
    if DEBUG_MODE:
        print(f"[SETTINGS] Device: {device}")
    return device, compute_type


def load_model(use_v2_fast, device, compute_type):
    """Load the speech-recognition model.

    Args:
        use_v2_fast: True to load a faster_whisper ``WhisperModel`` from
            MODEL_PATH_V2_FAST; False to build a transformers ASR ``pipeline``
            from MODEL_PATH_V1 (30 s chunks, auth token from ``HF_TOKEN``).
        device: ``"cuda"`` or ``"cpu"``.
        compute_type: Passed to ``WhisperModel`` only; ignored for the pipeline.

    Returns:
        The ``WhisperModel`` or the transformers pipeline object.
    """
    if DEBUG_MODE:
        print(f"[MODEL LOADING] use_v2_fast: {use_v2_fast}")

    if use_v2_fast:
        return WhisperModel(
            MODEL_PATH_V2_FAST,
            device=device,
            compute_type=compute_type,
        )
    return pipeline(
        task="automatic-speech-recognition",
        model=MODEL_PATH_V1,
        chunk_length_s=30,
        device=device,
        token=os.getenv("HF_TOKEN"),
    )


def split_input_stereo_channels(audio_path):
    """Split a stereo ``.wav``/``.mp3`` file into two mono temporary WAV files.

    The first (left) channel goes to LEFT_CHANNEL_TEMP_PATH and the second
    (right) channel to RIGHT_CHANNEL_TEMP_PATH.

    Raises:
        ValueError: The extension is not ``.wav``/``.mp3``, or the audio does
            not have exactly 2 channels.
    """
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"[FORMAT AUDIO] Unsupported file format for: {audio_path}")

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"[FORMAT AUDIO] Audio {audio_path} has {len(channels)} channels (instead of 2).")

    # split_to_mono() de-interleaves samples in channel order. In standard
    # interleaved stereo PCM the first channel is the left one.
    left_channel, right_channel = channels
    left_channel.export(LEFT_CHANNEL_TEMP_PATH, format="wav")
    right_channel.export(RIGHT_CHANNEL_TEMP_PATH, format="wav")


def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
    """Map a model compute type and device to the numpy dtype for the waveform.

    Returns ``np.float16`` on CUDA devices when the (case-insensitive) compute
    type mentions ``float16`` or ``int8``; ``np.float32`` in every other case.
    """
    # NOTE(review): this substring test also matches e.g. "int8_float32" and
    # "bfloat16". Model compute type describes weights, not input audio;
    # confirm the backend really wants a half-precision waveform here.
    compute_type = compute_type.lower()
    on_cuda = device.startswith("cuda")
    half_precision = "float16" in compute_type or "int8" in compute_type
    return np.float16 if on_cuda and half_precision else np.float32


def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
    """Load an audio file as a mono, resampled, 1-D numpy waveform.

    Multi-channel input is averaged down to one channel and resampled to
    RESAMPLING_FREQ. The result is cast to the dtype chosen by
    ``compute_type_to_audio_dtype(compute_type, device)``.
    """
    input_audio, sample_rate = torchaudio.load(audio_path)  # (channels, frames)
    # Downmix any multi-channel input, not only stereo; otherwise >2-channel
    # files would come out 2-D after the squeeze below.
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)

    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio)
    # squeeze(0) drops only the channel axis, so a one-frame clip stays 1-D.
    input_audio = input_audio.squeeze(0)

    np_dtype = compute_type_to_audio_dtype(compute_type, device)
    input_audio = input_audio.numpy().astype(np_dtype)
    if DEBUG_MODE:
        print(f"[FORMAT AUDIO] Audio dtype for actual_compute_type: {input_audio.dtype}")
    return input_audio


def process_waveforms(device: str, compute_type: str):
    """Load both per-channel temp files as waveforms; returns ``(left, right)``."""
    left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
    right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)
    return left_waveform, right_waveform


def transcribe_pipeline(audio, model):
    """Run the transformers ASR pipeline on ``audio`` and return the full text."""
    # NOTE(review): ``audio`` is a bare array, so the pipeline presumably
    # assumes it is already at the model's sampling rate; this relies on
    # RESAMPLING_FREQ matching it -- confirm.
    result = model(
        audio,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": TASK},
        return_timestamps=True,
    )
    return result["text"]


def transcribe_channels(left_waveform, right_waveform, model, beam_size=5, task="transcribe"):
    """Transcribe both channel waveforms with a faster_whisper model.

    Args:
        left_waveform, right_waveform: Waveforms from ``format_audio``.
        model: A faster_whisper ``WhisperModel``.
        beam_size: Beam width passed to ``model.transcribe`` (default 5).
        task: Task passed to ``model.transcribe`` (default ``"transcribe"``).

    Returns:
        ``(left_segments, right_segments)`` as lists.
    """
    left_segments, _ = model.transcribe(left_waveform, beam_size=beam_size, task=task)
    right_segments, _ = model.transcribe(right_waveform, beam_size=beam_size, task=task)
    # transcribe() yields segments lazily; list() runs the actual decoding.
    return list(left_segments), list(right_segments)


# TODO refactor and rename this function
def post_process_transcription(transcription, max_repeats=2):
    """Clean a raw transcription string.

    Steps:
      1. Tokenize into words, each with at most one trailing ``.,!?`` mark.
         Other punctuation and symbols are dropped.
      2. Inside each token, delete stutter runs: a 1-3 character non-digit
         unit repeated 3+ times in a row ("hahaha", "mmm"). Tokens that
         become empty are discarded.
      3. Keep at most ``max_repeats`` consecutive copies of the same token.

    Returns:
        The surviving tokens joined by single spaces ("" if none survive).
    """
    cleaned_tokens = []
    previous_token = None
    repetition_count = 0

    for token in _TOKEN_RE.findall(transcription):
        reduced_token = _STUTTER_RE.sub("", token)
        if not reduced_token:
            # A fully removed stutter should neither emit a blank token nor
            # break the repetition run around it.
            continue
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            previous_token = reduced_token
            repetition_count = 1
            cleaned_tokens.append(reduced_token)

    # Tokens contain no whitespace, so a single-space join is already normalized.
    return " ".join(cleaned_tokens)


# TODO not used right now, decide to use it or not
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive ``[SPEAKER_NN]`` segments spoken by the same speaker.

    Produces one ``"[SPEAKER_NN] text"`` line per run of the same speaker.
    Any text before the first tag is discarded.
    """
    parts = _SPEAKER_TAG_RE.split(transcription_text)
    lines = []
    current_speaker = None
    current_segment = []

    # re.split with one capturing group yields [prefix, tag, text, tag, text, ...].
    for speaker_tag, raw_text in zip(parts[1::2], parts[2::2]):
        speaker = _TWO_DIGITS_RE.search(speaker_tag).group()
        text = raw_text.strip()
        if speaker == current_speaker:
            current_segment.append(text)
            continue
        if current_speaker is not None:
            joined = " ".join(current_segment)
            lines.append(f"[SPEAKER_{current_speaker}] {joined}")
        current_speaker = speaker
        current_segment = [text]

    if current_speaker is not None:
        joined = " ".join(current_segment)
        lines.append(f"[SPEAKER_{current_speaker}] {joined}")

    return "\n".join(lines).strip()


def get_segments(result, speaker_label):
    """Convert transcription segments to ``(start, end, speaker, text)`` tuples.

    Text is cleaned with ``post_process_transcription``. Segments whose text
    is empty before or after cleaning are skipped.
    """
    final_segments = []
    for seg in result:
        if not seg.text:
            continue
        text = post_process_transcription(seg.text.strip())
        if text:
            final_segments.append((seg.start, seg.end, speaker_label, text))
    return final_segments


def post_process_transcripts(left_result, right_result):
    """Interleave both channels' segments by start time into one transcript.

    Left-channel segments are labeled "Speaker 1" and right-channel segments
    "Speaker 2". Segments without a start time sort last; on equal start
    times the left segment comes first because the sort is stable.

    Returns:
        Lines of the form ``"[Speaker N]: text"`` joined by newlines.
    """
    left_segs = get_segments(left_result, "Speaker 1")
    right_segs = get_segments(right_result, "Speaker 2")
    merged_transcript = sorted(
        left_segs + right_segs,
        key=lambda seg: float(seg[0]) if seg[0] is not None else float("inf"),
    )
    return "\n".join(f"[{speaker}]: {text}" for _, _, speaker, text in merged_transcript)


def cleanup_temp_files(*file_paths):
    """Delete each given path that is non-empty and currently exists."""
    for path in file_paths:
        if path and os.path.exists(path):
            if DEBUG_MODE:
                print(f"Removing path: {path}")
            os.remove(path)


def generate(audio_path, use_v2_fast):
    """Transcribe ``audio_path`` end to end and return the cleaned text.

    With ``use_v2_fast`` the stereo input is split per channel and transcribed
    with faster_whisper into a two-speaker transcript. The temporary channel
    files are removed even if a step fails. Otherwise the whole file is
    downmixed and sent through the transformers pipeline.
    """
    load_cudnn()
    device, requested_compute_type = get_settings()
    model = load_model(use_v2_fast, device, requested_compute_type)

    if use_v2_fast:
        # Compute type actually chosen by the loaded backend model.
        actual_compute_type = model.model.compute_type
    else:
        actual_compute_type = "float32"  # HF pipeline safe default

    if DEBUG_MODE:
        print(f"[SETTINGS] Requested compute_type: {requested_compute_type}")
        print(f"[SETTINGS] Actual compute_type: {actual_compute_type}")

    if not use_v2_fast:
        audio = format_audio(audio_path, actual_compute_type, device)
        merged_results = transcribe_pipeline(audio, model)
        return post_process_transcription(merged_results)

    try:
        split_input_stereo_channels(audio_path)
        left_waveform, right_waveform = process_waveforms(device, actual_compute_type)
        left_result, right_result = transcribe_channels(left_waveform, right_waveform, model)
        return post_process_transcripts(left_result, right_result)
    finally:
        cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)