Hebrew
File size: 5,689 Bytes
6a2757d
 
f0785c9
4b615a0
f0785c9
6a2757d
062214b
4990c86
 
 
8ff3a7b
152c35a
4b615a0
4990c86
 
 
 
 
 
4b615a0
f0785c9
4b615a0
4990c86
 
f0785c9
4990c86
 
 
 
4c8e2a9
f0785c9
4990c86
 
4b615a0
6a2757d
4990c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b615a0
f0785c9
4b615a0
4990c86
4b615a0
6a2757d
4b615a0
4990c86
6a2757d
 
 
 
 
 
4990c86
6a2757d
4990c86
6a2757d
 
 
 
 
f0785c9
 
062214b
6a2757d
f0785c9
6a2757d
4b615a0
d93cc71
f0785c9
 
e98cc49
4990c86
6a2757d
 
 
 
 
 
e98cc49
4990c86
 
 
 
 
 
 
 
6a2757d
 
4990c86
f0785c9
6a2757d
4990c86
 
 
 
 
 
 
 
f0785c9
e98cc49
d93cc71
6a2757d
f0785c9
 
4990c86
f0785c9
 
4990c86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# handler.py

import base64
import binascii
import os
import tempfile
import time
import uuid
from datetime import timedelta

import torch

import faster_whisper

class EndpointHandler:
    """
    Custom handler for a Hugging Face Inference Endpoint that transcribes
    audio using a high-performance faster-whisper model.

    Expected request payload:
        {"inputs": "<base64-encoded audio>",
         "parameters": {"language": "he", "beam_size": 5, "word_timestamps": true}}

    Successful responses carry the plain transcript plus SRT/VTT subtitle
    renderings and run metadata; failures are reported in-band as
    {"error": ..., "error_type": ...} rather than by raising.
    """

    def __init__(self, path=""):
        """
        Load the faster-whisper model once at endpoint start-up.

        Uses CUDA with float16 when a GPU is available, otherwise int8 on
        CPU to keep memory and latency reasonable.

        Args:
            path: Unused; present for Inference Endpoint handler API
                compatibility.
        """
        # Model ID from the Colab notebook
        model_id = "ivrit-ai/whisper-large-v3-turbo-ct2"

        # Reliable GPU detection
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

        print(f"Initializing model '{model_id}' on device '{device}' with compute_type '{compute_type}'...")

        # Load the transcription model
        self.model = faster_whisper.WhisperModel(model_id, device=device, compute_type=compute_type)
        print("✅ Model loaded successfully.")

    # --- Helper functions adapted from Colab notebook ---

    def _format_timestamp(self, seconds, format_type="srt"):
        """Format a float number of seconds as an SRT or VTT timestamp.

        SRT separates milliseconds with a comma ("HH:MM:SS,mmm"); VTT
        uses a dot ("HH:MM:SS.mmm"). A None input maps to time zero.
        """
        if seconds is None:
            return "00:00:00,000" if format_type == "srt" else "00:00:00.000"
        delta = timedelta(seconds=seconds)
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, sec = divmod(remainder, 60)
        milliseconds = delta.microseconds // 1000
        separator = "," if format_type == "srt" else "."
        return f"{hours:02d}:{minutes:02d}:{sec:02d}{separator}{milliseconds:03d}"

    def _generate_srt_content(self, segments):
        """Build an SRT subtitle document from transcription segments."""
        lines = []
        for index, segment in enumerate(segments, start=1):
            lines.append(str(index))
            lines.append(
                f"{self._format_timestamp(segment.start, 'srt')} --> {self._format_timestamp(segment.end, 'srt')}"
            )
            lines.append(segment.text.strip())
            lines.append("")  # blank line terminates each cue
        return "\n".join(lines)

    def _generate_vtt_content(self, segments):
        """Build a WebVTT subtitle document from transcription segments."""
        lines = ["WEBVTT", ""]
        for segment in segments:
            lines.append(
                f"{self._format_timestamp(segment.start, 'vtt')} --> {self._format_timestamp(segment.end, 'vtt')}"
            )
            lines.append(segment.text.strip())
            lines.append("")
        return "\n".join(lines)

    def __call__(self, data):
        """
        Handle one transcription request.

        Args:
            data: Request payload dict with a base64 audio string under
                "inputs" and an optional "parameters" dict supporting
                "language" (default "he"), "beam_size" (default 5) and
                "word_timestamps" (default True).

        Returns:
            On success, a dict with "text", "srt", "vtt" and "metadata";
            on failure, a dict with "error" and "error_type".
        """
        start_time = time.time()

        # 1. Extract audio (as base64) and parameters from the payload
        try:
            audio_base64 = data.get("inputs")
            if not audio_base64:
                return {"error": "Missing 'inputs' key with base64 audio string.", "error_type": "Bad Request"}

            params = data.get("parameters", {})
            language = params.get("language", "he")
            beam_size = int(params.get("beam_size", 5))
            word_timestamps = bool(params.get("word_timestamps", True))

        except Exception as e:
            return {"error": f"Error parsing input data: {e}", "error_type": "Bad Request"}

        # 2. Decode the base64 string
        try:
            audio_bytes = base64.b64decode(audio_base64)
        except (TypeError, binascii.Error) as e:
            return {"error": f"Invalid base64 string provided: {e}", "error_type": "Bad Request"}

        # Write the audio to a real temp file because the transcriber takes
        # a path. tempfile.mkstemp picks a platform-appropriate directory
        # (the previous hard-coded "/tmp" does not exist everywhere) and
        # creates the file race-free, unlike a hand-rolled uuid filename.
        fd, temp_filename = tempfile.mkstemp(suffix=".mp3")

        try:
            with os.fdopen(fd, "wb") as f:
                f.write(audio_bytes)

            # 3. Run transcription with the specified parameters
            segments, info = self.model.transcribe(
                temp_filename,
                language=language,
                beam_size=beam_size,
                word_timestamps=word_timestamps,
            )

            # Segments is a generator; materialize once so it can be
            # iterated for the text and both subtitle formats below.
            segment_list = list(segments)

            # 4. Generate full text and subtitle formats
            full_text = " ".join(s.text.strip() for s in segment_list)
            srt_content = self._generate_srt_content(segment_list)
            vtt_content = self._generate_vtt_content(segment_list)

            processing_time = time.time() - start_time

            # 5. Return the complete response
            return {
                "text": full_text,
                "srt": srt_content,
                "vtt": vtt_content,
                "metadata": {
                    "language": info.language,
                    "language_probability": round(info.language_probability, 2),
                    "audio_duration_seconds": round(info.duration, 2),
                    "processing_time_seconds": round(processing_time, 2),
                },
            }

        except Exception as e:
            # Inference failures are reported in-band per endpoint convention.
            return {"error": str(e), "error_type": "Inference Error"}

        finally:
            # 6. Always clean up the temporary file, success or failure.
            if os.path.exists(temp_filename):
                os.remove(temp_filename)