from xml.sax.saxutils import escape

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from docx import Document
from gtts import gTTS  # the PyPI package is "gtts"; "from gTTS import gTTS" raises ImportError
from jiwer import cer
# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
# Model 3 has been removed to conserve memory.
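# A third checkpoint can be restored later by adding an entry of the same shape
# (repo id + model class). The repo id below is purely illustrative:
#   MODEL_PATHS["Model 3 (printed text)"] = ("your-org/your-ocr-model", Qwen2_5_VLForConditionalGeneration)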
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}
print("πŸš€ Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
try:
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = cls.from_pretrained(
repo_id,
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True
).to(device).eval()
_loaded_processors[name], _loaded_models[name] = processor, model
print(f"βœ… {name} ready.")
except Exception as e:
print(f"⚠️ Failed to load {name}: {e}")
# ---------------- GPU Warmup ----------------
@spaces.GPU
def warmup(progress=gr.Progress(track_tqdm=True)):
    try:
        default_model_choice = next(iter(MODEL_PATHS.keys()))
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        tokenizer = getattr(processor, "tokenizer", None)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = (
            tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if tokenizer and hasattr(tokenizer, "apply_chat_template")
            else "Warmup."
        )
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------- Helpers ----------------
def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
    return processor(text=[prompt], images=[image], return_tensors="pt")
def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
    try:
        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        prompt_start = decoded_text.find(prompt)
        if prompt_start != -1:
            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
        else:
            decoded_text = decoded_text.strip()
        return decoded_text
    except Exception:
        try:
            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            prompt_start = decoded_text.find(prompt)
            if prompt_start != -1:
                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
            return decoded_text
        except Exception:
            return str(output_ids).strip()
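# Note: stripping the prompt with str.find() is a heuristic and can fail when
# the chat template rewrites the prompt text. A more robust sketch (not wired
# in) is to slice off the input tokens in the caller before decoding:
#   trimmed = output_ids[:, batch["input_ids"].shape[1]:]
#   text = processor.batch_decode(trimmed, skip_special_tokens=True)[0]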
def _default_prompt(query: str | None) -> str:
    if query and query.strip():
        return query.strip()
    return (
        "You are a professional Handwritten OCR system.\n"
        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
        "- Preserve original structure and line breaks.\n"
        "- Keep spacing, bullet points, numbering, and indentation.\n"
        "- Render tables as Markdown tables if present.\n"
        "- Do NOT autocorrect spelling or grammar.\n"
        "- Do NOT merge lines.\n"
        "Return RAW transcription only."
    )
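# Example: _default_prompt(None) and _default_prompt("") both fall back to the
# strict transcription prompt above, while _default_prompt("List only the dates")
# returns the custom query verbatim.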
# ---------------- OCR Function ----------------
@spaces.GPU
def ocr_image(image: Image.Image, model_choice: str, query: str | None = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
              progress=gr.Progress(track_tqdm=True)):
    if image is None:
        return "Please upload or capture an image."
    if model_choice not in _loaded_models:
        return f"Invalid model: {model_choice}"
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    tokenizer = getattr(processor, "tokenizer", None)
    prompt = _default_prompt(query)
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
    with torch.inference_mode():
        # Decoding is greedy (do_sample=False): transformers only applies
        # temperature/top_p/top_k when sampling, while repetition_penalty
        # takes effect either way.
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                    temperature=temperature, top_p=top_p, top_k=top_k,
                                    repetition_penalty=repetition_penalty)
    return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
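# Minimal local usage sketch (assumes a handwriting sample at "sample.png"):
#   img = Image.open("sample.png").convert("RGB")
#   print(ocr_image(img, list(MODEL_PATHS.keys())[0]))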
# ---------------- Export Helpers ----------------
def _safe_text(text: str) -> str:
    return (text or "").strip()

def save_as_pdf(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = SimpleDocTemplate("output.pdf")
    style = getSampleStyleSheet()["Normal"]
    # Paragraph parses its text as mini-XML markup, so raw "&"/"<" must be escaped.
    flowables = [Paragraph(escape(t), style) for t in text.splitlines() if t != ""]
    if not flowables:
        flowables = [Paragraph(" ", style)]
    doc.build(flowables)
    return "output.pdf"
def save_as_word(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = Document()
    for line in text.splitlines():
        doc.add_paragraph(line)
    doc.save("output.docx")
    return "output.docx"
def save_as_audio(text):
    text = _safe_text(text)
    if not text:
        return None
    try:
        tts = gTTS(text)
        tts.save("output.mp3")
        return "output.mp3"
    except Exception as e:
        print(f"gTTS failed: {e}")
        return None
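# Caveat: all three exporters write fixed filenames into the working directory,
# so concurrent users can overwrite each other's downloads; unique names (e.g.
# from tempfile.mkstemp) would avoid that.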
# ---------------- Metrics Function ----------------
def calculate_cer_score(ground_truth: str, prediction: str) -> str:
    """
    Calculates the Character Error Rate (CER) between two strings.
    A CER of 0.0 means the prediction is perfect.
    """
    if not ground_truth or not prediction:
        return "Cannot calculate CER: Missing ground truth or prediction."
    ground_truth_cleaned = " ".join(ground_truth.strip().split())
    prediction_cleaned = " ".join(prediction.strip().split())
    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
    return f"Character Error Rate (CER): {error_rate:.4f}"
# ---------------- Evaluation Orchestration ----------------
@spaces.GPU
def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
    if image is None or not ground_truth:
        return "Please upload an image and provide the ground truth.", "N/A"
    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens, temperature=temperature,
                           top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    cer_score = calculate_cer_score(ground_truth, prediction)
    return prediction, cer_score
# ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ✍🏾 Wilson Handwritten OCR")
    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")

    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        extract_btn = gr.Button("📀 Extract RAW Text", variant="primary")
        clear_btn = gr.Button("🧹 Clear")
        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
        pdf_btn = gr.Button("⬇️ Download as PDF")
        word_btn = gr.Button("⬇️ Download as Word")
        audio_btn = gr.Button("🔊 Download as Audio")
        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")

        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[raw_output],
            api_name="ocr_image",
        )
        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
        clear_btn.click(
            fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
            outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        )
with gr.Tab("πŸ“Š Model Evaluation"):
gr.Markdown("### πŸ” Evaluate Model Accuracy")
eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
with gr.Row():
run_evaluation_btn = gr.Button("πŸš€ Run OCR and Evaluate", variant="primary")
clear_evaluation_btn = gr.Button("🧹 Clear")
run_evaluation_btn.click(
fn=perform_evaluation,
inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
outputs=[eval_model_output, eval_cer_output]
)
clear_evaluation_btn.click(
fn=lambda: (None, "", "", ""),
outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
)
if __name__ == "__main__":
    # share=True is ignored when the app runs inside a Hugging Face Space;
    # it only creates a public link for local runs.
    demo.queue(max_size=50).launch(share=True)