# handler.py
import base64
import binascii
import os
import tempfile
import time
import uuid
from datetime import timedelta

import torch

import faster_whisper


class EndpointHandler:
    """
    A custom handler for a Hugging Face Inference Endpoint that transcribes
    audio using a high-performance faster-whisper model.

    This handler is adapted from a multi-cell Colab notebook, combining model
    loading, audio processing, transcription, and subtitle generation into a
    single, robust API call.
    """

    def __init__(self, path: str = ""):
        """
        Initializes the handler by loading the specified faster-whisper model.

        It automatically detects and uses a GPU (CUDA) if available; on CPU it
        falls back to int8 quantization to keep memory and latency reasonable.

        Args:
            path: Unused; present to satisfy the Inference Endpoint handler
                contract, which passes the model repository path.
        """
        # Model ID from the Colab notebook
        model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"

        # Reliable GPU detection; float16 is only supported on CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

        print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")

        # Load the transcription model (downloads/caches weights on first use).
        self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
        print("✅ Model loaded successfully.")

    # --- Helper functions adapted from Colab notebook ---

    def _format_timestamp(self, seconds, format_type: str = "srt") -> str:
        """
        Formats a duration in seconds into an SRT or VTT timestamp.

        Args:
            seconds: Offset in seconds (float or None). None maps to zero.
            format_type: "srt" (comma millisecond separator) or "vtt" (dot).

        Returns:
            A "HH:MM:SS,mmm" (SRT) or "HH:MM:SS.mmm" (VTT) string.
        """
        if seconds is None:
            # Segments can occasionally lack a timestamp; render as zero.
            return "00:00:00,000" if format_type == "srt" else "00:00:00.000"
        delta = timedelta(seconds=seconds)
        # Whole seconds are truncated; the sub-second part comes from
        # timedelta's microseconds field, so the two never disagree.
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, sec = divmod(remainder, 60)
        milliseconds = delta.microseconds // 1000
        separator = "," if format_type == "srt" else "."
        return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"

    def _generate_srt_content(self, segments) -> str:
        """
        Generates SRT-formatted subtitle content from transcription segments.

        Args:
            segments: Sequence of segment objects exposing .start, .end, .text.

        Returns:
            The full SRT document as a single string (1-based cue numbering).
        """
        srt_content = []
        for i, segment in enumerate(segments):
            start_time, end_time = segment.start, segment.end
            srt_content.append(str(i + 1))
            srt_content.append(
                f"{self._format_timestamp(start_time, 'srt')} --> {self._format_timestamp(end_time, 'srt')}"
            )
            srt_content.append(segment.text.strip())
            srt_content.append("")  # blank line terminates each cue
        return "\n".join(srt_content)

    def _generate_vtt_content(self, segments) -> str:
        """
        Generates WebVTT-formatted subtitle content from transcription segments.

        Args:
            segments: Sequence of segment objects exposing .start, .end, .text.

        Returns:
            The full VTT document as a single string (starts with "WEBVTT").
        """
        vtt_content = ["WEBVTT", ""]  # mandatory WebVTT header
        for segment in segments:
            start_time, end_time = segment.start, segment.end
            vtt_content.append(
                f"{self._format_timestamp(start_time, 'vtt')} --> {self._format_timestamp(end_time, 'vtt')}"
            )
            vtt_content.append(segment.text.strip())
            vtt_content.append("")
        return "\n".join(vtt_content)

    def __call__(self, data: dict) -> dict:
        """
        Handles a single API request for audio transcription.

        Args:
            data: Request payload. Expects "inputs" (base64-encoded audio) and
                an optional "parameters" dict with "language" (default "he"),
                "beam_size" (default 5), and "word_timestamps" (default True).

        Returns:
            On success: {"text", "srt", "vtt", "metadata"}.
            On failure: {"error", "error_type"} describing what went wrong.
        """
        start_time = time.time()

        # 1. Extract audio (as base64) and parameters from the payload
        try:
            audio_base64 = data.get("inputs")
            if not audio_base64:
                return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}
            params = data.get("parameters", {})
            language = params.get("language", "he")
            beam_size = int(params.get("beam_size", 5))
            word_timestamps = bool(params.get("word_timestamps", True))
        except Exception as e:
            # Broad catch is deliberate: any malformed payload is a client error.
            return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}

        # 2. Decode the base64 string and save to a temporary file
        try:
            audio_bytes = base64.b64decode(audio_base64)
        except (TypeError, binascii.Error) as e:
            return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}

        # Use the platform temp dir (not a hardcoded "/tmp") for portability;
        # a UUID name avoids collisions between concurrent requests. The .mp3
        # suffix is nominal — ffmpeg sniffs the real container format.
        temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.mp3")
        try:
            with open(temp_filename, "wb") as f:
                f.write(audio_bytes)

            # 3. Run transcription with the specified parameters
            segments, info = self.model.transcribe(
                temp_filename,
                language=language,
                beam_size=beam_size,
                word_timestamps=word_timestamps,
            )

            # Segments is a generator, so we convert it to a list to reuse it
            segment_list = list(segments)

            # 4. Generate full text and subtitle formats
            full_text = " ".join(s.text.strip() for s in segment_list)
            srt_content = self._generate_srt_content(segment_list)
            vtt_content = self._generate_vtt_content(segment_list)

            processing_time = time.time() - start_time

            # 5. Return the complete response
            return {
                "text": full_text,
                "srt": srt_content,
                "vtt": vtt_content,
                "metadata": {
                    "language": info.language,
                    "language_probability": round(info.language_probability, 2),
                    "audio_duration_seconds": round(info.duration, 2),
                    "processing_time_seconds": round(processing_time, 2),
                },
            }
        except Exception as e:
            # Top-level boundary: surface any inference failure to the caller.
            return {"error": str(e), "error_type": "Inference Error"}
        finally:
            # 6. Clean up the temporary file
            if os.path.exists(temp_filename):
                os.remove(temp_filename)