Hebrew
meirk committed on
Commit
4990c86
·
verified ·
1 Parent(s): 136106c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +76 -50
handler.py CHANGED
@@ -5,64 +5,93 @@ import os
5
  import uuid
6
  import time
7
  import binascii
8
- import torch # Import the torch library
9
- from faster_whisper import WhisperModel
 
10
 
11
  class EndpointHandler:
12
  """
13
- A sophisticated handler for a Hugging Face Inference Endpoint that transcribes
14
- audio using the faster-whisper model. It accepts transcription parameters
15
- and returns a detailed JSON response.
 
 
 
16
  """
17
  def __init__(self, path=""):
18
  """
19
- Initializes the handler by loading the faster-whisper model.
20
- The model is loaded onto the GPU if available for better performance.
21
  """
22
- # --- THIS IS THE FIX ---
23
- # A more reliable way to check for GPU availability.
 
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  compute_type = "float16" if device == "cuda" else "int8"
26
- model_size = "large-v3" # You can change this to your desired model size
 
27
 
28
  # Load the transcription model
29
- self.model = WhisperModel(model_size, device=device, compute_type=compute_type)
30
- print(f"Model '{model_size}' loaded successfully on device: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def __call__(self, data):
33
  """
34
- The main inference method. It extracts audio and parameters, transcribes,
35
- and returns a detailed JSON response.
36
-
37
- Args:
38
- data (dict): The input data from the request.
39
-
40
- Returns:
41
- dict: A detailed dictionary containing the transcription or an error.
42
  """
43
  start_time = time.time()
44
 
45
- # 1. Extract audio and parameters from the payload
46
  try:
47
- # The base64 audio string is expected in the 'inputs' key
48
  audio_base64 = data.get("inputs")
49
  if not audio_base64:
50
  return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}
51
 
52
- # Transcription parameters are in the 'parameters' key
53
  params = data.get("parameters", {})
54
- language = params.get("language") # Can be None, letting the model detect
55
  beam_size = int(params.get("beam_size", 5))
56
- word_timestamps = bool(params.get("word_timestamps", False))
57
 
58
  except Exception as e:
59
  return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}
60
 
61
-
62
  # 2. Decode the base64 string and save to a temporary file
63
  try:
64
  audio_bytes = base64.b64decode(audio_base64)
65
- file_size_mb = len(audio_bytes) / (1024 * 1024)
66
  except (TypeError, binascii.Error) as e:
67
  return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}
68
 
@@ -71,10 +100,8 @@ class EndpointHandler:
71
  try:
72
  with open(temp_filename, "wb") as f:
73
  f.write(audio_bytes)
74
-
75
- print(f"Temporarily saved {file_size_mb:.2f} MB audio to {temp_filename}")
76
 
77
- # 3. Run the transcription with the specified parameters
78
  segments, info = self.model.transcribe(
79
  temp_filename,
80
  language=language,
@@ -82,35 +109,34 @@ class EndpointHandler:
82
  word_timestamps=word_timestamps
83
  )
84
 
85
- # 4. Process segments and build the successful response
86
- text_segments = []
87
- for segment in segments:
88
- text_segments.append(segment.text)
89
-
90
- full_text = "".join(text_segments).strip()
 
 
91
  processing_time = time.time() - start_time
92
-
93
- print(f"Transcription successful in {processing_time:.2f} seconds.")
94
 
95
- # 5. Return the detailed JSON response
96
  return {
97
  "text": full_text,
98
- "preview": full_text[:200] + ("..." if len(full_text) > 200 else ""),
99
- "language": info.language,
100
- "language_probability": round(info.language_probability, 2),
101
- "duration": round(info.duration, 2),
102
- "processing_time": round(processing_time, 2),
103
- "segments_count": len(text_segments),
104
- "audio_size_mb": round(file_size_mb, 2),
 
105
  }
106
 
107
  except Exception as e:
108
- # Catch any other exceptions during file handling or transcription
109
- print(f"An error occurred during transcription: {e}")
110
  return {"error": str(e), "error_type": "Inference Error"}
111
 
112
  finally:
113
- # 6. Clean up by deleting the temporary file
114
  if os.path.exists(temp_filename):
115
  os.remove(temp_filename)
116
- print(f"Cleaned up temporary file: {temp_filename}")
 
5
  import uuid
6
  import time
7
  import binascii
8
+ import torch
9
+ from datetime import timedelta
10
+ import faster_whisper
11
 
12
  class EndpointHandler:
13
  """
14
+ A custom handler for a Hugging Face Inference Endpoint that transcribes
15
+ audio using a high-performance faster-whisper model.
16
+
17
+ This handler is adapted from a multi-cell Colab notebook, combining model
18
+ loading, audio processing, transcription, and subtitle generation into
19
+ a single, robust API call.
20
  """
