# IPA-Transcription-EN / app/inference.py
# Commit 20b52a3 (SanderGi): "hubert phoneme + quick test model"
# This module handles model inference
import torch
from transformers import AutoProcessor, AutoModelForCTC
from espnet2.bin.s2t_inference import Speech2Text
from inference_huberphoneme import HuBERTPhoneme, Tokenizer
# Model families accepted by load_model() and transcribe().
MODEL_TYPES = ["Transformers CTC", "POWSM", "HuBERTPhoneme"]

# Default torch device: prefer CUDA, then Apple MPS, and fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
# Set the espeak library path for phonemizer on macOS.
# The default below is a version-pinned Homebrew Cellar path (Apple-silicon
# prefix), so it is only installed when that file actually exists. An explicit
# PHONEMIZER_ESPEAK_LIBRARY environment variable (read by phonemizer itself)
# is left in charge instead of being overridden by the hard-coded path.
import os
import sys

if sys.platform == "darwin":
    from phonemizer.backend.espeak.wrapper import EspeakWrapper

    _ESPEAK_LIBRARY = "/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib"
    if not os.environ.get("PHONEMIZER_ESPEAK_LIBRARY") and os.path.isfile(
        _ESPEAK_LIBRARY
    ):
        EspeakWrapper.set_library(_ESPEAK_LIBRARY)
