from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datetime import datetime
import gradio as gr
# import spaces
import torch
DEVICE = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'
DEBUG_UI = False
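
# Dropdown label -> NLLB-style language-code token understood by the tokenizer.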
LANGS = {
    'English': 'eng_Latn',
    'Interslavic': 'isv_Latn',
    # 'Интерславик': 'isv_Cyrl',
    'Russian': 'rus_Cyrl',
    'Belarusian': 'bel_Cyrl',
    'Ukrainian': 'ukr_Cyrl',
    'Polish': 'pol_Latn',
    'Silesian': 'szl_Latn',
    'Czech': 'ces_Latn',
    'Slovak': 'slk_Latn',
    'Slovenian': 'slv_Latn',
    'Croatian': 'hrv_Latn',
    'Bosnian': 'bos_Latn',
    'Serbian': 'srp_Cyrl',
    'Macedonian': 'mkd_Cyrl',
    'Bulgarian': 'bul_Cyrl',
    'Esperanto': 'epo_Latn',
    'German': 'deu_Latn',
    'French': 'fra_Latn',
    'Spanish': 'spa_Latn',
}
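
# In debug mode the UI simply echoes the input; otherwise the fine-tuned NLLB model
# is loaded once at startup and `translate` runs real inference.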
if DEBUG_UI:
    # Stub translator: return the input unchanged so the UI can be tested without the model.
    def translate(text, src_lang, tgt_lang):
        return text
else:
    model_name = 'salavat/nllb-200-distilled-600M-finetuned-isv_v2'
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Language-code tokens (e.g. 'eng_Latn', 'isv_Latn') are looked up among the tokenizer's added tokens.
    lang2id = tokenizer.added_tokens_encoder

    # @spaces.GPU
    def translate(text, from_, to_):
        start = datetime.now()
        # Append a space to every line so that empty lines do not make the model hallucinate.
        lines = [f'{line} ' for line in text.split('\n')] if text else ''
        inputs = tokenizer(lines, return_tensors="pt", padding=True).to(DEVICE)
        # Replace the prepended language token with the selected source language.
        inputs['input_ids'][:, 0] = lang2id[LANGS[from_]]
        translated_tokens = model.generate(**inputs, max_length=400, forced_bos_token_id=lang2id[LANGS[to_]])
        result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        output = '\n'.join(result)
        end = datetime.now()
        input_shape = tuple(inputs.input_ids.shape)
        output_shape = tuple(translated_tokens.shape)
        print(f'[{end}] {DEVICE}_time - {end-start}; input {input_shape} / output {output_shape};')
        return output
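
# Gradio UI: language dropdowns, input/output textboxes and a Translate button.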
with gr.Blocks() as demo:
    gr.Markdown('<div align="center"><h1>Interslavic translator</h1></div>')
    with gr.Row():
        lang_input = gr.components.Dropdown(label="From", choices=list(LANGS.keys()), value='English')
        lang_output = gr.components.Dropdown(label="To", choices=list(LANGS.keys()), value='Interslavic')
    with gr.Row(equal_height=True):
        text_input = gr.components.Textbox(label="Text", lines=5, placeholder="Your text")
        text_output = gr.components.Textbox(label="Result", lines=5, placeholder="Translation...")
    translate_btn = gr.Button("Translate")
    gr.Markdown((
        '[NLLB200](https://ai.facebook.com/research/no-language-left-behind/) model '
        'fine-tuned on a corpus of the [Interslavic](https://interslavic-dictionary.com/grammar) language'
    ))
    translate_btn.click(translate, inputs=[text_input, lang_input, lang_output], outputs=text_output)

demo.launch(share=True)