ryanzhangfan committed on
Commit
b044178
1 Parent(s): 65c5db2

Upload 6 files

config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "architectures": [
3
+ "Emu3VisionVQModel"
4
+ ],
5
+ "attn_resolutions": [
6
+ 3
7
+ ],
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_emu3visionvq.Emu3VisionVQConfig",
10
+ "AutoModel": "modeling_emu3visionvq.Emu3VisionVQModel"
11
+ },
12
+ "ch": 256,
13
+ "ch_mult": [
14
+ 1,
15
+ 2,
16
+ 2,
17
+ 4
18
+ ],
19
+ "codebook_size": 32768,
20
+ "double_z": false,
21
+ "dropout": 0.0,
22
+ "embed_dim": 4,
23
+ "in_channels": 3,
24
+ "model_type": "Emu3VisionVQ",
25
+ "num_res_blocks": 2,
26
+ "out_channels": 3,
27
+ "temporal_downsample_factor": 4,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.44.0",
30
+ "z_channels": 4
31
+ }
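
Because the `auto_map` entries above point `AutoConfig`/`AutoModel` at the bundled modules, the checkpoint can be loaded with `trust_remote_code=True`. A minimal loading sketch (the path below is a placeholder for this repository's hub id or a local clone containing these files):

```python
from transformers import AutoConfig, AutoModel

# Placeholder: substitute the hub repo id or a local clone of this repository.
MODEL_PATH = "path/to/Emu3-VisionVQ"

config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)  # -> Emu3VisionVQConfig
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)    # -> Emu3VisionVQModel
print(config.codebook_size, config.embed_dim, config.temporal_downsample_factor)  # 32768 4 4
```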
configuration_emu3visionvq.py ADDED
@@ -0,0 +1,106 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Emu3VisionVQ model configuration """
16
+
17
+ from typing import List
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class Emu3VisionVQConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of an [`Emu3VisionVQModel`]. It is used to instantiate a video MoVQ
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a configuration similar to that of the VQ model presented in the Emu3 paper.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ codebook_size (`int`, *optional*, defaults to 32768):
38
+ Codebook size of the VQ model.
39
+ embed_dim (`int`, *optional*, defaults to 4):
40
+ Dimension of the quantized vectors in the codebook.
41
+ z_channels (`int`, *optional*, defaults to 4):
42
+ Number of channels of the encoder output and of the decoder input.
43
+ double_z (`bool`, *optional*, defaults to `False`):
44
+ Whether to double the output dimension of the encoder.
45
+ in_channels (`int`, *optional*, defaults to 3):
46
+ Number of input channels of the encoder.
47
+ out_channels (`int`, *optional*, defaults to 3):
48
+ Number of output channels of the decoder.
49
+ temporal_downsample_factor (`int`, *optional*, defaults to 4):
50
+ Temporal downsample factor.
51
+ ch (`int`, *optional*, defaults to 256):
52
+ Base number of channels in the intermediate blocks.
53
+ ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
54
+ Channel multiplier for each stage of the intermediate blocks.
55
+ num_res_blocks (`int`, *optional*, defaults to 2):
56
+ Number of residual blocks in each stage.
57
+ attn_resolutions (`List[int]`, *optional*, defaults to `[3]`):
58
+ Stage indices at which to apply attention.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ Dropout probability.
61
+
62
+ ```python
63
+ >>> from transformers import Emu3VisionVQModel, Emu3VisionVQConfig
64
+
65
+ >>> # Initializing a configuration for the Emu3 video VQ model
66
+ >>> configuration = Emu3VisionVQConfig()
67
+
68
+ >>> # Initializing a model (with random weights) from the configuration above
69
+ >>> model = Emu3VisionVQModel(configuration)
70
+
71
+ >>> # Accessing the model configuration
72
+ >>> configuration = model.config
73
+ ```"""
74
+
75
+ model_type = "Emu3VisionVQ"
76
+
77
+ def __init__(
78
+ self,
79
+ codebook_size: int = 32768,
80
+ embed_dim: int = 4,
81
+ z_channels: int = 4,
82
+ double_z: bool = False,
83
+ in_channels: int = 3,
84
+ out_channels: int = 3,
85
+ temporal_downsample_factor: int = 4,
86
+ ch: int = 256,
87
+ ch_mult: List[int] = [1, 2, 2, 4],
88
+ num_res_blocks: int = 2,
89
+ attn_resolutions: List[int] = [3],
90
+ dropout: float = 0.0,
91
+ **kwargs,
92
+ ):
93
+ super().__init__(**kwargs)
94
+
95
+ self.codebook_size = codebook_size
96
+ self.embed_dim = embed_dim
97
+ self.z_channels = z_channels
98
+ self.double_z = double_z
99
+ self.in_channels = in_channels
100
+ self.out_channels = out_channels
101
+ self.temporal_downsample_factor = temporal_downsample_factor
102
+ self.ch = ch
103
+ self.ch_mult = ch_mult
104
+ self.num_res_blocks = num_res_blocks
105
+ self.attn_resolutions = attn_resolutions
106
+ self.dropout = dropout
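
As a side note, the defaults for `ch`, `ch_mult`, and `attn_resolutions` above determine the per-stage widths used by the encoder/decoder (via `block_out = ch * ch_mult[i_level]` in the modeling file). A small illustrative sketch, not part of the module:

```python
# Illustrative only: per-stage widths implied by the default configuration above.
ch, ch_mult, attn_resolutions = 256, [1, 2, 2, 4], [3]

stage_widths = [ch * m for m in ch_mult]                                  # [256, 512, 512, 1024]
attn_stages = [i for i in range(len(ch_mult)) if i in attn_resolutions]   # attention only at stage 3
spatial_scale = 2 ** (len(ch_mult) - 1)                                   # 8x spatial downsampling overall
print(stage_widths, attn_stages, spatial_scale)                           # [256, 512, 512, 1024] [3] 8
```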
image_processing_emu3visionvq.py ADDED
@@ -0,0 +1,442 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Emu3VisionVQ."""
16
+
17
+
18
+ import math
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
24
+ from transformers.image_transforms import (
25
+ convert_to_rgb,
26
+ resize,
27
+ to_channel_dimension_format,
28
+ )
29
+ from transformers.image_utils import (
30
+ IMAGENET_STANDARD_MEAN,
31
+ IMAGENET_STANDARD_STD,
32
+ ChannelDimension,
33
+ ImageInput,
34
+ PILImageResampling,
35
+ get_image_size,
36
+ infer_channel_dimension_format,
37
+ is_scaled_image,
38
+ make_list_of_images,
39
+ to_numpy_array,
40
+ valid_images,
41
+ validate_preprocess_arguments,
42
+ )
43
+ from transformers.utils import TensorType, is_vision_available, logging
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ if is_vision_available():
50
+ from PIL import Image
51
+
52
+
53
+ def smart_resize(
54
+ height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
55
+ ):
56
+ """Rescales the image so that the following conditions are met:
57
+
58
+ 1. Both dimensions (height and width) are divisible by 'factor'.
59
+
60
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
61
+
62
+ 3. The aspect ratio of the image is maintained as closely as possible.
63
+
64
+ """
65
+ if height < factor or width < factor:
66
+ raise ValueError(f"height:{height} and width:{width} must both be at least factor:{factor}")
67
+ elif max(height, width) / min(height, width) > 5:
68
+ raise ValueError(
69
+ f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
70
+ )
71
+
72
+ h_bar = round(height / factor) * factor
73
+ w_bar = round(width / factor) * factor
74
+ if h_bar * w_bar > max_pixels:
75
+ beta = math.sqrt((height * width) / max_pixels)
76
+ h_bar = math.floor(height / beta / factor) * factor
77
+ w_bar = math.floor(width / beta / factor) * factor
78
+ elif h_bar * w_bar < min_pixels:
79
+ beta = math.sqrt(min_pixels / (height * width))
80
+ h_bar = math.ceil(height * beta / factor) * factor
81
+ w_bar = math.ceil(width * beta / factor) * factor
82
+
83
+ return h_bar, w_bar
84
+
85
+
86
+ class Emu3VisionVQImageProcessor(BaseImageProcessor):
87
+ r"""
88
+ Constructs an Emu3VisionVQ image processor that dynamically resizes images based on the original image size.
89
+
90
+ Args:
91
+ do_resize (`bool`, *optional*, defaults to `True`):
92
+ Whether to resize the image's (height, width) dimensions.
93
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
94
+ Resampling filter to use when resizing the image.
95
+ do_rescale (`bool`, *optional*, defaults to `True`):
96
+ Whether to rescale the image by the specified scale `rescale_factor`.
97
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
98
+ Scale factor to use if rescaling the image.
99
+ do_normalize (`bool`, *optional*, defaults to `True`):
100
+ Whether to normalize the image.
101
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
102
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
103
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
104
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
105
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
106
+ Whether to convert the image to RGB.
107
+ min_pixels (`int`, *optional*, defaults to `512 * 512`):
108
+ The minimum number of pixels in the resized image.
109
+ max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
110
+ The maximum number of pixels in the resized image.
111
+ spatial_factor (`int`, *optional*, defaults to 8):
112
+ The spatial downsampling factor applied to the image during the feature extraction phase.
113
+ """
114
+
115
+ model_input_names = ["pixel_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ do_resize: bool = True,
120
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
121
+ do_rescale: bool = True,
122
+ rescale_factor: Union[int, float] = 1 / 255,
123
+ do_normalize: bool = True,
124
+ image_mean: Optional[Union[float, List[float]]] = None,
125
+ image_std: Optional[Union[float, List[float]]] = None,
126
+ do_convert_rgb: bool = True,
127
+ min_pixels: int = 512 * 512,
128
+ max_pixels: int = 1024 * 1024,
129
+ spatial_factor: int = 8,
130
+ **kwargs,
131
+ ) -> None:
132
+ super().__init__(**kwargs)
133
+ self.do_resize = do_resize
134
+ self.resample = resample
135
+ self.do_rescale = do_rescale
136
+ self.rescale_factor = rescale_factor
137
+ self.do_normalize = do_normalize
138
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
139
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
140
+ self.min_pixels = min_pixels
141
+ self.max_pixels = max_pixels
142
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
143
+ self.do_convert_rgb = do_convert_rgb
144
+ self.spatial_factor = spatial_factor
145
+
146
+ def _preprocess(
147
+ self,
148
+ images: ImageInput,
149
+ do_resize: Optional[bool] = None,
150
+ resample: Optional[PILImageResampling] = None,
151
+ do_rescale: Optional[bool] = None,
152
+ rescale_factor: Optional[float] = None,
153
+ do_normalize: Optional[bool] = None,
154
+ image_mean: Optional[Union[float, List[float]]] = None,
155
+ image_std: Optional[Union[float, List[float]]] = None,
156
+ do_convert_rgb: Optional[bool] = None,
157
+ spatial_factor: Optional[int] = None,
158
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
159
+ output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
160
+ ):
161
+ """
162
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
163
+
164
+ Args:
165
+ images (`ImageInput`):
166
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
167
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
168
+ Whether to resize the image.
169
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
170
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
171
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
172
+ Whether to rescale the image.
173
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
174
+ Scale factor to use if rescaling the image.
175
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
176
+ Whether to normalize the image.
177
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
178
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
179
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
180
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
181
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
182
+ Whether to convert the image to RGB.
183
+ spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
184
+ The spatial downsampling factor applied to the image during the feature extraction phase.
185
+ input_data_format (`ChannelDimension` or `str`, *optional*):
186
+ The channel dimension format for the input image. Can be one of:
187
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
188
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
189
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
190
+ output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
191
+ The channel dimension format for the output image. Can be one of:
192
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
193
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
194
+ - Unset: Use the channel dimension format of the input image.
195
+ """
196
+ spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
197
+
198
+ images = make_list_of_images(images)
199
+ if do_convert_rgb:
200
+ images = [convert_to_rgb(image) for image in images]
201
+
202
+ # All transformations expect numpy arrays.
203
+ images = [to_numpy_array(image) for image in images]
204
+
205
+ if is_scaled_image(images[0]) and do_rescale:
206
+ logger.warning_once(
207
+ "It looks like you are trying to rescale already rescaled images. If the input"
208
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
209
+ )
210
+
211
+ if input_data_format is None:
212
+ # We assume that all images have the same channel dimension format.
213
+ input_data_format = infer_channel_dimension_format(images[0])
214
+
215
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
216
+ resized_height, resized_width = height, width
217
+ processed_images = []
218
+ for image in images:
219
+ if do_resize:
220
+ resized_height, resized_width = smart_resize(
221
+ height,
222
+ width,
223
+ factor=spatial_factor,
224
+ min_pixels=self.min_pixels,
225
+ max_pixels=self.max_pixels,
226
+ )
227
+ image = resize(
228
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
229
+ )
230
+
231
+ if do_rescale:
232
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
233
+
234
+ if do_normalize:
235
+ image = self.normalize(
236
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
237
+ )
238
+
239
+ image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
240
+ processed_images.append(image)
241
+
242
+ images = np.array(processed_images)
243
+ return images
244
+
245
+ def preprocess(
246
+ self,
247
+ images: ImageInput,
248
+ do_resize: Optional[bool] = None,
249
+ resample: Optional[PILImageResampling] = None,
250
+ do_rescale: Optional[bool] = None,
251
+ rescale_factor: Optional[float] = None,
252
+ do_normalize: Optional[bool] = None,
253
+ image_mean: Optional[Union[float, List[float]]] = None,
254
+ image_std: Optional[Union[float, List[float]]] = None,
255
+ do_convert_rgb: Optional[bool] = None,
256
+ spatial_factor: Optional[int] = None,
257
+ return_tensors: Optional[Union[str, TensorType]] = None,
258
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
259
+ output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
260
+ ):
261
+ """
262
+ Args:
263
+ images (`ImageInput`):
264
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
265
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
266
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
267
+ Whether to resize the image.
268
+ resample (`int`, *optional*, defaults to `self.resample`):
269
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
270
+ has an effect if `do_resize` is set to `True`.
271
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
272
+ Whether to rescale the image.
273
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
274
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
275
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
276
+ Whether to normalize the image.
277
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
278
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
279
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
280
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
281
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
282
+ Whether to convert the image to RGB.
283
+ spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
284
+ The spatial downsampling factor applied to the image during the feature extraction phase.
285
+ return_tensors (`str` or `TensorType`, *optional*):
286
+ The type of tensors to return. Can be one of:
287
+ - Unset: Return a list of `np.ndarray`.
288
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
289
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
290
+ input_data_format (`ChannelDimension` or `str`, *optional*):
291
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
292
+ from the input image. Can be one of:
293
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
294
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
295
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
296
+ output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
297
+ The channel dimension format for the output image. Can be one of:
298
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
299
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
300
+ - Unset: Use the channel dimension format of the input image.
301
+ """
302
+ do_resize = do_resize if do_resize is not None else self.do_resize
303
+ resample = resample if resample is not None else self.resample
304
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
305
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
306
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
307
+ image_mean = image_mean if image_mean is not None else self.image_mean
308
+ image_std = image_std if image_std is not None else self.image_std
309
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
310
+ spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
311
+
312
+ images = make_list_of_images(images)
313
+ if images is None or not valid_images(images):
314
+ raise ValueError(
315
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
316
+ "torch.Tensor, tf.Tensor or jax.ndarray."
317
+ )
318
+
319
+ validate_preprocess_arguments(
320
+ rescale_factor=rescale_factor,
321
+ do_normalize=do_normalize,
322
+ image_mean=image_mean,
323
+ image_std=image_std,
324
+ do_resize=do_resize,
325
+ size=self.size,
326
+ resample=resample,
327
+ )
328
+
329
+ pixel_values = []
330
+ for image in images:
331
+ norm_image = self._preprocess(
332
+ image,
333
+ do_resize=do_resize,
334
+ resample=resample,
335
+ do_rescale=do_rescale,
336
+ rescale_factor=rescale_factor,
337
+ do_normalize=do_normalize,
338
+ image_mean=image_mean,
339
+ image_std=image_std,
340
+ do_convert_rgb=do_convert_rgb,
341
+ spatial_factor=spatial_factor,
342
+ input_data_format=input_data_format,
343
+ output_data_format=output_data_format,
344
+ )
345
+ pixel_values.extend(norm_image)
346
+ pixel_values = np.array(pixel_values)
347
+ data = {"pixel_values": pixel_values}
348
+
349
+ return BatchFeature(data=data, tensor_type=return_tensors)
350
+
351
+ def postprocess(
352
+ self,
353
+ images: ImageInput,
354
+ do_rescale: Optional[bool] = None,
355
+ rescale_factor: Optional[float] = None,
356
+ do_normalize: Optional[bool] = None,
357
+ image_mean: Optional[Union[float, List[float]]] = None,
358
+ image_std: Optional[Union[float, List[float]]] = None,
359
+ return_tensors: Union[str, TensorType] = "PIL.Image.Image",
360
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
361
+ ):
362
+ """
363
+ Postprocess an image or a batch of image tensors. Postprocessing is the reverse of preprocessing;
364
+ the parameters should be the same as those used in `preprocess`.
365
+
366
+ Args:
367
+ images (`ImageInput`):
368
+ Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
369
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
370
+ Whether to rescale the image.
371
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
372
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
373
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
374
+ Whether to normalize the image.
375
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
376
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
377
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
378
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
379
+ return_tensors (`str` or `TensorType`, *optional*):
380
+ The type of tensors to return. Can be one of:
381
+ - Unset: Return a list of `np.ndarray`.
382
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
383
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
384
+ input_data_format (`ChannelDimension` or `str`, *optional*):
385
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
386
+ from the input image. Can be one of:
387
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
388
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
389
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
390
+ """
391
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
392
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
393
+ rescale_factor = 1 / rescale_factor
394
+
395
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
396
+ image_mean = image_mean if image_mean is not None else self.image_mean
397
+ image_std = image_std if image_std is not None else self.image_std
398
+ image_mean, image_std = self.inverse_meanstd(image_mean, image_std)
399
+
400
+ images = make_list_of_images(images)
401
+ if isinstance(images[0], Image.Image):
402
+ return images if len(images) > 1 else images[0]
403
+
404
+ if input_data_format is None:
405
+ # We assume that all images have the same channel dimension format.
406
+ input_data_format = infer_channel_dimension_format(images[0])
407
+
408
+ pixel_values = []
409
+ for image in images:
410
+ image = to_numpy_array(image)
411
+ if do_normalize:
412
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
413
+
414
+ if do_rescale:
415
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
416
+ image = image.clip(0, 255).astype(np.uint8)
417
+
418
+ if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
419
+ image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
420
+ pixel_values.append(Image.fromarray(image))
421
+ else:
422
+ pixel_values.append(image)
423
+
424
+ data = {"pixel_values": pixel_values}
425
+ return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
426
+
427
+ return BatchFeature(data=data, tensor_type=return_tensors)
428
+
429
+ def inverse_meanstd(self, image_mean, image_std):
430
+ image_mean = self.to_tuple(image_mean)
431
+ image_std = self.to_tuple(image_std)
432
+
433
+ rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
434
+ rev_image_std = tuple(1 / s for s in image_std)
435
+
436
+ return rev_image_mean, rev_image_std
437
+
438
+ def to_tuple(self, value, dim=3):
439
+ if isinstance(value, (int, float)):
440
+ return (value,) * dim
441
+
442
+ return tuple(value)
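
A minimal usage sketch for this processor, assuming it is loaded through `AutoImageProcessor` with `trust_remote_code=True`; the model path and image file below are placeholders:

```python
from PIL import Image
from transformers import AutoImageProcessor

MODEL_PATH = "path/to/Emu3-VisionVQ"  # placeholder: hub id or local clone of this repository
processor = AutoImageProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

image = Image.open("example.jpg")                # placeholder input image
inputs = processor(image, return_tensors="pt")   # smart_resize + rescale + normalize
print(inputs["pixel_values"].shape)              # (1, 3, H', W') with H', W' divisible by spatial_factor

# postprocess inverts the normalization/rescaling and returns PIL images by default
restored = processor.postprocess(inputs["pixel_values"])["pixel_values"]
```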
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89536431c69b08b10b449ec309f52dcea22f14b7647317f30f5715273392bbf1
3
+ size 1083015124
modeling_emu3visionvq.py ADDED
@@ -0,0 +1,822 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Emu3VisionVQ model """
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+ from transformers.modeling_utils import PreTrainedModel
24
+
25
+ from .configuration_emu3visionvq import Emu3VisionVQConfig
26
+
27
+
28
+ class Emu3VisionVQActivation(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def forward(self, x: torch.Tensor):
34
+ return x * torch.sigmoid(x)
35
+
36
+
37
+ class Emu3VisionVQUpsample(nn.Module):
38
+
39
+ def __init__(self, in_channels: int):
40
+ super().__init__()
41
+ self.conv = nn.Conv2d(
42
+ in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1,
47
+ )
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ x = self.conv(x)
52
+ return x
53
+
54
+
55
+ class Emu3VisionVQDownsample(nn.Module):
56
+
57
+ def __init__(self, in_channels: int):
58
+ super().__init__()
59
+ self.conv = nn.Conv2d(
60
+ in_channels,
61
+ in_channels,
62
+ kernel_size=3,
63
+ stride=2,
64
+ padding=0,
65
+ )
66
+
67
+ def forward(self, x: torch.Tensor):
68
+ pad = (0, 1, 0, 1)
69
+ x = F.pad(x, pad, mode="constant", value=0)
70
+ x = self.conv(x)
71
+ return x
72
+
73
+
74
+ class Emu3VisionVQCausalConv3d(nn.Module):
75
+
76
+ def __init__(
77
+ self,
78
+ in_channel: int,
79
+ out_channel: int,
80
+ kernel_size: Union[int, Tuple[int, ...]] = (3, 1, 1),
81
+ stride: Union[int, Tuple[int, ...]] = (1, 1, 1),
82
+ ):
83
+ super().__init__()
84
+
85
+ if isinstance(kernel_size, int):
86
+ kernel_size = (kernel_size,) * 3
87
+ if isinstance(stride, int):
88
+ stride = (stride,) * 3
89
+
90
+ hw_pad = [k - s for k, s in zip(kernel_size[1:], stride[1:])]
91
+ self.padding = tuple()
92
+ for p in hw_pad[::-1]:
93
+ self.padding += (p // 2 + p % 2, p // 2)
94
+ self.padding += (2, 0)
95
+
96
+ self.conv = nn.Conv3d(
97
+ in_channel,
98
+ out_channel,
99
+ kernel_size,
100
+ stride=stride,
101
+ )
102
+
103
+ def forward(self, x: torch.Tensor):
104
+ x = F.pad(x, self.padding)
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Emu3VisionVQResnetTemporalBlock(nn.Module):
110
+
111
+ def __init__(
112
+ self,
113
+ in_channels: int,
114
+ out_channels: Optional[int] = None,
115
+ conv_shortcut: bool = False,
116
+ dropout: float = 0.0,
117
+ ):
118
+ super().__init__()
119
+ self.in_channels = in_channels
120
+ out_channels = in_channels if out_channels is None else out_channels
121
+ self.out_channels = out_channels
122
+ self.use_conv_shortcut = conv_shortcut
123
+
124
+ stride = (1, 1, 1)
125
+ kernel_size = (3, 3, 3)
126
+
127
+ self.norm1 = nn.BatchNorm3d(in_channels)
128
+ self.conv1 = Emu3VisionVQCausalConv3d(
129
+ in_channels,
130
+ out_channels,
131
+ kernel_size=kernel_size,
132
+ stride=stride,
133
+ )
134
+ self.norm2 = nn.BatchNorm3d(out_channels)
135
+ self.dropout = nn.Dropout(dropout)
136
+ self.conv2 = Emu3VisionVQCausalConv3d(
137
+ out_channels,
138
+ out_channels,
139
+ kernel_size=kernel_size,
140
+ stride=stride,
141
+ )
142
+ self.act = Emu3VisionVQActivation()
143
+
144
+ if self.in_channels != self.out_channels:
145
+ if self.use_conv_shortcut:
146
+ self.conv_shortcut = Emu3VisionVQCausalConv3d(
147
+ in_channels,
148
+ out_channels,
149
+ kernel_size=kernel_size,
150
+ stride=stride,
151
+ )
152
+ else:
153
+ self.nin_shortcut = nn.Conv3d(
154
+ in_channels,
155
+ out_channels,
156
+ kernel_size=1,
157
+ stride=1,
158
+ padding=0,
159
+ )
160
+
161
+ def forward(self, x: torch.Tensor):
162
+ h = self.norm1(x)
163
+ h = self.act(h)
164
+ h = self.conv1(h)
165
+
166
+ h = self.norm2(h)
167
+ h = self.act(h)
168
+ h = self.dropout(h)
169
+ h = self.conv2(h)
170
+
171
+ if self.in_channels != self.out_channels:
172
+ if self.use_conv_shortcut:
173
+ x = self.conv_shortcut(x)
174
+ else:
175
+ x = self.nin_shortcut(x)
176
+
177
+ return x + h
178
+
179
+
180
+ class Emu3VisionVQSpatialNorm(nn.Module):
181
+
182
+ def __init__(
183
+ self,
184
+ f_channels: int,
185
+ zq_channels: int,
186
+ norm_layer: nn.Module = nn.GroupNorm,
187
+ add_conv: bool = False,
188
+ num_groups: int = 32,
189
+ eps: float = 1e-6,
190
+ affine: bool = True,
191
+ ):
192
+ super().__init__()
193
+ self.norm_layer = norm_layer(
194
+ num_channels=f_channels,
195
+ num_groups=num_groups,
196
+ eps=eps,
197
+ affine=affine,
198
+ )
199
+
200
+ self.add_conv = add_conv
201
+ if self.add_conv:
202
+ self.conv = nn.Conv2d(
203
+ zq_channels,
204
+ zq_channels,
205
+ kernel_size=3,
206
+ stride=1,
207
+ padding=1,
208
+ )
209
+
210
+ self.conv_y = nn.Conv2d(
211
+ zq_channels,
212
+ f_channels,
213
+ kernel_size=1,
214
+ stride=1,
215
+ padding=0,
216
+ )
217
+ self.conv_b = nn.Conv2d(
218
+ zq_channels,
219
+ f_channels,
220
+ kernel_size=1,
221
+ stride=1,
222
+ padding=0,
223
+ )
224
+
225
+ def forward(self, x: torch.Tensor, zq: torch.Tensor):
226
+ zq = F.interpolate(zq, size=x.shape[-2:], mode="nearest")
227
+
228
+ if self.add_conv:
229
+ zq = self.conv(zq)
230
+
231
+ x = self.norm_layer(x)
232
+ x = x * self.conv_y(zq) + self.conv_b(zq)
233
+ return x
234
+
235
+
236
+ class Emu3VisionVQResnetBlock(nn.Module):
237
+
238
+ def __init__(
239
+ self,
240
+ in_channels: int,
241
+ out_channels: Optional[int] = None,
242
+ conv_shortcut: bool = False,
243
+ dropout: float = 0.0,
244
+ zq_ch: Optional[int] = None,
245
+ add_conv: bool = False,
246
+ ):
247
+ super().__init__()
248
+ self.in_channels = in_channels
249
+ out_channels = in_channels if out_channels is None else out_channels
250
+ self.out_channels = out_channels
251
+ self.use_conv_shortcut = conv_shortcut
252
+ self.zq_ch = zq_ch
253
+
254
+ if zq_ch is None:
255
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
256
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
257
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
258
+ else:
259
+ self.norm1 = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
260
+ self.norm2 = Emu3VisionVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
261
+
262
+ self.conv1 = nn.Conv2d(
263
+ in_channels,
264
+ out_channels,
265
+ kernel_size=3,
266
+ stride=1,
267
+ padding=1,
268
+ )
269
+
270
+ self.dropout = nn.Dropout(dropout)
271
+ self.conv2 = nn.Conv2d(
272
+ out_channels,
273
+ out_channels,
274
+ kernel_size=3,
275
+ stride=1,
276
+ padding=1,
277
+ )
278
+
279
+ self.act = Emu3VisionVQActivation()
280
+
281
+ if self.in_channels != self.out_channels:
282
+ if self.use_conv_shortcut:
283
+ self.conv_shortcut = nn.Conv2d(
284
+ in_channels,
285
+ out_channels,
286
+ kernel_size=3,
287
+ stride=1,
288
+ padding=1,
289
+ )
290
+ else:
291
+ self.nin_shortcut = nn.Conv2d(
292
+ in_channels,
293
+ out_channels,
294
+ kernel_size=1,
295
+ stride=1,
296
+ padding=0,
297
+ )
298
+
299
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
300
+ norm_args = tuple() if self.zq_ch is None else (zq, )
301
+
302
+ h = self.norm1(x, *norm_args)
303
+ h = self.act(h)
304
+ h = self.conv1(h)
305
+
306
+ h = self.norm2(h, *norm_args)
307
+ h = self.act(h)
308
+ h = self.dropout(h)
309
+ h = self.conv2(h)
310
+
311
+ if self.in_channels != self.out_channels:
312
+ if self.use_conv_shortcut:
313
+ x = self.conv_shortcut(x)
314
+ else:
315
+ x = self.nin_shortcut(x)
316
+
317
+ return x + h
318
+
319
+
320
+ class Emu3VisionVQAttnBlock(nn.Module):
321
+
322
+ def __init__(
323
+ self,
324
+ in_channels: int,
325
+ zq_ch: Optional[int] = None,
326
+ add_conv: bool = False
327
+ ):
328
+ super().__init__()
329
+ self.in_channels = in_channels
330
+ self.zq_ch = zq_ch
331
+
332
+ if zq_ch is None:
333
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
334
+ self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
335
+ else:
336
+ self.norm = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
337
+
338
+ self.q = nn.Conv2d(
339
+ in_channels,
340
+ in_channels,
341
+ kernel_size=1,
342
+ stride=1,
343
+ padding=0,
344
+ )
345
+ self.k = nn.Conv2d(
346
+ in_channels,
347
+ in_channels,
348
+ kernel_size=1,
349
+ stride=1,
350
+ padding=0,
351
+ )
352
+ self.v = nn.Conv2d(
353
+ in_channels,
354
+ in_channels,
355
+ kernel_size=1,
356
+ stride=1,
357
+ padding=0,
358
+ )
359
+ self.proj_out = nn.Conv2d(
360
+ in_channels,
361
+ in_channels,
362
+ kernel_size=1,
363
+ stride=1,
364
+ padding=0,
365
+ )
366
+
367
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
368
+ norm_args = tuple() if self.zq_ch is None else (zq, )
369
+
370
+ nx = self.norm(x, *norm_args)
371
+ q = self.q(nx)
372
+ k = self.k(nx)
373
+ v = self.v(nx)
374
+
375
+ # compute attention
376
+ b, c, h, w = q.shape
377
+ q = q.reshape(b, c, h * w)
378
+ k = k.reshape(b, c, h * w)
379
+ score = torch.bmm(q.permute(0, 2, 1), k)
380
+ score = score / (c ** 0.5)
381
+ score = F.softmax(score, dim=2)
382
+
383
+ # attend to values
384
+ v = v.reshape(b, c, h * w)
385
+ v = torch.bmm(v, score.permute(0, 2, 1))
386
+ v = v.reshape(b, c, h, w)
387
+
388
+ v = self.proj_out(v)
389
+
390
+ return x + v
391
+
392
+
393
+ class Emu3VisionVQTemporalUpsample(nn.Module):
394
+
395
+ def __init__(
396
+ self,
397
+ in_channel: int,
398
+ out_channel: int,
399
+ kernel_size: Tuple[int, ...] = (3, 3, 3),
400
+ stride: Tuple[int, ...] = (1, 1, 1)
401
+ ):
402
+ super().__init__()
403
+ self.in_channel = in_channel
404
+ self.out_channel = out_channel
405
+ self.conv = Emu3VisionVQCausalConv3d(
406
+ in_channel,
407
+ out_channel,
408
+ kernel_size,
409
+ stride=stride,
410
+ )
411
+
412
+ def forward(self, x: torch.Tensor):
413
+ b, c, t, h, w = x.shape
414
+ x = x.permute(0, 1, 3, 4, 2).contiguous().view(b, -1, t)
415
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
416
+ x = x.view(b, c, h, w, -1).permute(0, 1, 4, 2, 3).contiguous()
417
+ x = self.conv(x)
418
+ return x
419
+
420
+
421
+ class Emu3VisionVQTemporalDownsample(nn.Module):
422
+
423
+ def __init__(
424
+ self,
425
+ in_channel: int,
426
+ out_channel: int,
427
+ kernel_size: Tuple[int, ...] = (4, 3, 3),
428
+ stride: Tuple[int, ...] = (2, 1, 1),
429
+ ):
430
+ super().__init__()
431
+ self.in_channel = in_channel
432
+ self.out_channel = out_channel
433
+ self.kernel_size = kernel_size
434
+
435
+ self.conv = Emu3VisionVQCausalConv3d(
436
+ in_channel,
437
+ out_channel,
438
+ kernel_size=kernel_size,
439
+ stride=stride,
440
+ )
441
+
442
+ def forward(self, x: torch.Tensor):
443
+ x = self.conv(x)
444
+ return x
445
+
446
+
447
+ class Emu3VisionVQVectorQuantizer(nn.Module):
448
+
449
+ def __init__(self, config: Emu3VisionVQConfig):
450
+ super().__init__()
451
+ self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
452
+ self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
453
+
454
+ def forward(self, x: torch.Tensor):
455
+ # b t c h w -> b t h w c
456
+ b, t, c, h, w = x.shape
457
+ x = x.permute(0, 1, 3, 4, 2).contiguous()
458
+ x_flattened = x.view(-1, c)
459
+
460
+ codebook = self.embedding.weight
461
+
462
+ d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
463
+ torch.sum(codebook ** 2, dim=1) - 2 * \
464
+ torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
465
+
466
+ indices = torch.argmin(d, dim=1)
467
+ indices = indices.view(b, t, h, w)
468
+ return indices
469
+
470
+
471
+ class Emu3VisionVQEncoder(nn.Module):
472
+
473
+ def __init__(self, config: Emu3VisionVQConfig):
474
+ super().__init__()
475
+ self.ch = config.ch
476
+ self.num_resolutions = len(config.ch_mult)
477
+ self.num_res_blocks = config.num_res_blocks
478
+ self.in_channels = config.in_channels
479
+
480
+ # downsampling
481
+ self.conv_in = nn.Conv2d(
482
+ self.in_channels,
483
+ self.ch,
484
+ kernel_size=3,
485
+ stride=1,
486
+ padding=1
487
+ )
488
+
489
+ in_ch_mult = (1,) + tuple(config.ch_mult)
490
+ self.down = nn.ModuleList()
491
+ for i_level in range(self.num_resolutions):
492
+ block = nn.ModuleList()
493
+ attn = nn.ModuleList()
494
+ block_in = config.ch * in_ch_mult[i_level]
495
+ block_out = config.ch * config.ch_mult[i_level]
496
+ for i_block in range(self.num_res_blocks):
497
+ block.append(
498
+ Emu3VisionVQResnetBlock(
499
+ in_channels=block_in,
500
+ out_channels=block_out,
501
+ dropout=config.dropout,
502
+ )
503
+ )
504
+ block_in = block_out
505
+ if i_level in config.attn_resolutions:
506
+ attn.append(Emu3VisionVQAttnBlock(block_in))
507
+
508
+ down = nn.Module()
509
+ down.block = block
510
+ down.attn = attn
511
+ if i_level != self.num_resolutions - 1:
512
+ down.downsample = Emu3VisionVQDownsample(block_in)
513
+
514
+ self.down.append(down)
515
+
516
+ # middle
517
+ self.mid = nn.Module()
518
+ self.mid.block_1 = Emu3VisionVQResnetBlock(
519
+ in_channels=block_in,
520
+ out_channels=block_in,
521
+ dropout=config.dropout,
522
+ )
523
+ self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in)
524
+ self.mid.block_2 = Emu3VisionVQResnetBlock(
525
+ in_channels=block_in,
526
+ out_channels=block_in,
527
+ dropout=config.dropout,
528
+ )
529
+
530
+ # end
531
+ self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
532
+
533
+ out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
534
+ self.conv_out = nn.Conv2d(
535
+ block_in,
536
+ out_z_channels,
537
+ kernel_size=3,
538
+ stride=1,
539
+ padding=1,
540
+ )
541
+
542
+ temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
543
+ self.time_conv = nn.ModuleList()
544
+
545
+ for i in range(temporal_down_blocks):
546
+ conv = Emu3VisionVQTemporalDownsample(out_z_channels, out_z_channels)
547
+ self.time_conv.append(conv)
548
+
549
+ self.time_res_stack = nn.Sequential(*[
550
+ Emu3VisionVQResnetTemporalBlock(
551
+ in_channels=out_z_channels,
552
+ out_channels=out_z_channels,
553
+ dropout=config.dropout,
554
+ ) for _ in range(self.num_res_blocks)
555
+ ])
556
+
557
+ self.act = Emu3VisionVQActivation()
558
+
559
+ def forward(self, x: torch.Tensor):
560
+ t = x.shape[1]
561
+ x = x.reshape(-1, *x.shape[2:])
562
+
563
+ # downsampling
564
+ h = self.conv_in(x)
565
+ for i_level in range(self.num_resolutions):
566
+ for i_block in range(self.num_res_blocks):
567
+ h = self.down[i_level].block[i_block](h)
568
+ if len(self.down[i_level].attn) > 0:
569
+ h = self.down[i_level].attn[i_block](h)
570
+
571
+ if i_level != self.num_resolutions - 1:
572
+ h = self.down[i_level].downsample(h)
573
+
574
+ h = self.mid.block_1(h)
575
+ h = self.mid.attn_1(h)
576
+ h = self.mid.block_2(h)
577
+
578
+ # end
579
+ h = self.norm_out(h)
580
+ h = self.act(h)
581
+
582
+ h = self.conv_out(h)
583
+
584
+ h = h.reshape(-1, t, *h.shape[1:])
585
+ h = h.permute(0, 2, 1, 3, 4)
586
+
587
+ for conv in self.time_conv:
588
+ h = self.act(conv(h))
589
+
590
+ h = self.time_res_stack(h)
591
+ h = h.permute(0, 2, 1, 3, 4)
592
+
593
+ return h
594
+
595
+
596
+ class Emu3VisionVQDecoder(nn.Module):
597
+
598
+ def __init__(self, config: Emu3VisionVQConfig):
599
+ super().__init__()
600
+ self.ch = config.ch
601
+ self.num_resolutions = len(config.ch_mult)
602
+ self.num_res_blocks = config.num_res_blocks
603
+
604
+ in_ch_mult = (1,) + tuple(config.ch_mult)
605
+ zq_ch = config.embed_dim
606
+
607
+ block_in = config.ch * config.ch_mult[-1]
608
+ self.time_res_stack = nn.Sequential(*[
609
+ Emu3VisionVQResnetTemporalBlock(
610
+ in_channels=config.z_channels,
611
+ out_channels=config.z_channels,
612
+ dropout=config.dropout,
613
+ ) for _ in range(config.num_res_blocks)
614
+ ])
615
+
616
+ tempo_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
617
+ self.time_conv = nn.ModuleList()
618
+ for i in range(tempo_upsample_block_num):
619
+ conv = Emu3VisionVQTemporalUpsample(config.z_channels, config.z_channels)
620
+ self.time_conv.append(conv)
621
+
622
+ self.conv_in = nn.Conv2d(
623
+ config.z_channels,
624
+ block_in,
625
+ kernel_size=3,
626
+ stride=1,
627
+ padding=1,
628
+ )
629
+
630
+ # middle
631
+ self.mid = nn.Module()
632
+ self.mid.block_1 = Emu3VisionVQResnetBlock(
633
+ in_channels=block_in,
634
+ out_channels=block_in,
635
+ dropout=config.dropout,
636
+ zq_ch=zq_ch,
637
+ )
638
+ self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in, zq_ch)
639
+ self.mid.block_2 = Emu3VisionVQResnetBlock(
640
+ in_channels=block_in,
641
+ out_channels=block_in,
642
+ dropout=config.dropout,
643
+ zq_ch=zq_ch,
644
+ )
645
+
646
+ # upsampling
647
+ self.up = nn.ModuleList()
648
+ for i_level in reversed(range(self.num_resolutions)):
649
+ block = nn.ModuleList()
650
+ attn = nn.ModuleList()
651
+ block_out = config.ch * config.ch_mult[i_level]
652
+ for i_block in range(self.num_res_blocks + 1):
653
+ block.append(
654
+ Emu3VisionVQResnetBlock(
655
+ in_channels=block_in,
656
+ out_channels=block_out,
657
+ dropout=config.dropout,
658
+ zq_ch=zq_ch,
659
+ )
660
+ )
661
+ block_in = block_out
662
+ if i_level in config.attn_resolutions:
663
+ attn.append(Emu3VisionVQAttnBlock(block_in, zq_ch))
664
+
665
+ up = nn.Module()
666
+ up.block = block
667
+ up.attn = attn
668
+ if i_level != 0:
669
+ up.upsample = Emu3VisionVQUpsample(block_in)
670
+
671
+ self.up.insert(0, up)
672
+
673
+ self.act = Emu3VisionVQActivation()
674
+
675
+ self.norm_out = Emu3VisionVQSpatialNorm(block_in, zq_ch)
676
+ self.conv_out = nn.Conv2d(
677
+ block_in,
678
+ config.out_channels,
679
+ kernel_size=3,
680
+ stride=1,
681
+ padding=1,
682
+ )
683
+
684
+ def forward(self, z: torch.Tensor, zq: torch.Tensor):
685
+ z_zq = torch.cat((z, zq), dim=0)
686
+ z_zq = z_zq.permute(0, 2, 1, 3, 4)
687
+ z_zq = self.time_res_stack(z_zq)
688
+
689
+ for conv in self.time_conv:
690
+ z_zq = self.act(conv(z_zq))
691
+
692
+ z_zq = z_zq.permute(0, 2, 1, 3, 4)
693
+
694
+ h, zq = torch.chunk(z_zq, 2, dim=0)
695
+
696
+ h = h.reshape(-1, *h.shape[2:])
697
+ zq = zq.reshape(-1, *zq.shape[2:])
698
+
699
+ h = self.conv_in(h)
700
+
701
+ # middle
702
+ h = self.mid.block_1(h, zq)
703
+ h = self.mid.attn_1(h, zq)
704
+ h = self.mid.block_2(h, zq)
705
+
706
+ # upsampling
707
+ for i_level in reversed(range(self.num_resolutions)):
708
+ for i_block in range(self.num_res_blocks+1):
709
+ h = self.up[i_level].block[i_block](h, zq)
710
+ if len(self.up[i_level].attn) > 0:
711
+ h = self.up[i_level].attn[i_block](h, zq)
712
+
713
+ if i_level != 0:
714
+ h = self.up[i_level].upsample(h)
715
+
716
+ h = self.norm_out(h, zq)
717
+ h = self.act(h)
718
+ h = self.conv_out(h)
719
+
720
+ return h
721
+
722
+
723
+ class Emu3VisionVQPretrainedModel(PreTrainedModel):
724
+ """
725
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
726
+ models.
727
+ """
728
+
729
+ config_class = Emu3VisionVQConfig
730
+ base_model_prefix = "emuvideovq"
731
+ main_input_name = "pixel_values"
732
+ _no_split_modules = ["Emu3VisionVQResnetBlock", "Emu3VisionVQAttnBlock", "Emu3VisionVQResnetTemporalBlock"]
733
+
734
+ def _init_weights(self, module):
735
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
736
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
737
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
738
+ elif isinstance(module, nn.Linear):
739
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
740
+ if module.bias is not None:
741
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
742
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
743
+ nn.init.uniform_(module.bias, -bound, bound)
744
+ elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
745
+ nn.init.constant_(module.weight, 1)
746
+ nn.init.constant_(module.bias, 0)
747
+
748
+
749
+ class Emu3VisionVQModel(Emu3VisionVQPretrainedModel):
750
+
751
+ def __init__(self, config):
752
+ super().__init__(config)
753
+ self.config = config
754
+
755
+ self.encoder = Emu3VisionVQEncoder(config)
756
+ self.decoder = Emu3VisionVQDecoder(config)
757
+ self.quantize = Emu3VisionVQVectorQuantizer(config)
758
+
759
+ self.quant_conv = Emu3VisionVQCausalConv3d(config.z_channels, config.embed_dim)
760
+ self.post_quant_conv = Emu3VisionVQCausalConv3d(config.embed_dim, config.z_channels)
761
+
762
+ self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
763
+
764
+ self.post_init()
765
+
766
+ def encode(self, x: torch.Tensor):
767
+ ndim = x.ndim
768
+ if ndim == 4:
769
+ t = self.config.temporal_downsample_factor
770
+ b, c, h, w = x.shape
771
+ x = x.unsqueeze(1).repeat(1, t, 1, 1, 1)
772
+ elif ndim == 5:
773
+ b, t, c, h, w = x.shape
774
+
775
+ h = self.encoder(x)
776
+
777
+ # b t c h w -> b c t h w
778
+ h = h.permute(0, 2, 1, 3, 4)
779
+ h = self.quant_conv(h)
780
+ # b c t h w -> b t c h w
781
+ h = h.permute(0, 2, 1, 3, 4)
782
+
783
+ codes = self.quantize(h)
784
+
785
+ if ndim == 4:
786
+ codes = codes.squeeze(1)
787
+
788
+ return codes
789
+
790
+ def decode(self, x: torch.Tensor):
791
+ ndim = x.ndim
792
+ if ndim == 3:
793
+ x = x.unsqueeze(1)
794
+
795
+ b, t, h, w = x.shape
796
+ quant = self.quantize.embedding(x.flatten())
797
+ c = quant.shape[-1]
798
+ quant = quant.view(b, t, h, w, c).permute(0, 4, 1, 2, 3).contiguous()
799
+ quant2 = self.post_quant_conv(quant)
800
+
801
+ quant = quant.permute(0, 2, 1, 3, 4)
802
+ quant2 = quant2.permute(0, 2, 1, 3, 4)
803
+
804
+ video = self.decoder(quant2, quant)
805
+ video = video.reshape(
806
+ b,
807
+ t * self.config.temporal_downsample_factor,
808
+ self.config.out_channels,
809
+ h * self.spatial_scale_factor,
810
+ w * self.spatial_scale_factor,
811
+ )
812
+ if ndim == 3:
813
+ return video[:, 0]
814
+ return video
815
+
816
+ @property
817
+ def device(self):
818
+ return next(self.parameters()).device
819
+
820
+ @property
821
+ def dtype(self):
822
+ return next(self.parameters()).dtype
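
A minimal end-to-end tokenize/reconstruct sketch built on the `encode`/`decode` methods above, assuming the checkpoint is loaded via its `auto_map` entries with `trust_remote_code=True`; the model path and image file are placeholders:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

MODEL_PATH = "path/to/Emu3-VisionVQ"  # placeholder: hub id or local clone of this repository
processor = AutoImageProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).eval()

pixel_values = processor(Image.open("example.jpg"), return_tensors="pt")["pixel_values"]  # (1, 3, H, W)

with torch.no_grad():
    codes = model.encode(pixel_values)  # (1, H/8, W/8) indices into the 32768-entry codebook
    recon = model.decode(codes)         # (1, 3, H, W) reconstruction, roughly in [-1, 1]

images = processor.postprocess(recon)["pixel_values"]  # back to a list of PIL images
```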
preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_emu3visionvq.Emu3VisionVQImageProcessor"
4
+ },
5
+ "do_convert_rgb": true,
6
+ "do_normalize": true,
7
+ "do_rescale": true,
8
+ "do_resize": true,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_processor_type": "Emu3VisionVQImageProcessor",
15
+ "image_std": [
16
+ 0.5,
17
+ 0.5,
18
+ 0.5
19
+ ],
20
+ "max_pixels": 1048576,
21
+ "min_pixels": 262144,
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "max_pixels": 1048576,
26
+ "min_pixels": 262144
27
+ },
28
+ "spatial_factor": 8
29
+ }
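
For reference, these serialized values line up with the processor defaults above: `min_pixels` 262144 = 512 * 512, `max_pixels` 1048576 = 1024 * 1024, `rescale_factor` 0.00392156862745098 = 1/255, and `resample` 3 is bicubic resampling (`PILImageResampling.BICUBIC`).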