def clear_cache():
    """Release memory cached by PyTorch's accelerator allocators.

    With CUDA available, this empties the CUDA cache and collects memory
    released through CUDA IPC. With Apple MPS available, it empties the MPS
    cache. On a CPU-only host it does nothing.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    mps_ready = torch.backends.mps.is_available()
    if mps_ready:
        torch.mps.empty_cache()
# ================================== POWSM ==================================
def load_powsm(model_id, language="<eng>", device=DEVICE):
    """Load a POWSM model through ESPnet's ``Speech2Text``.

    Args:
        model_id: pretrained model identifier passed to
            ``Speech2Text.from_pretrained``.
        language: language symbol passed as ``lang_sym`` (default ``"<eng>"``).
        device: torch device string; defaults to the module-level DEVICE.

    Returns:
        The configured ``Speech2Text`` instance. ``task_sym`` is fixed to
        ``"<pr>"`` (presumably the phone-recognition task; TODO confirm).

    When *device* is ``"mps"``, the object is first built on CPU. Its
    ``s2t_model`` and ``beam_search`` are then moved to MPS as float32, and
    its ``dtype``/``device`` attributes are updated to match.
    """
    # "mps" is swapped for "cpu" at construction time (presumably because
    # constructing directly on MPS is not supported; TODO confirm).
    build_device = device.replace("mps", "cpu")
    s2t = Speech2Text.from_pretrained(
        model_id,
        device=build_device,
        lang_sym=language,
        task_sym="<pr>",
    )
    if device != "mps":
        return s2t

    # Move the pieces that run at inference time onto MPS in float32.
    s2t.s2t_model.to(device=device, dtype=torch.float32)
    s2t.beam_search.to(device=device, dtype=torch.float32)
    s2t.dtype = "float32"
    s2t.device = device
    return s2t
def transcribe_powsm(audio, model):
    """Transcribe *audio* with a POWSM model and return the cleaned text.

    Args:
        audio: waveform passed straight to the model.
        model: the ``Speech2Text`` object returned by load_powsm. It is called
            with ``text_prev="<na>"``.

    Returns:
        The text of ``model(...)[0][0]`` (presumably the best hypothesis),
        with everything up to the first ``<notimestamps>`` marker dropped and
        surrounding whitespace and all ``/`` characters removed.

    If the ``<notimestamps>`` marker is missing from the output, the whole
    text is used. Previously this case raised an opaque IndexError.
    """
    pred = model(audio, text_prev="<na>")[0][0]
    segments = pred.split("<notimestamps>")
    # When the marker is present, keep the same segment as before (the one
    # right after the first marker).
    text = segments[1] if len(segments) > 1 else segments[0]
    return text.strip().replace("/", "")
# ===========================================================================
# ============================= Transformers CTC ============================
def load_transformers_ctc(model_id, device=DEVICE):
    """Load a Hugging Face CTC model together with its processor.

    Args:
        model_id: identifier given to both ``AutoProcessor.from_pretrained``
            and ``AutoModelForCTC.from_pretrained``.
        device: torch device the model is moved to; defaults to DEVICE.

    Returns:
        A ``(model, processor)`` tuple, the shape that
        transcribe_transformers_ctc unpacks.
    """
    proc = AutoProcessor.from_pretrained(model_id)
    ctc_model = AutoModelForCTC.from_pretrained(model_id)
    ctc_model = ctc_model.to(device)
    return ctc_model, proc
def transcribe_transformers_ctc(audio, model) -> str:
    """Greedy-decode a single waveform with a Hugging Face CTC model.

    Args:
        audio: one waveform. It is handed to the processor with the feature
            extractor's sampling rate, so it is assumed to already be at that
            rate (TODO confirm at call sites).
        model: the ``(model, processor)`` tuple from load_transformers_ctc.

    Returns:
        ``processor.decode`` of the per-frame argmax ids for the first (only)
        item in the batch.
    """
    ctc_model, processor = model
    target_rate = processor.feature_extractor.sampling_rate
    batch = processor(
        [audio],
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )
    # Force float32 and place the features on the model's device.
    features = batch.input_values.type(torch.float32).to(ctc_model.device)

    with torch.no_grad():
        scores = ctc_model(features).logits

    best_ids = scores.argmax(dim=-1)
    return processor.decode(best_ids[0])
# ===========================================================================
# ============================== HuBERTPhoneme ==============================
def load_hubert_phoneme(model_id, device=DEVICE):
    """Load a HuBERTPhoneme model in eval mode plus a matching tokenizer.

    Args:
        model_id: identifier passed to ``HuBERTPhoneme.from_pretrained``.
        device: torch device the model is moved to; defaults to DEVICE.

    Returns:
        A ``(model, tokenizer, device)`` triple, the shape that
        transcribe_hubert_phoneme unpacks. The tokenizer is built with
        ``with_blank`` set to the model's ``ctc_training`` flag.
    """
    net = HuBERTPhoneme.from_pretrained(model_id)
    net = net.to(device)
    net = net.eval()
    tok = Tokenizer(with_blank=net.ctc_training)
    return net, tok, device
def transcribe_hubert_phoneme(audio, model) -> str:
    """Greedy-decode a single waveform with a HuBERTPhoneme model.

    Args:
        audio: one waveform (NumPy array or array-like). It is converted to a
            float32 tensor, moved to the model's device and given a leading
            batch axis.
        model: the ``(model, tokenizer, device)`` triple from
            load_hubert_phoneme.

    Returns:
        ``tokenizer.decode`` of the per-frame argmax ids after consecutive
        repeats are collapsed. The original variable name suggests ARPAbet
        text (TODO confirm against Tokenizer).
    """
    net, tokenizer, device = model
    with torch.inference_mode():
        # Cast to float32 (matching the Transformers CTC path) so float64
        # audio does not reach the model as a double tensor.
        waveform = torch.as_tensor(audio, dtype=torch.float32).to(device).unsqueeze(0)
        output, _ = net.inference(waveform)
        # Assumes output is (batch, frames, vocab); TODO confirm. Drop only the
        # batch axis: a bare squeeze() would also collapse a one-frame
        # prediction into a 0-d tensor.
        predictions = output.argmax(dim=-1).squeeze(0).cpu()
        # Merge consecutive duplicates; blank handling is presumably left to
        # the tokenizer (built with with_blank); TODO confirm.
        arpabet = tokenizer.decode(predictions.unique_consecutive())
    return arpabet
# ===========================================================================
def load_model(model_id, type, device=DEVICE):
    """Load a model of the requested family.

    Args:
        model_id: identifier forwarded to the family-specific loader.
        type: model family: "POWSM", "Transformers CTC" or "HuBERTPhoneme".
            The name shadows the builtin but is kept for caller compatibility.
        device: torch device string forwarded to the loader; defaults to DEVICE.

    Returns:
        Whatever the family loader returns; pass it unchanged to transcribe().

    Raises:
        ValueError: if *type* is not one of the supported families.
    """
    loaders = (
        ("POWSM", load_powsm),
        ("Transformers CTC", load_transformers_ctc),
        ("HuBERTPhoneme", load_hubert_phoneme),
    )
    # A linear == scan (rather than a dict lookup) keeps unhashable or
    # oddly-typed inputs on the ValueError path.
    for family, loader in loaders:
        if type == family:
            return loader(model_id, device=device)
    raise ValueError("Unsupported model type: " + str(type))
def transcribe(audio, type, model) -> str:
    """Transcribe *audio* with a model previously returned by load_model.

    Args:
        audio: waveform handed to the family-specific transcriber.
        type: model family: "POWSM", "Transformers CTC" or "HuBERTPhoneme".
            The name shadows the builtin but is kept for caller compatibility.
        model: the object load_model returned for the same *type*.

    Returns:
        The transcription produced by the family-specific function.

    Raises:
        ValueError: if *type* is not one of the supported families.
    """
    transcribers = (
        ("POWSM", transcribe_powsm),
        ("Transformers CTC", transcribe_transformers_ctc),
        ("HuBERTPhoneme", transcribe_hubert_phoneme),
    )
    for family, run in transcribers:
        if type == family:
            return run(audio, model)
    raise ValueError("Unsupported model type: " + str(type))