# DualTextOCRFusion / ocr_cpu.py

import os
import re

import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    MllamaForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

# ---- GOT OCR Model Initialization and Extraction ----

def init_got_model():
    """Initialize the GOT model and tokenizer."""
    model_name = "srimanth-d/GOT_CPU"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval(), tokenizer

def extract_text_got(uploaded_file):
    """Extract text from the uploaded image using the GOT model."""
    temp_file_path = 'temp_image_got.jpg'
    try:
        # GOT's chat() API reads from a path on disk, so persist the upload first.
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())
        print(f"Processing image using GOT from: {temp_file_path}")
        model, tokenizer = init_got_model()
        outputs = model.chat(tokenizer, temp_file_path, ocr_type='ocr')
        # chat() may return a plain string or a list of strings; handle both.
        if isinstance(outputs, list):
            outputs = outputs[0] if outputs else ""
        return outputs.strip() if outputs and outputs.strip() else "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
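

# Each extract_text_* helper reloads its model on every call, which dominates
# CPU latency. A minimal memoization sketch; the cache dict and the
# get_got_model helper are additions for illustration, not part of the
# original app, and the extractors above still call init_got_model() directly.
_MODEL_CACHE = {}

def get_got_model():
    """Load the GOT model and tokenizer once per process, then reuse them."""
    if "got" not in _MODEL_CACHE:
        _MODEL_CACHE["got"] = init_got_model()
    return _MODEL_CACHE["got"]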

# ---- Qwen OCR Model Initialization and Extraction ----

def init_qwen_model():
    """Initialize the Qwen2-VL model and processor."""
    # float32 rather than float16: half-precision kernels are poorly supported
    # on CPU and float16 generation can fail outright.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float32
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor

def extract_text_qwen(uploaded_file):
    """Extract text using the Qwen2-VL model."""
    try:
        model, processor = init_qwen_model()
        image = Image.open(uploaded_file).convert('RGB')
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Extract text from this image."},
            ],
        }]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=[image], return_tensors="pt")
        # The default generation length (20 new tokens) truncates most OCR
        # output, so request a larger budget explicitly.
        output_ids = model.generate(**inputs, max_new_tokens=1024)
        # Decode only the tokens generated after the prompt.
        trimmed_ids = output_ids[:, inputs.input_ids.shape[1]:]
        output_text = processor.batch_decode(trimmed_ids, skip_special_tokens=True)
        return output_text[0].strip() if output_text and output_text[0].strip() else "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"

# ---- LLaMA OCR Model Initialization and Extraction ----

def init_llama_model():
    """Initialize the Llama 3.2 Vision model and processor."""
    model = MllamaForConditionalGeneration.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype=torch.bfloat16, device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    return model.eval(), processor

def extract_text_llama(uploaded_file):
    """Extract text using the Llama 3.2 Vision model."""
    try:
        model, processor = init_llama_model()
        image = Image.open(uploaded_file).convert('RGB')
        # Mllama expects the <|image|> token in the prompt, so build it via the chat template.
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "You are an OCR engine. Extract text from this image."},
            ],
        }]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(images=image, text=prompt, add_special_tokens=False, return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=1024)
        # Decode only the tokens generated after the prompt.
        trimmed_ids = output_ids[:, inputs.input_ids.shape[1]:]
        return processor.decode(trimmed_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error: {str(e)}"

# ---- Rule-based Text Cleanup ----

def clean_extracted_text(text):
    """Clean extracted text by collapsing whitespace and fixing spaced punctuation."""
    # Collapse runs of whitespace (spaces, tabs, newlines) into single spaces.
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    # Remove stray spaces that precede punctuation marks.
    cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
    return cleaned_text
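

# A minimal usage sketch, assuming the helpers receive a binary file handle as
# a Streamlit/Gradio upload widget would provide. The sample path below is an
# illustrative placeholder, not a file shipped with the app.
if __name__ == "__main__":
    with open("sample_image.jpg", "rb") as f:
        raw_text = extract_text_got(f)
    print(clean_extracted_text(raw_text))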