jw2yang committed
Commit 58f2b4f · verified · 1 Parent(s): d2d3a96

Upload image_processing_magma.py

Files changed (1):
  1. image_processing_magma.py +333 -0
image_processing_magma.py ADDED
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image processor class for Magma."""

from typing import List, Optional, Union

import numpy as np
import torch
import torchvision

from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import convert_to_rgb
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ImageInput,
    make_list_of_images,
    valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging

logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image

def padding_336(b):
    # Pad the image height up to the next multiple of 336 with white pixels,
    # splitting the padding evenly between top and bottom.
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = torchvision.transforms.functional.pad(
        b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255]
    )
    return b

def calc_padded_size(width, height, padding_unit=336):
    # Compute the output size of padding_336 without touching any pixel data.
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    left_padding = 0
    right_padding = 0
    padded_width = width + left_padding + right_padding
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height
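# Worked example (illustrative values, not from the original file): an image of
# width 1000 and height 700 keeps its width, while the height is padded up to
# the next multiple of 336, so calc_padded_size(1000, 700) returns (1000, 1008),
# with 154 white rows added on each of the top and bottom by padding_336.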
def HD_transform(img, hd_num=4, base_img_size=768):
    # Resize the image so it covers at most `hd_num` base-size tiles while
    # preserving aspect ratio, then pad the height to a multiple of 336.
    # Portrait images are transposed first so the long side is treated as the
    # width, and transposed back at the end.
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * base_img_size)
    new_h = int(new_w / ratio)

    img = torchvision.transforms.functional.resize(img, [new_h, new_w])
    img = padding_336(img)
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img

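# Usage sketch (illustrative, assuming a PIL image): a 1536x768 landscape image
# has ratio 2.0, so the scale search stops at 2 and the image is resized to
# width 2 * 768 = 1536 and height 768 before padding_336 rounds the height up.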
def calc_hd_transform_size(width, height, hd_num=16):
    # Size-only counterpart of HD_transform (note: this helper uses a 336 base).
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = calc_padded_size(new_width, new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height

def pad_to_max_num_crops_tensor(images, max_crops=5):
    """
    images: B x 3 x H x W, B <= max_crops
    """
    B, _, H, W = images.shape
    if B < max_crops:
        pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
        images = torch.cat([images, pad], dim=0)
    return images

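# For example, a stack of 3 crops of shape (3, 3, 768, 768) is zero-padded along
# the batch dimension to (5, 3, 768, 768) with the default max_crops=5.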
def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit

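# Worked example (illustrative): for original_size (1200, 800) and candidates
# [(768, 768), (1536, 768), (1536, 1536)], the effective resolutions are
# 393216, 884736, and 960000 (the last capped at the original pixel count), so
# (1536, 1536) is selected even though it wastes the most area.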
def process_anyres_image(image, max_num_crops=None, base_width=768, base_height=768):
    """
    Process an image with variable resolutions.

    Args:
        image (torch.Tensor): The input image to be processed.
        max_num_crops (int): Maximum number of crops.

    Returns:
        tuple: A tensor containing the processed image patches and the (rows, cols) crop grid.
    """
    assert max_num_crops is not None
    # enumerate all i x j crop grids with at most max_num_crops crops
    grid_pinpoints = []
    for i in range(1, max_num_crops + 1):
        for j in range(1, max_num_crops // i + 1):
            grid_pinpoints.append((i, j))
    possible_resolutions = [(int(res[0] * base_width), int(res[1] * base_height)) for res in grid_pinpoints]

    best_resolution = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)
    # NOTE: reverse best_resolution from (width, height) to (height, width)
    best_resolution = (best_resolution[1], best_resolution[0])
    best_resolution_grid = (best_resolution[0] // base_height, best_resolution[1] // base_width)

    # resize image tensor to best resolution
    image = torch.nn.functional.interpolate(image[None, :, :, :], size=best_resolution, mode='bilinear')
    # divide image tensor into patches
    patches = image.unfold(2, base_height, base_height).unfold(3, base_width, base_width)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(best_resolution_grid[0] * best_resolution_grid[1], -1, base_height, base_width)
    return (patches, best_resolution_grid)

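# Shape sketch (illustrative): a 3 x 768 x 1536 tensor (C x H x W) with
# max_num_crops=4 matches the (1536, 768) candidate exactly, giving
# best_resolution_grid = (1, 2) and `patches` of shape (2, 3, 768, 768).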
def process_anyres_image_global(image, max_num_crops=None, base_width=768, base_height=768):
    """
    Process an image with variable resolutions by resizing it holistically.

    Args:
        image (torch.Tensor): The input image to be processed.
        max_num_crops (int): Maximum number of crops.

    Returns:
        torch.Tensor: The image resized to the best-fitting resolution, without cropping.
    """
    assert max_num_crops is not None
    grid_pinpoints = []
    for i in range(1, max_num_crops + 1):
        for j in range(1, max_num_crops // i + 1):
            grid_pinpoints.append((i, j))
    possible_resolutions = [(int(res[0] * base_width), int(res[1] * base_height)) for res in grid_pinpoints]

    best_resolution = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)
    # NOTE: reverse best_resolution from (width, height) to (height, width)
    best_resolution = (best_resolution[1], best_resolution[0])

    # resize image tensor to best resolution
    image = torch.nn.functional.interpolate(image[None, :, :, :], size=best_resolution, mode='bilinear')
    return image

class preprocessor():
    # Thin wrapper that exposes an HF-style preprocess() interface on top of a
    # torchvision transform pipeline.
    def __init__(self, image_preprocessor, base_resolution=(256, 256)):
        self.image_preprocessor = image_preprocessor
        self.crop_size = {
            'height': base_resolution[0],
            'width': base_resolution[1]
        }
        self.image_mean = image_preprocessor.transforms[-1].mean

    def preprocess(self, image, return_tensors='pt'):
        image = self.image_preprocessor(image).unsqueeze(0)
        return {
            'pixel_values': image,
        }

class MagmaImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Magma image processor. Based on [`CLIPImageProcessor`], with additional techniques
    for processing high-resolution images as explained in [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512).

    Args:
        anyres_strategy (`str`):
            Strategy for coping with high-resolution images. The conventional approach is multi-crop, used by many
            works to accommodate CLIP-ViT models. However, since Magma uses ConvNeXt, which is a convolutional
            network, it can take images of arbitrary resolution. We therefore use the global strategy by default,
            i.e., the image is directly resized holistically to a target resolution.
        base_img_size (int, *optional*, defaults to 768):
            As ConvNeXt has a 1/32 downsampling rate, we use 768 as the base resolution so that the resulting
            feature map is 24x24.
        num_crops (int, *optional*, defaults to 1):
            Number of effective crops when coping with images of resolution higher than 768x768. Note that
            num_crops > 1 does not mean the image is actually cropped.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        anyres_strategy: str = 'global',
        base_img_size: int = 768,
        num_crops: int = 1,
        do_convert_rgb: bool = True,
        image_mean: List[float] = OPENAI_CLIP_MEAN,
        image_std: List[float] = OPENAI_CLIP_STD,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.base_img_size = base_img_size
        self.anyres_strategy = anyres_strategy
        self.num_crops = num_crops
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = image_mean
        self.image_std = image_std

    def preprocess(
        self,
        images: Union[ImageInput, List[ImageInput]],
        do_pad: bool = False,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        num_crops: int = None,
    ):
        """
        Args:
            images (`ImageInput` or `List[ImageInput]`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_pad (`bool`, *optional*, defaults to `False`):
                Currently unused.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            num_crops (`int`, *optional*, defaults to `self.num_crops`):
                Maximum number of crops; overrides the processor default when given.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb is None:
            do_convert_rgb = self.do_convert_rgb
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # tensor transform and normalize
        img_processor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(self.image_mean, self.image_std)
        ])

        images = [img_processor(image) for image in images]
        image_data_type = 'half' if images[0].type() == 'torch.HalfTensor' else 'float'
        images = [image.float() for image in images]

        # crop images to the same size
        image_patches = [
            process_anyres_image(
                image,
                self.num_crops if num_crops is None else num_crops,
                base_width=self.base_img_size,
                base_height=self.base_img_size,
            )
            for image in images
        ]
        pixel_values = torch.cat([image[0] for image in image_patches], dim=0)
        image_sizes = [image_patch[1] for image_patch in image_patches]

        if image_data_type == 'half':
            pixel_values = pixel_values.half()

        data = {
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

AutoImageProcessor.register("MagmaImageProcessor", MagmaImageProcessor)
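
# Minimal usage sketch (illustrative, not part of the uploaded module itself;
# the white test image and its size are arbitrary):
if __name__ == "__main__":
    from PIL import Image as PILImage

    processor = MagmaImageProcessor()  # global strategy, 768 base, num_crops=1
    image = PILImage.new("RGB", (1024, 768), color="white")
    outputs = processor.preprocess(image, return_tensors="pt")
    print(outputs["pixel_values"].shape)  # torch.Size([1, 3, 768, 768])
    print(outputs["image_sizes"])  # a single (1, 1) crop grid per image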