davanstrien (HF staff) committed
Commit 1796549 • 1 Parent(s): 4e1ec1c
Files changed (1): app.py (+17 -7)
app.py CHANGED
@@ -1,13 +1,16 @@
 import spaces
 import gradio as gr
 
-from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 import os
 import json
+from pydantic import BaseModel
+from typing import Tuple
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
 model = Qwen2VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2-VL-7B-Instruct",
     torch_dtype=torch.bfloat16,
@@ -15,8 +18,7 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
     device_map="auto",
 )
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
-from pydantic import BaseModel
-from typing import Tuple
+
 
 class GeneralRetrievalQuery(BaseModel):
     broad_topical_query: str
@@ -26,6 +28,7 @@ class GeneralRetrievalQuery(BaseModel):
     visual_element_query: str
     visual_element_explanation: str
 
+
 def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
     if prompt_name != "general":
         raise ValueError("Only 'general' prompt is available in this version")
@@ -66,11 +69,11 @@ Generate the queries based on this image and provide the response in the specifi
     return prompt, GeneralRetrievalQuery
 
 
-
+# defined like this so we can later add more prompting options
 prompt, pydantic_model = get_retrieval_prompt("general")
 
-@spaces.GPU
-def generate_response(image):
+
+def _prep_data_for_input(image):
     messages = [
         {
             "role": "user",
@@ -97,6 +100,12 @@ def generate_response(image):
         padding=True,
         return_tensors="pt",
     )
+    return inputs
+
+
+@spaces.GPU
+def generate_response(image):
+    inputs = _prep_data_for_input(image)
     inputs = inputs.to("cuda")
 
     generated_ids = model.generate(**inputs, max_new_tokens=200)
@@ -116,5 +125,6 @@ def generate_response(image):
     except Exception:
         return {}
 
-demo = gr.Interface(fn=generate_response, inputs=gr.Image(type='pil'), outputs="json")
+
+demo = gr.Interface(fn=generate_response, inputs=gr.Image(type="pil"), outputs="json")
 demo.launch()
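
The core of this commit splits the old generate_response into a plain _prep_data_for_input helper plus a @spaces.GPU-decorated wrapper. On ZeroGPU Spaces, a GPU is attached only while a spaces.GPU-decorated function is executing, so keeping the CPU-bound chat templating and processor work outside the decorator limits GPU occupancy to the generate call itself. A minimal sketch of the pattern, assuming a ZeroGPU Space with the spaces package installed (the _build_inputs and describe_image names are illustrative, not from the commit):

import spaces


def _build_inputs(image):
    # CPU-only work: chat templating, tokenization, tensor construction.
    # No GPU is held while this runs.
    return {"pixel_values": image}


@spaces.GPU  # a GPU is attached only for the duration of this call
def describe_image(image):
    inputs = _build_inputs(image)
    # move tensors to "cuda" and run model.generate(...) here
    return inputs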