import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from reportlab.platypus import SimpleDocTemplate, Paragraph
from reportlab.lib.styles import getSampleStyleSheet
from xml.sax.saxutils import escape
from docx import Document
from gtts import gTTS
from jiwer import cer

# ---------------- Models ----------------
MODEL_PATHS = {
    "Model 1 (Complex handwriting)": ("prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it", Qwen2_5_VLForConditionalGeneration),
    "Model 2 (Simple and scanned handwriting)": ("nanonets/Nanonets-OCR-s", Qwen2_5_VLForConditionalGeneration),
}
# Model 3 has been removed to conserve memory.
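# Additional models can be registered as ("hf-repo-id", ModelClass) entries above;
# both current entries load via the Qwen2.5-VL conditional-generation class.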

MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
_loaded_processors, _loaded_models = {}, {}

print("πŸš€ Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device).eval()
        _loaded_processors[name], _loaded_models[name] = processor, model
        print(f"βœ… {name} ready.")
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")

# ---------------- GPU Warmup ----------------
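# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the decorated call,
# so this one-token generation pays the first-allocation cost up front.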
@spaces.GPU
def warmup(progress=gr.Progress(track_tqdm=True)):
    try:
        default_model_choice = next(iter(MODEL_PATHS.keys()))
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        tokenizer = getattr(processor, "tokenizer", None)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if tokenizer and hasattr(tokenizer, "apply_chat_template") else "Warmup."
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode(): _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"

# ---------------- Helpers ----------------
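# _build_inputs wraps the prompt in the model's chat template when one is
# available; otherwise the raw prompt string is passed straight to the processor.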
def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
    if tokenizer and hasattr(tokenizer, "apply_chat_template"):
        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return processor(text=[chat_prompt], images=[image], return_tensors="pt")
    return processor(text=[prompt], images=[image], return_tensors="pt")
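
# _decode_text prefers the processor's batch_decode, falling back to the
# tokenizer, and strips the echoed prompt when the model repeats it.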

def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
    try:
        decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        prompt_start = decoded_text.find(prompt)
        if prompt_start != -1:
            decoded_text = decoded_text[prompt_start + len(prompt):].strip()
        else:
            decoded_text = decoded_text.strip()
        return decoded_text
    except Exception:
        try:
            decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            prompt_start = decoded_text.find(prompt)
            if prompt_start != -1:
                decoded_text = decoded_text[prompt_start + len(prompt):].strip()
            return decoded_text
        except Exception:
            return str(output_ids).strip()

def _default_prompt(query: str | None) -> str:
    if query and query.strip(): return query.strip()
    return (
        "You are a professional Handwritten OCR system.\n"
        "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
        "- Preserve original structure and line breaks.\n"
        "- Keep spacing, bullet points, numbering, and indentation.\n"
        "- Render tables as Markdown tables if present.\n"
        "- Do NOT autocorrect spelling or grammar.\n"
        "- Do NOT merge lines.\n"
        "Return RAW transcription only."
    )

# ---------------- OCR Function ----------------
@spaces.GPU
def ocr_image(image: Image.Image, model_choice: str, query: str | None = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
              progress=gr.Progress(track_tqdm=True)):
    if image is None: return "Please upload or capture an image."
    if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
    processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
    prompt = _default_prompt(query)
    batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
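    # Note: do_sample=False forces greedy decoding, so temperature/top_p/top_k
    # below are effectively ignored unless sampling is enabled.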
    with torch.inference_mode():
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
                                    temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
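
# Illustrative usage (hypothetical local file, not part of the app flow):
#   img = Image.open("sample.png").convert("RGB")
#   text = ocr_image(img, "Model 2 (Simple and scanned handwriting)")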

# ---------------- Export Helpers ----------------
def _safe_text(text: str) -> str: return (text or "").strip()

def save_as_pdf(text):
    text = _safe_text(text)
    if not text: return None
    doc = SimpleDocTemplate("output.pdf")
    styles = getSampleStyleSheet()
    # Escape &, <, > since ReportLab's Paragraph parses its input as XML markup.
    flowables = [Paragraph(escape(t), styles["Normal"]) for t in text.splitlines() if t != ""]
    if not flowables: flowables = [Paragraph(" ", styles["Normal"])]
    doc.build(flowables)
    return "output.pdf"

def save_as_word(text):
    text = _safe_text(text)
    if not text: return None
    doc = Document()
    for line in text.splitlines(): doc.add_paragraph(line)
    doc.save("output.docx")
    return "output.docx"

def save_as_audio(text):
    text = _safe_text(text)
    if not text: return None
    try: 
        tts = gTTS(text)
        tts.save("output.mp3")
        return "output.mp3"
    except Exception as e: 
        print(f"gTTS failed: {e}")
        return None

# ---------------- Metrics Function ----------------
def calculate_cer_score(ground_truth: str, prediction: str) -> str:
    """
    Calculates the Character Error Rate (CER) between two strings.
    A CER of 0.0 means the prediction is perfect.
    """
    if not ground_truth or not prediction:
        return "Cannot calculate CER: Missing ground truth or prediction."
    
    ground_truth_cleaned = " ".join(ground_truth.strip().split())
    prediction_cleaned = " ".join(prediction.strip().split())
    
    error_rate = cer(ground_truth_cleaned, prediction_cleaned)
    return f"Character Error Rate (CER): {error_rate:.4f}"

# ---------------- Evaluation Orchestration ----------------
@spaces.GPU
def perform_evaluation(image: Image.Image, model_name: str, ground_truth: str,
                       max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
    if image is None or not ground_truth:
        return "Please upload an image and provide the ground truth.", "N/A"
    
    prediction = ocr_image(image, model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
    
    cer_score = calculate_cer_score(ground_truth, prediction)
    
    return prediction, cer_score
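
# Note: the evaluation wiring below reuses the Advanced Options sliders from the
# Image Inference tab, so both tabs share one set of generation settings.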

# ---------------- Gradio Interface ----------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ✍🏾 wilson Handwritten OCR")
    model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")

    with gr.Tab("πŸ–Ό Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload / Capture Handwritten Image", sources=["upload", "webcam"])
        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
        clear_btn = gr.Button("🧹 Clear")
        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
        pdf_btn = gr.Button("⬇️ Download as PDF")
        word_btn = gr.Button("⬇️ Download as Word")
        audio_btn = gr.Button("🔊 Download as Audio")
        pdf_file, word_file, audio_file = gr.File(label="PDF File"), gr.File(label="Word File"), gr.File(label="Audio File")

        extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[raw_output], api_name="ocr_image")
        pdf_btn.click(fn=save_as_pdf, inputs=[raw_output], outputs=[pdf_file])
        word_btn.click(fn=save_as_word, inputs=[raw_output], outputs=[word_file])
        audio_btn.click(fn=save_as_audio, inputs=[raw_output], outputs=[audio_file])
        clear_btn.click(fn=lambda: ("", None, "", MAX_NEW_TOKENS_DEFAULT, 0.1, 1.0, 0, 1.0), outputs=[raw_output, image_input, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty])

    with gr.Tab("πŸ“Š Model Evaluation"):
        gr.Markdown("### πŸ” Evaluate Model Accuracy")
        eval_image_input = gr.Image(type="pil", label="Upload Image for Evaluation", sources=["upload"])
        eval_ground_truth = gr.Textbox(label="Ground Truth (Correct Transcription)", lines=10, placeholder="Type or paste the correct text here.")
        eval_model_output = gr.Textbox(label="Model's Prediction", lines=10, interactive=False, show_copy_button=True)
        eval_cer_output = gr.Textbox(label="Metrics", interactive=False)
        
        with gr.Row():
            run_evaluation_btn = gr.Button("🚀 Run OCR and Evaluate", variant="primary")
            clear_evaluation_btn = gr.Button("🧹 Clear")
            
        run_evaluation_btn.click(
            fn=perform_evaluation,
            inputs=[eval_image_input, model_choice, eval_ground_truth, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[eval_model_output, eval_cer_output]
        )
        clear_evaluation_btn.click(
            fn=lambda: (None, "", "", ""),
            outputs=[eval_image_input, eval_ground_truth, eval_model_output, eval_cer_output]
        )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True)