TianheWu committed on
Commit
719cdd3
·
verified ·
1 Parent(s): edea5ef

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +332 -3
README.md CHANGED
@@ -1,3 +1,332 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ base_model:
6
+ - Qwen/Qwen2.5-VL-7B-Instruct
7
+ pipeline_tag: reinforcement-learning
8
+ tags:
9
+ - IQA
10
+ - Reasoning
11
+ - VLM
12
+ - Pytorch
13
+ - R1
14
+ ---
15
+
16
+ # VisualQuality-R1-7B
17
+ This is the final version of VisualQuality-R1, trained on a diverse combination of synthetic and realistic datasets.<br>
18
+ Paper link: [arXiv](https://arxiv.org/abs/2505.14460)<br>
19
+ Code link: [github](https://github.com/TianheWu/VisualQuality-R1)
20
+
21
+ > The first no-reference image quality assessment (NR-IQA) model enhanced by reinforcement learning to rank (RL2R), capable of both quality description and rating through reasoning.
22
+
23
+
24
+
25
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/655de51982afda0fc479fb91/JZgVeMtAVASCCNYO5VCyn.png)
26
+
27
+
28
+
29
+ ## Quick Start
30
+ This section shows how to use **VisualQuality-R1**.
31
+
32
+ <details>
33
+ <summary>Example Code (Single Image Quality Rating)</summary>
34
+
35
+ ```python
36
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
37
+ from qwen_vl_utils import process_vision_info
38
+
39
+ import torch
40
+ import random
41
+ import re
42
+ import os
43
+
44
+
45
def score_image(image_path, model, processor):
    """Rate the quality of a single image and return the model's reasoning.

    Args:
        image_path: Path (or URL) of the image, passed through unchanged to
            the chat message as the ``image`` entry.
        model: Loaded ``Qwen2_5_VLForConditionalGeneration`` instance.
        processor: Matching ``AutoProcessor`` instance.

    Returns:
        A ``(reasoning, score)`` tuple. ``reasoning`` is the text of the last
        ``<think>...</think>`` block, or ``""`` when the reply contains none.
        ``score`` is the first number found in the last
        ``<answer>...</answer>`` block (or in the whole reply when no such
        block exists). When no number can be parsed, a random integer in
        ``[1, 5]`` is returned and a warning is printed.
    """
    # The prompt already ends with the <think>/<answer> format instruction,
    # so no additional template is applied.
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality. "
        "First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."
    )

    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image_path},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]
    batch_messages = [message]

    # Preparation for inference
    text = [
        processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
        )
        for msg in batch_messages
    ]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of relying on a module-level
    # ``device`` global, so the function also works when imported elsewhere.
    inputs = inputs.to(model.device)

    # Inference: sampling is enabled, so the reply may differ between runs.
    generated_ids = model.generate(
        **inputs, use_cache=True, max_new_tokens=2048, do_sample=True, top_k=50, top_p=1
    )
    # Strip the prompt tokens so only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    output_text = batch_output_text[0]

    # A reply without a complete <think> block must not crash the call.
    think_matches = re.findall(r'<think>(.*?)</think>', output_text, re.DOTALL)
    reasoning = think_matches[-1].strip() if think_matches else ""

    try:
        answer_matches = re.findall(r'<answer>(.*?)</answer>', output_text, re.DOTALL)
        model_answer = answer_matches[-1].strip() if answer_matches else output_text.strip()
        # re.search returns None when no number is present -> AttributeError.
        score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
    except (AttributeError, ValueError):
        print(f"================= Meet error with {image_path}, please generate again. =================")
        score = random.randint(1, 5)

    return reasoning, score
100
+
101
+
102
# Only seeds Python's ``random`` module, i.e. the fallback score that
# score_image returns when the model's answer cannot be parsed.
random.seed(1)
MODEL_PATH = ""  # local checkpoint directory or Hugging Face Hub model id
# Use the first visible GPU; a hard-coded high index such as "cuda:5" fails on
# hosts with fewer GPUs. Change the index to target a different GPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
image_path = ""  # path of the image to rate

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    # FlashAttention-2 needs a CUDA device; fall back to PyTorch SDPA on CPU.
    attn_implementation="flash_attention_2" if device.type == "cuda" else "sdpa",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Pad on the left so generated tokens directly follow each prompt.
processor.tokenizer.padding_side = "left"

reasoning, score = score_image(image_path, model, processor)

print(reasoning)
print(score)
122
+ ```
123
+ </details>
124
+
125
+
126
+ <details>
127
+ <summary>Example Code (Batch Images Quality Rating)</summary>
128
+
129
+ ```python
130
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
131
+ from qwen_vl_utils import process_vision_info
132
+ from tqdm import tqdm
133
+
134
+ import torch
135
+ import random
136
+ import re
137
+ import os
138
+
139
+
140
def get_image_paths(folder_path):
    """Recursively collect the paths of all image files under ``folder_path``.

    A file counts as an image when its extension (compared case-insensitively)
    is one of ``.jpg``, ``.jpeg``, ``.png``, ``.bmp``, ``.gif``, ``.tiff`` or
    ``.webp``.

    Args:
        folder_path: Root directory to walk.

    Returns:
        A sorted list of image paths. ``os.walk`` yields entries in a
        filesystem-dependent order, so sorting keeps the result (and anything
        written out from it) reproducible. A missing directory yields ``[]``.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    image_paths = []

    for root, _dirs, files in os.walk(folder_path):
        image_paths.extend(
            os.path.join(root, file_name)
            for file_name in files
            if os.path.splitext(file_name)[1].lower() in image_extensions
        )

    return sorted(image_paths)
151
+
152
def score_batch_image(image_paths, model, processor, batch_size=32):
    """Rate the quality of many images, processing them in mini-batches.

    Args:
        image_paths: List of image paths (or URLs); each is sent to the model
            in its own single-image conversation.
        model: Loaded ``Qwen2_5_VLForConditionalGeneration`` instance.
        processor: Matching ``AutoProcessor`` instance.
        batch_size: Number of images per ``model.generate`` call. Must be a
            positive integer. Defaults to 32.

    Returns:
        Dict mapping each image path to its score. The score is the first
        number found in the last ``<answer>...</answer>`` block of the reply
        (or in the whole reply when no such block exists). When no number can
        be parsed, a random integer in ``[1, 5]`` is stored and a warning is
        printed.
    """
    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
    )

    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."
    question = QUESTION_TEMPLATE.format(Question=PROMPT)

    # One single-turn conversation per image.
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {'type': 'image', 'image': img_path},
                    {"type": "text", "text": question},
                ],
            }
        ]
        for img_path in image_paths
    ]

    all_outputs = []  # Decoded replies, in the same order as image_paths
    for start in tqdm(range(0, len(messages), batch_size)):
        batch_messages = messages[start:start + batch_size]

        # Preparation for inference
        text = [
            processor.apply_chat_template(
                msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )
            for msg in batch_messages
        ]

        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's own device instead of relying on a module-level
        # ``device`` global, so the function also works when imported elsewhere.
        inputs = inputs.to(model.device)

        # Inference: sampling is enabled, so replies may differ between runs.
        generated_ids = model.generate(
            **inputs, use_cache=True, max_new_tokens=512, do_sample=True, top_k=50, top_p=1
        )
        # Strip the prompt tokens so only the newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        all_outputs.extend(batch_output_text)

    path_score_dict = {}
    for img_path, model_output in zip(image_paths, all_outputs):
        # Only the score is needed here; the <think> text is intentionally not
        # extracted, so a reply without it cannot abort the whole run.
        try:
            answer_matches = re.findall(r'<answer>(.*?)</answer>', model_output, re.DOTALL)
            model_answer = answer_matches[-1].strip() if answer_matches else model_output.strip()
            # re.search returns None when no number is present -> AttributeError.
            score = float(re.search(r'\d+(\.\d+)?', model_answer).group())
        except (AttributeError, ValueError):
            print(f"Meet error with {img_path}, please generate again.")
            score = random.randint(1, 5)

        path_score_dict[img_path] = score

    return path_score_dict
219
+
220
+
221
# Only seeds Python's ``random`` module, i.e. the fallback score that
# score_batch_image stores when the model's answer cannot be parsed.
random.seed(1)
MODEL_PATH = ""  # local checkpoint directory or Hugging Face Hub model id
# Use the first visible GPU; a hard-coded high index such as "cuda:3" fails on
# hosts with fewer GPUs. Change the index to target a different GPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    # FlashAttention-2 needs a CUDA device; fall back to PyTorch SDPA on CPU.
    attn_implementation="flash_attention_2" if device.type == "cuda" else "sdpa",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Pad on the left so generated tokens directly follow each prompt in a batch.
processor.tokenizer.padding_side = "left"

image_root = ""  # directory that is searched recursively for images
image_paths = get_image_paths(image_root)  # It should be a list

path_score_dict = score_batch_image(image_paths, model, processor)

# One "<path> <score>" pair per line; explicit UTF-8 so non-ASCII paths are
# written the same way on every platform.
file_name = "output.txt"
with open(file_name, "w", encoding="utf-8") as out_file:
    for path, score in path_score_dict.items():
        out_file.write(f"{path} {score}\n")

print("Done!")
247
+ ```
248
+ </details>
249
+
250
+
251
+ <details>
252
+ <summary>Example Code (Images Inference)</summary>
253
+
254
+ You can use any prompt you like in the following code (multiple images are supported as input).
255
+ ```python
256
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
257
+ from qwen_vl_utils import process_vision_info
258
+
259
+ import torch
260
+ import random
261
+ import re
262
+ import os
263
+
264
+
265
def generate(image_paths, model, prompt, processor):
    """Send one free-form prompt with one or more images and return the reply.

    Args:
        image_paths: Iterable of image paths (or URLs) attached, in order, to
            a single user message. A single ``str`` is treated as one image
            rather than being iterated character by character.
        model: Loaded ``Qwen2_5_VLForConditionalGeneration`` instance.
        prompt: Text instruction placed after the images.
        processor: Matching ``AutoProcessor`` instance.

    Returns:
        The decoded text generated by the model for this message.
    """
    if isinstance(image_paths, str):
        image_paths = [image_paths]

    # All images first, then the text prompt, inside one user turn.
    content = [{'type': 'image', 'image': img_path} for img_path in image_paths]
    content.append({"type": "text", "text": prompt})
    batch_messages = [[{"role": "user", "content": content}]]

    # Preparation for inference
    text = [
        processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
        )
        for msg in batch_messages
    ]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of relying on a module-level
    # ``device`` global, so the function also works when imported elsewhere.
    inputs = inputs.to(model.device)

    # Inference: sampling is enabled, so the reply may differ between runs.
    generated_ids = model.generate(
        **inputs, use_cache=True, max_new_tokens=2048, do_sample=True, top_k=50, top_p=1
    )
    # Strip the prompt tokens so only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return batch_output_text[0]
300
+
301
+
302
random.seed(1)
MODEL_PATH = ""  # local checkpoint directory or Hugging Face Hub model id
# Use the first visible GPU; a hard-coded high index such as "cuda:5" fails on
# hosts with fewer GPUs. Change the index to target a different GPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Paths of the images to send together in one message.
image_path = [
    "",
    "",
]

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    # FlashAttention-2 needs a CUDA device; fall back to PyTorch SDPA on CPU.
    attn_implementation="flash_attention_2" if device.type == "cuda" else "sdpa",
    device_map=device,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Pad on the left so generated tokens directly follow the prompt.
processor.tokenizer.padding_side = "left"

prompt = "Please describe the quality of given two images."
answer = generate(image_path, model, prompt, processor)

print(answer)
325
+ ```
326
+ </details>
327
+
328
+
329
+
330
+
331
+
332
+