UniquePratham committed on
Commit
c3906f7
1 Parent(s): a0652de

Update ocr_cpu.py

Files changed (1)
  1. ocr_cpu.py +63 -97
ocr_cpu.py CHANGED
@@ -1,122 +1,88 @@
-# ocr_cpu.py
-
 import os
 import torch
-from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
 import re

-# -----------------------------
-# OCR Model Initialization
-# -----------------------------
-
-# Load OCR model and tokenizer
-ocr_model_name = "srimanth-d/GOT_CPU" # Using GOT model on CPU
-ocr_tokenizer = AutoTokenizer.from_pretrained(
-    ocr_model_name, trust_remote_code=True, return_tensors='pt'
-)
-
-# Load the OCR model
-ocr_model = AutoModel.from_pretrained(
-    ocr_model_name,
-    trust_remote_code=True,
-    low_cpu_mem_usage=True,
-    use_safetensors=True,
-    pad_token_id=ocr_tokenizer.eos_token_id,
-)
-
-# Ensure the OCR model is in evaluation mode and loaded on CPU
-ocr_device = torch.device("cpu")
-ocr_model = ocr_model.eval().to(ocr_device)
-
-# -----------------------------
-# Text Cleaning Model Initialization
-# -----------------------------

-# Load Text Cleaning model and tokenizer
-clean_model_name = "gpt2" # You can choose a different model if preferred
-clean_tokenizer = AutoTokenizer.from_pretrained(clean_model_name)
-clean_model = AutoModelForCausalLM.from_pretrained(clean_model_name)
-
-# Ensure the Text Cleaning model is in evaluation mode and loaded on CPU
-clean_device = torch.device("cpu")
-clean_model = clean_model.eval().to(clean_device)
-
-# -----------------------------
-# OCR Function
-# -----------------------------

 def extract_text_got(uploaded_file):
-    """
-    Use GOT-OCR2.0 model to extract text from the uploaded image.
-    """
-    temp_file_path = 'temp_image.jpg'
-
     try:
-        # Save the uploaded file temporarily
         with open(temp_file_path, 'wb') as temp_file:
             temp_file.write(uploaded_file.read())

-        print(f"Processing image from path: {temp_file_path}")
-
-        ocr_types = ['ocr', 'format']
-        results = []
-
-        # Run OCR on the image
-        for ocr_type in ocr_types:
-            with torch.no_grad():
-                print(f"Running OCR with type: {ocr_type}")
-                outputs = ocr_model.chat(ocr_tokenizer, temp_file_path, ocr_type=ocr_type)
-
-            if isinstance(outputs, list) and outputs[0].strip():
-                return outputs[0].strip() # Return the result if successful
-            results.append(outputs[0].strip() if outputs else "No result")
-
-        # Combine results or return no text found message
-        return results[0] if results else "No text extracted."

     except Exception as e:
-        return f"Error during text extraction: {str(e)}"
-
     finally:
-        # Clean up temporary file
         if os.path.exists(temp_file_path):
             os.remove(temp_file_path)
-            print(f"Temporary file {temp_file_path} removed.")

-# -----------------------------
-# Text Cleaning Function
-# -----------------------------

-def clean_text_with_ai(extracted_text):
-    """
-    Cleans extracted text by leveraging a language model to intelligently remove extra spaces and correct formatting.
-    """
     try:
-        # Define the prompt for cleaning
-        prompt = f"Please clean the following text by removing extra spaces and ensuring proper formatting:\n\n{extracted_text}\n\nCleaned Text:"

-        # Tokenize the input prompt
-        inputs = clean_tokenizer.encode(prompt, return_tensors="pt").to(clean_device)

-        # Generate the cleaned text
-        with torch.no_grad():
-            outputs = clean_model.generate(
-                inputs,
-                max_length=500, # Adjust as needed
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True,
-                eos_token_id=clean_tokenizer.eos_token_id,
-                pad_token_id=clean_tokenizer.eos_token_id
-            )

-        # Decode the generated text
-        cleaned_text = clean_tokenizer.decode(outputs[0], skip_special_tokens=True)

-        # Extract the cleaned text after the prompt
-        cleaned_text = cleaned_text.split("Cleaned Text:")[-1].strip()

-        return cleaned_text

-    except Exception as e:
-        return f"Error during AI text cleaning: {str(e)}"
 import os
+from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor, MllamaForConditionalGeneration
 import torch
 import re
+from PIL import Image

+# ---- GOT OCR Model Initialization and Extraction ----

+def init_got_model():
+    """Initialize GOT model and tokenizer."""
+    model_name = "srimanth-d/GOT_CPU"
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, return_tensors='pt')
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
+    return model.eval(), tokenizer

 def extract_text_got(uploaded_file):
+    """Extract text from the uploaded image using GOT model."""
+    temp_file_path = 'temp_image_got.jpg'
     try:
         with open(temp_file_path, 'wb') as temp_file:
             temp_file.write(uploaded_file.read())

+        print(f"Processing image using GOT from: {temp_file_path}")
+        model, tokenizer = init_got_model()
+        outputs = model.chat(tokenizer, temp_file_path, ocr_type='ocr')

+        if outputs and isinstance(outputs, list):
+            return outputs[0].strip() if outputs[0].strip() else "No text extracted."
+        return "No text extracted."
     except Exception as e:
+        return f"Error: {str(e)}"
     finally:
         if os.path.exists(temp_file_path):
             os.remove(temp_file_path)

+# ---- Qwen OCR Model Initialization and Extraction ----

+def init_qwen_model():
+    """Initialize Qwen model and processor."""
+    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+    return model.eval(), processor
+
+def extract_text_qwen(uploaded_file):
+    """Extract text using Qwen model."""
     try:
+        model, processor = init_qwen_model()
+        image = Image.open(uploaded_file).convert('RGB')
+        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
+        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+        inputs = processor(text=[prompt], images=[image], return_tensors="pt")
+        output_ids = model.generate(**inputs)
+        output_text = processor.batch_decode(output_ids, skip_special_tokens=True)
+        return output_text[0] if output_text else "No text extracted."
+    except Exception as e:
+        return f"Error: {str(e)}"

+# ---- LLaMA OCR Model Initialization and Extraction ----

+def init_llama_model():
+    """Initialize LLaMA OCR model and processor."""
+    model = MllamaForConditionalGeneration.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype=torch.bfloat16, device_map="cpu")
+    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
+    return model.eval(), processor

+def extract_text_llama(uploaded_file):
+    """Extract text using LLaMA model."""
+    try:
+        model, processor = init_llama_model()
+        image = Image.open(uploaded_file).convert('RGB')
+        prompt = "You are an OCR engine. Extract text from this image."
+        inputs = processor(images=image, text=prompt, return_tensors="pt")
+        output_ids = model.generate(**inputs)
+        return processor.decode(output_ids[0], skip_special_tokens=True).strip()
+    except Exception as e:
+        return f"Error: {str(e)}"

+# ---- AI-based Text Cleanup ----

+def clean_extracted_text(text):
+    """Clean the extracted text by removing extra spaces intelligently."""
+    # Remove multiple spaces
+    cleaned_text = re.sub(r'\s+', ' ', text).strip()

+    # Further clean punctuations with spaces around them
+    cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
+
+    return cleaned_text
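
For reference, a minimal sketch of how the refactored helpers in ocr_cpu.py might be wired into a front end. The Streamlit app below is an assumption, not part of this commit: it assumes the functions receive a file-like upload object and that clean_extracted_text is applied to whichever model's raw output is selected.

# app_sketch.py - hypothetical caller, assuming a Streamlit front end (not part of this commit)
import streamlit as st
from ocr_cpu import extract_text_got, extract_text_qwen, extract_text_llama, clean_extracted_text

st.title("CPU OCR Demo")
uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
model_choice = st.radio("OCR model", ["GOT (CPU)", "Qwen2-VL-2B", "Llama-3.2-11B-Vision"])

if uploaded is not None:
    if model_choice == "GOT (CPU)":
        raw_text = extract_text_got(uploaded)      # saves a temp file, then runs model.chat
    elif model_choice == "Qwen2-VL-2B":
        raw_text = extract_text_qwen(uploaded)     # PIL reads the file-like object directly
    else:
        raw_text = extract_text_llama(uploaded)
    st.subheader("Extracted text")
    st.text(clean_extracted_text(raw_text))        # regex pass: collapse spaces, tidy punctuation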