Shan-ASR-Demo / asr.py
NorHsangPha's picture
Initial: initial commit
772adb0
raw
history blame
2.28 kB
import os
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
# Sample rate (Hz) used both when loading/resampling the input audio and when
# handing the samples to the processor, so the two always agree.
ASR_SAMPLING_RATE = 16_000
def transcribe(
    model_name: str, audio_source=None, microphone=None, file_upload=None
) -> str:
    """Transcribe an audio clip with the selected Wav2Vec2 CTC model.

    Args:
        model_name: Which checkpoint to use: ``"original"`` (the base MMS
            model with the ``shn`` adapter loaded) or ``"finetune"``.
        audio_source: UI label of the chosen input. When it contains
            "upload" (case-insensitive) ``file_upload`` is used, otherwise
            ``microphone``.
        microphone: Path of the recorded audio, or a dict carrying that path
            under the ``"name"`` key.
        file_upload: Path of the uploaded audio file.

    Returns:
        The decoded transcription, or a string starting with ``"ERROR:"``
        when no audio was supplied or ``model_name`` is not recognised.
    """
    # The microphone value may arrive as a dict; the file path is under "name".
    if isinstance(microphone, dict):
        microphone = microphone["name"]

    use_upload = "upload" in str(audio_source or "").lower()
    audio_fp = file_upload if use_upload else microphone
    if audio_fp is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Validate the model name *before* decoding audio. A plain ``[...]``
    # lookup would raise KeyError here and never reach the friendly message.
    model_id = {
        "original": "facebook/mms-1b-all",
        "finetune": "NorHsangPha/wav2vec2-large-mms-1b-shan",
    }.get(model_name)
    if model_id is None:
        return "ERROR: Wrong model name, or model not available please restart."

    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    if model_name == "original":
        # The base checkpoint needs the Shan tokenizer vocabulary and the
        # matching adapter weights selected explicitly.
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        processor = AutoProcessor.from_pretrained(model_id)
        processor.tokenizer.set_target_lang("shn")
        model.load_adapter("shn")
    else:  # "finetune" -- the only other key in the lookup table above
        # Use the token from the environment when set; otherwise pass True
        # (presumably falls back to locally stored Hub credentials -- confirm).
        auth_token = os.environ.get("TOKEN_READ_SECRET") or True
        model = Wav2Vec2ForCTC.from_pretrained(
            model_id, target_lang="shn", ignore_mismatched_sizes=True, token=auth_token
        )
        processor = AutoProcessor.from_pretrained(model_id, token=auth_token)

    # Prefer CUDA, then Apple MPS (guarded: older torch builds have no
    # ``torch.backends.mps``), and fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: take the arg-max token per frame of the first (only)
    # batch item and let the processor turn the ids into text.
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)
# Example rows: every model paired with every bundled sample file. The column
# order mirrors transcribe()'s parameters (model_name, audio_source,
# microphone, file_upload); presumably consumed by the demo UI -- verify
# against the caller.
ASR_EXAMPLES = [
    [model_name, "Upload audio", None, sample_path]
    for model_name in ("finetune", "original")
    for sample_path in ("upload/sample1.wav", "upload/sample2.wav")
]