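"""Gradio web app for English <-> Uzbek translation.

Loads the shared vocabularies and the pretrained Seq2Seq LSTM weights for both
directions (en2uz and uz2en) and serves a small UI with a direction-swap button.
"""
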
import json

import torch
import nltk

# Punkt tokenizer data for NLTK (presumably required by the tokenization in custom_utils).
nltk.download("punkt_tab")

import gradio as gr

from custom_utils import encode_texts, decode_ids
from model_structure import Seq2Seq

# Inference runs on the CPU.
device = "cpu"

# Encoding vocabularies map tokens to ids; decoding vocabularies map ids back to
# tokens (JSON keys are strings, so they are cast back to int).
with open("uz_encoding_vocab.json", "r") as f:
    uz_encode_vocab = json.load(f)

with open("uz_decoding_vocab.json", "r") as f:
    uz_decode_vocab = {int(k): v for k, v in json.load(f).items()}

with open("en_encoding_vocab.json", "r") as f:
    en_encode_vocab = json.load(f)

with open("en_decoding_vocab.json", "r") as f:
    en_decode_vocab = {int(k): v for k, v in json.load(f).items()}

# English -> Uzbek model; embedding and hidden sizes match the saved checkpoint.
en2uz_model = Seq2Seq(
    input_vocab_size=len(en_encode_vocab),
    output_vocab_size=len(uz_encode_vocab),
    emb_dim=1000,
    hid_size=1000,
).to(device)

en2uz_model.load_state_dict(
    torch.load("en2uz_lstm_model_weights_50k.pth", map_location=device)
)
en2uz_model.eval()  # evaluation mode: disables dropout etc. during inference


# Uzbek -> English model with the same architecture.
uz2en_model = Seq2Seq(
    input_vocab_size=len(uz_encode_vocab),
    output_vocab_size=len(en_encode_vocab),
    emb_dim=1000,
    hid_size=1000,
).to(device)

uz2en_model.load_state_dict(
    torch.load("uz2en_lstm_model_weights_50k.pth", map_location=device)
)
uz2en_model.eval()  # evaluation mode for inference


def swap_labels_and_texts(is_en2uz, source_text, target_text):
    """Flip the translation direction and swap the labels and texts.

    Returns the new source/target labels, the toggled direction flag and the
    exchanged texts, in the order expected by the swap button's outputs.
    """
    if is_en2uz:
        return "Uzbek", "English", False, target_text, source_text
    return "English", "Uzbek", True, target_text, source_text


@torch.no_grad()  # no gradients are needed at inference time
def translate_text(source_text, is_en2uz):
    """Translate source_text in the direction selected by the is_en2uz flag."""
    if is_en2uz:
        # English -> Uzbek: encode with the English vocab, decode with the Uzbek one.
        input_tokens, input_mask = encode_texts(
            [source_text], en_encode_vocab, decoder_input=False, device=device
        )
        output_tokens, _ = en2uz_model.translate(input_tokens, input_mask, max_len=50)
        target_text = decode_ids(
            output_tokens, uz_decode_vocab, uz_encode_vocab["<eos>"]
        )[0]
    else:
        # Uzbek -> English: mirror of the branch above.
        input_tokens, input_mask = encode_texts(
            [source_text], uz_encode_vocab, decoder_input=False, device=device
        )
        output_tokens, _ = uz2en_model.translate(input_tokens, input_mask, max_len=50)
        target_text = decode_ids(
            output_tokens, en_decode_vocab, en_encode_vocab["<eos>"]
        )[0]
    return target_text


with gr.Blocks(title="Seq2Seq Language Translator") as demo:
    # Translation direction: True = English -> Uzbek, False = Uzbek -> English.
    is_en2uz = gr.State(True)

    # Language labels with a swap button between them.
    with gr.Row(equal_height=True):
        label1 = gr.Label(value="English", label="Source Language")
        swap_button = gr.Button("⇄ Swap", size="sm", variant="huggingface")
        label2 = gr.Label(value="Uzbek", label="Target Language")

    # Source (editable) and target (read-only) text boxes, plus a clear button.
    with gr.Group():
        with gr.Row(equal_height=True):
            source_textbox = gr.Textbox(label="Source Text", placeholder="Enter text")
            target_textbox = gr.Textbox(
                label="Target Text", placeholder="Translation", interactive=False
            )

        clear_button = gr.ClearButton(
            [source_textbox, target_textbox], size="sm", variant="stop"
        )
    translate_button = gr.Button("Translate", variant="primary", size="lg")

    # Run the selected model when the Translate button is clicked.
    translate_button.click(
        fn=translate_text,
        inputs=[source_textbox, is_en2uz],
        outputs=target_textbox,
    )

    # Flip the direction, labels and texts when the swap button is clicked.
    swap_button.click(
        fn=swap_labels_and_texts,
        inputs=[is_en2uz, source_textbox, target_textbox],
        outputs=[label1, label2, is_en2uz, source_textbox, target_textbox],
    )

# Listen on all interfaces on port 7860 (e.g. for a container or Spaces deployment).
demo.launch(server_name="0.0.0.0", server_port=7860)