Spaces:
Running
Running
| import os | |
| from abc import ABC, abstractmethod | |
| import librosa | |
| import numpy as np | |
| from transformers import pipeline | |
# Registry filled by `register_asr_model`: model-id prefix -> wrapper class.
ASR_MODEL_REGISTRY = {}

# Hugging Face access token taken from the HF_TOKEN environment variable;
# None when the variable is not set.
hf_token = os.environ.get("HF_TOKEN")
class AbstractASRModel(ABC):
    """Common base class for speech-recognition model wrappers.

    Stores the model id, device and cache directory. Concrete subclasses are
    expected to load the actual model and must implement ``transcribe``.
    """

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        """Record the configuration shared by every wrapper.

        Args:
            model_id: Identifier of the model being loaded.
            device: Device name the wrapper should use (default ``"cpu"``).
            cache_dir: Cache directory path (default ``"cache"``).
            **kwargs: Accepted so subclasses can forward extra options;
                ignored by this base implementation.
        """
        print(f"Loading ASR model {model_id}...")
        self.model_id = model_id
        self.device = device
        self.cache_dir = cache_dir

    # Declared abstract so that a subclass which forgets to override it fails
    # at instantiation time instead of silently returning None from a method
    # annotated as returning str.
    @abstractmethod
    def transcribe(self, audio: np.ndarray, audio_sample_rate: int, **kwargs) -> str:
        """Transcribe ``audio`` and return the recognised text.

        Args:
            audio: Audio samples as a NumPy array.
            audio_sample_rate: Sample rate of ``audio`` in Hz.
            **kwargs: Model-specific options.

        Returns:
            The transcription as a string.
        """
def register_asr_model(prefix):
    """Return a class decorator that registers an ASR wrapper under ``prefix``.

    The decorated class is stored in ``ASR_MODEL_REGISTRY[prefix]`` (replacing
    any class already stored under the same prefix) and returned unchanged.

    Args:
        prefix: Model-id prefix the wrapper class should handle.

    Raises:
        TypeError: When the decorated object is not a subclass of
            ``AbstractASRModel``.
    """

    def wrapper(cls):
        # Explicit raise rather than `assert`: assertions are removed when
        # Python runs with -O, which would let arbitrary objects into the
        # registry. The isinstance guard keeps the error message uniform for
        # non-class objects, on which issubclass itself would raise.
        if not (isinstance(cls, type) and issubclass(cls, AbstractASRModel)):
            raise TypeError(f"{cls} must inherit AbstractASRModel")
        ASR_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper
def get_asr_model(model_id: str, device="cpu", **kwargs) -> AbstractASRModel:
    """Instantiate the registered wrapper whose prefix matches ``model_id``.

    Registry entries are tried in insertion order and the first prefix that
    ``model_id`` starts with wins.

    Args:
        model_id: Identifier of the model to load.
        device: Device name forwarded to the wrapper constructor.
        **kwargs: Extra options forwarded to the wrapper constructor.

    Returns:
        A new instance of the matching wrapper class.

    Raises:
        ValueError: When no registered prefix matches ``model_id``.
    """
    for registered_prefix, wrapper_cls in ASR_MODEL_REGISTRY.items():
        if not model_id.startswith(registered_prefix):
            continue
        return wrapper_cls(model_id, device=device, **kwargs)

    raise ValueError(f"No ASR wrapper found for model: {model_id}")
class WhisperASR(AbstractASRModel):
    """ASR wrapper that runs a model through a ``transformers`` ASR pipeline.

    NOTE(review): no ``@register_asr_model(...)`` decorator is visible on this
    class in this view of the file; ``get_asr_model`` can only return it if it
    is registered under some prefix -- confirm against the original source.
    """

    # Sample rate (Hz) that audio is resampled to before it reaches the pipeline.
    TARGET_SAMPLE_RATE = 16000

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        """Build the speech-recognition pipeline for ``model_id``.

        Args:
            model_id: Model identifier passed to ``pipeline(model=...)``.
            device: ``"cuda"`` selects GPU 0, an explicit ``"cuda:N"`` string
                is passed through to the pipeline unchanged, and any other
                value runs on CPU.
            cache_dir: Forwarded to the model loader as
                ``model_kwargs["cache_dir"]``.
            **kwargs: Extra keyword arguments forwarded to ``pipeline``. A
                ``model_kwargs`` entry is merged with ``cache_dir``.
        """
        super().__init__(model_id, device, cache_dir, **kwargs)

        # Work on a copy so a model_kwargs dict owned by the caller is not
        # modified when cache_dir is injected.
        model_kwargs = dict(kwargs.pop("model_kwargs", None) or {})
        model_kwargs["cache_dir"] = cache_dir

        if device == "cuda":
            pipeline_device = 0
        elif isinstance(device, str) and device.startswith("cuda:"):
            # Previously any spelling other than exactly "cuda" fell through
            # to CPU, so "cuda:1" silently ran on the wrong device.
            pipeline_device = device
        else:
            pipeline_device = -1

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device=pipeline_device,
            token=hf_token,
            model_kwargs=model_kwargs,
            **kwargs,
        )

    def transcribe(self, audio: np.ndarray, audio_sample_rate: int, language: str, **kwargs) -> str:
        """Transcribe ``audio`` in the given ``language``.

        Args:
            audio: Audio samples as a NumPy array (presumably a mono float
                waveform -- confirm against callers).
            audio_sample_rate: Sample rate of ``audio`` in Hz; anything other
                than ``TARGET_SAMPLE_RATE`` is resampled with librosa first.
            language: Language passed to generation via ``generate_kwargs``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The ``"text"`` field of the pipeline result, or ``""`` when the
            result has no such field.
        """
        target_rate = self.TARGET_SAMPLE_RATE
        if audio_sample_rate != target_rate:
            audio = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=target_rate)

        result = self.pipe(
            audio,
            generate_kwargs={"language": language},
            return_timestamps=False,
        )
        return result.get("text", "")