"""Hugging Face Inference Endpoint handler: Whisper transcription via faster-whisper (CTranslate2)."""

import base64
import logging
import time
from tempfile import NamedTemporaryFile

import torch
from faster_whisper import WhisperModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path=""):
        # The CTranslate2 conversion of Whisper large-v3-turbo, pulled from the Hub.
        model_path = "meirk/whisper-large-v3-turbo-ct2-copy"
        logger.info(f"Loading CTranslate2 model from: {model_path}")

        # Prefer GPU with float16; fall back to CPU with float32.
        if torch.cuda.is_available():
            self.device = "cuda"
            self.compute_type = "float16"
            logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}. Using compute_type: {self.compute_type}")
        else:
            self.device = "cpu"
            self.compute_type = "float32"
            logger.info(f"CUDA not available, using CPU with compute_type: {self.compute_type}")

        self.model = WhisperModel(model_path, device=self.device, compute_type=self.compute_type)
        logger.info(f"Model loaded on {self.device}")

    def __call__(self, data):
        try:
            start_time = time.time()

            params = data.get("parameters", {})

            # The payload carries the audio as a base64-encoded string under "inputs".
            audio_b64 = data.get("inputs", None)
            if not audio_b64:
                return {"error": "Missing 'inputs' field"}

            audio_bytes = base64.b64decode(audio_b64)
            audio_size_mb = len(audio_bytes) / (1024 * 1024)
            logger.info(f"Processing {audio_size_mb:.2f} MB of audio on {self.device}")

            # Write the decoded bytes to a temporary file so faster-whisper can
            # probe the container format and decode it with its audio backend.
            with NamedTemporaryFile(delete=True) as tmp:
                tmp.write(audio_bytes)
                tmp.flush()

                logger.info("Starting transcription...")

                segments_generator, info = self.model.transcribe(
                    tmp.name,
                    language=params.get("language", "he"),
                    task=params.get("task", "transcribe"),
                    beam_size=params.get("beam_size", 5),
                    temperature=params.get("temperature", 0),
                    word_timestamps=params.get("word_timestamps", False),
                    initial_prompt=params.get("initial_prompt", None),
                    no_speech_threshold=0.6,
                    log_prob_threshold=-1.0,
                    condition_on_previous_text=False,
                )

                # The generator is lazy: inference runs as segments are consumed,
                # so drain it here while collecting text and timestamps.
                segments = []
                full_text = ""
                for seg in segments_generator:
                    full_text += seg.text
                    segment_data = {
                        "text": seg.text.strip(),
                        "start": round(seg.start, 2),
                        "end": round(seg.end, 2),
                    }

                    # Attach per-word timing only when the caller asked for it.
                    if params.get("word_timestamps", False) and seg.words:
                        segment_data["words"] = [
                            {
                                "word": w.word,
                                "start": round(w.start, 2),
                                "end": round(w.end, 2),
                                "probability": round(w.probability, 3),
                            }
                            for w in seg.words
                        ]

                    segments.append(segment_data)

            detected_language = info.language
            audio_duration = info.duration

            processing_time = time.time() - start_time
            speed_ratio = audio_duration / processing_time if processing_time > 0 else 0

            logger.info(f"Detected language '{detected_language}' with probability {info.language_probability:.2f}")
            logger.info(f"Audio duration: {audio_duration:.2f}s")
            logger.info(f"Completed in {processing_time:.1f}s ({speed_ratio:.1f}x realtime)")

            return {
                "text": full_text,
                "chunks": segments,
                "language": detected_language,
                "processing_time": round(processing_time, 2),
                "speed_ratio": round(speed_ratio, 2),
                "segment_count": len(segments),
                "device_used": self.device,
            }

        except Exception as e:
            logger.error("Error: %s", str(e), exc_info=True)
            return {
                "error": str(e),
                "error_type": type(e).__name__,
                "device_used": getattr(self, "device", "unknown"),
            }
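

# A minimal local smoke test, not part of the endpoint contract: it sketches the
# expected payload shape (base64 audio under "inputs", optional decoding overrides
# under "parameters"). "sample.mp3" is a hypothetical path, not a file shipped here.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("sample.mp3", "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "parameters": {"language": "he", "beam_size": 5},
        }
    print(handler(payload))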