import os, time
from threading import Thread
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from docx import Document
from gtts import gTTS
from jiwer import cer
# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
# Model 3 has been removed to conserve memory.
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}
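# Models are loaded once at import time so each @spaces.GPU call can reuse the
# cached processors/models instead of reloading weights on every request.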
print("π Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device).eval()
        _loaded_processors[name], _loaded_models[name] = processor, model
        print(f"✅ {name} ready.")
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")
# ---------------- GPU Warmup ----------------
@spaces.GPU
def warmup(progress=gr.Progress(track_tqdm=True)):
    try:
        default_model_choice = next(iter(MODEL_PATHS.keys()))
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        tokenizer = getattr(processor, "tokenizer", None)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = (
            tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if tokenizer and hasattr(tokenizer, "apply_chat_template")
            else "Warmup."
        )
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------- Helpers ----------------
def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
    return processor(text=[prompt], images=[image], return_tensors="pt")
def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
    try:
        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        prompt_start = decoded_text.find(prompt)
        if prompt_start != -1:
            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
        else:
            decoded_text = decoded_text.strip()
        return decoded_text
    except Exception:
        try:
            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            prompt_start = decoded_text.find(prompt)
            if prompt_start != -1:
                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
            return decoded_text
        except Exception:
            return str(output_ids).strip()
def _default_prompt(query: str | None) -> str:
    if query and query.strip():
        return query.strip()
    return (
        "You are a professional Handwritten OCR system.\n"
        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
        "- Preserve original structure and line breaks.\n"
        "- Keep spacing, bullet points, numbering, and indentation.\n"
        "- Render tables as Markdown tables if present.\n"
        "- Do NOT autocorrect spelling or grammar.\n"
        "- Do NOT merge lines.\n"
        "Return RAW transcription only."
    )
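# A custom query, when provided, replaces the default transcription prompt above,
# e.g. "List only the dates that appear in this note." (hypothetical example).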
# ---------------- OCR Function ----------------
@spaces.GPU
def ocr_image(image: Image.Image, model_choice: str, query: str | None = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
              progress=gr.Progress(track_tqdm=True)):
    if image is None:
        return "Please upload or capture an image."
    if model_choice not in _loaded_models:
        return f"Invalid model: {model_choice}"
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    tokenizer = getattr(processor, "tokenizer", None)
    prompt = _default_prompt(query)
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
    with torch.inference_mode():
        # Note: with do_sample=False decoding is greedy, so temperature/top_p/top_k
        # are effectively ignored; they are kept so the UI sliders stay wired up.
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                    temperature=temperature, top_p=top_p, top_k=top_k,
                                    repetition_penalty=repetition_penalty)
    return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
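# Minimal usage sketch (assumed local call outside the Gradio UI; "note.png" is
# a hypothetical file path):
#   img = Image.open("note.png").convert("RGB")
#   print(ocr_image(img, "Model 2 (simple and scanned handwriting)"))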
# ---------------- Export Helpers ----------------
def _safe_text(text: str) -> str:
    return (text or "").strip()

def save_as_pdf(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = SimpleDocTemplate("output.pdf")
    style = getSampleStyleSheet()["Normal"]
    flowables = [Paragraph(t, style) for t in text.splitlines() if t != ""]
    if not flowables:
        flowables = [Paragraph(" ", style)]
    doc.build(flowables)
    return "output.pdf"

def save_as_word(text):
    text = _safe_text(text)
    if not text:
        return None
    doc = Document()
    for line in text.splitlines():
        doc.add_paragraph(line)
    doc.save("output.docx")
    return "output.docx"

def save_as_audio(text):
    text = _safe_text(text)
    if not text:
        return None
    try:
        tts = gTTS(text)
        tts.save("output.mp3")
        return "output.mp3"
    except Exception as e:
        print(f"gTTS failed: {e}")
        return None
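# Note: all three exporters write fixed filenames ("output.pdf", "output.docx",
# "output.mp3") in the working directory, so concurrent sessions can overwrite
# each other's downloads.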
# ---------------- Metrics Function ----------------
def calculate_cer_score(ground_truth: str, prediction: str) -> str:
    """
    Calculates the Character Error Rate (CER) between two strings.
    A CER of 0.0 means the prediction is perfect.
    """
    if not ground_truth or not prediction:
        return "Cannot calculate CER: missing ground truth or prediction."
    ground_truth_cleaned = " ".join(ground_truth.strip().split())
    prediction_cleaned = " ".join(prediction.strip().split())
    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
    return f"Character Error Rate (CER): {error_rate:.4f}"
# ---------------- Evaluation Orchestration ----------------
@spaces.GPU
def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
    if image is None or not ground_truth:
        return "Please upload an image and provide the ground truth.", "N/A"
    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens, temperature=temperature,
                           top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    cer_score = calculate_cer_score(ground_truth, prediction)
    return prediction, cer_score
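# perform_evaluation calls ocr_image with query=None, so scoring always uses the
# default RAW transcription prompt rather than a custom one.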
# ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ✍️ Wilson Handwritten OCR")
    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")
    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
        clear_btn = gr.Button("🧹 Clear")
        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
        pdf_btn = gr.Button("⬇️ Download as PDF")
        word_btn = gr.Button("⬇️ Download as Word")
        audio_btn = gr.Button("🔊 Download as Audio")
        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")
        extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[raw_output], api_name="ocr_image")
        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
        clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0), outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])
    with gr.Tab("📊 Model Evaluation"):
        gr.Markdown("### 📏 Evaluate Model Accuracy")
        eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
        eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
        eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
        eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
        with gr.Row():
            run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
            clear_evaluation_btn = gr.Button("🧹 Clear")
        run_evaluation_btn.click(
            fn=perform_evaluation,
            inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[eval_model_output, eval_cer_output]
        )
        clear_evaluation_btn.click(
            fn=lambda: (None, "", "", ""),
            outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
        )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True)