"""Gradio app translating French text and documents (.docx / .pdf) into Ngambay.

Loads the ``Toadoum/ngambay-fr-v1`` seq2seq model once at import time and
exposes two tabs: free-text translation and whole-document translation.
"""

import os

# --- Disable compile/dynamo to avoid meta tensor issues ---
# These must be set BEFORE torch/transformers are imported so that any
# configuration those libraries read at import time already sees them.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")

import functools  # noqa: E402
import io  # noqa: E402
import re  # noqa: E402
import tempfile  # noqa: E402
from html import escape as html_escape  # noqa: E402
from typing import List, Tuple  # noqa: E402

import docx  # noqa: E402
import fitz  # PyMuPDF  # noqa: E402
import gradio as gr  # noqa: E402
import torch  # noqa: E402
from docx.enum.text import WD_ALIGN_PARAGRAPH  # noqa: E402
from docx.text.paragraph import Paragraph as DocxParagraph  # noqa: E402
from reportlab.lib.enums import TA_JUSTIFY  # noqa: E402
from reportlab.lib.pagesizes import A4  # noqa: E402
from reportlab.lib.styles import getSampleStyleSheet  # noqa: E402
from reportlab.lib.units import cm  # noqa: E402
from reportlab.platypus import (  # noqa: E402
    PageBreak,
    Paragraph as RLParagraph,
    SimpleDocTemplate,
    Spacer,
)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline  # noqa: E402

# --- Config ---
MODEL_REPO = "Toadoum/ngambay-fr-v1"
FR_CODE_PREFERRED = "fra_Latn"  # French (NLLB)
FR_CODE_ALT = "fr_Latn"  # Some custom models use this
NG_CODE_PREFERRED = "sba_Latn"  # Ngambay (Saba) Latin

# --- Inference params ---
MAX_NEW_TOKENS = 256  # minimum generation budget per request/chunk
TEMPERATURE = 0.0  # not used when do_sample=False

# --- Device selection ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Load model & tokenizer ---
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

# fp16 only on GPU; CPU keeps the default dtype.
model_kwargs = {"torch_dtype": torch.float16} if device == "cuda" else {}
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO, **model_kwargs)
model = model.to(device)  # Move model to device after full loading
print(f"Model loaded on: {model.device}")

# Ensure a pad token to avoid generate() quirks
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    elif tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token
    else:
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # Only grow the embedding matrix; never shrink it (model vocabularies
        # can legitimately be larger than len(tokenizer)).
        if len(tokenizer) > model.get_input_embeddings().num_embeddings:
            model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id


# --- Language code resolution ---
def _vocab_token_id(token: str) -> int | None:
    """Return the vocabulary id of *token*, or None if it maps to the unknown token."""
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id is None or token_id == tokenizer.unk_token_id:
        return None
    return token_id


def _resolve_lang_code(preferred: str, alt: str | None) -> str:
    """Pick whichever of *preferred* / *alt* the tokenizer knows; default to *preferred*.

    Checks, in order: the legacy ``lang_code_to_id`` mapping, ``get_lang_id``
    (M2M100-style tokenizers), and finally the plain vocabulary (newer NLLB
    tokenizers store language codes only as special tokens).
    """
    candidates = [code for code in (preferred, alt) if code]

    codes = getattr(tokenizer, "lang_code_to_id", None)
    if isinstance(codes, dict) and codes:
        for code in candidates:
            if code in codes:
                return code

    get_lang_id = getattr(tokenizer, "get_lang_id", None)
    if callable(get_lang_id):
        for code in candidates:
            try:
                get_lang_id(code)
            except Exception:
                continue  # unknown to this tokenizer; try the next candidate
            return code

    for code in candidates:
        if _vocab_token_id(code) is not None:
            return code

    return preferred


