# Repository path: IPA-Transcription-EN / app / inference.py
# Last commit: f718747 (SanderGi) - "fix typo"
# This module handles model inference
import torch
from transformers import AutoProcessor, AutoModelForCTC
# Default torch device name: CUDA is preferred, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
# Set the espeak library path for macOS.
import os
import sys

if sys.platform == "darwin":
    from phonemizer.backend.espeak.wrapper import EspeakWrapper

    # Homebrew (Apple Silicon prefix) location of one specific espeak build.
    _ESPEAK_LIBRARY = "/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib"

    # Only pin this library when it is really installed at this exact path.
    # On machines with another espeak version or Homebrew prefix, forcing a
    # missing file would point phonemizer at a library that cannot be loaded,
    # so in that case the wrapper's configuration is left untouched.
    if os.path.exists(_ESPEAK_LIBRARY):
        EspeakWrapper.set_library(_ESPEAK_LIBRARY)
def clear_cache():
    """Release cached accelerator memory held by torch.

    When CUDA is available, empties the CUDA cache and runs
    ``torch.cuda.ipc_collect()``. When MPS is available, empties the MPS
    cache. Does nothing on a machine with neither backend.
    """
    cuda = torch.cuda
    if cuda.is_available():
        cuda.empty_cache()
        cuda.ipc_collect()

    mps_backend = torch.backends.mps
    if mps_backend.is_available():
        torch.mps.empty_cache()
def load_model(model_id, device=DEVICE):
    """Load a CTC model and its processor from a pretrained checkpoint.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            processor and the model.
        device: Device the model is moved to; defaults to the module-level
            ``DEVICE``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # The processor is loaded first, then the model, which is moved to `device`.
    loaded_processor = AutoProcessor.from_pretrained(model_id)
    ctc_model = AutoModelForCTC.from_pretrained(model_id)
    ctc_model = ctc_model.to(device)
    return ctc_model, loaded_processor
def transcribe(audio, model, processor) -> str:
    """Transcribe a single audio input with greedy (argmax) decoding.

    Args:
        audio: One audio sample; it is wrapped in a one-element batch before
            being handed to the processor.
        model: Callable whose output exposes ``.logits`` and which has a
            ``.device`` attribute.
        processor: Callable feature processor that also provides
            ``feature_extractor.sampling_rate`` and ``decode``.

    Returns:
        The decoded string for the only item in the batch.
    """
    # Featurize using the sampling rate the processor itself declares.
    encoded = processor(
        [audio],
        sampling_rate=processor.feature_extractor.sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    # Cast to float32 and place the batch on the same device as the model.
    batch = encoded.input_values.type(torch.float32)
    batch = batch.to(model.device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        frame_logits = model(batch).logits

    # Pick the highest-scoring token id per frame, then decode batch item 0.
    best_ids = torch.argmax(frame_logits, dim=-1)
    first_sequence = best_ids[0]
    return processor.decode(first_sequence)