vk committed on
Commit 5919b75 · 1 Parent(s): c0118f4

first commit

Files changed (5)
  1. .gitattributes +1 -0
  2. .idea/.gitignore +3 -0
  3. app.py +64 -0
  4. requirements.txt +5 -0
  5. utils.py +41 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example/invoice1.png filter=lfs diff=lfs merge=lfs -text
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
app.py ADDED
@@ -0,0 +1,64 @@
+ import gradio as gr
+ import torch
+ import PIL.Image
+ from peft import PeftModel, PeftConfig
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+ from utils import parse_bbox_and_labels, display_boxes
+
+
+ def get_response(
+     image: PIL.Image.Image,
+     prompt: str,
+     max_new_tokens: str,
+ ) -> tuple[str, PIL.Image.Image]:
+     raw_image = image.convert("RGB")
+     width, height = raw_image.size
+
+     inputs = processor(raw_image, prompt, return_tensors="pt").to(device)
+     with torch.inference_mode():
+         output = peft_model.generate(**inputs, max_new_tokens=int(max_new_tokens))
+
+     # Decode only the newly generated tokens, skipping the echoed prompt.
+     input_len = inputs["input_ids"].shape[-1]
+     output = processor.decode(output[0][input_len:], skip_special_tokens=True)
+     print(output)
+     # Detection prompts emit <loc....> tokens; draw the parsed boxes on the image.
+     if "loc" in output:
+         boxes, labels = parse_bbox_and_labels(output)
+         raw_image = display_boxes(raw_image, boxes, labels, target_size=(width, height))
+
+     return output, raw_image
+
+
+ if __name__ == "__main__":
+     device = torch.device("cpu")
+     # bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)  # for GPU
+     peft_model_id = "vk888/paligemma_vqav2"
+     model_id = "google/paligemma2-3b-pt-448"
+     config = PeftConfig.from_pretrained(peft_model_id)
+     base_model = PaliGemmaForConditionalGeneration.from_pretrained(
+         config.base_model_name_or_path,
+         device_map=device,
+     )  # add quantization_config=bnb_config on GPU
+     peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
+     processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+     examples = [
+         ["example/invoice1.png", "<image>answer en what is the balance due ?\n", 80],
+         ["example/invoice1.png", "<image>detect signature\n", 80],
+         ["example/invoice1.png", "<image>answer en what is the rate cada of design ?\n", 80],
+     ]
+
+     iface = gr.Interface(
+         cache_examples=False,
+         fn=get_response,
+         inputs=[
+             gr.Image(type="pil"),
+             gr.Textbox(placeholder="<image>answer en what is the balance due ?\n"),
+             gr.Textbox(placeholder="80"),
+         ],
+         examples=examples,
+         outputs=[gr.Textbox(), gr.Image(type="pil")],
+         title="DocVQA with Paligemma2 VLM",
+         description="DocVQA with Paligemma2 VLM",
+     )
+     iface.launch(share=True)
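
Note: the commented-out bnb_config line above points at a 4-bit GPU path. A minimal sketch of how that could be wired up, assuming a CUDA device and the bitsandbytes package (neither is part of this commit):

    # Hypothetical GPU variant of the model setup (assumes CUDA and bitsandbytes).
    import torch
    from peft import PeftConfig
    from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration

    peft_model_id = "vk888/paligemma_vqav2"
    config = PeftConfig.from_pretrained(peft_model_id)
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    base_model = PaliGemmaForConditionalGeneration.from_pretrained(
        config.base_model_name_or_path,
        device_map="auto",               # let accelerate place layers on the GPU
        quantization_config=bnb_config,  # 4-bit weights, bf16 compute
    )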
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ --index-url https://download.pytorch.org/whl/cpu
+ torch
+
+ transformers==4.53.0.dev0
+ peft
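
app.py also imports gradio (which pulls in the numpy and Pillow that utils.py uses); on Hugging Face Spaces the gradio SDK supplies it, but running locally you would likely also need:

    pip install -r requirements.txt
    pip install gradio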
utils.py ADDED
@@ -0,0 +1,41 @@
+ import re
+ import numpy as np
+ from PIL import ImageDraw
+
+
+ def parse_bbox_and_labels(detokenized_output: str):
+     # PaliGemma detection output encodes each box as four <locNNNN> tokens
+     # (y0, x0, y1, x1 on a 1024-step grid) followed by the label text.
+     matches = re.finditer(
+         r'<loc(?P<y0>\d{4})><loc(?P<x0>\d{4})><loc(?P<y1>\d{4})><loc(?P<x1>\d{4})>'
+         r' (?P<label>.+?)( ;|$)',
+         detokenized_output,
+     )
+     labels, boxes = [], []
+     fmt = lambda x: float(x) / 1024.0  # normalize grid coordinates to [0, 1]
+     for m in matches:
+         d = m.groupdict()
+         boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
+         labels.append(d['label'])
+     return np.array(boxes), np.array(labels)
+
+
+ def display_boxes(image, boxes, labels, target_size):
+     # target_size is (width, height); boxes hold normalized [y0, x0, y1, x1].
+     w, h = target_size
+     draw = ImageDraw.Draw(image)
+     for i in range(boxes.shape[0]):
+         y0, x0, y1, x1 = boxes[i][0] * h, boxes[i][1] * w, boxes[i][2] * h, boxes[i][3] * w
+         draw.rectangle((x0, y0, x1, y1), outline="red", width=3)
+         draw.text((x0, y0), labels[i], fill="red")  # label the box
+     return image
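
For reference, a tiny illustrative check of the parser; the <loc...> string below is a made-up sample in PaliGemma's detection format, not real model output:

    from utils import parse_bbox_and_labels

    # Made-up sample: <locY0><locX0><locY1><locX1> label, coords on a 1024 grid.
    sample = "<loc0256><loc0128><loc0512><loc0768> signature"
    boxes, labels = parse_bbox_and_labels(sample)
    print(boxes)   # -> [[0.25  0.125 0.5   0.75 ]]
    print(labels)  # -> ['signature']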