wusize committed
Commit 133ff85 · verified · 1 Parent(s): aa4c3c3

Update modeling_harmon.py

Files changed (1)
  1. modeling_harmon.py +288 -288
modeling_harmon.py CHANGED
@@ -1,288 +1,288 @@
- import torch
- import math
- import numpy as np
- import torch.nn as nn
- import copy
- from einops import rearrange
- from torch.nn.modules.module import T
- from transformers.cache_utils import DynamicCache
-
- from tqdm import tqdm
- from transformers import Qwen2ForCausalLM, Qwen2Config, PreTrainedModel
-
- from .diffusion_utils import *
- from .gaussian_diffusion import *
- from .respace import *
- from .misc import *
- from .diffloss import *
-
-
- from .configuration_harmon import HarmonConfig
- from .vae import AutoencoderKL
- from .mar import mar_base, mar_large, mar_huge
-
-
-
- def build_mlp(hidden_size, projector_dim, z_dim):
-     return nn.Sequential(
-         nn.Linear(hidden_size, projector_dim),
-         nn.SiLU(),
-         nn.Linear(projector_dim, z_dim),)
-
-
- def mask_by_order(mask_len, order, bsz, seq_len):
-     masking = torch.zeros(bsz, seq_len, device=order.device)
-     masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
-                             src=torch.ones(bsz, seq_len, device=order.device)).bool()
-     return masking
-
-
- class HarmonModel(PreTrainedModel):
-     config_class = HarmonConfig
-
-     def __init__(self, config: HarmonConfig):
-         super().__init__(config)
-         # VAE
-         self.vae = AutoencoderKL(
-             embed_dim=16,
-             ch_mult=(1, 1, 2, 2, 4)
-         )
-         self.vae_scale = 0.2325
-
-         # LLM
-         self.llm = Qwen2ForCausalLM(config=Qwen2Config.from_dict(config.llm))
-
-         # MAR
-         mar_config = copy.deepcopy(config.mar)
-         mar_type = mar_config.pop('type')
-         if mar_type == 'mar_base':
-             self.mar = mar_base(**mar_config)
-         elif mar_type == 'mar_large':
-             self.mar = mar_large(**mar_config)
-         elif mar_type == 'mar_huge':
-             self.mar = mar_huge(**mar_config)
-         else:
-             raise ValueError
-
-         # projection layers
-         self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
-                                  projector_dim=self.llm.config.hidden_size,
-                                  z_dim=self.llm.config.hidden_size)
-         self.proj_out = build_mlp(hidden_size=self.llm.config.hidden_size,
-                                   projector_dim=self.llm.config.hidden_size,
-                                   z_dim=self.mar.encoder_embed_dim)
-
-     @property
-     def llm_model(self):
-         return self.llm.model
-
-     @property
-     def device(self):
-         return self.llm.device
-
-     @property
-     def dtype(self):
-         return self.llm.dtype
-
-     @property
-     def gen_seq_len(self):
-         return self.mar.seq_len
-
-     @property
-     def token_embed_dim(self):
-         return self.vae.embed_dim * (self.mar.patch_size ** 2)
-
-     @torch.no_grad()
-     def encode(self, x):
-         posterior = self.vae.encode(x)
-         z = posterior.sample().mul_(self.vae_scale)
-         z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
-                       p=self.mar.patch_size, q=self.mar.patch_size)
-
-         return z
-
-     @torch.no_grad()
-     def decode(self, z):
-         z /= self.vae_scale
-         z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
-                       p=self.mar.patch_size, q=self.mar.patch_size)
-
-         x = self.vae.decode(z)
-         return x
-
-     def prepare_forward_input(self,
-                               x,
-                               inputs_embeds=None,
-                               input_ids=None,
-                               attention_mask=None,
-                               past_key_values=None):
-         b, l, _ = x.shape
-         attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
-         attention_mask = torch.cat([
-             attention_mask, attention_mask.new_ones(b, l)
-         ], dim=1)
-         position_ids = torch.cumsum(attention_mask, dim=1) - 1
-         position_ids[position_ids < 0] = 0
-
-         # import pdb; pdb.set_trace()
-
-         # prepare context
-         if past_key_values is not None:
-             inputs_embeds = x
-             position_ids = position_ids[:, -l:]
-         else:
-             if inputs_embeds is None:
-                 input_ids = input_ids.to(self.device)
-                 inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-             inputs_embeds = torch.cat([inputs_embeds, x], dim=1)
-
-         return dict(inputs_embeds=inputs_embeds,
-                     attention_mask=attention_mask,
-                     position_ids=position_ids,
-                     past_key_values=past_key_values)
-
-     def extract_visual_feature(self, x, mask=None, detach=False):
-         b, m, n, _ = x.shape
-         x = x.view(b, m*n, -1)
-         # x: b mn c
-         if mask is None:
-             mask = torch.zeros_like(x[..., 0])
-         null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
-         x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
-
-         z_enc = self.proj_in(x_enc)
-         # Move buffers to the end of the image sequence
-         z_enc = torch.cat([
-             z_enc[:, self.mar.buffer_size:],
-             z_enc[:, :self.mar.buffer_size]], dim=1)
-
-         if detach:
-             x_enc = x_enc.detach()
-             z_enc = z_enc.detach()
-
-         return x_enc, z_enc
-
-     def forward_mae_encoder(self, x, mask, detach=False, **context):
-         b, m, n, _ = x.shape
-         x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
-         inputs = self.prepare_forward_input(x=z_enc, **context)
-         output = self.llm_model(**inputs, return_dict=True)
-
-         z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]
-
-         # move buffers back to the start of the image sequence
-         z_llm = torch.cat([
-             z_llm[:, -self.mar.buffer_size:],
-             z_llm[:, :-self.mar.buffer_size]], dim=1)
-
-         # residual learning
-         x_enc = x_enc + self.proj_out(z_llm)
-
-         return x_enc
-
-     @staticmethod
-     def curtail_cache(past_key_values, cur_len):
-         for past_key_values_ in past_key_values:
-             keys, values = past_key_values_
-             keys.data = keys.data[:, :, :cur_len]
-             values.data = values.data[:, :, :cur_len]
-
-     @torch.no_grad()
-     def sample(self,
-                input_ids=None, inputs_embeds=None,
-                attention_mask=None, num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
-                progress=False, mask=None, past_key_values=None, image_shape=None, x_con=None, **kwargs):
-         if inputs_embeds is None and input_ids is not None:
-             inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-
-         bsz = attention_mask.shape[0]
-         if cfg != 1.0:
-             assert bsz % 2 == 0
-
-         if image_shape is None:
-             m = n = int(self.gen_seq_len ** 0.5)
-         else:
-             m, n = image_shape
-
-         if mask is None:
-             mask = torch.ones(bsz, m*n, device=self.device, dtype=self.dtype)
-         else:
-             mask = mask.view(bsz, m*n)
-         tokens = torch.zeros(bsz, m*n, self.token_embed_dim,
-                              device=self.device, dtype=self.dtype)
-         orders = self.mar.sample_orders(bsz, seq_len=m*n)
-         if cfg != 1.0:
-             orders[bsz//2:] = orders[:bsz//2]
-
-         indices = list(range(num_iter))
-         if progress:
-             indices = tqdm(indices)
-
-         # past key values can be prepared outside (usually in multi-turn editing)
-         if past_key_values is None:
-             output = self.llm_model(inputs_embeds=inputs_embeds,
-                                     attention_mask=None,
-                                     position_ids=None,
-                                     past_key_values=DynamicCache.from_legacy_cache(),
-                                     return_dict=True,
-                                     use_cache=True)
-             past_key_values = output.past_key_values
-
-         # generate latents
-         for step in indices:
-             cur_tokens = tokens.clone()
-             x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
-                                              mask.to(self.dtype),
-                                              past_key_values=past_key_values,
-                                              # inputs_embeds=inputs_embeds,
-                                              attention_mask=attention_mask)
-             # import pdb; pdb.set_trace()
-             self.curtail_cache(past_key_values, inputs_embeds.shape[1])
-             # import pdb; pdb.set_trace()
-
-             z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con)
-
-             # mask ratio for the next round, following MaskGIT and MAGE.
-             mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
-             mask_len = torch.Tensor([np.floor(m*n * mask_ratio)]).to(self.device)
-
-             # masks out at least one for the next iteration
-             mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
-                                      torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
-
-             # get masking for next iteration and locations to be predicted in this iteration
-             mask_next = mask_by_order(mask_len[0], orders, bsz, m*n).to(self.device)
-             if cfg != 1.0:
-                 mask_next[bsz//2:] = mask_next[:bsz//2]
-             if step >= num_iter - 1:
-                 mask_to_pred = mask[:bsz].bool()
-             else:
-                 mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
-             mask = mask_next
-             # if not cfg == 1.0:
-             #     mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
-
-             # sample token latents for this step
-             z = z[mask_to_pred.nonzero(as_tuple=True)]
-             # cfg schedule follow Muse
-             if cfg_schedule == "linear":
-                 cfg_iter = 1 + (cfg - 1) * (m*n - mask_len[0]) / (m*n)
-             elif cfg_schedule == "constant":
-                 cfg_iter = cfg
-             else:
-                 raise NotImplementedError
-             sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)
-             # if not cfg == 1.0:
-             #     sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
-             #     mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
-
-             cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
-             if cfg != 1.0:
-                 cur_tokens[bsz//2:] = cur_tokens[:bsz//2]
-             tokens = cur_tokens.clone()
-
-         pred = self.decode(tokens.view(bsz, m, n, -1))
-
-         if cfg != 1.0:
-             pred = pred[:bsz//2]
-         return pred
 
+ import torch
+ import math
+ import numpy as np
+ import torch.nn as nn
+ import copy
+ from einops import rearrange
+ from torch.nn.modules.module import T
+ from transformers.cache_utils import DynamicCache
+
+ from tqdm import tqdm
+ from transformers import Qwen2ForCausalLM, Qwen2Config, PreTrainedModel
+
+ from .diffusion_utils import *
+ from .gaussian_diffusion import *
+ from .respace import *
+ from .misc import *
+ from .diffloss import *
+
+
+ from .configuration_harmon import HarmonConfig
+ from .vae import AutoencoderKL
+ from .mar import mar_base, mar_large, mar_huge
+
+
+
+ def build_mlp(hidden_size, projector_dim, z_dim):
+     return nn.Sequential(
+         nn.Linear(hidden_size, projector_dim),
+         nn.SiLU(),
+         nn.Linear(projector_dim, z_dim),)
+
+
+ def mask_by_order(mask_len, order, bsz, seq_len):
+     masking = torch.zeros(bsz, seq_len, device=order.device)
+     masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
+                             src=torch.ones(bsz, seq_len, device=order.device)).bool()
+     return masking
+
+
+ class HarmonModel(PreTrainedModel):
+     config_class = HarmonConfig
+
+     def __init__(self, config: HarmonConfig):
+         super().__init__(config)
+         # VAE
+         self.vae = AutoencoderKL(
+             embed_dim=16,
+             ch_mult=(1, 1, 2, 2, 4)
+         )
+         self.vae_scale = 0.2325
+
+         # LLM
+         self.llm = Qwen2ForCausalLM(config=Qwen2Config.from_dict(config.llm))
+
+         # MAR
+         mar_config = copy.deepcopy(config.mar)
+         mar_type = mar_config.pop('type')
+         if mar_type == 'mar_base':
+             self.mar = mar_base(**mar_config)
+         elif mar_type == 'mar_large':
+             self.mar = mar_large(**mar_config)
+         elif mar_type == 'mar_huge':
+             self.mar = mar_huge(**mar_config)
+         else:
+             raise ValueError
+
+         # projection layers
+         self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
+                                  projector_dim=self.llm.config.hidden_size,
+                                  z_dim=self.llm.config.hidden_size)
+         self.proj_out = build_mlp(hidden_size=self.llm.config.hidden_size,
+                                   projector_dim=self.llm.config.hidden_size,
+                                   z_dim=self.mar.encoder_embed_dim)
+
+     @property
+     def llm_model(self):
+         return self.llm.model
+
+     @property
+     def device(self):
+         return self.llm.device
+
+     @property
+     def dtype(self):
+         return self.llm.dtype
+
+     @property
+     def gen_seq_len(self):
+         return self.mar.seq_len
+
+     @property
+     def token_embed_dim(self):
+         return self.vae.embed_dim * (self.mar.patch_size ** 2)
+
+     @torch.no_grad()
+     def encode(self, x):
+         posterior = self.vae.encode(x)
+         z = posterior.mode().mul_(self.vae_scale)
+         z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
+                       p=self.mar.patch_size, q=self.mar.patch_size)
+
+         return z
+
+     @torch.no_grad()
+     def decode(self, z):
+         z /= self.vae_scale
+         z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
+                       p=self.mar.patch_size, q=self.mar.patch_size)
+
+         x = self.vae.decode(z)
+         return x
+
+     def prepare_forward_input(self,
+                               x,
+                               inputs_embeds=None,
+                               input_ids=None,
+                               attention_mask=None,
+                               past_key_values=None):
+         b, l, _ = x.shape
+         attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
+         attention_mask = torch.cat([
+             attention_mask, attention_mask.new_ones(b, l)
+         ], dim=1)
+         position_ids = torch.cumsum(attention_mask, dim=1) - 1
+         position_ids[position_ids < 0] = 0
+
+         # import pdb; pdb.set_trace()
+
+         # prepare context
+         if past_key_values is not None:
+             inputs_embeds = x
+             position_ids = position_ids[:, -l:]
+         else:
+             if inputs_embeds is None:
+                 input_ids = input_ids.to(self.device)
+                 inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+             inputs_embeds = torch.cat([inputs_embeds, x], dim=1)
+
+         return dict(inputs_embeds=inputs_embeds,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_values=past_key_values)
+
+     def extract_visual_feature(self, x, mask=None, detach=False):
+         b, m, n, _ = x.shape
+         x = x.view(b, m*n, -1)
+         # x: b mn c
+         if mask is None:
+             mask = torch.zeros_like(x[..., 0])
+         null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
+         x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
+
+         z_enc = self.proj_in(x_enc)
+         # Move buffers to the end of the image sequence
+         z_enc = torch.cat([
+             z_enc[:, self.mar.buffer_size:],
+             z_enc[:, :self.mar.buffer_size]], dim=1)
+
+         if detach:
+             x_enc = x_enc.detach()
+             z_enc = z_enc.detach()
+
+         return x_enc, z_enc
+
+     def forward_mae_encoder(self, x, mask, detach=False, **context):
+         b, m, n, _ = x.shape
+         x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
+         inputs = self.prepare_forward_input(x=z_enc, **context)
+         output = self.llm_model(**inputs, return_dict=True)
+
+         z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]
+
+         # move buffers back to the start of the image sequence
+         z_llm = torch.cat([
+             z_llm[:, -self.mar.buffer_size:],
+             z_llm[:, :-self.mar.buffer_size]], dim=1)
+
+         # residual learning
+         x_enc = x_enc + self.proj_out(z_llm)
+
+         return x_enc
+
+     @staticmethod
+     def curtail_cache(past_key_values, cur_len):
+         for past_key_values_ in past_key_values:
+             keys, values = past_key_values_
+             keys.data = keys.data[:, :, :cur_len]
+             values.data = values.data[:, :, :cur_len]
+
+     @torch.no_grad()
+     def sample(self,
+                input_ids=None, inputs_embeds=None,
+                attention_mask=None, num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
+                progress=False, mask=None, past_key_values=None, image_shape=None, x_con=None, **kwargs):
+         if inputs_embeds is None and input_ids is not None:
+             inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+         bsz = attention_mask.shape[0]
+         if cfg != 1.0:
+             assert bsz % 2 == 0
+
+         if image_shape is None:
+             m = n = int(self.gen_seq_len ** 0.5)
+         else:
+             m, n = image_shape
+
+         if mask is None:
+             mask = torch.ones(bsz, m*n, device=self.device, dtype=self.dtype)
+         else:
+             mask = mask.view(bsz, m*n)
+         tokens = torch.zeros(bsz, m*n, self.token_embed_dim,
+                              device=self.device, dtype=self.dtype)
+         orders = self.mar.sample_orders(bsz, seq_len=m*n)
+         if cfg != 1.0:
+             orders[bsz//2:] = orders[:bsz//2]
+
+         indices = list(range(num_iter))
+         if progress:
+             indices = tqdm(indices)
+
+         # past key values can be prepared outside (usually in multi-turn editing)
+         if past_key_values is None:
+             output = self.llm_model(inputs_embeds=inputs_embeds,
+                                     attention_mask=None,
+                                     position_ids=None,
+                                     past_key_values=DynamicCache.from_legacy_cache(),
+                                     return_dict=True,
+                                     use_cache=True)
+             past_key_values = output.past_key_values
+
+         # generate latents
+         for step in indices:
+             cur_tokens = tokens.clone()
+             x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
+                                              mask.to(self.dtype),
+                                              past_key_values=past_key_values,
+                                              # inputs_embeds=inputs_embeds,
+                                              attention_mask=attention_mask)
+             # import pdb; pdb.set_trace()
+             self.curtail_cache(past_key_values, inputs_embeds.shape[1])
+             # import pdb; pdb.set_trace()
+
+             z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con)
+
+             # mask ratio for the next round, following MaskGIT and MAGE.
+             mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
+             mask_len = torch.Tensor([np.floor(m*n * mask_ratio)]).to(self.device)
+
+             # masks out at least one for the next iteration
+             mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
+                                      torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
+
+             # get masking for next iteration and locations to be predicted in this iteration
+             mask_next = mask_by_order(mask_len[0], orders, bsz, m*n).to(self.device)
+             if cfg != 1.0:
+                 mask_next[bsz//2:] = mask_next[:bsz//2]
+             if step >= num_iter - 1:
+                 mask_to_pred = mask[:bsz].bool()
+             else:
+                 mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
+             mask = mask_next
+             # if not cfg == 1.0:
+             #     mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
+
+             # sample token latents for this step
+             z = z[mask_to_pred.nonzero(as_tuple=True)]
+             # cfg schedule follow Muse
+             if cfg_schedule == "linear":
+                 cfg_iter = 1 + (cfg - 1) * (m*n - mask_len[0]) / (m*n)
+             elif cfg_schedule == "constant":
+                 cfg_iter = cfg
+             else:
+                 raise NotImplementedError
+             sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)
+             # if not cfg == 1.0:
+             #     sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
+             #     mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
+
+             cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
+             if cfg != 1.0:
+                 cur_tokens[bsz//2:] = cur_tokens[:bsz//2]
+             tokens = cur_tokens.clone()
+
+         pred = self.decode(tokens.view(bsz, m, n, -1))
+
+         if cfg != 1.0:
+             pred = pred[:bsz//2]
+         return pred
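For reference, here is a minimal, hypothetical usage sketch of the sample() API defined above. The repository id, prompt formatting, tokenizer padding behaviour, and output value range are assumptions not specified by this file; only the sample() signature and its classifier-free-guidance batching convention (paired conditional/unconditional rows, with the first half of the decoded images returned) come from the code itself.

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Assumption: a checkpoint that exposes HarmonModel via auto_map / trust_remote_code.
repo = "wusize/Harmon"
model = AutoModel.from_pretrained(repo, trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

# With cfg != 1.0, sample() asserts an even batch: conditional prompt(s) first,
# unconditional context second; only the first half of the images is returned.
prompts = ["a photo of a corgi wearing sunglasses",  # conditional prompt
           ""]                                       # assumed unconditional prompt
# Padding side and prompt template are assumptions; follow the released tokenizer config.
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

with torch.no_grad():
    pred = model.sample(input_ids=inputs.input_ids,
                        attention_mask=inputs.attention_mask,
                        num_iter=64, cfg=3.0, cfg_schedule="constant",
                        temperature=1.0, progress=True)

# pred comes straight from AutoencoderKL.decode(); assuming pixel values in [-1, 1].
img = ((pred[0].float().clamp(-1, 1) + 1) / 2 * 255).byte()
Image.fromarray(img.permute(1, 2, 0).cpu().numpy()).save("sample.png")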