MohamedRashad's picture
chore: Add requirements for shakkala and kaldialign
bcc0c7f
raw
history blame
3.73 kB
import gradio as gr
from shakkala import Shakkala
from pathlib import Path
import torch
from eo_pl import TashkeelModel as TashkeelModelEO
from ed_pl import TashkeelModel as TashkeelModelED
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
import spaces
# Initialize the Shakkala model (version 3) once, at import time.
# `sh` provides the text pre/post-processing helpers and `model` runs the
# prediction; both are module-level globals read by infer_shakkala.
# NOTE(review): `graph` is not used anywhere in the visible code.
sh = Shakkala(version=3)
model, graph = sh.get_model()
def infer_shakkala(input_text):
    """Diacritize ``input_text`` with the Shakkala model.

    The text is encoded with ``sh.prepare_input``, scored by the module-level
    ``model``, decoded into harakat with ``sh.logits_to_text``, and merged back
    into the original text with ``sh.get_final_text``. The merged text is
    printed and then returned.
    """
    encoded = sh.prepare_input(input_text)
    # Element 0 of the prediction output is used (presumably the single
    # batch row for this one input -- confirm against Shakkala).
    harakat = sh.logits_to_text(model.predict(encoded)[0])
    diacritized = sh.get_final_text(input_text, harakat)
    print(diacritized)
    return diacritized
# Initialize the CaTT tokenizer and both model variants (encoder-only and
# encoder-decoder). They share one tokenizer, but each variant must be
# restored from its own checkpoint.
tokenizer = TashkeelTokenizer()
eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
# NOTE(review): filename follows the upstream CATT release asset for the
# encoder-decoder model; confirm it is present under models/.
ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device:', device)
max_seq_len = 1024
print('Creating Model...')
eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
# load_state_dict() returns an _IncompatibleKeys named tuple, not the module,
# so .eval()/.to() must be called on the model itself rather than chained.
eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
eo_model.eval().to(device)
# The encoder-decoder model needs its own weights; the encoder-only
# checkpoint does not match its architecture.
ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
ed_model.eval().to(device)
@spaces.GPU()
def infer_catt(input_text, choose_model):
    """Diacritize Arabic text with one of the two CATT models.

    Non-Arabic characters are removed from ``input_text`` first. The cleaned
    text is then sent as a one-element batch to the model selected by
    ``choose_model``: the string 'Encoder-Only' selects ``eo_model``, and any
    other value selects ``ed_model``. Returns the result for that single input.
    """
    cleaned = remove_non_arabic(input_text)
    selected = eo_model if choose_model == 'Encoder-Only' else ed_model
    # Positional arguments: batch size 16, verbose output enabled.
    predictions = selected.do_tashkeel_batch([cleaned], 16, True)
    return predictions[0]
# Build the Gradio UI with one tab per diacritization method, then serve it.
with gr.Blocks(title="Arabic Tashkeel") as demo:
    gr.HTML("<center><h1>Arabic Tashkeel</h1></center>")
    gr.HTML(
        "<center><p>Compare different methods for adding tashkeel to Arabic text.</p></center>"
    )
    with gr.Tab(label="CATT"):
        gr.Markdown("[CATT](https://github.com/abjadai/catt) is a new deep learning model for Arabic diacritization.")
        with gr.Row():
            with gr.Column():
                text_input1 = gr.Textbox(label="Input Text", rtl=True, text_align="right")
                # `value` sets the initially selected choice; `default` is not
                # a Radio parameter in current Gradio releases.
                choose_model = gr.Radio(
                    label="Choose Model",
                    choices=["Encoder-Only", "Encoder-Decoder"],
                    value="Encoder-Decoder",
                )
                with gr.Row():
                    clear_button1 = gr.Button(value="Clear", variant="secondary")
                    submit_button1 = gr.Button(value="Add Tashkeel", variant="primary")
            with gr.Column():
                text_output1 = gr.Textbox(label="Output Text", rtl=True, text_align="right")
        submit_button1.click(infer_catt, inputs=[text_input1, choose_model], outputs=text_output1)
        # A handler's return value only reaches components listed in `outputs`,
        # so the clear handler names the textbox it resets.
        clear_button1.click(lambda: "", outputs=text_input1)
    with gr.Tab(label="Shakkala"):
        with gr.Row():
            with gr.Column():
                text_input2 = gr.Textbox(
                    lines=1, label="Input Text", rtl=True, text_align="right"
                )
                with gr.Row():
                    clear_button2 = gr.Button(value="Clear", variant="secondary")
                    submit_button2 = gr.Button(
                        value="Apply Tashkeel", variant="primary"
                    )
            with gr.Column():
                text_output2 = gr.Textbox(
                    lines=1, label="Output Text", rtl=True, text_align="right"
                )
        submit_button2.click(infer_shakkala, inputs=text_input2, outputs=text_output2)
        # Same as above: return "" into the input textbox to clear it.
        clear_button2.click(lambda: "", outputs=text_input2)
demo.queue().launch()