"""OCR helpers backed by the GOT, Qwen2-VL and Llama-3.2-Vision models."""
import os
import re
import tempfile

import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    MllamaForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

GOT_MODEL_NAME = "srimanth-d/GOT_CPU"
QWEN_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
LLAMA_MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Upper bound on tokens generated per image by the Qwen and LLaMA extractors.
# Without an explicit limit, generate() falls back to the model's generation
# config, which may cap the output well below a page of text.
DEFAULT_MAX_NEW_TOKENS = 1024

NO_TEXT_MESSAGE = "No text extracted."


# ---- Shared helpers ----
def _normalize_output(outputs):
    """Return the stripped text of a model output, or NO_TEXT_MESSAGE.

    Accepts either a single string or a list/tuple of strings (only the first
    element is used), so every extractor reports empty results the same way.
    """
    if isinstance(outputs, (list, tuple)):
        outputs = outputs[0] if outputs else ""
    if isinstance(outputs, str) and outputs.strip():
        return outputs.strip()
    return NO_TEXT_MESSAGE


def _decode_new_tokens(processor, output_ids, input_ids):
    """Decode only the tokens generate() produced after the prompt.

    generate() returns prompt + continuation for these decoder-only models;
    slicing off the first ``input_ids.shape[1]`` positions keeps the chat
    template / instruction text out of the OCR result.
    """
    new_tokens = output_ids[:, input_ids.shape[1]:]
    texts = processor.batch_decode(new_tokens, skip_special_tokens=True)
    return _normalize_output(texts)


# ---- GOT OCR Model Initialization and Extraction ----
def init_got_model():
    """Initialize GOT model and tokenizer.

    Returns:
        (model, tokenizer) with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        GOT_MODEL_NAME, trust_remote_code=True, return_tensors="pt"
    )
    model = AutoModel.from_pretrained(
        GOT_MODEL_NAME,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval(), tokenizer


def extract_text_got(uploaded_file):
    """Extract text from the uploaded image using the GOT model.

    The GOT ``chat`` API takes a file path, so the upload is written to a
    uniquely named temporary file that is always removed afterwards.

    Args:
        uploaded_file: a readable binary file-like object holding the image.

    Returns:
        The extracted text, NO_TEXT_MESSAGE when nothing was recognized, or
        ``"Error: <message>"`` if any step raised.
    """
    temp_file_path = None
    try:
        # Unique path per call: a fixed name let concurrent calls overwrite
        # and delete each other's image.
        fd, temp_file_path = tempfile.mkstemp(prefix="temp_image_got_", suffix=".jpg")
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(uploaded_file.read())
        print(f"Processing image using GOT from: {temp_file_path}")
        # NOTE(review): the model is reloaded on every call; consider caching
        # it if memory allows.
        model, tokenizer = init_got_model()
        outputs = model.chat(tokenizer, temp_file_path, ocr_type="ocr")
        # GOT's remote-code chat() is expected to return a str; a list is
        # also accepted in case a fork returns one.
        return _normalize_output(outputs)
    except Exception as e:
        return f"Error: {e}"
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)


# ---- Qwen OCR Model Initialization and Extraction ----
def init_qwen_model():
    """Initialize Qwen model and processor.

    Returns:
        (model, processor) with the model on CPU in eval mode.
    """
    # NOTE(review): float16 on CPU may be slow or unsupported for some ops
    # depending on the installed torch version — confirm, or use
    # bfloat16/float32.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        QWEN_MODEL_NAME, device_map="cpu", torch_dtype=torch.float16
    )
    processor = AutoProcessor.from_pretrained(QWEN_MODEL_NAME)
    return model.eval(), processor


def extract_text_qwen(uploaded_file, max_new_tokens=DEFAULT_MAX_NEW_TOKENS):
    """Extract text from an image using the Qwen2-VL model.

    Args:
        uploaded_file: a path or file-like object accepted by ``PIL.Image.open``.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The extracted text (prompt excluded), NO_TEXT_MESSAGE when the model
        produced nothing, or ``"Error: <message>"`` if any step raised.
    """
    try:
        model, processor = init_qwen_model()
        image = Image.open(uploaded_file).convert("RGB")
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Extract text from this image."},
                ],
            }
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        return _decode_new_tokens(processor, output_ids, inputs["input_ids"])
    except Exception as e:
        return f"Error: {e}"


# ---- LLaMA OCR Model Initialization and Extraction ----
def init_llama_model():
    """Initialize LLaMA OCR model and processor.

    Returns:
        (model, processor) with the model on CPU in eval mode.
    """
    model = MllamaForConditionalGeneration.from_pretrained(
        LLAMA_MODEL_NAME, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained(LLAMA_MODEL_NAME)
    return model.eval(), processor


def extract_text_llama(uploaded_file, max_new_tokens=DEFAULT_MAX_NEW_TOKENS):
    """Extract text from an image using the Llama-3.2-Vision model.

    Args:
        uploaded_file: a path or file-like object accepted by ``PIL.Image.open``.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The extracted text (prompt excluded), NO_TEXT_MESSAGE when the model
        produced nothing, or ``"Error: <message>"`` if any step raised.
    """
    try:
        model, processor = init_llama_model()
        image = Image.open(uploaded_file).convert("RGB")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "You are an OCR engine. Extract text from this image."},
                ],
            }
        ]
        # The chat template inserts the <|image|> token the Mllama processor
        # needs to pair the text with the image. It also adds the BOS token,
        # hence add_special_tokens=False.
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(
            images=image, text=prompt, add_special_tokens=False, return_tensors="pt"
        )
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        return _decode_new_tokens(processor, output_ids, inputs["input_ids"])
    except Exception as e:
        return f"Error: {e}"


# ---- AI-based Text Cleanup ----
_WHITESPACE_RE = re.compile(r"\s+")
_SPACE_BEFORE_PUNCT_RE = re.compile(r"\s([?.!,])")


def clean_extracted_text(text):
    """Normalize whitespace in extracted text.

    Collapses every run of whitespace (including newlines) to one space,
    strips both ends, and removes the space immediately before ``? . ! ,``.

    Args:
        text: the raw text; ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned string.
    """
    if not text:
        return ""
    cleaned_text = _WHITESPACE_RE.sub(" ", text).strip()
    return _SPACE_BEFORE_PUNCT_RE.sub(r"\1", cleaned_text)