from xml.sax.saxutils import escape

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from docx import Document
from gtts import gTTS  # the PyPI package is "gtts"; "from gTTS import gTTS" raises ImportError
from jiwer import cer
# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
# Model 3 has been removed to conserve memory.
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}

print("🚀 Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device).eval()
        _loaded_processors[name], _loaded_models[name] = processor, model
        print(f"✅ {name} ready.")
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")
# ---------------- GPU Warmup ----------------
@spaces.GPU  # on a ZeroGPU Space, GPU work must run inside a @spaces.GPU-decorated function
def warmup(progress=gr.Progress(track_tqdm=True)):
    try:
        default_model_choice = next(iter(MODEL_PATHS.keys()))
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        tokenizer = getattr(processor, "tokenizer", None)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = (
            tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if tokenizer and hasattr(tokenizer, "apply_chat_template")
            else "Warmup."
        )
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------- Helpers ----------------
def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
    return processor(text=[prompt], images=[image], return_tensors="pt")
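# If the tokenizer ships a chat template (both Qwen2.5-VL-family checkpoints do),
# the prompt is wrapped in the model's expected <|im_start|>...<|im_end|> turn
# structure; otherwise the raw prompt is passed straight to the processor.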
def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
    try:
        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        prompt_start = decoded_text.find(prompt)
        if prompt_start != -1:
            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
        else:
            decoded_text = decoded_text.strip()
        return decoded_text
    except Exception:
        try:
            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            prompt_start = decoded_text.find(prompt)
            if prompt_start != -1:
                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
            else:
                decoded_text = decoded_text.strip()
            return decoded_text
        except Exception:
            return str(output_ids).strip()
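# Decoding strategy: the generated ids include the prompt tokens, so the prompt
# text is located in the decoded string and sliced off. Decoding falls back from
# the processor to the bare tokenizer, and finally to repr() of the ids, so the
# UI always receives a string.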
def _default_prompt(query: str | None) -> str:
    if query and query.strip():
        return query.strip()
    return (
        "You are a professional Handwritten OCR system.\n"
        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
        "- Preserve original structure and line breaks.\n"
        "- Keep spacing, bullet points, numbering, and indentation.\n"
        "- Render tables as Markdown tables if present.\n"
        "- Do NOT autocorrect spelling or grammar.\n"
        "- Do NOT merge lines.\n"
        "Return RAW transcription only."
    )
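# A non-empty custom query overrides the default transcription prompt, e.g.
#   _default_prompt("List only the dates in the note") -> "List only the dates in the note"
#   _default_prompt(None) / _default_prompt("  ")      -> the full OCR instruction block above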
# ---------------- OCR Function ----------------
@spaces.GPU  # required on ZeroGPU: allocates a GPU for the duration of the call
def ocr_image(image: Image.Image, model_choice: str, query: str | None = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
              progress=gr.Progress(track_tqdm=True)):
    if image is None:
        return "Please upload or capture an image."
    if model_choice not in _loaded_models:
        return f"Invalid model: {model_choice}"
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    tokenizer = getattr(processor, "tokenizer", None)
    prompt = _default_prompt(query)
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
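    # do_sample=False makes generation greedy, so temperature/top_p/top_k have no
    # effect unless sampling is enabled; they are still forwarded so the UI
    # sliders map directly onto generate() kwargs.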
    with torch.inference_mode():
        output_ids = model.generate(
            **batch, max_new_tokens=max_new_tokens, do_sample=False,
            temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
        )
    return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
# ---------------- Export Helpers ----------------
def _safe_text(text: str) -> str:
    return (text or "").strip()

def save_as_pdf(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = SimpleDocTemplate("output.pdf")
    style = getSampleStyleSheet()["Normal"]
    # escape() guards against "<", ">", and "&" in the OCR text, which Paragraph
    # would otherwise try to parse as markup.
    flowables = [Paragraph(escape(t), style) for t in text.splitlines() if t != ""]
    if not flowables:
        flowables = [Paragraph(" ", style)]
    doc.build(flowables)
    return "output.pdf"

def save_as_word(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = Document()
    for line in text.splitlines():
        doc.add_paragraph(line)
    doc.save("output.docx")
    return "output.docx"

def save_as_audio(text):
    text = _safe_text(text)
    if not text:
        return None
    try:
        tts = gTTS(text)
        tts.save("output.mp3")
        return "output.mp3"
    except Exception as e:
        print(f"gTTS failed: {e}")
        return None
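# All three exporters write fixed filenames into the working directory, so
# concurrent users can overwrite each other's files; per-request temp files
# (e.g. via tempfile.NamedTemporaryFile) would avoid that at the cost of cleanup.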
# ---------------- Metrics Function ----------------
def calculate_cer_score(ground_truth: str, prediction: str) -> str:
    """
    Calculates the Character Error Rate (CER) between two strings.
    A CER of 0.0 means the prediction is perfect.
    """
    if not ground_truth or not prediction:
        return "Cannot calculate CER: Missing ground truth or prediction."
    ground_truth_cleaned = " ".join(ground_truth.strip().split())
    prediction_cleaned = " ".join(prediction.strip().split())
    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
    return f"Character Error Rate (CER): {error_rate:.4f}"
# ---------------- Evaluation Orchestration ----------------
def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
    if image is None or not ground_truth:
        return "Please upload an image and provide the ground truth.", "N/A"
    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens, temperature=temperature,
                           top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    cer_score = calculate_cer_score(ground_truth, prediction)
    return prediction, cer_score
# ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ✍️ Wilson Handwritten OCR")
    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")

    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
        clear_btn = gr.Button("🧹 Clear")
        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
        pdf_btn = gr.Button("⬇️ Download as PDF")
        word_btn = gr.Button("⬇️ Download as Word")
        audio_btn = gr.Button("🔊 Download as Audio")
        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")

        extract_btn.click(fn=ocr_image,
                          inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
                          outputs=[raw_output], api_name="ocr_image")
        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
        clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0),
                        outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])
    with gr.Tab("📊 Model Evaluation"):
        gr.Markdown("### 📊 Evaluate Model Accuracy")
        eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
        eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
        eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
        eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
        with gr.Row():
            run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
            clear_evaluation_btn = gr.Button("🧹 Clear")

        run_evaluation_btn.click(
            fn=perform_evaluation,
            inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[eval_model_output, eval_cer_output],
        )
        clear_evaluation_btn.click(
            fn=lambda: (None, "", "", ""),
            outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output],
        )
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True)
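# Note: share=True only matters for local runs; on Hugging Face Spaces the app
# is already served publicly and Gradio warns that the flag is unsupported there.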