shervin-dadashzadeh committed on
Commit
1c13578
·
1 Parent(s): 8df44c6

added util

util/__init__.py ADDED
File without changes
util/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes)
 
util/__pycache__/logutil.cpython-310.pyc ADDED
Binary file (1.09 kB)
 
util/__pycache__/vision_util.cpython-310.pyc ADDED
Binary file (6.34 kB)
 
util/logutil.py ADDED
@@ -0,0 +1,33 @@
+import logging
+import datetime
+import os
+
+# Module-level logger object, created by init_logger()
+_logger = None
+
+def init_logger(log_dir="./"):
+    os.makedirs(log_dir, exist_ok=True)
+    global _logger
+    _logger = logging.getLogger('MyLogger')
+    _logger.setLevel(logging.INFO)  # Set the default logging level to INFO
+
+    # Create a formatter with a detailed format including filename and line number
+    _formatter = logging.Formatter('%(asctime)s-%(filename)s:%(lineno)d-%(levelname)s >> %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+
+    # Create a file handler and set the level to INFO; os.path.join handles a log_dir without a trailing slash
+    _file_handler = logging.FileHandler(os.path.join(log_dir, f'output.{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}.log.txt'), mode='w')
+    _file_handler.setLevel(logging.INFO)
+    _file_handler.setFormatter(_formatter)
+
+    # Create a console handler and set the level to INFO
+    _console_handler = logging.StreamHandler()
+    _console_handler.setLevel(logging.INFO)
+    _console_handler.setFormatter(_formatter)
+
+    # Add the handlers to the logger
+    _logger.addHandler(_file_handler)
+    _logger.addHandler(_console_handler)
+
+def get_logger():
+    assert _logger is not None, "Logger is not initialized. Please call init_logger() first."
+    return _logger
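
A minimal usage sketch for the logger (not part of the commit; the directory name is illustrative):

    from util.logutil import init_logger, get_logger

    init_logger(log_dir="./logs/")   # creates ./logs/ plus an output.<timestamp>.log.txt inside it
    logger = get_logger()
    logger.info("training started")  # written to both the console and the log file

Note that a second init_logger() call attaches another pair of handlers to the same 'MyLogger' instance and duplicates every message, so call it once per process.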
util/vision_util.py ADDED
@@ -0,0 +1,229 @@
+# original file: https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py
+# I made some modifications to the original code:
+# 1. Use torchvision.io.VideoReader instead of torchvision.io.read_video to read video frames. The former is much faster.
+# 2. Removed the FPS parameter; it is not really necessary.
+from __future__ import annotations
+
+import base64
+import math
+from io import BytesIO
+
+import requests
+import torch
+import torchvision
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 / 4 * 28 * 28
+VIDEO_TOTAL_PIXELS = 24576 / 4 * 28 * 28
+FRAME_FACTOR = 2
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768 / 4
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
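# Worked example (illustrative, not part of the committed file):
# smart_resize(531, 800) snaps each side to the nearest multiple of 28, giving
# round(531/28)*28 = 532 and round(800/28)*28 = 812; 532*812 = 431,984 pixels
# already lies within [MIN_PIXELS, MAX_PIXELS], so the result is (532, 812).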
+
+
+def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    if "image" in ele:
+        image = ele["image"]
+    else:
+        image = ele["image_url"]
+    image_obj = None
+    if isinstance(image, Image.Image):
+        image_obj = image
+    elif image.startswith("http://") or image.startswith("https://"):
+        image_obj = Image.open(requests.get(image, stream=True).raw)
+    elif image.startswith("file://"):
+        image_obj = Image.open(image[7:])
+    elif image.startswith("data:image"):
+        data = image.split(";", 1)[1]
+        if data.startswith("base64,"):
+            data = base64.b64decode(data[7:])
+        image_obj = Image.open(BytesIO(data))
+    else:
+        image_obj = Image.open(image)
+    if image_obj is None:
+        raise ValueError(f"Unrecognized image input; supported inputs are a local path, an http url, base64 data, or a PIL.Image, got {image}")
+    image = image_obj.convert("RGB")
+    ## resize
+    if "resized_height" in ele and "resized_width" in ele:
+        resized_height, resized_width = smart_resize(
+            ele["resized_height"],
+            ele["resized_width"],
+            factor=size_factor,
+        )
+    else:
+        width, height = image.size
+        min_pixels = ele.get("min_pixels", MIN_PIXELS)
+        max_pixels = ele.get("max_pixels", MAX_PIXELS)
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=size_factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    image = image.resize((resized_width, resized_height))
+
+    return image
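# Accepted "image" values, for illustration (paths/URLs below are made up):
#   {"image": pil_image}                          an existing PIL.Image
#   {"image": "photo.jpg"}                        plain local path
#   {"image": "file:///data/photo.jpg"}           file URI
#   {"image": "https://example.com/photo.jpg"}    fetched with requests
#   {"image": "data:image/jpeg;base64,/9j/..."}   inline base64 data URI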
+
+
+def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
+    if isinstance(ele["video"], str):
+        # TODO: support http url
+
+        video = ele["video"]
+        if video.startswith("file://"):
+            video = video[7:]
+
+        frames_data = [f for f in torchvision.io.VideoReader(video, "video")]
+        assert len(frames_data) > 0, "no frames could be decoded from the video"
+
+        duration = frames_data[-1]['pts'] - frames_data[0]['pts']
+        fps = len(frames_data) / duration
+
+        video = torch.stack([f["data"] for f in frames_data])
+
+        if "nframes" in ele:
+            nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+        else:
+            min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+            max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, video.size(0))), FRAME_FACTOR)
+            nframes = video.size(0) / fps  # equals the duration in seconds, i.e. sample roughly one frame per second
+            nframes = min(max(nframes, min_frames), max_frames)
+            nframes = round_by_factor(nframes, FRAME_FACTOR)
+        if not (FRAME_FACTOR <= nframes <= video.size(0)):
+            raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {video.size(0)}], but got {nframes}.")
+
+        idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
+        height, width = video.shape[2:]
+        video = video[idx]
+
+        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
+        max_pixels = ele.get("max_pixels", max_pixels)
+        if "resized_height" in ele and "resized_width" in ele:
+            resized_height, resized_width = smart_resize(
+                ele["resized_height"],
+                ele["resized_width"],
+                factor=image_factor,
+            )
+        else:
+            resized_height, resized_width = smart_resize(
+                height,
+                width,
+                factor=image_factor,
+                min_pixels=min_pixels,
+                max_pixels=max_pixels,
+            )
+
+        video = transforms.functional.resize(
+            video,
+            [resized_height, resized_width],
+            interpolation=InterpolationMode.BICUBIC,
+            antialias=True,
+        ).float()
+        return video
+    else:
+        assert isinstance(ele["video"], (list, tuple))
+        process_info = ele.copy()
+        process_info.pop("type", None)
+        process_info.pop("video", None)
+        images = [
+            fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
+            for video_element in ele["video"]
+        ]
+        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
+        if len(images) < nframes:
+            images.extend([images[-1]] * (nframes - len(images)))
+        return images
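# Frame-sampling example (illustrative, numbers made up): a 30 s clip decoded at
# 30 fps yields 900 frames, so fps = 900/30 = 30 and nframes = 900/30 = 30 (one
# frame per second); 30 lies inside [FPS_MIN_FRAMES, FPS_MAX_FRAMES] = [4, 192]
# and is already a multiple of FRAME_FACTOR, so 30 frames are sampled uniformly.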
+
+
+def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
+    vision_infos = []
+    if isinstance(conversations[0], dict):
+        conversations = [conversations]
+    for conversation in conversations:
+        for message in conversation:
+            if isinstance(message["content"], list):
+                for ele in message["content"]:
+                    if (
+                        "image" in ele
+                        or "image_url" in ele
+                        or "video" in ele
+                        or ele["type"] in ("image", "image_url", "video")
+                    ):
+                        vision_infos.append(ele)
+    return vision_infos
+
+
+def process_vision_info(
+    conversations: list[dict] | list[list[dict]],
+) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
+    vision_infos = extract_vision_info(conversations)
+    ## Read images or videos
+    image_inputs = []
+    video_inputs = []
+    for vision_info in vision_infos:
+        if "image" in vision_info or "image_url" in vision_info:
+            image_inputs.append(fetch_image(vision_info))
+        elif "video" in vision_info:
+            video_inputs.append(fetch_video(vision_info))
+        else:
+            raise ValueError("image, image_url or video should be in content.")
+    if len(image_inputs) == 0:
+        image_inputs = None
+    if len(video_inputs) == 0:
+        video_inputs = None
+    return image_inputs, video_inputs
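
A minimal usage sketch for the vision helpers (not part of the commit; the file path and prompt are illustrative):

    from util.vision_util import process_vision_info

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": "file:///data/clip.mp4"},
                {"type": "text", "text": "Describe this video."},
            ],
        }
    ]

    # image_inputs is None here; video_inputs is a list holding one
    # (nframes, channels, height, width) float tensor
    image_inputs, video_inputs = process_vision_info(messages)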