"""Gradio demo for an English <-> Uzbek Seq2Seq translator.

On start-up the script loads the encoding/decoding vocabularies and the two
sets of trained weights (one per translation direction), then serves a small
web UI with a source box, a read-only target box, and a button that swaps the
translation direction.
"""

import json

import gradio as gr
import nltk
import torch
import torch.nn.functional as F
from torch import nn

# Fetch the tokenizer data before the project-local imports below, keeping the
# original execution order.
nltk.download("punkt_tab")

from nltk.tokenize import word_tokenize

from custom_utils import encode_texts, decode_ids
from model_structure import Seq2Seq

device = "cpu"

# Hyper-parameters passed to both Seq2Seq constructors.
EMB_DIM = 1000
HID_SIZE = 1000

# Upper bound on generated length handed to ``Seq2Seq.translate``.
MAX_TRANSLATION_LEN = 50


def _load_json(path):
    """Read a JSON file as UTF-8 and close the handle afterwards."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def _load_decode_vocab(path):
    """Load an id -> token mapping, converting the JSON string keys to int."""
    return {int(token_id): token for token_id, token in _load_json(path).items()}


def _load_model(input_vocab, output_vocab, weights_path):
    """Build a Seq2Seq model, load its weights and put it in eval mode.

    Args:
        input_vocab: Encoding vocabulary of the source language; only its
            size is used.
        output_vocab: Encoding vocabulary of the target language; only its
            size is used.
        weights_path: Path to the saved ``state_dict``.

    Returns:
        The model on ``device``, in evaluation mode.
    """
    model = Seq2Seq(
        input_vocab_size=len(input_vocab),
        output_vocab_size=len(output_vocab),
        emb_dim=EMB_DIM,
        hid_size=HID_SIZE,
    ).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    # The script only runs inference, so switch off training-time behaviour
    # (e.g. dropout) if the model has any.
    model.eval()
    return model


uz_encode_vocab = _load_json("uz_encoding_vocab.json")
uz_decode_vocab = _load_decode_vocab("uz_decoding_vocab.json")

en_encode_vocab = _load_json("en_encoding_vocab.json")
en_decode_vocab = _load_decode_vocab("en_decoding_vocab.json")

en2uz_model = _load_model(
    en_encode_vocab, uz_encode_vocab, "en2uz_lstm_model_weights_50k.pth"
)
uz2en_model = _load_model(
    uz_encode_vocab, en_encode_vocab, "uz2en_lstm_model_weights_50k.pth"
)


def swap_labels_and_texts(is_en2uz, source_text, target_text):
    """Flip the translation direction and exchange the two text boxes.

    Args:
        is_en2uz: Current direction flag; truthy means English -> Uzbek.
        source_text: Current content of the source text box.
        target_text: Current content of the target text box.

    Returns:
        A 5-tuple ``(source_label, target_label, new_is_en2uz, new_source_text,
        new_target_text)`` in the order of the swap button's ``outputs`` list.
    """
    if is_en2uz:
        source_label, target_label = "Uzbek", "English"
    else:
        source_label, target_label = "English", "Uzbek"
    return source_label, target_label, not is_en2uz, target_text, source_text


def translate_text(source_text, is_en2uz):
    """Translate ``source_text`` in the direction selected by ``is_en2uz``.

    Args:
        source_text: Text to translate.
        is_en2uz: Truthy for English -> Uzbek, falsy for Uzbek -> English.

    Returns:
        The translated string, or ``""`` when ``source_text`` is empty or
        whitespace-only (the model is not invoked in that case).
    """
    if not source_text or not source_text.strip():
        return ""

    if is_en2uz:
        model = en2uz_model
        source_vocab = en_encode_vocab
        target_encode_vocab = uz_encode_vocab
        target_decode_vocab = uz_decode_vocab
    else:
        model = uz2en_model
        source_vocab = uz_encode_vocab
        target_encode_vocab = en_encode_vocab
        target_decode_vocab = en_decode_vocab

    input_tokens, input_mask = encode_texts(
        [source_text], source_vocab, decoder_input=False, device=device
    )
    # Gradients are never needed here; skip building the autograd graph.
    with torch.no_grad():
        output_tokens, _ = model.translate(
            input_tokens, input_mask, max_len=MAX_TRANSLATION_LEN
        )
    # NOTE(review): the empty-string key is kept exactly as in the original;
    # confirm it matches the special-token entry in the vocabulary files.
    return decode_ids(output_tokens, target_decode_vocab, target_encode_vocab[""])[0]


# NOTE(review): the nesting below was reconstructed from a whitespace-collapsed
# source — confirm it matches the intended layout.
with gr.Blocks(title="Seq2Seq Language Translator") as demo:
    # Direction flag shared by both callbacks; starts as English -> Uzbek.
    is_en2uz = gr.State(True)

    with gr.Row(equal_height=True):
        label1 = gr.Label(value="English", label="Source Language")
        swap_button = gr.Button("⇄ Swap", size="sm", variant="huggingface")
        label2 = gr.Label(value="Uzbek", label="Target Language")

    with gr.Group():
        with gr.Row(equal_height=True):
            source_textbox = gr.Textbox(label="Source Text", placeholder="Enter text")
            target_textbox = gr.Textbox(
                label="Target Text", placeholder="Translation", interactive=False
            )
        clear_button = gr.ClearButton(
            [source_textbox, target_textbox], size="sm", variant="stop"
        )

    translate_button = gr.Button("Translate", variant="primary", size="lg")

    translate_button.click(
        fn=translate_text,
        inputs=[source_textbox, is_en2uz],
        outputs=target_textbox,
    )
    swap_button.click(
        fn=swap_labels_and_texts,
        inputs=[is_en2uz, source_textbox, target_textbox],
        outputs=[label1, label2, is_en2uz, source_textbox, target_textbox],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)