UniquePratham committed on
Commit 6bb168e
1 Parent(s): 19b1a60

Update ocr_cpu.py

Files changed (1)
  1. ocr_cpu.py +98 -97
ocr_cpu.py CHANGED
@@ -1,97 +1,98 @@
- import os
- from transformers import AutoModel, AutoTokenizer
- import torch
-
- # Load model and tokenizer
- model_name = "ucaslcl/GOT-OCR2_0"
- tokenizer = AutoTokenizer.from_pretrained(
-     model_name, trust_remote_code=True, return_tensors='pt'
- )
-
- # Load the model
- model = AutoModel.from_pretrained(
-     model_name,
-     trust_remote_code=True,
-     low_cpu_mem_usage=True,
-     use_safetensors=True,
-     pad_token_id=tokenizer.eos_token_id,
- )
-
- # Ensure the model is in evaluation mode and loaded on CPU
- device = torch.device("cpu")
- dtype = torch.float32  # Use float32 on CPU
- model = model.eval().to(device)
-
- # OCR function
-
-
- def extract_text_got(uploaded_file):
-     """Use GOT-OCR2.0 model to extract text from the uploaded image."""
-     try:
-         temp_file_path = 'temp_image.jpg'
-         with open(temp_file_path, 'wb') as temp_file:
-             temp_file.write(uploaded_file.read())  # Save file
-
-         # OCR attempts
-         ocr_types = ['ocr', 'format']
-         fine_grained_options = ['ocr', 'format']
-         color_options = ['red', 'green', 'blue']
-         box = [10, 10, 100, 100]  # Example box for demonstration
-         multi_crop_types = ['ocr', 'format']
-
-         results = []
-
-         # Run the model without autocast (not necessary for CPU)
-         for ocr_type in ocr_types:
-             with torch.no_grad():
-                 outputs = model.chat(
-                     tokenizer, temp_file_path, ocr_type=ocr_type
-                 )
-             if isinstance(outputs, list) and outputs[0].strip():
-                 return outputs[0].strip()  # Return if successful
-             results.append(outputs[0].strip() if outputs else "No result")
-
-         # Try FINE-GRAINED OCR with box options
-         for ocr_type in fine_grained_options:
-             with torch.no_grad():
-                 outputs = model.chat(
-                     tokenizer, temp_file_path, ocr_type=ocr_type, ocr_box=box
-                 )
-             if isinstance(outputs, list) and outputs[0].strip():
-                 return outputs[0].strip()  # Return if successful
-             results.append(outputs[0].strip() if outputs else "No result")
-
-         # Try FINE-GRAINED OCR with color options
-         for ocr_type in fine_grained_options:
-             for color in color_options:
-                 with torch.no_grad():
-                     outputs = model.chat(
-                         tokenizer, temp_file_path, ocr_type=ocr_type, ocr_color=color
-                     )
-                 if isinstance(outputs, list) and outputs[0].strip():
-                     return outputs[0].strip()  # Return if successful
-                 results.append(outputs[0].strip()
-                                if outputs else "No result")
-
-         # Try MULTI-CROP OCR
-         for ocr_type in multi_crop_types:
-             with torch.no_grad():
-                 outputs = model.chat_crop(
-                     tokenizer, temp_file_path, ocr_type=ocr_type
-                 )
-             if isinstance(outputs, list) and outputs[0].strip():
-                 return outputs[0].strip()  # Return if successful
-             results.append(outputs[0].strip() if outputs else "No result")
-
-         # If no text was extracted
-         if all(not text for text in results):
-             return "No text extracted."
-         else:
-             return "\n".join(results)
-
-     except Exception as e:
-         return f"Error during text extraction: {str(e)}"
-
-     finally:
-         if os.path.exists(temp_file_path):
-             os.remove(temp_file_path)
 
 
+ import os
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+
+ # Load model and tokenizer
+ # model_name = "ucaslcl/GOT-OCR2_0"
+ model_name = "srimanth-d/GOT_CPU"
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name, trust_remote_code=True, return_tensors='pt'
+ )
+
+ # Load the model
+ model = AutoModel.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+     low_cpu_mem_usage=True,
+     use_safetensors=True,
+     pad_token_id=tokenizer.eos_token_id,
+ )
+
+ # Ensure the model is in evaluation mode and loaded on CPU
+ device = torch.device("cpu")
+ dtype = torch.float32  # Use float32 on CPU
+ model = model.eval()
+
+ # OCR function
+
+
+ def extract_text_got(uploaded_file):
+     """Use GOT-OCR2.0 model to extract text from the uploaded image."""
+     try:
+         temp_file_path = 'temp_image.jpg'
+         with open(temp_file_path, 'wb') as temp_file:
+             temp_file.write(uploaded_file.read())  # Save file
+
+         # OCR attempts
+         ocr_types = ['ocr', 'format']
+         fine_grained_options = ['ocr', 'format']
+         color_options = ['red', 'green', 'blue']
+         box = [10, 10, 100, 100]  # Example box for demonstration
+         multi_crop_types = ['ocr', 'format']
+
+         results = []
+
+         # Run the model without autocast (not necessary for CPU)
+         for ocr_type in ocr_types:
+             with torch.no_grad():
+                 outputs = model.chat(
+                     tokenizer, temp_file_path, ocr_type=ocr_type
+                 )
+             if isinstance(outputs, list) and outputs[0].strip():
+                 return outputs[0].strip()  # Return if successful
+             results.append(outputs[0].strip() if outputs else "No result")
+
+         # Try FINE-GRAINED OCR with box options
+         for ocr_type in fine_grained_options:
+             with torch.no_grad():
+                 outputs = model.chat(
+                     tokenizer, temp_file_path, ocr_type=ocr_type, ocr_box=box
+                 )
+             if isinstance(outputs, list) and outputs[0].strip():
+                 return outputs[0].strip()  # Return if successful
+             results.append(outputs[0].strip() if outputs else "No result")
+
+         # Try FINE-GRAINED OCR with color options
+         for ocr_type in fine_grained_options:
+             for color in color_options:
+                 with torch.no_grad():
+                     outputs = model.chat(
+                         tokenizer, temp_file_path, ocr_type=ocr_type, ocr_color=color
+                     )
+                 if isinstance(outputs, list) and outputs[0].strip():
+                     return outputs[0].strip()  # Return if successful
+                 results.append(outputs[0].strip()
+                                if outputs else "No result")
+
+         # Try MULTI-CROP OCR
+         for ocr_type in multi_crop_types:
+             with torch.no_grad():
+                 outputs = model.chat_crop(
+                     tokenizer, temp_file_path, ocr_type=ocr_type
+                 )
+             if isinstance(outputs, list) and outputs[0].strip():
+                 return outputs[0].strip()  # Return if successful
+             results.append(outputs[0].strip() if outputs else "No result")
+
+         # If no text was extracted
+         if all(not text for text in results):
+             return "No text extracted."
+         else:
+             return "\n".join(results)
+
+     except Exception as e:
+         return f"Error during text extraction: {str(e)}"
+
+     finally:
+         if os.path.exists(temp_file_path):
+             os.remove(temp_file_path)
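
For quick verification, a minimal usage sketch of the updated helper follows. It assumes ocr_cpu.py is importable from the working directory and that "sample_image.jpg" is a placeholder path to an existing image; neither the calling script nor the image path is part of this commit.

# Minimal usage sketch (not part of the commit): extract_text_got accepts any
# binary file-like object exposing .read(), e.g. an opened image file or an
# upload object from a web framework.
from ocr_cpu import extract_text_got  # assumes ocr_cpu.py is on the import path

with open("sample_image.jpg", "rb") as image_file:  # placeholder image path
    extracted_text = extract_text_got(image_file)

print(extracted_text)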