chrisc36 committed
Commit 21ac790 (1 parent: 37daec6)

Upload 5 files

config_molmo.py ADDED
@@ -0,0 +1,57 @@
+from typing import List
+
+from transformers import PretrainedConfig, AutoTokenizer
+
+
+class MolmoConfig(PretrainedConfig):
+    model_type = "molmo"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=50304,
+        embedding_size=50304,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        use_cache=True,
+        layer_norm_eps: float = 1e-5,
+        rope_theta=10000.0,
+        clip_qkv=None,
+        qkv_bias: bool = False,
+        weight_tying: bool = False,
+        use_position_ids: bool = True,
+        tie_word_embeddings: bool = True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.layer_norm_eps = layer_norm_eps
+        self.weight_tying = weight_tying
+        self.use_position_ids = use_position_ids
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.clip_qkv = clip_qkv
+        self.qkv_bias = qkv_bias
+        self.tie_word_embeddings = tie_word_embeddings
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
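For reference, a minimal usage sketch of the config above (assuming these files are importable as the `hf_molmo` package used by `convert_to_hf.py` below; the save directory is a placeholder). It only relies on the standard `PretrainedConfig` save/load API:

    from hf_molmo.config_molmo import MolmoConfig

    # Build a config with a couple of overrides and serialize it to config.json.
    config = MolmoConfig(hidden_size=4096, num_hidden_layers=32, num_attention_heads=32)
    config.save_pretrained("./molmo-config")

    # Reload it and confirm the overrides survived the round trip.
    reloaded = MolmoConfig.from_pretrained("./molmo-config")
    assert reloaded.hidden_size == 4096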
convert_to_hf.py ADDED
@@ -0,0 +1,89 @@
+import argparse
+import logging
+import os
+
+import torch
+
+from hf_molmo.config_molmo import MolmoConfig
+from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
+from hf_molmo.modelling_molmo import MOLMoForCausalLM
+from hf_molmo.preprocessing_molmo import MolmoProcessor
+from olmo import ModelConfig
+from olmo.mm_data.data_utils import build_tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+def write_config(checkpoint_dir: str, output_dir: str):
+    # save config as HF config
+
+    logger.info(f"Loading checkpoint from {checkpoint_dir}")
+
+    config_path = os.path.join(checkpoint_dir, "config.yaml")
+    model_config = ModelConfig.load(config_path, key="model")
+    config_kwargs = model_config.asdict()
+    config_kwargs["use_cache"] = True
+    config_kwargs["vit_load_path"] = None
+    config_kwargs["llm_load_path"] = None
+    config = MolmoConfig(
+        vocab_size=model_config.vocab_size,
+        embedding_size=model_config.embedding_size,
+        hidden_size=model_config.d_model,
+        intermediate_size=model_config.mlp_hidden_size,
+        num_hidden_layers=model_config.n_layers,
+        num_attention_heads=model_config.n_heads,
+        num_key_value_heads=model_config.n_kv_heads,
+        max_position_embeddings=model_config.max_position_embeddings or model_config.max_sequence_length,
+        initializer_range=model_config.initializer_range,
+        use_cache=True,
+        layer_norm_eps=model_config.layer_norm_eps,
+        rope_theta=model_config.rope_theta,
+        clip_qkv=model_config.clip_qkv,
+        qkv_bias=model_config.qkv_bias,
+        weight_tying=model_config.weight_tying,
+        use_position_ids=True,
+        tie_word_embeddings=False
+    )
+
+    logger.info(f"Saving HF-compatible config to {os.path.join(output_dir, 'config.json')}")
+    config.save_pretrained(output_dir)
+
+    preprocessor = MolmoProcessor(
+        MolmoImageProcessor(
+            max_crops=model_config.max_crops
+        ),  # FIXME for now this just assumes everything else is fixed
+        build_tokenizer(model_config.tokenizer.identifier.split("m:")[1]).tokenizer
+    )
+    preprocessor.save_pretrained(output_dir)
+
+
+def write_model(checkpoint_dir: str, output_dir: str, ignore_olmo_compatibility: bool = False):
+    # For device_map="auto", etc. the models are loaded in a way where start_prefix is not computed correctly,
+    # so we explicitly store the model with the expected prefix.
+    old_model_path = os.path.join(checkpoint_dir, "model.pt")
+    new_model_path = os.path.join(output_dir, "pytorch_model.bin")
+
+    state_dict = torch.load(old_model_path)
+    new_state_dict = {f"{MOLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
+    torch.save(new_state_dict, new_model_path)
+
+
+def convert_checkpoint(checkpoint_dir: str, output_dir: str):
+    os.makedirs(output_dir, exist_ok=True)
+    write_config(checkpoint_dir, output_dir)
+    write_model(checkpoint_dir, output_dir)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Adds a config.json to the checkpoint directory, and creates pytorch_model.bin, "
+                    "making it easier to load weights as HF models."
+    )
+    parser.add_argument("checkpoint_dir")
+    parser.add_argument("output_dir")
+    args = parser.parse_args()
+    convert_checkpoint(args.checkpoint_dir, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
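The converter is normally run as a script (`python convert_to_hf.py <checkpoint_dir> <output_dir>`), but `convert_checkpoint` can also be called directly; a hedged sketch with placeholder paths, assuming the script is importable as `convert_to_hf`:

    from convert_to_hf import convert_checkpoint

    # Writes config.json, the saved processor files, and pytorch_model.bin
    # (with base_model_prefix prepended to every key) into the output directory.
    convert_checkpoint("/path/to/olmo_checkpoint", "/path/to/hf_output")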
image_preprocessing_molmo.py ADDED
@@ -0,0 +1,566 @@
+"""Image processor class for Molmo"""
+from typing import List, Optional, Union, Mapping
+
+import numpy as np
+import einops
+import torch
+import torchvision.transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import convert_image_dtype
+
+from transformers.image_utils import (
+    OPENAI_CLIP_MEAN,
+    OPENAI_CLIP_STD,
+    ImageInput,
+    is_valid_image,
+)
+from transformers.processing_utils import ImagesKwargs
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def make_batched_images(images) -> List[List[ImageInput]]:
+    """
+    Accepts images in list or nested list format, and makes a list of images for preprocessing.
+
+    Args:
+        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
+            The input image.
+
+    Returns:
+        list: A list of images.
+    """
+    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
+        return [img for img_list in images for img in img_list]
+
+    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
+        return images
+
+    elif is_valid_image(images):
+        return [images]
+
+    raise ValueError(f"Could not make batched images from {images}")
+
+
+def pad_to_bounding_box(
+    image, offset_height, offset_width, target_height,
+    target_width, value=0
+):
+    height, width = image.shape[:2]
+    after_padding_width = target_width - offset_width - width
+    after_padding_height = target_height - offset_height - height
+    return np.pad(image, [
+        [offset_height, after_padding_height],
+        [offset_width, after_padding_width],
+        [0, 0]
+    ], constant_values=value)
+
+
+def normalize_image(image, offset, scale):
+    image -= np.array(offset, dtype=np.float32)[None, None, :]
+    image /= np.array(scale, dtype=np.float32)[None, None, :]
+    return image
+
+
+def resize_and_pad(
+    image,
+    desired_output_size,
+    resize_method=InterpolationMode.BILINEAR,
+    pad_value=0,
+    normalize=True,
+    image_mean=OPENAI_CLIP_MEAN,
+    image_std=OPENAI_CLIP_STD,
+):
+    desired_height, desired_width = desired_output_size
+    height, width = image.shape[:2]
+
+    # Cast into float32 since the training code did this in float32 and it (very rarely) affects
+    # the results after rounding.
+    image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
+    image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
+    image_scale = min(image_scale_x, image_scale_y)
+    scaled_height = int(np.array(height, np.float32) * image_scale)
+    scaled_width = int(np.array(width, np.float32) * image_scale)
+
+    # if resize_method == "tensorflow":
+    # FIXME remove
+    import tensorflow as tf
+    image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
+    image = tf.image.resize(
+        image,
+        [scaled_height, scaled_width],
+        method=tf.image.ResizeMethod.BILINEAR,
+        antialias=True,
+    )
+    image = tf.clip_by_value(image, 0.0, 1.0)
+    image = image.numpy()
+    # else:
+    #     image = torch.permute(torch.from_numpy(image), [2, 0, 1])
+    #     image = convert_image_dtype(image)  # resize in float32
+    #     image = torchvision.transforms.Resize(
+    #         [scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
+    #     )(image)
+    #     image = torch.clip(image, 0.0, 1.0)
+    #     image = torch.permute(image, [1, 2, 0]).numpy()
+
+    top_pad = (desired_height - scaled_height) // 2
+    left_pad = (desired_width - scaled_width) // 2
+    padding = [
+        [top_pad, desired_height - scaled_height - top_pad],
+        [left_pad, desired_width - scaled_width - left_pad],
+        [0, 0]
+    ]
+    image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
+    image = np.pad(image, padding, constant_values=pad_value)
+    if normalize:
+        image = normalize_image(image, offset=image_mean, scale=image_std)
+    return image, image_mask
+
+
+def select_tiling(h, w, patch_size, max_num_patches):
+    """Decide how best to divide an image of size [h, w] into up to max_num_patches crops of size patch_size"""
+    original_size = np.stack([h, w])  # [1, 2]
+    original_res = h * w
+    tilings = []
+    for i in range(1, max_num_patches+1):
+        for j in range(1, max_num_patches+1):
+            if i*j <= max_num_patches:
+                tilings.append((i, j))
+    # sort so argmin and argmax favour smaller tilings in the event of a tie
+    tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
+    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
+    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]
+
+    # How much we would need to scale the image to fit exactly in each tiling
+    original_size = np.stack([h, w], dtype=np.float32)  # [1, 2]
+    required_scale_d = candidate_resolutions.astype(np.float32) / original_size
+    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]
+    if np.all(required_scale < 1):
+        # We are forced to downscale, so try to minimize the amount of downscaling
+        ix = np.argmax(required_scale)
+    else:
+        # Pick the resolution that requires the least upscaling so that it most closely fits the image
+        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
+        ix = np.argmin(required_scale)
+    return candidate_tilings[ix]
+
+
+class MolmoImagesKwargs(ImagesKwargs, total=False):
+    max_crops: Optional[int]
+    overlap_margins: Optional[List[int]]
+    base_image_input_size: Optional[List[int]]
+    image_token_length_w: Optional[int]
+    image_token_length_h: Optional[int]
+    image_patch_size: Optional[int]
+    image_padding_mask: Optional[bool]
+
+
+class MolmoImageProcessor(BaseImageProcessor):
+    """Preprocess images and multi-modal inputs"""
+
+    def __init__(
+        self,
+        max_crops: int = 12,
+        overlap_margins: List[int] = (4, 4),
+        base_image_input_size: List[int] = (336, 336),
+        image_token_length_w: int = 12,
+        image_token_length_h: int = 12,
+        image_patch_size: int = 14,
+        image_padding_mask: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.max_crops = max_crops
+        self.overlap_margins = overlap_margins
+        self.base_image_input_size = base_image_input_size
+        self.image_token_length_w = image_token_length_w
+        self.image_token_length_h = image_token_length_h
+        self.image_patch_size = image_patch_size
+        self.image_padding_mask = image_padding_mask
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+    def image_to_patches_and_tokens(
+        self,
+        image: ImageInput,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+        max_crops: Optional[int] = None,
+        overlap_margins: Optional[List[int]] = None,
+        base_image_input_size: Optional[Union[int, List[int]]] = None,
+        image_token_length_w: Optional[int] = None,
+        image_token_length_h: Optional[int] = None,
+        image_patch_size: Optional[int] = None,
+    ):
+        """Preprocesses an image
+
+        Returns:
+            crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
+                change between images but the other dimensions are fixed
+            tokens: (n_tokens,) int32 tokens, pad tokens indicating where to insert the
+                patch features, might include other special tokens as well
+            patch_ordering: (n_crops, n_tokens_per_crop) order image features should be inserted
+                into the `tokens`, negative values indicate patch features to exclude
+            padding_mask: (n_crops, n_patches) what percent of each crop is padding, will be None
+                if the image mask is not being used.
+        """
+        if isinstance(base_image_input_size, int):
+            base_image_input_size = (base_image_input_size, base_image_input_size)
+
+        base_image_input_d = image_patch_size
+        tokens_per_image = image_token_length_w * image_token_length_h
+        image_base_patch_w = base_image_input_size[1] // base_image_input_d
+        image_base_patch_h = base_image_input_size[0] // base_image_input_d
+
+        original_image_h, original_image_w = image.shape[:2]
+        crop_size = base_image_input_size[0]
+
+        # Discard this many patches from the (left/top, right/bottom) of crops
+        left_margin, right_margin = overlap_margins
+        # left_margin, right_margin = 2, 2
+        assert left_margin % 2 == 0  # Required for compatibility with 2x2 pooling
+        total_margin_pixels = base_image_input_d*(right_margin + left_margin)  # pixels removed per dim
+        crop_patches = base_image_input_size[0] // base_image_input_d  # patches per crop dim
+        crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
+        crop_window_size = crop_window_patches * base_image_input_d
+        tiling = select_tiling(
+            original_image_h - total_margin_pixels,
+            original_image_w - total_margin_pixels,
+            crop_window_size,
+            max_crops
+        )
+        src, img_mask = resize_and_pad(
+            image,
+            [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
+        )
+
+        # Now we have to split the image into crops, while keeping track of how each patch in
+        # each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
+        n_crops = tiling[0] * tiling[1]
+        patches_arr = []
+        mask_arr = []
+        patch_ordering_arr = []
+
+        # We assume 2x2 pooling, but can allow padding the right/bottom with extra
+        # patches if the number of patches per side is not even
+        assert (crop_patches+1)//2 == image_token_length_h
+        assert (crop_patches+1)//2 == image_token_length_w
+        on = 0
+        on_patch = 0
+        for i in range(tiling[0]):
+            y0 = i*crop_window_size
+            if i == 0:
+                crop_y0 = 0
+            else:
+                crop_y0 = left_margin // 2
+
+            crop_h = image_base_patch_h - (right_margin + left_margin)
+            if i == 0:
+                crop_h += left_margin
+            if i == (tiling[0]-1):
+                crop_h += right_margin
+            for j in range(tiling[1]):
+                x0 = j*crop_window_size
+                if j == 0:
+                    crop_x0 = 0
+                else:
+                    crop_x0 = left_margin // 2
+
+                crop_w = image_base_patch_w - (right_margin + left_margin)
+                if j == 0:
+                    crop_w += left_margin
+                if j == (tiling[1]-1):
+                    crop_w += right_margin
+
+                pooled_w = (crop_w + 1) // 2
+                pooled_h = (crop_h + 1) // 2
+                patch_ordering_arr.append(
+                    pad_to_bounding_box(
+                        np.reshape(np.arange(on, on+pooled_h*pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
+                        crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
+                    )[:, :, 0]
+                )
+                patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
+                mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])
+
+                on += pooled_h*pooled_w
+                on_patch += 1
+        patches = np.stack(patches_arr)
+        patch_ordering = np.stack(patch_ordering_arr)
+        img_mask = np.stack(mask_arr)
+
+        # Switch to [n_crops, n_patches, pixels_per_patch] format
+        image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
+        patches = einops.rearrange(
+            patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
+            dh=base_image_input_d,
+            dw=base_image_input_d,
+            h=image_base_patch_h,
+            w=image_base_patch_w
+        )
+        img_mask = einops.rearrange(
+            img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
+            dh=base_image_input_d,
+            dw=base_image_input_d,
+            h=image_base_patch_h,
+            w=image_base_patch_w
+        )
+
+        img_mask = img_mask.astype(np.float32).mean(axis=-1)
+        patch_ordering = np.reshape(patch_ordering, [-1])
+        valid = patch_ordering >= 0
+
+        # Transpose order, to get left-to-right order instead of crop-by-crop order
+        patch_ordering_rh = np.reshape(
+            patch_ordering,
+            [tiling[0], tiling[1], image_token_length_h, image_token_length_w]
+        )
+        patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
+        patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
+
+        # The transpose will screw up which patches are masked, project the
+        # new order into the sparse structure of `patch_ordering` to fix this
+        patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
+
+        # Now build the output tokens
+        h = tiling[0] * crop_window_patches + (right_margin+left_margin)
+        w = tiling[1] * crop_window_patches + (right_margin+left_margin)
+        per_row = np.full(
+            ((w+1)//2,),
+            image_patch_token_id,
+        )
+        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
+
+        joint = np.tile(per_row, [(h+1)//2])
+        joint = [
+            [image_start_token_id],
+            joint,
+            [image_end_token_id]
+        ]
+
+        # Finally do the same for the global image
+        resized, _ = resize_and_pad(image, base_image_input_size)
+        resized = einops.rearrange(
+            resized, '(h dh) (w dw) c -> (h w) (dh dw c)',
+            dh=base_image_input_d,
+            dw=base_image_input_d,
+            h=image_base_patch_h,
+            w=image_base_patch_w
+        )
+        patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
+
+        # Global image goes first, so the order of patches in previous crops gets increased
+        patch_ordering = np.where(
+            patch_ordering >= 0,
+            patch_ordering + tokens_per_image,
+            -1
+        )
+        patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
+        per_row = np.full(
+            (image_token_length_w,),
+            image_patch_token_id,
+        )
+        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
+        extra_tokens = np.tile(per_row, [image_token_length_h])
+        joint = [
+            [image_start_token_id],
+            extra_tokens,
+            [image_end_token_id],
+        ] + joint
+
+        joint = np.concatenate(joint, 0)
+        img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
+        return patches, joint, patch_ordering, img_mask
+
+    def build_image_input_idx(
+        self,
+        image_tokens: np.ndarray,
+        patch_order: np.ndarray,
+        image_patch_token_id: int,
+        no_image: Optional[bool] = None,
+        image_token_length_w: Optional[int] = None,
+        image_token_length_h: Optional[int] = None,
+    ):
+        """Converts `patch_order` into a mapping of token_id -> patch_id"""
+
+        tokens_per_image = image_token_length_w * image_token_length_h
+        if no_image is not None and no_image:
+            return np.zeros((0, tokens_per_image), np.int32)
+
+        # Indices to insert the patches
+        image_input_idx = image_tokens == image_patch_token_id
+        image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
+
+        if patch_order is not None:
+            n_tokens = image_input_idx.shape[0]
+            patch_order = np.reshape(patch_order, [-1])
+            n_patches = patch_order.shape[0]
+
+            valid = patch_order >= 0
+            n_valid_patches = valid.sum()
+            assert len(image_input_idx) == n_valid_patches
+
+            sorted_patch_ixs = np.zeros([n_tokens], np.int32)
+            sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
+
+            # Project the inverted mapping into the same sparse structure
+            sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
+            sorted_patch_ixs_ex[valid] = sorted_patch_ixs
+
+            # Do the gather and then re-mask outputs that were masked in `sorted_patch_ixs`
+            valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
+            image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
+            image_input_idx = image_input_idx*valid - 100*(1 - valid)
+        image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
+        return image_input_idx
+
+    def preprocess(
+        self,
+        image: np.ndarray,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+        max_crops: Optional[int] = None,
+        overlap_margins: Optional[List[int]] = None,
+        base_image_input_size: Optional[Union[int, List[int]]] = None,
+        image_token_length_w: Optional[int] = None,
+        image_token_length_h: Optional[int] = None,
+        image_patch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        """Preprocesses a single image"""
+
+        max_crops = max_crops or self.max_crops
+        overlap_margins = overlap_margins or self.overlap_margins
+        base_image_input_size = base_image_input_size or self.base_image_input_size
+        image_token_length_w = image_token_length_w or self.image_token_length_w
+        image_token_length_h = image_token_length_h or self.image_token_length_h
+        image_patch_size = image_patch_size or self.image_patch_size
+
+        crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
+            image,
+            image_patch_token_id,
+            image_col_token_id,
+            image_start_token_id,
+            image_end_token_id,
+            max_crops,
+            overlap_margins,
+            base_image_input_size,
+            image_token_length_w,
+            image_token_length_h,
+            image_patch_size,
+        )
+        patch_idx = self.build_image_input_idx(
+            image_tokens,
+            patch_ordering,
+            image_patch_token_id,
+            image_token_length_w=image_token_length_w,
+            image_token_length_h=image_token_length_h,
+        )
+        return crops, image_tokens, patch_idx, img_mask
+
+    def multimodal_preprocess(
+        self,
+        images: np.ndarray,
+        tokens: List[int],
+        image_idx: np.ndarray,
+        sequence_length: int,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+        **kwargs,
+    ):
+        """Merge images and text tokens into multi-modal features for the model
+
+        :param images: images to use as input
+        :param tokens: input text tokens
+        :param image_idx: where to insert the images into `tokens`
+        :param image_patch_token_id: id to use for tokens that will contain image features
+        :param image_col_token_id: token id for image column special tokens
+        :param image_start_token_id: token id for image start special tokens
+        :param image_end_token_id: token id for image end special tokens
+        :param kwargs: override preprocessor default args
+        """
+        max_total_crops = kwargs.get("max_crops") or self.max_crops
+        image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
+        image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
+        image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
+        base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
+        image_num_patch = (
+            base_image_input_size[0] // image_patch_size,
+            base_image_input_size[1] // image_patch_size,
+        )
+        image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
+
+        tokens_per_image = image_token_length_w * image_token_length_h
+        n_pixels = image_patch_size * image_patch_size * 3
+        n_patches = image_num_patch[0] * image_num_patch[1]
+
+        if images is None:
+            return {
+                "input_ids": tokens,
+                "images": None,
+                "image_input_idx": None
+            }
+        else:
+            n = len(images)
+            all_crops = []
+            all_image_idx = []
+            out_tokens = []
+            all_crop_masks = []
+
+            for ix in range(n):
+                token_ix = image_idx[ix]
+                crops, image_tokens, patch_idx, img_mask = self.preprocess(
+                    images[ix],
+                    image_patch_token_id,
+                    image_col_token_id,
+                    image_start_token_id,
+                    image_end_token_id,
+                    **kwargs,
+                )
+
+                if token_ix == -1:  # -1 is an image inserted at the very start
+                    start = 0
+                    token_ix = 0
+                    end = 0
+                else:
+                    start = 0 if ix == 0 else image_idx[ix-1] + 1
+                    end = token_ix + 1
+
+                all_image_idx.append(patch_idx + token_ix)
+                all_crops.append(crops)
+                out_tokens.append(tokens[start:token_ix])
+                out_tokens.append(image_tokens)
+                if ix == (n - 1):
+                    out_tokens.append(tokens[end:])
+                if image_padding_mask:
+                    all_crop_masks.append(img_mask)
+
+            input_ids = np.concatenate(out_tokens, 0)
+            images = np.concatenate(all_crops, 0)
+            image_input_idx = np.concatenate(all_image_idx, 0)
+            if image_padding_mask:
+                image_masks = np.concatenate(all_crop_masks, 0)
+            else:
+                image_masks = None
+
+            out = {
+                "input_ids": input_ids,
+                "images": images,
+                "image_input_idx": image_input_idx
+            }
+            if image_masks is not None:
+                out["image_masks"] = image_masks
+            return out
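To illustrate the cropping logic above, a small example of `select_tiling`, which searches every (rows, cols) grid of at most `max_num_patches` crops and keeps the one whose resolution needs the least rescaling of the image (the numbers below are arbitrary, and the import path assumes the `hf_molmo` package layout used elsewhere in this commit):

    from hf_molmo.image_preprocessing_molmo import select_tiling

    # A 900x1600 image, 336-pixel crop windows, at most 12 crops. No grid of
    # 12 or fewer crops can hold this image without downscaling, so the
    # function falls back to the grid that minimizes the downscaling.
    rows, cols = select_tiling(900, 1600, 336, 12)
    print(rows, cols)  # the selected (rows, cols) crop grid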
modelling_molmo.py ADDED
The diff for this file is too large to render. See raw diff
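Since the model diff is not rendered here, only a rough, hedged sketch of loading the converted checkpoint with the class referenced in convert_to_hf.py (this assumes MOLMoForCausalLM is a standard transformers PreTrainedModel subclass; the path is a placeholder):

    import torch
    from hf_molmo.modelling_molmo import MOLMoForCausalLM

    # Load the weights written by convert_to_hf.py.
    model = MOLMoForCausalLM.from_pretrained("/path/to/hf_output", torch_dtype=torch.float32)
    model.eval()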
 
preprocessing_molmo.py ADDED
@@ -0,0 +1,170 @@
+"""
+Processor class for Molmo.
+"""
+
+from typing import List, Union, Optional
+
+
+try:
+    from typing import Unpack
+except ImportError:
+    from typing_extensions import Unpack
+
+import numpy as np
+import torch
+
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import (
+    TextKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+)
+
+from transformers.tokenization_utils_base import TextInput
+from transformers.utils import logging
+
+from transformers import AutoTokenizer
+from hf_molmo.image_preprocessing_molmo import MolmoImagesKwargs, make_batched_images, MolmoImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
+DEFAULT_IM_START_TOKEN = f"<im_start>"
+DEFAULT_IM_END_TOKEN = f"<im_end>"
+DEFAULT_IM_COL_TOKEN = f"<im_col>"
+IMAGE_PROMPT = "<|image|>"
+
+EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
+
+
+def get_special_token_ids(tokenizer):
+    ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
+    assert len(ids) == len(EXTRA_TOKENS)
+    return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
+
+
+class MolmoTextKwargs(TextKwargs, total=False):
+    style: Optional[str]
+    system_prompt: Optional[str]
+    message_format: Optional[str]
+    always_start_with_space: Optional[bool]
+    sequence_length: Optional[int]
+
+
+class MolmoProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: MolmoTextKwargs
+    images_kwargs: MolmoImagesKwargs
+    _defaults = {
+        "images_kwargs": {
+            "max_crops": 12,
+            "overlap_margins": [4, 4],
+            "base_image_input_size": [336, 336],
+            "image_token_length_w": 12,
+            "image_token_length_h": 12,
+            "image_patch_size": 14,
+            "image_padding_mask": True,
+        },
+        "text_kwargs": {
+            "style": "long_caption",
+            "system_prompt": "none",
+            "message_format": "role",
+            "always_start_with_space": True,
+            "sequence_length": 1536,
+            "padding": False,
+        },
+    }
+
+
+class MolmoProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "MolmoImageProcessor"
+    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+    def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+        self._special_tokens = None
+
+    @property
+    def special_token_ids(self):
+        if self._special_tokens is None:
+            self._special_tokens = get_special_token_ids(self.tokenizer)
+        return self._special_tokens
+
+    def get_tokens_input(self, prompt, message_format, always_start_with_space):
+        if message_format == "none" or message_format is None:
+            pass
+        elif message_format == "role":
+            prompt = "User: " + prompt + " Assistant:"
+        else:
+            raise NotImplementedError(f"Message format {message_format} not implemented")
+
+        if always_start_with_space:
+            prompt = " " + prompt
+
+        tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
+
+        return tokens
+
+    def process(
+        self,
+        text: TextInput = None,
+        images: ImageInput = None,
+        **kwargs: Unpack[MolmoProcessorKwargs],
+    ):
+        output_kwargs = self._merge_kwargs(
+            MolmoProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        tokens = self.get_tokens_input(
+            text,
+            output_kwargs["text_kwargs"]["message_format"],
+            output_kwargs["text_kwargs"]["always_start_with_space"],
+        )
+
+        image_token_id = self.special_token_ids[IMAGE_PROMPT]
+
+        if images is not None:
+            images = make_batched_images(images)
+            images = [np.array(image).astype(np.uint8) for image in images]
+            # For now only support inserting images at the start
+            image_idx = [-1]*len(images)
+        else:
+            image_idx = None
+
+        sequence_length = output_kwargs["text_kwargs"]["sequence_length"]
+
+        image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
+        image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
+        image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
+        image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
+        out = self.image_processor.multimodal_preprocess(
+            images=images,
+            image_idx=image_idx,
+            tokens=np.asarray(tokens).astype(np.int32),
+            sequence_length=sequence_length,
+            image_patch_token_id=image_patch_token_id,
+            image_col_token_id=image_col_token_id,
+            image_start_token_id=image_start_token_id,
+            image_end_token_id=image_end_token_id,
+            **output_kwargs["images_kwargs"]
+        )
+
+        # Prepend BOS
+        # qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
+        bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
+        decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
+        out["input_ids"] = decoder_input_tokens
+        if "image_input_idx" in out:
+            # Shift patch mapping up by one since we added BOS
+            image_input_idx = out["image_input_idx"]
+            out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
+
+        for k, v in out.items():
+            out[k] = torch.from_numpy(v)
+
+        return out
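Putting the pieces together, a hedged end-to-end preprocessing sketch (the tokenizer path and image file are placeholders; the tokenizer must already contain the special tokens listed in EXTRA_TOKENS):

    from PIL import Image
    from transformers import AutoTokenizer

    from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
    from hf_molmo.preprocessing_molmo import MolmoProcessor

    tokenizer = AutoTokenizer.from_pretrained("/path/to/molmo_tokenizer")
    processor = MolmoProcessor(MolmoImageProcessor(), tokenizer)

    # Returns torch tensors: input_ids (BOS prepended), images (the crops),
    # image_input_idx (where each patch feature goes in input_ids), and image_masks.
    batch = processor.process(
        text="Describe this image.",
        images=Image.open("example.jpg"),
    )
    print({k: tuple(v.shape) for k, v in batch.items()})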