def _lang_token_id(code: str) -> int | None:
    """Return the token id that starts output in language *code*, or None if unknown."""
    codes = getattr(tokenizer, "lang_code_to_id", None)
    if isinstance(codes, dict) and code in codes:
        return codes[code]

    convert = getattr(tokenizer, "convert_lang_code_to_id", None)
    if callable(convert):
        try:
            token_id = convert(code)
        except Exception:
            token_id = None  # fall through to the vocabulary lookup below
        if token_id is not None:
            return token_id

    return _vocab_token_id(code)


FR_CODE = _resolve_lang_code(FR_CODE_PREFERRED, FR_CODE_ALT)
NG_CODE = _resolve_lang_code(NG_CODE_PREFERRED, None)


# --- Helpers ---
# Capturing group: re.split keeps the punctuation + trailing whitespace so each
# sentence can be re-attached to its terminator.
_SENTENCE_BOUNDARY = re.compile(r"(\s*[\.\!\?…:;]\s+)")


def _token_len(s: str) -> int:
    """Number of tokens in *s*, excluding special tokens."""
    return len(tokenizer.encode(s, add_special_tokens=False))


def _max_new_tokens_for(text: str) -> int:
    """Generation budget for translating *text*.

    Chunks may contain up to 380 source tokens, more than MAX_NEW_TOKENS, so a
    fixed budget would silently truncate long translations. The budget scales
    with the source length (1.5x headroom; the ratio is a heuristic) and never
    drops below MAX_NEW_TOKENS.
    """
    return max(MAX_NEW_TOKENS, (3 * _token_len(text)) // 2)


def chunk_text_for_translation(text: str, max_src_tokens: int = 380) -> List[str]:
    """Split *text* into sentence-aligned chunks of at most *max_src_tokens* tokens.

    Sentences are greedily packed (joined by a single space) until adding the
    next one would exceed the budget. A single sentence longer than the budget
    becomes its own chunk. Returns [] for blank input.
    """
    parts = _SENTENCE_BOUNDARY.split(text)
    # parts alternates [sentence, delimiter, sentence, ..., sentence].
    sentences = []
    for i in range(0, len(parts), 2):
        delimiter = parts[i + 1] if i + 1 < len(parts) else ""
        unit = (parts[i] + delimiter).strip()
        if unit:
            sentences.append(unit)

    chunks, current = [], ""
    for sent in sentences:
        if not current:
            current = sent
            continue
        candidate = f"{current} {sent}"
        if _token_len(candidate) <= max_src_tokens:
            current = candidate
        else:
            chunks.append(current.strip())
            current = sent
    if current.strip():
        chunks.append(current.strip())

    return chunks if chunks else ([text] if text.strip() else [])


# --- Translation functions ---
@functools.lru_cache(maxsize=1)
def _get_translator():
    """Build the translation pipeline once and reuse it for every request."""
    return pipeline(
        task="translation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )


def _translate_with_pipeline(text: str) -> str:
    """Translate *text* FR -> Ngambay with the cached transformers pipeline."""
    out = _get_translator()(
        text,
        src_lang=FR_CODE,
        tgt_lang=NG_CODE,
        max_new_tokens=_max_new_tokens_for(text),
        do_sample=False,
    )
    first = out[0]
    key = "translation_text" if "translation_text" in first else "generated_text"
    return first[key]


def _translate_with_generate(text: str) -> str:
    """Translate *text* by calling model.generate() directly (pipeline fallback)."""
    if hasattr(tokenizer, "src_lang"):
        tokenizer.src_lang = FR_CODE
    inputs = tokenizer(text, return_tensors="pt").to(device)

    gen_kwargs = {"max_new_tokens": _max_new_tokens_for(text), "do_sample": False}
    forced_bos = _lang_token_id(NG_CODE)
    if forced_bos is not None:
        # generate() expects a plain int token id, not a tensor.
        gen_kwargs["forced_bos_token_id"] = int(forced_bos)

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]


# --- Public APIs ---
def translate_text_simple(text: str) -> str:
    """Translate *text* in one pass; returns "" for empty/None input."""
    text = (text or "").strip()
    if not text:
        return ""
    try:
        return _translate_with_pipeline(text)
    except Exception as e:
        print(f"Pipeline error: {e}. Falling back to generate().")
        return _translate_with_generate(text)


def translate_large_text(text: str) -> str:
    """Translate *text* chunk by chunk; chunk translations are joined with newlines."""
    outputs = []
    for chunk in chunk_text_for_translation(text):
        try:
            outputs.append(_translate_with_pipeline(chunk))
        except Exception as e:
            print(f"Pipeline error for chunk: {e}. Falling back to generate().")
            outputs.append(_translate_with_generate(chunk))
    return "\n".join(outputs).strip()


# --- DOCX helpers ---
def is_heading(par: DocxParagraph) -> Tuple[bool, int]:
    """Return (is_heading, level) based on the paragraph's style name.

    Recognises English ("Heading N") and French ("Titre N") style names; the
    level is the first digit 1-9 found in the name, defaulting to 1.
    """
    style = par.style
    style_name = ((style.name if style is not None else "") or "").lower()
    if not style_name:
        return False, 0
    if "heading" in style_name or "titre" in style_name:
        for lvl in range(1, 10):
            if str(lvl) in style_name:
                return True, lvl
        return True, 1
    return False, 0


def _justify(paragraph) -> None:
    """Best-effort justify *paragraph*; some paragraph types reject alignment."""
    try:
        paragraph.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
    except Exception:
        pass  # cosmetic only: keep the default alignment


def translate_docx_bytes(file_bytes: bytes) -> bytes:
    """Translate a .docx file and return the new document's bytes.

    Body paragraphs (headings preserved) are emitted first, followed by all
    tables; run-level formatting is not carried over.
    """
    doc = docx.Document(io.BytesIO(file_bytes))
    new = docx.Document()

    for par in doc.paragraphs:
        text = par.text or ""
        if not text.strip():
            new.add_paragraph("")
            continue
        is_head, lvl = is_heading(par)
        translated = translate_large_text(text)
        if is_head:
            new.add_heading(translated, level=min(max(lvl, 1), 9))
        else:
            _justify(new.add_paragraph(translated))

    for table in doc.tables:
        new_table = new.add_table(rows=len(table.rows), cols=len(table.columns))
        for r_idx, row in enumerate(table.rows):
            for c_idx, cell in enumerate(row.cells):
                cell_text = "\n".join(p.text for p in cell.paragraphs).strip()
                translated = translate_large_text(cell_text) if cell_text else ""
                tgt_cell = new_table.cell(r_idx, c_idx)
                tgt_cell.text = translated
                for p in tgt_cell.paragraphs:
                    _justify(p)

    out = io.BytesIO()
    new.save(out)
    return out.getvalue()


# --- PDF helpers ---
def extract_pdf_text_blocks(pdf_bytes: bytes) -> List[List[str]]:
    """Return, per page, the non-empty text blocks in reading order.

    Image blocks (block type != 0) are skipped: PyMuPDF reports them with a
    placeholder string that must not be translated.
    """
    pages_blocks: List[List[str]] = []
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            blocks = page.get_text("blocks") or []
            # Sort top-to-bottom (y0), then left-to-right (x0).
            blocks.sort(key=lambda b: (round(b[1], 1), round(b[0], 1)))
            page_texts = []
            for b in blocks:
                if len(b) > 6 and b[6] != 0:
                    continue
                text = ((b[4] if len(b) > 4 else "") or "").strip()
                if text:
                    page_texts.append(text)
            pages_blocks.append(page_texts)
    return pages_blocks


def build_pdf_from_blocks(translated_pages: List[List[str]]) -> bytes:
    """Render translated pages as an A4 PDF: one justified paragraph per block."""
    buf = io.BytesIO()
    doc = SimpleDocTemplate(
        buf,
        pagesize=A4,
        rightMargin=2 * cm,
        leftMargin=2 * cm,
        topMargin=2 * cm,
        bottomMargin=2 * cm,
    )
    styles = getSampleStyleSheet()
    body = styles["BodyText"]
    body.alignment = TA_JUSTIFY
    body.leading = 14

    story = []
    for p_idx, blocks in enumerate(translated_pages):
        if p_idx > 0:
            story.append(PageBreak())
        for blk in blocks:
            # RLParagraph parses a mini-markup: escape text, then use <br/>
            # for line breaks (raw newlines are collapsed to spaces).
            safe = html_escape(blk).replace("\n", "<br/>")
            story.append(RLParagraph(safe, body))
            story.append(Spacer(1, 0.35 * cm))

    doc.build(story)
    return buf.getvalue()


def translate_pdf_bytes(file_bytes: bytes) -> bytes:
    """Translate every text block of a PDF and rebuild a new PDF from them."""
    translated_pages = [
        [translate_large_text(blk) if blk else "" for blk in blocks]
        for blocks in extract_pdf_text_blocks(file_bytes)
    ]
    return build_pdf_from_blocks(translated_pages)


# --- Gradio file handler ---
def _write_output(out_bytes: bytes, filename: str) -> str:
    """Write *out_bytes* into a fresh temp directory and return the file path.

    Each request gets its own directory so concurrent jobs never overwrite
    each other's result.
    """
    out_dir = tempfile.mkdtemp(prefix="ngambay_")
    out_path = os.path.join(out_dir, filename)
    with open(out_path, "wb") as f:
        f.write(out_bytes)
    return out_path


def translate_document(file_path: str):
    """Gradio handler: translate an uploaded .docx/.pdf.

    Returns (output_path_or_None, status_message).
    """
    if not file_path:
        return None, "Veuillez sélectionner un fichier .docx ou .pdf"
    try:
        name = os.path.basename(file_path).lower()
        with open(file_path, "rb") as f:
            data = f.read()

        if name.endswith(".docx"):
            out_path = _write_output(translate_docx_bytes(data), "translated_ngambay.docx")
            return out_path, "✅ Traduction DOCX terminée (paragraphes justifiés)."

        if name.endswith(".pdf"):
            out_path = _write_output(translate_pdf_bytes(data), "translated_ngambay.pdf")
            return out_path, "✅ Traduction PDF terminée (paragraphes justifiés)."

        return None, "Type de fichier non supporté. Choisissez .docx ou .pdf"
    except Exception as e:
        return None, f"❌ Erreur pendant la traduction: {e}"


# --- UI ---
theme = gr.themes.Soft(
    primary_hue="indigo",
    radius_size="lg",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
).set(
    body_background_fill="#f7f7fb",
    button_primary_text_color="#ffffff",
)

CUSTOM_CSS = """
.gradio-container {max-width: 980px !important;}
.header-card {
  background: linear-gradient(135deg, #4f46e5 0%, #7c3aed 100%);
  color: white;
  padding: 22px;
  border-radius: 18px;
  box-shadow: 0 10px 30px rgba(79,70,229,.25);
  transition: transform .2s ease;
}
.header-card:hover { transform: translateY(-1px); }
.header-title { font-size: 26px; font-weight: 800; margin: 0 0 6px 0; letter-spacing: .2px; }
.header-sub { opacity: .98; font-size: 14px; }
.brand { display:flex; align-items:center; gap:10px; justify-content:space-between; flex-wrap:wrap; }
.badge {
  display:inline-block;
  background: rgba(255,255,255,.18);
  padding: 4px 10px;
  border-radius: 999px;
  font-size: 12px;
  border: 1px solid rgba(255,255,255,.25);
}
.footer-note { margin-top: 8px; color: #64748b; font-size: 12px; text-align: center; }
.support-banner {
  margin-top: 14px;
  border-radius: 14px;
  padding: 14px 16px;
  background: linear-gradient(135deg, rgba(79,70,229,.08), rgba(124,58,237,.08));
  border: 1px solid rgba(99,102,241,.25);
  box-shadow: 0 6px 18px rgba(79,70,229,.08);
}
.support-title { font-weight: 700; font-size: 16px; margin-bottom: 4px; }
.support-text { font-size: 13px; color: #334155; line-height: 1.5; }
.support-contacts { display: flex; gap: 10px; flex-wrap: wrap; margin-top: 8px; }
.support-chip {
  display:inline-block;
  padding: 6px 10px;
  border-radius: 999px;
  background: white;
  border: 1px dashed rgba(79,70,229,.45);
  font-size: 12px;
  color: #3730a3;
}
"""

HEADER_HTML = """
<div class="brand">
  <div>
    <div class="header-title">Français → Ngambay (v1)</div>
    <div class="header-sub">🚀 Version bêta · Merci de tester et partager vos retours pour améliorer la qualité de traduction.</div>
  </div>
  <span class="badge">Modèle : Toadoum/ngambay-fr-v1</span>
</div>
"""

SUPPORT_HTML = """
<div class="support-banner">
  <div class="support-title">💙 Contribuer au projet (recrutement de linguistes)</div>
  <div class="support-text">Nous cherchons à recruter des linguistes pour renforcer la construction de données Ngambay. Si vous souhaitez soutenir financièrement ou en tant que bénévole, contactez-nous :</div>
  <div class="support-contacts">
    <span class="support-chip">📱 WhatsApp, Airtel Money : +235 66 04 90 94</span>
    <span class="support-chip">✉️ Email : tsakayo@aimsammi.org</span>
  </div>
</div>
"""

with gr.Blocks(
    title="Français → Ngambay · Toadoum/ngambay-fr-v1",
    theme=theme,
    css=CUSTOM_CSS,
    fill_height=True,
) as demo:
    with gr.Group(elem_classes=["header-card"]):
        gr.HTML(HEADER_HTML)

    with gr.Tabs():
        with gr.Tab("Traduction de texte"):
            with gr.Row():
                with gr.Column(scale=5):
                    src = gr.Textbox(
                        label="Texte source (Français)",
                        placeholder="Saisissez votre texte en français…",
                        lines=8,
                        autofocus=True,
                    )
                    with gr.Row():
                        btn = gr.Button("Traduire", variant="primary", scale=3)
                        clear_btn = gr.Button("Effacer", scale=1)
                    gr.Examples(
                        examples=[
                            ["Bonjour, comment allez-vous aujourd’hui ?"],
                            ["La réunion de sensibilisation aura lieu demain au centre communautaire."],
                            ["Merci pour votre participation et votre soutien."],
                            ["Veuillez suivre les recommandations de santé pour protéger votre famille."],
                        ],
                        inputs=[src],
                        label="Exemples (cliquez pour remplir)",
                    )
                with gr.Column(scale=5):
                    tgt = gr.Textbox(
                        label="Traduction (Ngambay)",
                        lines=8,
                        interactive=False,
                        show_copy_button=True,
                    )
                    gr.Markdown('')

        with gr.Tab("Traduction de document (.docx / .pdf)"):
            with gr.Row():
                with gr.Column(scale=5):
                    doc_inp = gr.File(
                        label="Sélectionnez un document (.docx ou .pdf)",
                        file_types=[".docx", ".pdf"],
                        type="filepath",
                    )
                    run_doc = gr.Button("Traduire le document", variant="primary")
                with gr.Column(scale=5):
                    doc_out = gr.File(label="Fichier traduit (télécharger)")
                    doc_status = gr.Markdown("")
            run_doc.click(translate_document, inputs=doc_inp, outputs=[doc_out, doc_status])

    gr.HTML(SUPPORT_HTML)

    btn.click(translate_text_simple, inputs=src, outputs=tgt)
    clear_btn.click(lambda: ("", ""), outputs=[src, tgt])


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=4).launch(
        ssr_mode=False,
        # Public share link only when NOT running inside a Hugging Face Space.
        share=not os.environ.get("SPACE_ID"),
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        show_error=True,
    )