import gradio as gr
import os
import librosa
from transformers import Wav2Vec2ProcessorWithLM, AutoModelForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
import torch
# Model name and Hub access token are supplied via environment variables
# (e.g. Space secrets)
model_name = os.getenv("MODEL_NAME")
auth_token = os.getenv("API_TOKEN")
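
# A minimal sanity check (a sketch, not part of the original app): fail fast
# with a clear message if the secrets are missing, instead of a less obvious
# error surfacing from from_pretrained() below.
if model_name is None or auth_token is None:
    raise RuntimeError("MODEL_NAME and API_TOKEN environment variables must be set")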
# Load models. The tokenizer is re-created without BOS/EOS special tokens,
# and the processor is rebuilt around it so that CTC + LM decoding does not
# emit special tokens in the transcript.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name, eos_token=None, bos_token=None, use_auth_token=auth_token)
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name, use_auth_token=auth_token)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name, use_auth_token=auth_token)
decoder = processor.decoder
processor = Wav2Vec2ProcessorWithLM(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
model = AutoModelForCTC.from_pretrained(model_name, use_auth_token=auth_token)
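
# Optional (an assumption, not in the original app): move the model to GPU
# when available. If enabled, input_values in transcribe() would also need
# .to(model.device) before the forward pass.
# model.to("cuda" if torch.cuda.is_available() else "cpu")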
def load_data(input_file):
    # Read the file (librosa returns mono float32 audio by default)
    speech, sample_rate = librosa.load(input_file)
    # Downmix to mono if the audio still has multiple channels;
    # librosa returns multi-channel audio as (channels, samples)
    if speech.ndim > 1:
        speech = speech.mean(axis=0)
    # Resample to 16 kHz, the rate the Wav2Vec2 model expects
    if sample_rate != 16_000:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16_000)
    return speech
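
# Example: load_data("examples/example1.mp3") returns a 1-D numpy array
# sampled at 16 kHz, ready for the feature extractor.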
def transcribe(input_file):
    audio = load_data(input_file)
    # Extract input features (the processor normalizes the raw waveform)
    input_values = processor(audio, return_tensors="pt", sampling_rate=16_000).input_values
    # Run the acoustic model to get per-frame CTC logits
    with torch.no_grad():
        logits = model(input_values).logits.cpu().numpy()[0]
    # Beam-search decode the logits with the language model
    text = decoder.decode(logits, beam_width=30)
    return text
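
# Quick sanity check outside Gradio (kept commented out so the Space only
# runs the web app; the path matches one of the bundled examples below):
# print(transcribe("examples/example1.mp3"))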
examples = [
    ["examples/example1.mp3"],
    ["examples/example2.mp3"],
]
gr.Interface(
    title="Rozpoznání mluvené řeči pro český jazyk",  # "Speech recognition for the Czech language"
    fn=transcribe,
    inputs=gr.Audio(source="upload", type="filepath"),
    outputs="text",
    examples=examples,
).launch()