# FERIJI translator — Hugging Face Spaces app (Gradio).
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# 1) Load the base M2M100 tokenizer (avoids the "non-consecutive added token" error)
BASE_MODEL = "facebook/m2m100_418M"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# 2) Load the fine-tuned French -> Zarma model
FINETUNED_MODEL = "Mamadou2727/Feriji_model"
model = AutoModelForSeq2SeqLM.from_pretrained(FINETUNED_MODEL)

# 3) Ensure the model's embedding matrix matches the tokenizer vocab size
model.resize_token_embeddings(len(tokenizer))

# 4) Move model to GPU if available and switch to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# 5) Language codes understood by the M2M100 tokenizer.
# NOTE(review): Zarma's ISO 639-3 code is "dje", but M2M100 ships a fixed
# 100-language inventory that does not include Zarma, so the fine-tune
# repurposes "yo" (Yoruba) as the Zarma slot — confirm against the
# training configuration before changing this mapping.
LANG_CODES = {
    "French": "fr",
    "Zarma": "yo",
}
def translate(text: str, num_seqs: int) -> str:
    """Translate French *text* to Zarma with the fine-tuned M2M100 model.

    Args:
        text: Source text in French.
        num_seqs: Number of beams and of returned hypotheses (the UI
            slider provides 1-5; floats from the widget are coerced).

    Returns:
        The decoded hypotheses joined by blank lines ("" for empty input).
    """
    # Empty/whitespace input: nothing to translate, avoid a pointless generate().
    if not text or not text.strip():
        return ""

    # Gradio sliders can deliver floats; generate() requires integer beam counts.
    num_seqs = int(num_seqs)

    # Set source & target language codes on the M2M100 tokenizer.
    tokenizer.src_lang = LANG_CODES["French"]
    tokenizer.tgt_lang = LANG_CODES["Zarma"]

    # Tokenize & move tensors to the model's device.
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)

    # Generate translations; forcing the BOS token selects the target language.
    # NOTE(review): newer transformers releases expose tokenizer.get_lang_id()
    # instead of the lang_code_to_id mapping — verify against the pinned version.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[LANG_CODES["Zarma"]],
            num_beams=num_seqs,
            num_return_sequences=num_seqs,
            length_penalty=1.0,
            early_stopping=True,
        )

    # Decode all hypotheses and separate them with blank lines.
    translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return "\n\n".join(translations)
# 6) Build the Gradio UI: French input + beam-count slider -> Zarma output.
with gr.Blocks() as app:
    gr.Markdown(
        """
        # FERIJI Translator: French ⇄ Zarma
        *Beta version – academic & research use only.*
        """
    )
    with gr.Row():
        inp = gr.Textbox(lines=7, label="Français / French")
        beams = gr.Slider(
            label="Nombre de séquences retournées",
            minimum=1, maximum=5, value=1, step=1,
        )
    out = gr.Textbox(lines=7, label="Zarma")
    btn = gr.Button("Traduire")
    # api_name="predict" keeps the REST endpoint stable for API clients.
    btn.click(fn=translate, inputs=[inp, beams], outputs=out, api_name="predict")

# On HF Spaces you don't need share=True; launch() serves the app directly.
app.launch()