# Hebrew transcription endpoint handler
# Author: meirk — "Update handler.py" (commit 4990c86, verified)
# handler.py
import base64
import binascii
import os
import tempfile
import time
import uuid
from datetime import timedelta

import faster_whisper
import torch
class EndpointHandler:
    """
    Custom handler for a Hugging Face Inference Endpoint that transcribes
    audio using a high-performance faster-whisper model.

    Expected request payload::

        {"inputs": "<base64-encoded audio>",
         "parameters": {"language": "he", "beam_size": 5, "word_timestamps": true}}

    On success the response contains the plain transcript plus SRT and VTT
    subtitle documents; on failure it contains ``{"error": ..., "error_type": ...}``.
    """

    def __init__(self, path=""):
        """
        Load the faster-whisper model, using the GPU (CUDA) when available.

        :param path: unused; kept to satisfy the Inference Endpoint handler contract.
        """
        model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 is only safe on GPU; int8 keeps CPU inference reasonably fast.
        compute_type = "float16" if device == "cuda" else "int8"
        print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")
        self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
        print("✅ Model loaded successfully.")

    # --- Helper functions adapted from the original Colab notebook ---
    def _format_timestamp(self, seconds, format_type="srt"):
        """
        Format a duration in seconds as an SRT ("HH:MM:SS,mmm") or
        VTT ("HH:MM:SS.mmm") timestamp. ``None`` maps to time zero.
        """
        separator = "," if format_type == "srt" else "."
        if seconds is None:
            return f"00:00:00{separator}000"
        delta = timedelta(seconds=seconds)
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, sec = divmod(remainder, 60)
        # timedelta already rounded to whole microseconds; keep milliseconds.
        milliseconds = delta.microseconds // 1000
        return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"

    def _generate_srt_content(self, segments):
        """Build an SRT document (numbered, comma-separated cues) from segments."""
        srt_content = []
        for index, segment in enumerate(segments, start=1):
            srt_content.append(str(index))
            srt_content.append(
                f"{self._format_timestamp(segment.start, 'srt')} --> "
                f"{self._format_timestamp(segment.end, 'srt')}"
            )
            srt_content.append(segment.text.strip())
            srt_content.append("")  # blank line terminates each cue
        return "\n".join(srt_content)

    def _generate_vtt_content(self, segments):
        """Build a WebVTT document (header + dot-separated cues) from segments."""
        vtt_content = ["WEBVTT", ""]
        for segment in segments:
            vtt_content.append(
                f"{self._format_timestamp(segment.start, 'vtt')} --> "
                f"{self._format_timestamp(segment.end, 'vtt')}"
            )
            vtt_content.append(segment.text.strip())
            vtt_content.append("")  # blank line terminates each cue
        return "\n".join(vtt_content)

    def __call__(self, data):
        """
        Handle one transcription request.

        :param data: dict with "inputs" (base64 audio) and optional "parameters"
            ("language", "beam_size", "word_timestamps").
        :return: dict with "text", "srt", "vtt", "metadata" on success, or an
            {"error", "error_type"} dict on failure.
        """
        start_time = time.time()
        # 1. Extract audio (as base64) and parameters from the payload
        try:
            audio_base64 = data.get("inputs")
            if not audio_base64:
                return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}
            params = data.get("parameters", {})
            language = params.get("language", "he")
            beam_size = int(params.get("beam_size", 5))
            word_timestamps = bool(params.get("word_timestamps", True))
        except Exception as e:
            return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}
        # 2. Decode the base64 string. binascii.Error covers bad padding;
        # TypeError covers non-string input.
        try:
            audio_bytes = base64.b64decode(audio_base64)
        except (TypeError, binascii.Error) as e:
            return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}
        # b64decode silently drops non-alphabet characters, so garbage input can
        # decode to zero bytes — reject it here instead of failing in inference.
        if not audio_bytes:
            return {"error": "Decoded audio payload is empty.", "error_type": "Bad Request"}
        # Use the platform temp dir (not a hard-coded /tmp) so the handler also
        # works on hosts/containers without a /tmp directory.
        temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.mp3")
        try:
            with open(temp_filename, "wb") as f:
                f.write(audio_bytes)
            # 3. Run transcription with the specified parameters
            segments, info = self.model.transcribe(
                temp_filename,
                language=language,
                beam_size=beam_size,
                word_timestamps=word_timestamps,
            )
            # Segments is a lazy generator; materialize once so it can be
            # reused for the transcript and both subtitle formats.
            segment_list = list(segments)
            # 4. Generate full text and subtitle formats
            full_text = " ".join(s.text.strip() for s in segment_list)
            srt_content = self._generate_srt_content(segment_list)
            vtt_content = self._generate_vtt_content(segment_list)
            processing_time = time.time() - start_time
            # 5. Return the complete response
            return {
                "text": full_text,
                "srt": srt_content,
                "vtt": vtt_content,
                "metadata": {
                    "language": info.language,
                    "language_probability": round(info.language_probability, 2),
                    "audio_duration_seconds": round(info.duration, 2),
                    "processing_time_seconds": round(processing_time, 2),
                },
            }
        except Exception as e:
            # Surface inference failures as a structured error payload.
            return {"error": str(e), "error_type": "Inference Error"}
        finally:
            # 6. Always clean up the temporary file, even on failure.
            if os.path.exists(temp_filename):
                os.remove(temp_filename)