21
  def __init__(self, path=""):
22
  """
23
+ Initializes the handler by loading the specified faster-whisper model.
24
+ It automatically detects and uses a GPU (CUDA) if available.
25
  """
26
+ # Model ID from the Colab notebook
27
+ model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"
28
+
29
+ # Reliable GPU detection
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
  compute_type = "float16" if device == "cuda" else "int8"
32
+
33
+ print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")
34
 
35
  # Load the transcription model
36
+ self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
37
+ print("Model loaded successfully.")
38
+
39
+ # --- Helper functions adapted from Colab notebook ---
40
+
41
+ def _format_timestamp(self, seconds, format_type="srt"):
42
+ """Formats seconds into SRT or VTT timestamp format."""
43
+ if seconds is None: return "00:00:00,000" if format_type == "srt" else "00:00:00.000"
44
+ delta = timedelta(seconds=seconds)
45
+ hours, remainder = divmod(int(delta.total_seconds()), 3600)
46
+ minutes, sec = divmod(remainder, 60)
47
+ milliseconds = delta.microseconds // 1000
48
+ separator = "," if format_type == "srt" else "."
49
+ return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"
50
+
51
+ def _generate_srt_content(self, segments):
52
+ """Generates SRT formatted subtitle content from transcription segments."""
53
+ srt_content = []
54
+ for i, segment in enumerate(segments):
55
+ start_time, end_time = segment.start, segment.end
56
+ srt_content.append(str(i + 1))
57
+ srt_content.append(f"{self._format_timestamp(start_time, 'srt')} --> {self._format_timestamp(end_time, 'srt')}")
58
+ srt_content.append(segment.text.strip())
59
+ srt_content.append("")
60
+ return "\n".join(srt_content)
61
+
62
+ def _generate_vtt_content(self, segments):
63
+ """Generates VTT formatted subtitle content from transcription segments."""
64
+ vtt_content = ["WEBVTT", ""]
65
+ for segment in segments:
66
+ start_time, end_time = segment.start, segment.end
67
+ vtt_content.append(f"{self._format_timestamp(start_time, 'vtt')} --> {self._format_timestamp(end_time, 'vtt')}")
68
+ vtt_content.append(segment.text.strip())
69
+ vtt_content.append("")
70
+ return "\n".join(vtt_content)
71
 
72
  def __call__(self, data):
73
  """
74
+ Handles a single API request for audio transcription.
 
 
 
 
 
 
 
75
  """
76
  start_time = time.time()
77
 
78
+ # 1. Extract audio (as base64) and parameters from the payload
79
  try:
 
80
  audio_base64 = data.get("inputs")
81
  if not audio_base64:
82
  return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}
83
 
 
84
  params = data.get("parameters", {})
85
+ language = params.get("language", "he")
86
  beam_size = int(params.get("beam_size", 5))
87
+ word_timestamps = bool(params.get("word_timestamps", True))
88
 
89
  except Exception as e:
90
  return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}
91
 
 
92
  # 2. Decode the base64 string and save to a temporary file
93
  try:
94
  audio_bytes = base64.b64decode(audio_base64)
 
95
  except (TypeError, binascii.Error) as e:
96
  return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}
97
 
 
100
  try:
101
  with open(temp_filename, "wb") as f:
102
  f.write(audio_bytes)
 
 
103
 
104
+ # 3. Run transcription with the specified parameters
105
  segments, info = self.model.transcribe(
106
  temp_filename,
107
  language=language,
 
109
  word_timestamps=word_timestamps
110
  )
111
 
112
+ # Segments is a generator, so we convert it to a list to reuse it
113
+ segment_list = list(segments)
114
+
115
+ # 4. Generate full text and subtitle formats
116
+ full_text = " ".join(s.text.strip() for s in segment_list)
117
+ srt_content = self._generate_srt_content(segment_list)
118
+ vtt_content = self._generate_vtt_content(segment_list)
119
+
120
  processing_time = time.time() - start_time
 
 
121
 
122
+ # 5. Return the complete response
123
  return {
124
  "text": full_text,
125
+ "srt": srt_content,
126
+ "vtt": vtt_content,
127
+ "metadata": {
128
+ "language": info.language,
129
+ "language_probability": round(info.language_probability, 2),
130
+ "audio_duration_seconds": round(info.duration, 2),
131
+ "processing_time_seconds": round(processing_time, 2),
132
+ }
133
  }
134
 
135
  except Exception as e:
 
 
136
  return {"error": str(e), "error_type": "Inference Error"}
137
 
138
  finally:
139
+ # 6. Clean up the temporary file
140
  if os.path.exists(temp_filename):
141
  os.remove(temp_filename)
142
+