# Spaces: Running
import os
import re
import tempfile

import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    MllamaForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)
# ---- GOT OCR Model Initialization and Extraction ----
def init_got_model():
    """Load the GOT OCR checkpoint together with its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is the value returned by
        its ``eval()`` call.
    """
    checkpoint = "srimanth-d/GOT_CPU"
    got_tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        return_tensors='pt',
    )
    got_model = AutoModel.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # The tokenizer's end-of-sequence id doubles as the padding id.
        pad_token_id=got_tokenizer.eos_token_id,
    )
    return got_model.eval(), got_tokenizer
def extract_text_got(uploaded_file):
    """Extract text from an uploaded image using the GOT model.

    The upload is spooled to a uniquely named temporary ``.jpg`` file because
    ``model.chat`` is given a file path rather than file contents. The file
    is always removed afterwards.

    Args:
        uploaded_file: Binary file-like object whose ``read()`` returns the
            image bytes.

    Returns:
        The stripped OCR text, ``"No text extracted."`` when the model yields
        nothing usable, or ``"Error: <message>"`` if any step raises.
    """
    temp_file_path = None
    try:
        # A unique name per call avoids concurrent requests overwriting (or
        # deleting) each other's image, which a fixed filename allowed.
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(uploaded_file.read())
        print(f"Processing image using GOT from: {temp_file_path}")
        model, tokenizer = init_got_model()
        outputs = model.chat(tokenizer, temp_file_path, ocr_type='ocr')
        # Accept both a plain string result and a sequence of strings; the
        # previous list-only check discarded string results entirely.
        if isinstance(outputs, (list, tuple)):
            outputs = outputs[0] if outputs else ""
        text = outputs.strip() if isinstance(outputs, str) else ""
        return text or "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
# ---- Qwen OCR Model Initialization and Extraction ----
def init_qwen_model():
    """Load the Qwen2-VL model and its processor on the CPU.

    Returns:
        A ``(model, processor)`` tuple. The model is the value returned by
        its ``eval()`` call.
    """
    # NOTE(review): float16 weights on a CPU device map — half precision may
    # be slow or unsupported for some CPU ops; confirm this is intended.
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return qwen_model.eval(), qwen_processor
def extract_text_qwen(uploaded_file, max_new_tokens=512):
    """Extract text from an image using the Qwen model.

    Args:
        uploaded_file: Path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_new_tokens: Upper bound on the number of tokens generated for the
            answer. Without an explicit bound, generation falls back to a
            short default length that truncates OCR output.

    Returns:
        The model's answer with the prompt removed and whitespace stripped,
        ``"No text extracted."`` when the answer is empty, or
        ``"Error: <message>"`` if any step raises.
    """
    try:
        model, processor = init_qwen_model()
        image = Image.open(uploaded_file).convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # The generated sequence starts with the prompt tokens; keep only the
        # newly generated tail so the chat prompt is not echoed in the result.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[:, prompt_length:]
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        text = output_text[0].strip() if output_text else ""
        return text or "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"
# ---- LLaMA OCR Model Initialization and Extraction ----
def init_llama_model():
    """Load the Llama 3.2 Vision model and its processor on the CPU.

    Returns:
        A ``(model, processor)`` tuple. The model is the value returned by
        its ``eval()`` call.
    """
    llama_model = MllamaForConditionalGeneration.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    llama_processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    return llama_model.eval(), llama_processor
def extract_text_llama(uploaded_file, max_new_tokens=512):
    """Extract text from an image using the LLaMA model.

    Args:
        uploaded_file: Path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_new_tokens: Upper bound on the number of tokens generated for the
            answer. Without an explicit bound, generation falls back to a
            short default length that truncates OCR output.

    Returns:
        The model's answer with the prompt removed and whitespace stripped,
        or ``"Error: <message>"`` if any step raises.
    """
    try:
        model, processor = init_llama_model()
        image = Image.open(uploaded_file).convert('RGB')
        # Build the prompt through the chat template so it carries the image
        # placeholder the processor pairs with ``images``; a bare text prompt
        # contains no image token. Mirrors the Qwen extraction path.
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "You are an OCR engine. Extract text from this image."}]}]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        # The chat template already emits the special begin-of-text token, so
        # the processor must not add special tokens a second time.
        inputs = processor(images=image, text=prompt, add_special_tokens=False, return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the echoed prompt tokens; decode only the generated tail.
        prompt_length = inputs["input_ids"].shape[1]
        return processor.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error: {str(e)}"
# ---- Regex-based Text Cleanup ----
def clean_extracted_text(text):
    """Normalize whitespace in extracted text.

    Every run of whitespace (spaces, tabs, newlines) is collapsed into a
    single space, leading and trailing whitespace is trimmed, and a
    whitespace character directly before ``?``, ``.``, ``!`` or ``,`` is
    removed.

    Args:
        text: The raw extracted text. ``None`` or an empty string is treated
            as "nothing to clean".

    Returns:
        The cleaned string; ``""`` when ``text`` is ``None`` or empty.
    """
    # Guard against None (e.g. an extractor that produced no result), which
    # would otherwise raise TypeError inside re.sub.
    if not text:
        return ""
    # Collapse all whitespace runs first so only single spaces remain.
    collapsed = re.sub(r'\s+', ' ', text).strip()
    # Pull sentence punctuation back against the preceding word.
    return re.sub(r'\s([?.!,])', r'\1', collapsed)