|
|
|
|
|
import base64
import binascii
import os
import tempfile
import time
import uuid
from datetime import timedelta

import faster_whisper
import torch
|
|
|
class EndpointHandler:
    """
    A custom handler for a Hugging Face Inference Endpoint that transcribes
    audio using a high-performance faster-whisper model.

    This handler is adapted from a multi-cell Colab notebook, combining model
    loading, audio processing, transcription, and subtitle generation into
    a single, robust API call.
    """

    def __init__(self, path=""):
        """
        Initializes the handler by loading the specified faster-whisper model.

        A GPU (CUDA) is used automatically when available, with float16
        compute for speed; otherwise the model falls back to int8 on CPU to
        keep memory usage low.

        Args:
            path: Unused; accepted for Hugging Face handler-interface
                compatibility.
        """
        model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"

        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

        print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")

        self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
        print("✅ Model loaded successfully.")

    def _format_timestamp(self, seconds, format_type="srt"):
        """Formats seconds into SRT ("HH:MM:SS,mmm") or VTT ("HH:MM:SS.mmm").

        A ``None`` input (a segment without a timestamp) maps to the zero
        timestamp in the requested format.
        """
        if seconds is None:
            return "00:00:00,000" if format_type == "srt" else "00:00:00.000"
        delta = timedelta(seconds=seconds)
        # Whole seconds come from the truncated total; the sub-second part is
        # read back from timedelta's normalized microseconds field.
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, sec = divmod(remainder, 60)
        milliseconds = delta.microseconds // 1000
        separator = "," if format_type == "srt" else "."
        return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"

    def _generate_srt_content(self, segments):
        """Generates SRT formatted subtitle content from transcription segments.

        Each cue is: 1-based index, timestamp line, stripped text, blank line.
        """
        srt_content = []
        for i, segment in enumerate(segments):
            start_time, end_time = segment.start, segment.end
            srt_content.append(str(i + 1))
            srt_content.append(f"{self._format_timestamp(start_time, 'srt')} --> {self._format_timestamp(end_time, 'srt')}")
            srt_content.append(segment.text.strip())
            srt_content.append("")
        return "\n".join(srt_content)

    def _generate_vtt_content(self, segments):
        """Generates VTT formatted subtitle content from transcription segments.

        Same cue layout as SRT but without indices, prefixed by the mandatory
        "WEBVTT" header block.
        """
        vtt_content = ["WEBVTT", ""]
        for segment in segments:
            start_time, end_time = segment.start, segment.end
            vtt_content.append(f"{self._format_timestamp(start_time, 'vtt')} --> {self._format_timestamp(end_time, 'vtt')}")
            vtt_content.append(segment.text.strip())
            vtt_content.append("")
        return "\n".join(vtt_content)

    def __call__(self, data):
        """
        Handles a single API request for audio transcription.

        Expected payload::

            {"inputs": "<base64 audio>",
             "parameters": {"language": "he", "beam_size": 5,
                            "word_timestamps": true}}

        Returns a dict with "text", "srt", "vtt" and "metadata" on success,
        or {"error": ..., "error_type": ...} on failure ("Bad Request" for
        client errors, "Inference Error" otherwise).
        """
        start_time = time.time()

        # --- Parse and validate the request payload ----------------------
        try:
            audio_base64 = data.get("inputs")
            if not audio_base64:
                return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}

            params = data.get("parameters", {})
            language = params.get("language", "he")
            beam_size = int(params.get("beam_size", 5))
            word_timestamps = bool(params.get("word_timestamps", True))

        except Exception as e:
            # Malformed payloads (non-dict data, non-numeric beam_size, ...)
            # are the caller's fault, not a server error.
            return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}

        # --- Decode the audio bytes --------------------------------------
        try:
            audio_bytes = base64.b64decode(audio_base64)
        except (TypeError, binascii.Error) as e:
            return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}

        if not audio_bytes:
            # b64decode(validate=False) silently drops non-alphabet
            # characters, so pure junk can decode to zero bytes without
            # raising; reject it here instead of failing inside inference.
            return {"error": "Decoded audio payload is empty.", "error_type": "Bad Request"}

        # faster-whisper reads from a path, so write the payload to a
        # uniquely named file in the platform temp directory (portable,
        # unlike a hard-coded /tmp).
        temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.mp3")

        try:
            with open(temp_filename, "wb") as f:
                f.write(audio_bytes)

            segments, info = self.model.transcribe(
                temp_filename,
                language=language,
                beam_size=beam_size,
                word_timestamps=word_timestamps,
            )

            # transcribe() returns a lazy generator; materialize it once so
            # the segments can be reused for the text, SRT and VTT outputs.
            segment_list = list(segments)

            full_text = " ".join(s.text.strip() for s in segment_list)
            srt_content = self._generate_srt_content(segment_list)
            vtt_content = self._generate_vtt_content(segment_list)

            processing_time = time.time() - start_time

            return {
                "text": full_text,
                "srt": srt_content,
                "vtt": vtt_content,
                "metadata": {
                    "language": info.language,
                    "language_probability": round(info.language_probability, 2),
                    "audio_duration_seconds": round(info.duration, 2),
                    "processing_time_seconds": round(processing_time, 2),
                },
            }

        except Exception as e:
            return {"error": str(e), "error_type": "Inference Error"}

        finally:
            # Always clean up the temp file, on success and failure alike.
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
|
|
|
|