from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datetime import datetime
import gradio as gr
# import spaces
import torch


DEVICE = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'
DEBUG_UI = False
LANGS = {
    'English': 'eng_Latn',
    'Interslavic': 'isv_Latn',
    # 'Интерславик': 'isv_Cyrl',
    'Russian': 'rus_Cyrl',
    'Belarusian': 'bel_Cyrl',
    'Ukrainian': 'ukr_Cyrl',
    'Polish': 'pol_Latn',
    'Silesian': 'szl_Latn',
    'Czech': 'ces_Latn',
    'Slovak': 'slk_Latn',
    'Slovenian': 'slv_Latn',
    'Croatian': 'hrv_Latn',
    'Bosnian': 'bos_Latn',
    'Serbian': 'srp_Cyrl',
    'Macedonian': 'mkd_Cyrl',
    'Bulgarian': 'bul_Cyrl',
    'Esperanto': 'epo_Latn',
    'German': 'deu_Latn',
    'French': 'fra_Latn',
    'Spanish': 'spa_Latn',
}


if DEBUG_UI:
    # UI-only debug mode: echo the input text back without loading the model
    def translate(text, src_lang, tgt_lang):
        return text

else:
    model_name = 'salavat/nllb-200-distilled-600M-finetuned-isv_v2'
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    lang2id = tokenizer.added_tokens_encoder  # maps language codes (e.g. 'eng_Latn') to their added-token ids

    # @spaces.GPU
    def translate(text, from_, to_):
        start = datetime.now()

        # append a trailing space to each line: empty lines otherwise make the model hallucinate output
        lines = [f'{line} ' for line in text.split('\n')] if text else ''
        inputs = tokenizer(lines, return_tensors="pt", padding=True).to(DEVICE)
        # overwrite the first token of each sequence with the source-language token
        inputs['input_ids'][:, 0] = lang2id[LANGS[from_]]
        translated_tokens = model.generate(**inputs, max_length=400, forced_bos_token_id=lang2id[LANGS[to_]])
        result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        output = '\n'.join(result)

        end = datetime.now()
        input_shape = tuple(inputs.input_ids.shape)
        output_shape = tuple(translated_tokens.shape)
        print(f'[{end}] {DEVICE}_time - {end-start}; input {input_shape} / output {output_shape};')
        return output
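

# A minimal sketch of calling translate() outside the Gradio UI (assumes DEBUG_UI is False
# and the model above loaded successfully), e.g. as a quick smoke test:
#
#     translate('Hello, world!', 'English', 'Interslavic')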
        

with gr.Blocks() as demo:
    gr.Markdown('<div align="center"><h1>Interslavic translator</h1></div>')
    with gr.Row():
        lang_input = gr.components.Dropdown(label="From", choices=list(LANGS.keys()), value='English')
        lang_output = gr.components.Dropdown(label="To", choices=list(LANGS.keys()), value='Interslavic')
    with gr.Row(equal_height=True):
        text_input = gr.components.Textbox(label="Text", lines=5, placeholder="Your text")
        text_output = gr.components.Textbox(label="Result", lines=5, placeholder="Translation...")
    translate_btn = gr.Button("Translate")
    gr.Markdown((
        '[NLLB-200](https://ai.facebook.com/research/no-language-left-behind/) model '
        'fine-tuned on a corpus of the [Interslavic](https://interslavic-dictionary.com/grammar) language'
    ))
    translate_btn.click(translate, inputs=[text_input, lang_input, lang_output], outputs=text_output)

demo.launch(share=True)