Update handler.py
handler.py · +76 -50
@@ -5,64 +5,93 @@ import os
 import uuid
 import time
 import binascii
-import torch
-from faster_whisper import WhisperModel
+import torch
+from datetime import timedelta
+import faster_whisper

 class EndpointHandler:
     """
-    A …
-    audio using …
-    …
+    A custom handler for a Hugging Face Inference Endpoint that transcribes
+    audio using a high-performance faster-whisper model.
+
+    This handler is adapted from a multi-cell Colab notebook, combining model
+    loading, audio processing, transcription, and subtitle generation into
+    a single, robust API call.
     """
     def __init__(self, path=""):
         """
-        Initializes the handler by loading the faster-whisper model.
-        …
+        Initializes the handler by loading the specified faster-whisper model.
+        It automatically detects and uses a GPU (CUDA) if available.
         """
-        # …
-        …
+        # Model ID from the Colab notebook
+        model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"
+
+        # Reliable GPU detection
         device = "cuda" if torch.cuda.is_available() else "cpu"
         compute_type = "float16" if device == "cuda" else "int8"
-        …
+
+        print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")

         # Load the transcription model
-        self.model = WhisperModel(…
-        print(…
+        self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
+        print("✅ Model loaded successfully.")
+
+    # --- Helper functions adapted from Colab notebook ---
+
+    def _format_timestamp(self, seconds, format_type="srt"):
+        """Formats seconds into SRT or VTT timestamp format."""
+        if seconds is None: return "00:00:00,000" if format_type == "srt" else "00:00:00.000"
+        delta = timedelta(seconds=seconds)
+        hours, remainder = divmod(int(delta.total_seconds()), 3600)
+        minutes, sec = divmod(remainder, 60)
+        milliseconds = delta.microseconds // 1000
+        separator = "," if format_type == "srt" else "."
+        return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"
+
+    def _generate_srt_content(self, segments):
+        """Generates SRT formatted subtitle content from transcription segments."""
+        srt_content = []
+        for i, segment in enumerate(segments):
+            start_time, end_time = segment.start, segment.end
+            srt_content.append(str(i + 1))
+            srt_content.append(f"{self._format_timestamp(start_time, 'srt')} --> {self._format_timestamp(end_time, 'srt')}")
+            srt_content.append(segment.text.strip())
+            srt_content.append("")
+        return "\n".join(srt_content)
+
+    def _generate_vtt_content(self, segments):
+        """Generates VTT formatted subtitle content from transcription segments."""
+        vtt_content = ["WEBVTT", ""]
+        for segment in segments:
+            start_time, end_time = segment.start, segment.end
+            vtt_content.append(f"{self._format_timestamp(start_time, 'vtt')} --> {self._format_timestamp(end_time, 'vtt')}")
+            vtt_content.append(segment.text.strip())
+            vtt_content.append("")
+        return "\n".join(vtt_content)

     def __call__(self, data):
         """
-        …
-        and returns a detailed JSON response.
-
-        Args:
-            data (dict): The input data from the request.
-
-        Returns:
-            dict: A detailed dictionary containing the transcription or an error.
+        Handles a single API request for audio transcription.
         """
         start_time = time.time()

-        # 1. Extract audio and parameters from the payload
+        # 1. Extract audio (as base64) and parameters from the payload
         try:
-            # The base64 audio string is expected in the 'inputs' key
             audio_base64 = data.get("inputs")
             if not audio_base64:
                 return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}

-            # Transcription parameters are in the 'parameters' key
             params = data.get("parameters", {})
-            language = params.get("language"…
+            language = params.get("language", "he")
             beam_size = int(params.get("beam_size", 5))
-            word_timestamps = bool(params.get("word_timestamps",…
+            word_timestamps = bool(params.get("word_timestamps", True))

         except Exception as e:
             return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}

-
         # 2. Decode the base64 string and save to a temporary file
         try:
             audio_bytes = base64.b64decode(audio_base64)
-            file_size_mb = len(audio_bytes) / (1024 * 1024)
         except (TypeError, binascii.Error) as e:
             return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}

@@ -71,10 +100,8 @@ class EndpointHandler:
         try:
             with open(temp_filename, "wb") as f:
                 f.write(audio_bytes)
-
-            print(f"Temporarily saved {file_size_mb:.2f} MB audio to {temp_filename}")

-            # 3. Run …
+            # 3. Run transcription with the specified parameters
             segments, info = self.model.transcribe(
                 temp_filename,
                 language=language,
@@ -82,35 +109,34 @@
                 word_timestamps=word_timestamps
             )

-            # …
-            …
-            …
-            …
-            …
-            …
+            # Segments is a generator, so we convert it to a list to reuse it
+            segment_list = list(segments)
+
+            # 4. Generate full text and subtitle formats
+            full_text = " ".join(s.text.strip() for s in segment_list)
+            srt_content = self._generate_srt_content(segment_list)
+            vtt_content = self._generate_vtt_content(segment_list)
+
             processing_time = time.time() - start_time
-
-            print(f"Transcription successful in {processing_time:.2f} seconds.")

-            # 5. Return the …
+            # 5. Return the complete response
             return {
                 "text": full_text,
-                "…
-                "…
-                "…
-                …
-                …
-                …
-                …
+                "srt": srt_content,
+                "vtt": vtt_content,
+                "metadata": {
+                    "language": info.language,
+                    "language_probability": round(info.language_probability, 2),
+                    "audio_duration_seconds": round(info.duration, 2),
+                    "processing_time_seconds": round(processing_time, 2),
+                }
             }

         except Exception as e:
-            # Catch any other exceptions during file handling or transcription
-            print(f"An error occurred during transcription: {e}")
             return {"error": str(e), "error_type": "Inference Error"}

         finally:
-            # 6. Clean up …
+            # 6. Clean up the temporary file
             if os.path.exists(temp_filename):
                 os.remove(temp_filename)
-
+
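For reference, a minimal client-side sketch of how the updated handler could be called once deployed. The endpoint URL, token, and file names are placeholders, and the use of the requests library is an assumption, not part of this commit; only the payload shape ("inputs" as a base64 audio string, optional "parameters") and the response keys ("text", "srt", "vtt", "metadata") come from the handler above:

import base64
import requests  # assumed HTTP client, not part of this commit

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder

# Base64-encode an audio file to match the handler's 'inputs' contract.
with open("sample.mp3", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": audio_b64,
    # These mirror the handler's defaults; omit 'parameters' to use them as-is.
    "parameters": {"language": "he", "beam_size": 5, "word_timestamps": True},
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
result = resp.json()

print(result["text"])
# 'srt' and 'vtt' are ready-to-save subtitle documents.
with open("sample.srt", "w", encoding="utf-8") as f:
    f.write(result["srt"])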
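And a quick local smoke test, assuming handler.py is importable from the working directory and a short audio clip is on disk (the clip name is illustrative); it exercises __call__ directly, without an HTTP layer:

import base64
from handler import EndpointHandler  # assumes handler.py is on the path

handler = EndpointHandler()  # loads the CT2 model; slow on first run

with open("clip.wav", "rb") as f:  # illustrative file name
    payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

result = handler(payload)
if "error" in result:
    print(result["error_type"], ":", result["error"])
else:
    print(result["metadata"])   # language, duration, processing time
    print(result["vtt"][:200])  # first lines of the generated WebVTT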