toilaluan committed
Commit 25bd6f8 · 1 Parent(s): 83f4a0c
Files changed (3)
  1. app.py +70 -4
  2. requirements.txt +9 -0
  3. scorer.py +405 -0
app.py CHANGED
@@ -1,7 +1,73 @@
 import gradio as gr
+import io
+import spaces
+import matplotlib.pyplot as plt
+import networkx as nx
+from PIL import Image
+
+from scorer import DSGPromptProcessor
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+def draw_colored_graph(dependencies, questions, answers):
+    # Create a directed graph
+    G = nx.DiGraph()
+
+    # Add nodes with labels, colored by whether the question was answered "yes"
+    for node, question in questions.items():
+        color = 'green' if answers[node] else 'red'
+        G.add_node(int(node), label=question, color=color)
+
+    # Add edges based on dependencies
+    for node, deps in dependencies.items():
+        for dep in deps:
+            G.add_edge(int(dep), int(node))
+
+    # Set node positions using a layout (shell_layout or circular_layout also work)
+    pos = nx.spring_layout(G)
+
+    # Draw nodes with custom colors and labels
+    node_colors = [G.nodes[node]['color'] for node in G.nodes()]
+    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=2000, edgecolors='black')
+
+    # Draw edges with arrows
+    nx.draw_networkx_edges(G, pos, arrowstyle='-|>', arrows=True, arrowsize=20, connectionstyle='arc3,rad=0.1')
+
+    # Draw labels
+    labels = nx.get_node_attributes(G, 'label')
+    nx.draw_networkx_labels(G, pos, labels, font_size=10, font_color='black')
+
+    # Render the figure into a Pillow image
+    buf = io.BytesIO()
+    plt.axis('off')
+    plt.savefig(buf, format='png')
+    plt.close()
+    buf.seek(0)
+    img = Image.open(buf)
+    return img
+
+
+dsg_scorer = DSGPromptProcessor("mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+
+def process_image(image, prompt):
+    tuples, _ = dsg_scorer.generate_tuples(prompt)
+    dependencies, _ = dsg_scorer.generate_dependencies(tuples)
+    questions, _ = dsg_scorer.generate_questions(prompt, tuples.tuples, dependencies)
+    # reward maps question id -> probability that the VQA model answers "yes"
+    reward = dsg_scorer.get_reward(prompt, questions, dependencies, [image])[0]
+    answers = {k: reward.get(k, 0) > 0.5 for k in questions}
+    graph_img = draw_colored_graph(dependencies, questions, answers)
+    return graph_img, f"Questions: {questions}\nReward per question: {reward}"
+
+
+# Define the Gradio interface
+interface = gr.Interface(
+    fn=process_image,
+    inputs=[gr.Image(type="pil"), gr.Textbox(label="Enter your prompt")],
+    outputs=[gr.Image(type="pil"), gr.Textbox(label="Output text")],
+    title="Image and Prompt Interface",
+    description="Upload an image and enter a prompt. The output is the question graph and the per-question rewards.",
+)
+
+# Launch the Gradio app
+interface.launch()
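Note: a toy sketch (values hypothetical) of the data shapes process_image wires together — questions and dependencies are keyed by string ids, as the JSON-mode LLM calls in scorer.py return them, and get_reward(...)[0] maps each question id to a "yes" probability:

    questions = {"1": "Is there a motorcycle?", "2": "Is the motorcycle blue?"}
    dependencies = {"1": [], "2": [1]}
    reward = {"1": 0.93, "2": 0.21}  # hypothetical output of get_reward(...)[0]
    answers = {k: reward.get(k, 0) > 0.5 for k in questions}  # {"1": True, "2": False}
    # draw_colored_graph(dependencies, questions, answers) renders node 1 green and node 2 red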
requirements.txt ADDED
@@ -0,0 +1,9 @@
+openai
+pydantic
+transformers
+torch
+pillow
+timm
+einops
+matplotlib  # imported by app.py
+networkx  # imported by app.py
scorer.py ADDED
@@ -0,0 +1,405 @@
+import openai
+import json
+from pydantic import BaseModel, Field
+from PIL import Image
+from tqdm import tqdm
+from transformers import AutoProcessor, AutoModelForCausalLM
+import torch
+import requests
+import spaces
+
+class PromptTuple(BaseModel):
+    class Tuple(BaseModel):
+        type: str = Field(
+            description="The type of the tuple. One of entity, attribute, relation",
+            example="attribute",
+        )
+        type_detail: str = Field(
+            description="""The detail of the type. For example:
+            - Entity: whole (entire entity, e.g., chair), part (part of entity, e.g., back of chair).
+            - Attribute: color (e.g., red book), type (e.g., aviator goggles), material (e.g., wooden chair), count (e.g., 5 geese), texture (e.g., rough surface), text rendering (e.g., letters “Macaroni”), shape (e.g., triangle block), size (e.g., large fence).
+            - Relation: spatial (e.g., A next to B); action (A kicks B).""",
+            example="color",
+        )
+        semantics: list = Field(
+            description="List of strings that explain the existence of type and type_detail in the tuple",
+            example=["motorcycle", "blue"],
+        )
+
+    tuples: list[Tuple] = Field(
+        description="List of tuples. Maximum 8 tuples.",
+        example=[
+            {
+                "type": "attribute",
+                "type_detail": "color",
+                "semantics": ["motorcycle", "blue"],
+            }
+        ],
+    )
+
+
+class DSGPromptProcessor:
+    def __init__(self, model_name="gpt-4o-mini"):
+        self.client = openai.OpenAI()
+        self.model_name = model_name
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.binary_vqa = AutoModelForCausalLM.from_pretrained("toilaluan/Florence-2-base-Yes-No-VQA", trust_remote_code=True).to(self.device, torch.float16)
+        self.binary_vqa_processor = AutoProcessor.from_pretrained("toilaluan/Florence-2-base-Yes-No-VQA", trust_remote_code=True)
+
+
+    def generate_tuples(self, input_text: str) -> tuple[PromptTuple, int]:
+        system_message = """
+        Given an image caption, extract the relevant entities, attributes, and relations present in the caption, and structure them into JSON format according to the following schema:
+        Each tuple contains the following information:
+        - Id: A unique identifier for the tuple.
+        - Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
+        - Type Detail: Provide additional details based on the selected type:
+            - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
+            - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size".
+            - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
+        - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
+        Example Input: "A blue motorcycle parked next to a red car."
+        Example output:
+        {
+            "tuples": [
+                {
+                    "type": "entity",
+                    "type_detail": "whole",
+                    "semantics": ["motorcycle"]
+                },
+                {
+                    "type": "attribute",
+                    "type_detail": "color",
+                    "semantics": ["motorcycle", "blue"]
+                },
+                {
+                    "type": "entity",
+                    "type_detail": "whole",
+                    "semantics": ["car"]
+                },
+                {
+                    "type": "attribute",
+                    "type_detail": "color",
+                    "semantics": ["car", "red"]
+                },
+                {
+                    "type": "relation",
+                    "type_detail": "spatial",
+                    "semantics": ["motorcycle", "next to", "car"]
+                }
+            ]
+        }
+        The final JSON should contain a list of tuples, each describing a unique entity, attribute, or relation from the image caption. Each JSON should contain a maximum of 8 tuples.
+        """
+        messages = [
+            {
+                "role": "system",
+                "content": system_message,
+            },
+            {
+                "role": "user",
+                "content": input_text,
+            },
+        ]
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            response_format={"type": "json_object"},
+            max_tokens=512,
+        )
+        output = json.loads(response.choices[0].message.content)
+        return PromptTuple(**output), response.usage.total_tokens
+
+    def generate_dependencies(self, tuples: PromptTuple) -> tuple[dict, int]:
+        DEPENDENCY_PROMPT = """
+        Given the following tuples extracted from an image caption, determine the dependencies between the entities, attributes, and relations, and output them in JSON format.
+        Each tuple contains the following information:
+        - Id: A unique identifier for the tuple.
+        - Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
+        - Type Detail: Provide additional details based on the selected type:
+            - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
+            - Attribute: Specify the attribute type, such as "color," "type," "material," "count," "texture," "text rendering," "shape," or "size."
+            - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
+        - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
+        Output is a dictionary where the key is the id of the tuple and the value is a list of ids that the tuple depends on.
+        Example input:
+        [
+            {
+                "id": 1,
+                "type": "entity",
+                "type_detail": "whole",
+                "semantics": ["motorcycle"]
+            },
+            {
+                "id": 2,
+                "type": "attribute",
+                "type_detail": "color",
+                "semantics": ["motorcycle", "blue"]
+            },
+            {
+                "id": 3,
+                "type": "entity",
+                "type_detail": "whole",
+                "semantics": ["car"]
+            },
+            {
+                "id": 4,
+                "type": "attribute",
+                "type_detail": "color",
+                "semantics": ["car", "red"]
+            },
+            {
+                "id": 5,
+                "type": "relation",
+                "type_detail": "spatial",
+                "semantics": ["motorcycle", "next to", "car"]
+            }
+        ]
+
+        Example output:
+        {
+            "1": [],
+            "2": [1],
+            "3": [],
+            "4": [3],
+            "5": [1, 3]
+        }
+
+        """
+        input_obj = [{"id": i, **t.dict()} for i, t in enumerate(tuples.tuples, start=1)]  # 1-based ids, matching the example
+
+        messages = [
+            {
+                "role": "system",
+                "content": DEPENDENCY_PROMPT,
+            },
+            {
+                "role": "user",
+                "content": json.dumps(input_obj),
+            },
+        ]
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            response_format={"type": "json_object"},
+        )
+        return (
+            json.loads(response.choices[0].message.content),
+            response.usage.total_tokens,
+        )
+
+    def generate_questions(
+        self, prompt: str, tuples: list[dict], dependencies: dict
+    ) -> tuple[dict, int]:
+        """Generate validation questions based on the tuples and dependencies.
+
+        Args:
+            prompt (str): a prompt describing the image
+            tuples (list[dict]): each tuple is a unit of information extracted from the prompt
+            dependencies (dict): the dependencies between tuples
+        """
+        system_message = """
+        Task: Given a prompt that describes the image and a list of tuples extracted from the prompt, generate a question for each tuple in natural language.
+        Each tuple contains the following information:
+        - Id: A unique identifier for the tuple.
+        - Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
+        - Type Detail: Provide additional details based on the selected type:
+            - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
+            - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size".
+            - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
+        - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
+        Output is a list of questions, one question per tuple. The number of questions must be the same as the number of tuples.
+        Example input:
+        Prompt: "A traffic light and a signpost at a crossroads intersection near a waterway"
+        Tuples:
+        [
+            {
+                "id": 1,
+                "type": "entity",
+                "type_detail": "whole",
+                "semantics": ["traffic light"]
+            },
+            {
+                "id": 2,
+                "type": "entity",
+                "type_detail": "whole",
+                "semantics": ["signpost"]
+            },
+            {
+                "id": 3,
+                "type": "relation",
+                "type_detail": "spatial",
+                "semantics": ["traffic light", "at", "crossroads intersection"]
+            },
+            {
+                "id": 4,
+                "type": "relation",
+                "type_detail": "spatial",
+                "semantics": ["crossroads intersection", "near", "waterway"]
+            }
+        ]
+        Dependencies:
+        {
+            "1": [],
+            "2": [],
+            "3": [1, 2],
+            "4": [3]
+        }
+        Example output is a JSON object. Each question asks about the existence of the tuple in the prompt, and the answer should always be yes.
+        {
+            "1": "Is there a traffic light?",
+            "2": "Is there a signpost?",
+            "3": "Is the traffic light at a crossroads intersection?",
+            "4": "Is the crossroads intersection near a waterway?"
+        }
+        """
+
+        user_str = f"""
+        Prompt: {prompt}
+        Tuples: {tuples}
+        Dependencies: {dependencies}
+        """
+        messages = [
+            {
+                "role": "system",
+                "content": system_message,
+            },
+            {
+                "role": "user",
+                "content": user_str,
+            },
+        ]
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            response_format={"type": "json_object"},
+        )
+        return (
+            json.loads(response.choices[0].message.content),
+            response.usage.total_tokens,
+        )
+
+    def find_layers(self, dep_dict):
+        layers = []
+        remaining_keys = set(dep_dict.keys())
+
+        while remaining_keys:
+            current_layer = []
+            for key in list(remaining_keys):
+                # If all dependencies of the key are in previous layers
+                if all(
+                    str(dep) in [k for layer in layers for k in layer]
+                    for dep in dep_dict[key]
+                ):
+                    current_layer.append(key)
+
+            # If no new layer is formed, break to avoid an infinite loop
+            if not current_layer:
+                break
+
+            # Add the current layer to the list of layers
+            layers.append(current_layer)
+            # Remove the keys that are now layered
+            remaining_keys -= set(current_layer)
+
+            # Cap the graph at three layers; deeper questions are dropped
+            if len(layers) == 3:
+                break
+
+        ordered_indexes = [item for sublist in layers for item in sublist]
+        return ordered_indexes
+
+    def _create_graph_questions(self, questions: dict, dependencies: dict) -> tuple[list, list]:
+        # Order question ids so that every question comes after its dependencies
+        layered_indexes = self.find_layers(dependencies)
+        print(layered_indexes)
+        sorted_questions = [questions[i] for i in layered_indexes]
+
+        return layered_indexes, sorted_questions
+
+    @spaces.GPU(duration=120)
+    def get_reward(
+        self,
+        prompt: str,
+        questions: dict,
+        dependencies: dict,
+        images: list[Image.Image],
+        mode="hybrid",
+    ):
+        """Get rewards for the generated questions using the structured question graph.
+
+        Args:
+            prompt (str): a prompt describing the image
+            questions (dict): maps each question id to its question text
+            dependencies (dict): maps each question id to the ids it depends on
+            images (list[Image.Image]): the images to score
+        """
+        # Order ids so that every question is scored after its dependencies
+        ordered_ids, sorted_questions = self._create_graph_questions(questions, dependencies)
+        print(sorted_questions)
+
+        # scores[j] maps question id -> probability of a "yes" answer for image j
+        scores = {}
+        for j in range(len(images)):
+            scores[j] = {i: 0.0 for i in ordered_ids}
+
+        def get_reward_for_a_question(
+            question: str,
+            question_dependencies: list[int],
+            image: Image.Image,
+            prev_scores: dict,
+        ) -> float:
+            # Skip the question if any question it depends on was answered as No
+            if any(prev_scores.get(str(dep), 0.0) <= 0.5 for dep in question_dependencies):
+                print(
+                    f"Skipping question: {question}. A dependency in {question_dependencies} was answered as No."
+                )
+                return 0.0
+            if not isinstance(image, Image.Image):
+                raise ValueError("Invalid image type")
+
+            inputs = self.binary_vqa_processor(text=question, images=image, return_tensors="pt").to(self.device, torch.float16)
+            decoder_input_ids = torch.LongTensor([[self.binary_vqa.language_model.config.pad_token_id, self.binary_vqa.language_model.config.decoder_start_token_id]]).to(self.device)
+            outputs = self.binary_vqa(
+                input_ids=inputs["input_ids"],
+                pixel_values=inputs["pixel_values"],
+                decoder_input_ids=decoder_input_ids,
+            )
+            # The fine-tuned binary head gives the probability of a "Yes" answer
+            logits = outputs.logits[:, -1]
+            score = logits[0].sigmoid().item()
+            print(f"The answer Yes has probability {score}")
+            return score
+
+        pbar = tqdm(
+            total=len(sorted_questions) * len(images),
+            desc=f"Calculating reward over {len(images)} images and {len(sorted_questions)} questions",
+        )
+        for qid, question in zip(ordered_ids, sorted_questions):
+            for j, image in enumerate(images):
+                scores[j][qid] = get_reward_for_a_question(
+                    question, dependencies[qid], image, scores[j]
+                )
+                pbar.update(1)
+
+        return scores
+
+
+if __name__ == "__main__":
+    processor = DSGPromptProcessor(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
+    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
+    image = Image.open(requests.get(url, stream=True).raw)
+    input_text = "ghibli style image of a cat"
+    tuples, tokens = processor.generate_tuples(input_text)
+    print(tuples)
+    dependencies, tokens = processor.generate_dependencies(tuples)
+    print(dependencies)
+    questions, tokens = processor.generate_questions(
+        input_text, tuples.tuples, dependencies
+    )
+    print(questions)
+
+    reward = processor.get_reward(input_text, questions, dependencies, [image])
+    print(reward)
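Note: find_layers implements a Kahn-style topological layering over the question graph (with a three-layer cap). A minimal standalone sketch of the same pass, run on the example dependency dict from the generate_dependencies prompt:

    def layer_order(dep_dict):
        # Repeatedly peel off the keys whose dependencies are all already layered
        layers, remaining = [], set(dep_dict)
        while remaining:
            placed = {k for layer in layers for k in layer}
            current = [k for k in remaining if all(str(d) in placed for d in dep_dict[k])]
            if not current:  # a cycle or unresolved dependency: stop rather than loop forever
                break
            layers.append(current)
            remaining -= set(current)
        return [k for layer in layers for k in layer]

    deps = {"1": [], "2": [1], "3": [], "4": [3], "5": [1, 3]}
    print(layer_order(deps))  # e.g. ['1', '3', '2', '4', '5']; order within a layer may vary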