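"""Speech-to-Text demo for Ukrainian using the Yehor/hubert-uk HuBERT model."""
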
import sys
import time

from importlib.metadata import version

import torch
import torchaudio
import torchaudio.transforms as T

import gradio as gr
import numpy as np

from transformers import HubertForCTC, Wav2Vec2Processor

model_name = "Yehor/hubert-uk"

# Accepted audio duration, in seconds.
min_duration = 0.5
max_duration = 60

concurrency_limit = 5
use_torch_compile = False
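
# Run in float16 on GPU to cut memory use; fall back to float32 on CPU,
# where half precision is poorly supported.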
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
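
# Load the model weights in torch_dtype directly onto the target device;
# device_map relies on the accelerate package.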
asr_model = HubertForCTC.from_pretrained(
    model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2Processor.from_pretrained(model_name)

if use_torch_compile:
    asr_model = torch.compile(asr_model)

examples = [
    "example_1.wav",
    "example_2.wav",
    "example_3.wav",
    "example_4.wav",
    "example_5.wav",
    "example_6.wav",
]

examples_table = """
| File | Text |
| ------------- | ------------- |
| `example_1.wav` | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav` | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav` | хімічних снарядів причому склад отруйної речовини різний а отже й наслідки для наших військових теж різні |
| `example_4.wav` | використовує на фронті все що має і хімічна зброя не виняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav` | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav` | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()

authors_table = """
## Authors

Follow him on social networks and **contact** him if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw on Telegram |
| https://x.com/yehor_smoliakov on X |
| https://github.com/egorsmkv on GitHub |
| https://huggingface.co/Yehor on Hugging Face |
| or use [email protected] |
""".strip()

description_head = f"""
# Speech-to-Text for Ukrainian using HuBERT

## Overview

This space uses the https://huggingface.co/Yehor/hubert-uk model to recognize speech in audio files.

> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()

description_foot = f"""
## Community

- **Discord**: https://discord.gg/yVAjkBgmt4
- Speech Recognition: https://t.me/speech_recognition_uk
- Speech Synthesis: https://t.me/speech_synthesis_uk

## More

Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk

{authors_table}
""".strip()

transcription_value = """
Recognized text will appear here.

Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record your own voice.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
""".strip()

tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- torchaudio: {version('torchaudio')}
- transformers: {version('transformers')}
- accelerate: {version('accelerate')}
- gradio: {version('gradio')}
""".strip()


def inference(audio_path, progress=gr.Progress()):
    if not audio_path:
        raise gr.Error("Please upload an audio file.")

    gr.Info("Starting recognition", duration=2)

    progress(0, desc="Recognizing")

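    # Reject files outside the accepted duration range before recognizing.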
    meta = torchaudio.info(audio_path)
    duration = meta.num_frames / meta.sample_rate

    if duration < min_duration:
        raise gr.Error(
            f"The duration of the file is less than {min_duration} seconds; it is {round(duration, 2)} seconds."
        )
    if duration > max_duration:
        raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.")

    paths = [
        audio_path,
    ]

    results = []

    for path in progress.tqdm(paths, desc="Recognizing...", unit="file"):
        t0 = time.time()

        meta = torchaudio.info(path)
        audio_duration = meta.num_frames / meta.sample_rate

        audio_input, sr = torchaudio.load(path)

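        # Downmix multi-channel audio to mono.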
        if meta.num_channels > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)

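        # The model expects 16 kHz input, so resample anything else.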
        if sr != 16_000:
            resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze(0).numpy()

        inputs = processor(
            [audio_input], sampling_rate=16_000, padding=True
        ).input_values
        features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)

        with torch.inference_mode():
            logits = asr_model(features).logits

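        # Greedy CTC decoding: pick the most likely token per frame, then let
        # the processor collapse repeats and strip blank tokens.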
        predicted_ids = torch.argmax(logits, dim=-1)
        predictions = processor.batch_decode(predicted_ids)

        if not predictions:
            predictions = ["-"]

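        # Real-Time Factor: processing time divided by audio duration;
        # values below 1 mean faster than real time.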
        elapsed_time = round(time.time() - t0, 2)
        rtf = round(elapsed_time / audio_duration, 4)
        audio_duration = round(audio_duration, 2)

        results.append(
            {
                "path": path.split("/")[-1],
                "transcription": "\n".join(predictions),
                "audio_duration": audio_duration,
                "rtf": rtf,
            }
        )

    gr.Info("Finished!", duration=2)

    result_texts = []

    for result in results:
        result_texts.append(f'**{result["path"]}**')
        result_texts.append("\n\n")
        result_texts.append(f'> {result["transcription"]}')
        result_texts.append("\n\n")
        result_texts.append(f'**Audio duration**: {result["audio_duration"]}')
        result_texts.append("\n")
        result_texts.append(f'**Real-Time Factor**: {result["rtf"]}')

    return "\n".join(result_texts)


demo = gr.Blocks(
    title="Speech-to-Text for Ukrainian",
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        audio_file = gr.Audio(label="Audio file", type="filepath")
        transcription = gr.Markdown(
            label="Transcription",
            value=transcription_value,
        )

    gr.Button("Recognize").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=audio_file,
        outputs=transcription,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=audio_file, examples=examples)

    gr.Markdown(examples_table)

    gr.Markdown(description_foot)

    gr.Markdown("### This Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()