# handler.py
import base64
import os
import uuid
import time
import binascii
import torch
from datetime import timedelta
import faster_whisper


class EndpointHandler:
    """
    A custom handler for a Hugging Face Inference Endpoint that transcribes
    audio using a high-performance faster-whisper model.

    This handler is adapted from a multi-cell Colab notebook, combining model
    loading, audio processing, transcription, and subtitle generation into
    a single, robust API call.
    """

    def __init__(self, path=""):
        """
        Initializes the handler by loading the specified faster-whisper model.
        It automatically detects and uses a GPU (CUDA) if available.
        """
        # Model ID from the Colab notebook. Note that `path` (the repository
        # directory supplied by the endpoint runtime) is intentionally unused;
        # the model is pinned to this ID instead.
        model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"

        # Reliable GPU detection
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

        print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")

        # Load the transcription model
        self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
        print("✅ Model loaded successfully.")

    # --- Helper functions adapted from Colab notebook ---

    def _format_timestamp(self, seconds, format_type="srt"):
        """Formats seconds into SRT or VTT timestamp format."""
        if seconds is None:
            return "00:00:00,000" if format_type == "srt" else "00:00:00.000"
        delta = timedelta(seconds=seconds)
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, sec = divmod(remainder, 60)
        milliseconds = delta.microseconds // 1000
        separator = "," if format_type == "srt" else "."
        return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"
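
    # Illustrative check (hypothetical values): _format_timestamp(3725.5, "srt")
    # yields "01:02:05,500", while _format_timestamp(3725.5, "vtt") yields
    # "01:02:05.500"; the two formats differ only in the millisecond separator.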

    def _generate_srt_content(self, segments):
        """Generates SRT formatted subtitle content from transcription segments."""
        srt_content = []
        for i, segment in enumerate(segments):
            start_time, end_time = segment.start, segment.end
            srt_content.append(str(i + 1))
            srt_content.append(f"{self._format_timestamp(start_time, 'srt')} --> {self._format_timestamp(end_time, 'srt')}")
            srt_content.append(segment.text.strip())
            srt_content.append("")
        return "\n".join(srt_content)

    def _generate_vtt_content(self, segments):
        """Generates VTT formatted subtitle content from transcription segments."""
        vtt_content = ["WEBVTT", ""]
        for segment in segments:
            start_time, end_time = segment.start, segment.end
            vtt_content.append(f"{self._format_timestamp(start_time, 'vtt')} --> {self._format_timestamp(end_time, 'vtt')}")
            vtt_content.append(segment.text.strip())
            vtt_content.append("")
        return "\n".join(vtt_content)
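
    # Illustrative output for a single 0.0-2.5s segment (placeholder text):
    #   SRT -> "1\n00:00:00,000 --> 00:00:02,500\n<text>\n"
    #   VTT -> "WEBVTT\n\n00:00:00.000 --> 00:00:02.500\n<text>\n"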

    def __call__(self, data):
        """
        Handles a single API request for audio transcription.
        """
        start_time = time.time()

        # 1. Extract audio (as base64) and parameters from the payload
        try:
            audio_base64 = data.get("inputs")
            if not audio_base64:
                return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}
            params = data.get("parameters", {})
            language = params.get("language", "he")
            beam_size = int(params.get("beam_size", 5))
            word_timestamps = bool(params.get("word_timestamps", True))
        except Exception as e:
            return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}

        # 2. Decode the base64 string and save to a temporary file
        try:
            audio_bytes = base64.b64decode(audio_base64)
        except (TypeError, binascii.Error) as e:
            return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}

        temp_filename = os.path.join("/tmp", f"{uuid.uuid4()}.mp3")
        try:
            with open(temp_filename, "wb") as f:
                f.write(audio_bytes)

            # 3. Run transcription with the specified parameters
            segments, info = self.model.transcribe(
                temp_filename,
                language=language,
                beam_size=beam_size,
                word_timestamps=word_timestamps,
            )

            # Segments is a generator, so we convert it to a list to reuse it
            segment_list = list(segments)

            # 4. Generate full text and subtitle formats
            full_text = " ".join(s.text.strip() for s in segment_list)
            srt_content = self._generate_srt_content(segment_list)
            vtt_content = self._generate_vtt_content(segment_list)

            processing_time = time.time() - start_time

            # 5. Return the complete response
            return {
                "text": full_text,
                "srt": srt_content,
                "vtt": vtt_content,
                "metadata": {
                    "language": info.language,
                    "language_probability": round(info.language_probability, 2),
                    "audio_duration_seconds": round(info.duration, 2),
                    "processing_time_seconds": round(processing_time, 2),
                },
            }
        except Exception as e:
            return {"error": str(e), "error_type": "Inference Error"}
        finally:
            # 6. Clean up the temporary file
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
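

# --- Local smoke test ---
# A minimal sketch for exercising the handler outside the endpoint runtime.
# Assumptions: "sample.mp3" is a hypothetical local audio file, and the
# parameters shown simply mirror the defaults used in __call__ above.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Build the same JSON payload a client would POST to the endpoint:
    # {"inputs": "<base64 audio>", "parameters": {...}}
    with open("sample.mp3", "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "parameters": {"language": "he", "beam_size": 5, "word_timestamps": True},
        }

    result = handler(payload)
    if "error" in result:
        print(f"Request failed ({result['error_type']}): {result['error']}")
    else:
        print(result["text"])
        print(result["metadata"])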