File size: 3,849 Bytes
6bb168e
c3906f7
6bb168e
afedbd6
c3906f7
afedbd6
c3906f7
afedbd6
c3906f7
 
 
 
 
 
6bb168e
 
c3906f7
 
6bb168e
 
b0416c1
 
c3906f7
 
 
6bb168e
c3906f7
 
 
6bb168e
c3906f7
6bb168e
 
 
9919fac
c3906f7
afedbd6
c3906f7
 
 
 
 
 
 
 
8b34af2
c3906f7
 
 
 
 
 
 
 
 
 
afedbd6
c3906f7
afedbd6
c3906f7
 
 
 
 
afedbd6
c3906f7
 
 
 
 
 
 
 
 
 
 
afedbd6
c3906f7
afedbd6
c3906f7
 
 
 
9919fac
c3906f7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import functools
import os
import re
import tempfile

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor, MllamaForConditionalGeneration

# ---- GOT OCR Model Initialization and Extraction ----

@functools.lru_cache(maxsize=1)
def init_got_model():
    """Load the GOT OCR model and tokenizer, reusing them across calls.

    The result is memoized with ``functools.lru_cache`` so only the first OCR
    request pays the cost of reading the checkpoint. Later calls get the same
    already-loaded objects back. Use ``init_got_model.cache_clear()`` to force
    a reload.

    Returns:
        tuple: ``(model, tokenizer)``, with the model switched to eval mode.
    """
    model_name = "srimanth-d/GOT_CPU"
    # trust_remote_code is needed because the checkpoint's model class
    # (which provides the custom ``chat`` method used by the extractor) is
    # custom code shipped with the checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, return_tensors='pt')
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # Reuse EOS as the pad id; presumably so generation has a pad token.
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval(), tokenizer

def extract_text_got(uploaded_file):
    """Extract text from an uploaded image using the GOT model.

    GOT's ``chat`` API takes a file path, so the upload is first written to a
    private temporary file. That file is always removed afterwards.

    Args:
        uploaded_file: Object whose ``read()`` returns the raw image bytes.

    Returns:
        The stripped OCR text, "No text extracted." when the model returns
        nothing usable, or "Error: <message>" if any step raised.
    """
    temp_file_path = None
    try:
        # A unique temp file per call. A fixed name in the CWD would race
        # between concurrent requests and could overwrite (then delete) an
        # unrelated file with the same name.
        fd, temp_file_path = tempfile.mkstemp(prefix='temp_image_got_', suffix='.jpg')
        with os.fdopen(fd, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())

        print(f"Processing image using GOT from: {temp_file_path}")
        model, tokenizer = init_got_model()
        outputs = model.chat(tokenizer, temp_file_path, ocr_type='ocr')

        # GOT's chat() returns a plain string. Also accept a list/tuple in
        # case a model variant returns a batch of results.
        if isinstance(outputs, (list, tuple)):
            outputs = outputs[0] if outputs else None
        text = outputs.strip() if isinstance(outputs, str) else ""
        return text if text else "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)

# ---- Qwen OCR Model Initialization and Extraction ----

@functools.lru_cache(maxsize=1)
def init_qwen_model():
    """Load the Qwen2-VL-2B-Instruct model and processor, reusing them across calls.

    The result is memoized with ``functools.lru_cache`` so the weights are
    read from disk only once per process. Use ``init_qwen_model.cache_clear()``
    to force a reload.

    Returns:
        tuple: ``(model, processor)``, with the model on CPU in eval mode.
    """
    model_id = "Qwen/Qwen2-VL-2B-Instruct"
    # NOTE(review): float16 kernels are limited and often slow on CPU;
    # bfloat16 or float32 may be safer. Confirm against the deployed torch version.
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained(model_id)
    return model.eval(), processor

def extract_text_qwen(uploaded_file, max_new_tokens=512):
    """Extract text from an image using the Qwen2-VL model.

    Args:
        uploaded_file: Path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_new_tokens: Upper bound on the number of generated tokens.
            Without an explicit bound, ``generate`` falls back to the model's
            generation config, which may cut the transcription short.

    Returns:
        The transcribed text, "No text extracted." if the model produced
        nothing, or "Error: <message>" if any step raised.
    """
    try:
        model, processor = init_qwen_model()
        image = Image.open(uploaded_file).convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion for decoder-only models. Drop
        # the prompt tokens so the chat template and instruction are not
        # returned as part of the "extracted" text.
        prompt_len = inputs["input_ids"].shape[1]
        output_text = processor.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
        text = output_text[0].strip() if output_text else ""
        return text if text else "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"

# ---- LLaMA OCR Model Initialization and Extraction ----

@functools.lru_cache(maxsize=1)
def init_llama_model():
    """Load the Llama 3.2 11B Vision model and processor, reusing them across calls.

    The result is memoized with ``functools.lru_cache``. Reloading an 11B
    checkpoint on every OCR request would dominate latency and memory churn.
    Use ``init_llama_model.cache_clear()`` to force a reload.

    Returns:
        tuple: ``(model, processor)``, with the model on CPU in bfloat16 and in eval mode.
    """
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cpu")
    processor = AutoProcessor.from_pretrained(model_id)
    return model.eval(), processor

def extract_text_llama(uploaded_file, max_new_tokens=512):
    """Extract text from an image using the Llama 3.2 Vision model.

    The instruction is wrapped in the processor's chat template with an image
    slot. The Mllama processor expects the text to contain the image token
    when an image is passed, and a bare string prompt does not contain it.

    Args:
        uploaded_file: Path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The transcribed text, "No text extracted." if the model produced
        nothing, or "Error: <message>" if any step raised.
    """
    try:
        model, processor = init_llama_model()
        image = Image.open(uploaded_file).convert('RGB')
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "You are an OCR engine. Extract text from this image."},
        ]}]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        # The chat template already emits the BOS token, so don't add another.
        inputs = processor(images=image, text=prompt, add_special_tokens=False, return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Strip the echoed prompt tokens and keep only the model's answer.
        prompt_len = inputs["input_ids"].shape[1]
        text = processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
        return text if text else "No text extracted."
    except Exception as e:
        return f"Error: {str(e)}"

# ---- Regex-based Text Cleanup ----

_WHITESPACE_RUN = re.compile(r'\s+')
_SPACE_BEFORE_PUNCT = re.compile(r'\s([?.!,])')


def clean_extracted_text(text):
    """Normalize the whitespace in OCR output.

    The cleanup runs in two steps:

    1. Every run of whitespace (spaces, tabs, newlines) becomes one space,
       and leading/trailing whitespace is dropped.
    2. The single space left in front of ``?``, ``.``, ``!`` or ``,`` is
       removed, so the punctuation attaches to the preceding word.
    """
    collapsed = _WHITESPACE_RUN.sub(' ', text).strip()
    return _SPACE_BEFORE_PUNCT.sub(r'\1', collapsed)