File size: 3,726 Bytes
b36fd77
7efa162
bcc0c7f
 
 
 
 
 
 
b36fd77
bcc0c7f
7efa162
 
 
bcc0c7f
7efa162
 
 
 
 
 
 
bcc0c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7efa162
 
 
 
 
 
bcc0c7f
 
7efa162
 
bcc0c7f
 
 
 
 
7efa162
 
 
 
 
 
bcc0c7f
7efa162
bcc0c7f
7efa162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcc0c7f
7efa162
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from shakkala import Shakkala
from pathlib import Path
import torch
from eo_pl import TashkeelModel as TashkeelModelEO
from ed_pl import TashkeelModel as TashkeelModelED
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
import spaces

# Initialize the Shakkala model once at import time.
# `sh` is the Shakkala helper (constructed with version=3); `model` is the
# predictor that infer_shakkala() calls. `graph` is returned by get_model()
# alongside the model and is not referenced in the visible code.
sh = Shakkala(version=3)
model, graph = sh.get_model()

def infer_shakkala(input_text):
    """Add tashkeel to ``input_text`` using the module-level Shakkala model.

    The text is encoded with ``sh.prepare_input``, passed to
    ``model.predict`` (only the first prediction is used), decoded with
    ``sh.logits_to_text`` and merged back onto the original text with
    ``sh.get_final_text``. The result is printed and returned.
    """
    encoded = sh.prepare_input(input_text)
    first_logits = model.predict(encoded)[0]
    harakat = sh.logits_to_text(first_logits)
    diacritized = sh.get_final_text(input_text, harakat)
    print(diacritized)
    return diacritized

# Initialize the CaTT models and their shared tokenizer
tokenizer = TashkeelTokenizer()
eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
# The encoder-decoder model needs its own weights; loading the encoder-only
# checkpoint into it would not match its architecture.
# NOTE(review): filename follows the upstream CATT release - confirm the file
# is present under models/.
ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device:', device)

max_seq_len = 1024
print('Creating Model...')
eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)

# load_state_dict() returns a report of missing/unexpected keys, not the
# module, so eval()/to() must be called on the model itself afterwards.
eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
eo_model.eval().to(device)

ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
ed_model.eval().to(device)

@spaces.GPU()
def infer_catt(input_text, choose_model):
    """Add tashkeel to ``input_text`` with one of the CaTT models.

    The text is first filtered through ``remove_non_arabic``. When
    ``choose_model`` is ``'Encoder-Only'`` the encoder-only model is used;
    any other value selects the encoder-decoder model. The text is submitted
    as a single-item batch and the first (only) result is returned.
    """
    cleaned_text = remove_non_arabic(input_text)
    tashkeel_model = eo_model if choose_model == 'Encoder-Only' else ed_model
    # Positional arguments are batch_size=16 and verbose=True.
    results = tashkeel_model.do_tashkeel_batch([cleaned_text], 16, True)
    return results[0]

# Build the Gradio UI: one tab per diacritization backend, each with an input
# textbox, Clear/Submit buttons and an output textbox.
with gr.Blocks(title="Arabic Tashkeel") as demo:
    gr.HTML("<center><h1>Arabic Tashkeel</h1></center>")
    gr.HTML(
        "<center><p>Compare different methods for adding tashkeel to Arabic text.</p></center>"
    )
    with gr.Tab(label="CATT"):
        gr.Markdown("[CATT](https://github.com/abjadai/catt) is a new deep learning model for Arabic diacritization.")
        with gr.Row():
            with gr.Column():
                text_input1 = gr.Textbox(label="Input Text", rtl=True, text_align="right")
                choose_model = gr.Radio(
                    label="Choose Model",
                    choices=["Encoder-Only", "Encoder-Decoder"],
                    # `value` (not `default`) sets the initially selected choice.
                    value="Encoder-Decoder",
                )
                with gr.Row():
                    clear_button1 = gr.Button(value="Clear", variant="secondary")
                    submit_button1 = gr.Button(value="Add Tashkeel", variant="primary")

            with gr.Column():
                text_output1 = gr.Textbox(label="Output Text", rtl=True, text_align="right")

    submit_button1.click(infer_catt, inputs=[text_input1, choose_model], outputs=text_output1)
    # An event handler only changes components listed in `outputs`; returning
    # an empty string into the input textbox clears it.
    clear_button1.click(lambda: "", outputs=text_input1)

    with gr.Tab(label="Shakkala"):
        with gr.Row():
            with gr.Column():
                text_input2 = gr.Textbox(
                    lines=1, label="Input Text", rtl=True, text_align="right"
                )
                with gr.Row():
                    clear_button2 = gr.Button(value="Clear", variant="secondary")
                    submit_button2 = gr.Button(
                        value="Apply Tashkeel", variant="primary"
                    )

            with gr.Column():
                text_output2 = gr.Textbox(
                    lines=1, label="Output Text", rtl=True, text_align="right"
                )

    submit_button2.click(infer_shakkala, inputs=text_input2, outputs=text_output2)
    # Same clearing pattern as the CATT tab.
    clear_button2.click(lambda: "", outputs=text_input2)

# Enable request queuing, then start the app.
demo.queue().launch()