HarryHe commited on
Commit
f7c417a
·
1 Parent(s): 745a0a7
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. models/__init__.py +0 -0
  2. models/valle_ar.py +265 -0
  3. models/valle_nar.py +303 -0
  4. modules/__init__.py +0 -0
  5. modules/__pycache__/__init__.cpython-39.pyc +0 -0
  6. modules/activation_functions/__init__.py +7 -0
  7. modules/activation_functions/__pycache__/__init__.cpython-39.pyc +0 -0
  8. modules/activation_functions/__pycache__/gated_activation_unit.cpython-39.pyc +0 -0
  9. modules/activation_functions/__pycache__/snake.cpython-39.pyc +0 -0
  10. modules/activation_functions/gated_activation_unit.py +61 -0
  11. modules/activation_functions/snake.py +122 -0
  12. modules/anti_aliasing/__init__.py +8 -0
  13. modules/anti_aliasing/__pycache__/__init__.cpython-39.pyc +0 -0
  14. modules/anti_aliasing/__pycache__/act.cpython-39.pyc +0 -0
  15. modules/anti_aliasing/__pycache__/filter.cpython-39.pyc +0 -0
  16. modules/anti_aliasing/__pycache__/resample.cpython-39.pyc +0 -0
  17. modules/anti_aliasing/act.py +36 -0
  18. modules/anti_aliasing/filter.py +99 -0
  19. modules/anti_aliasing/resample.py +65 -0
  20. modules/base/base_module.py +75 -0
  21. modules/diffusion/__init__.py +7 -0
  22. modules/diffusion/bidilconv/bidilated_conv.py +102 -0
  23. modules/diffusion/bidilconv/residual_block.py +73 -0
  24. modules/diffusion/karras/karras_diffusion.py +977 -0
  25. modules/diffusion/karras/random_utils.py +177 -0
  26. modules/diffusion/karras/sample.py +185 -0
  27. modules/diffusion/unet/attention.py +241 -0
  28. modules/diffusion/unet/basic.py +15 -0
  29. modules/diffusion/unet/resblock.py +178 -0
  30. modules/diffusion/unet/unet.py +310 -0
  31. modules/distributions/__init__.py +0 -0
  32. modules/distributions/distributions.py +107 -0
  33. modules/duration_predictor/__init__.py +0 -0
  34. modules/duration_predictor/standard_duration_predictor.py +53 -0
  35. modules/duration_predictor/stochastic_duration_predictor.py +120 -0
  36. modules/encoder/__init__.py +1 -0
  37. modules/encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  38. modules/encoder/__pycache__/token_encoder.cpython-39.pyc +0 -0
  39. modules/encoder/condition_encoder.py +251 -0
  40. modules/encoder/conv_encoder.py +103 -0
  41. modules/encoder/position_encoder.py +85 -0
  42. modules/encoder/token_encoder.py +25 -0
  43. modules/flow/modules.py +457 -0
  44. modules/general/__init__.py +3 -0
  45. modules/general/__pycache__/__init__.cpython-39.pyc +0 -0
  46. modules/general/__pycache__/input_strategies.cpython-39.pyc +0 -0
  47. modules/general/__pycache__/scaling.cpython-39.pyc +0 -0
  48. modules/general/__pycache__/utils.cpython-39.pyc +0 -0
  49. modules/general/input_strategies.py +130 -0
  50. modules/general/scaling.py +1349 -0
models/__init__.py ADDED
File without changes
models/valle_ar.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python -m models.tts.valle_gpt.valle_ar
2
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import os
7
+ import torch.nn as nn
8
+
9
+
10
+ class ValleAR(nn.Module):
11
+ def __init__(
12
+ self,
13
+ phone_vocab_size=256,
14
+ target_vocab_size=1024,
15
+ hidden_size=1024,
16
+ intermediate_size=4096,
17
+ num_hidden_layers=12,
18
+ num_attention_heads=16,
19
+ pad_token_id=1281,
20
+ bos_target_id=1282,
21
+ eos_target_id=1283,
22
+ bos_phone_id=1284,
23
+ eos_phone_id=1285,
24
+ use_input_embeds=False,
25
+ emb_dim=256,
26
+ ):
27
+ super(ValleAR, self).__init__()
28
+ self.config = LlamaConfig(
29
+ vocab_size=phone_vocab_size + target_vocab_size + 10,
30
+ hidden_size=hidden_size,
31
+ intermediate_size=intermediate_size,
32
+ num_hidden_layers=num_hidden_layers,
33
+ num_attention_heads=num_attention_heads,
34
+ pad_token_id=pad_token_id,
35
+ bos_token_id=bos_target_id,
36
+ eos_token_id=eos_target_id,
37
+ )
38
+ self.phone_vocab_size = phone_vocab_size
39
+ self.target_vocab_size = target_vocab_size
40
+ self.pad_token_id = pad_token_id
41
+ self.bos_target_id = bos_target_id
42
+ self.eos_target_id = eos_target_id
43
+ self.bos_phone_id = bos_phone_id
44
+ self.eos_phone_id = eos_phone_id
45
+ self.model = LlamaForCausalLM(self.config)
46
+
47
+ self.use_input_embeds = use_input_embeds
48
+
49
+ # no input embedding is used to provide speaker information
50
+ if self.use_input_embeds:
51
+ self.emb_linear = nn.Linear(emb_dim, hidden_size)
52
+ self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
53
+ self.emb_linear.bias.data.zero_()
54
+
55
+ def forward(
56
+ self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None
57
+ ):
58
+ if input_embeds is not None:
59
+ input_embeds = self.emb_linear(input_embeds)
60
+ phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
61
+ phone_ids,
62
+ phone_mask,
63
+ self.eos_phone_id,
64
+ self.bos_phone_id,
65
+ self.pad_token_id,
66
+ )
67
+ target_ids, target_mask, target_label = self.add_target_eos_bos_label(
68
+ target_ids,
69
+ target_mask,
70
+ self.eos_target_id,
71
+ self.bos_target_id,
72
+ self.pad_token_id,
73
+ )
74
+ input_token_ids = torch.cat([phone_ids, target_ids], dim=-1)
75
+ attention_mask = torch.cat([phone_mask, target_mask], dim=-1)
76
+ if input_embeds is not None:
77
+ raise NotImplementedError
78
+ attention_mask = torch.cat(
79
+ [
80
+ torch.ones(
81
+ (input_embeds.shape[0], input_embeds.shape[1]),
82
+ dtype=attention_mask.dtype,
83
+ device=attention_mask.device,
84
+ ),
85
+ attention_mask,
86
+ ],
87
+ dim=-1,
88
+ )
89
+ labels = torch.cat([phone_label, target_label], dim=-1)
90
+ if input_embeds is not None:
91
+ raise NotImplementedError
92
+ labels = torch.cat(
93
+ [
94
+ -100
95
+ * torch.ones(
96
+ (input_embeds.shape[0], input_embeds.shape[1]),
97
+ dtype=labels.dtype,
98
+ device=labels.device,
99
+ ),
100
+ labels,
101
+ ],
102
+ dim=-1,
103
+ )
104
+
105
+ if input_embeds is not None:
106
+ raise NotImplementedError
107
+ inputs_embeds = torch.cat(
108
+ [input_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
109
+ )
110
+ out = self.model(
111
+ inputs_embeds=inputs_embeds,
112
+ attention_mask=attention_mask,
113
+ labels=labels,
114
+ return_dict=True,
115
+ )
116
+ return out
117
+
118
+ out = self.model(
119
+ input_token_ids,
120
+ attention_mask=attention_mask,
121
+ labels=labels,
122
+ return_dict=True,
123
+ )
124
+ return out
125
+
126
+ def add_phone_eos_bos_label(
127
+ self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
128
+ ):
129
+ # phone_ids: [B, T]
130
+ # phone_mask: [B, T]
131
+
132
+ phone_ids = phone_ids + self.target_vocab_size * phone_mask
133
+
134
+ phone_ids = phone_ids * phone_mask
135
+ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
136
+ 1 - phone_mask, (0, 1), value=1
137
+ ) # make pad token eos token, add eos token at the end
138
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
139
+ phone_ids = phone_ids * phone_mask + pad_token_id * (1 - phone_mask) # restore pad token ids
140
+ phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
141
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
142
+ phone_label = -100 * torch.ones_like(phone_ids) # loss for entire phone is not computed (passed to llama)
143
+ return phone_ids, phone_mask, phone_label
144
+
145
+ def add_target_eos_bos_label(
146
+ self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
147
+ ):
148
+ # target_ids: [B, T]
149
+ # target_mask: [B, T]
150
+ target_ids = target_ids * target_mask
151
+ target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
152
+ 1 - target_mask, (0, 1), value=1
153
+ )
154
+ target_mask = F.pad(target_mask, (1, 0), value=1)
155
+ target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
156
+ target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)
157
+ target_mask = F.pad(target_mask, (1, 0), value=1)
158
+ target_label = target_ids * target_mask + (-100) * (1 - target_mask) # loss for target is computed on unmasked tokens
159
+ return target_ids, target_mask, target_label
160
+
161
+ def sample_hf(
162
+ self,
163
+ phone_ids, # the phones of prompt and target should be concatenated together
164
+ prompt_ids,
165
+ inputs_embeds=None,
166
+ max_length=2000,
167
+ temperature=1.0,
168
+ top_k=100,
169
+ top_p=0.9,
170
+ repeat_penalty=1.0,
171
+ ):
172
+ if inputs_embeds is not None:
173
+ inputs_embeds = self.emb_linear(inputs_embeds)
174
+ phone_mask = torch.ones_like(phone_ids)
175
+ prompt_mask = torch.ones_like(prompt_ids)
176
+ phone_ids, _, _ = self.add_phone_eos_bos_label(
177
+ phone_ids,
178
+ phone_mask,
179
+ self.eos_phone_id,
180
+ self.bos_phone_id,
181
+ self.pad_token_id,
182
+ )
183
+ prompt_ids, _, _ = self.add_target_eos_bos_label(
184
+ prompt_ids,
185
+ prompt_mask,
186
+ self.eos_target_id,
187
+ self.bos_target_id,
188
+ self.pad_token_id,
189
+ )
190
+ prompt_ids = prompt_ids[:, :-1] # remove end token. Make it continue mode
191
+
192
+ input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)
193
+
194
+ if inputs_embeds is not None:
195
+ raise NotImplementedError
196
+ inputs_embeds = torch.cat(
197
+ [inputs_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
198
+ )
199
+ generated_ids = self.model.generate(
200
+ inputs_embeds=inputs_embeds,
201
+ do_sample=True,
202
+ max_length=max_length,
203
+ pad_token_id=self.pad_token_id,
204
+ eos_token_id=self.eos_target_id,
205
+ temperature=temperature,
206
+ top_k=top_k,
207
+ top_p=top_p,
208
+ repetition_penalty=repeat_penalty,
209
+ )
210
+ gen_tokens = generated_ids[:, :-1]
211
+ return gen_tokens
212
+
213
+ input_length = input_token_ids.shape[1]
214
+ generated_ids = self.model.generate(
215
+ input_token_ids,
216
+ do_sample=True,
217
+ max_length=max_length,
218
+ pad_token_id=self.pad_token_id,
219
+ eos_token_id=self.eos_target_id,
220
+ temperature=temperature,
221
+ top_k=top_k,
222
+ top_p=top_p,
223
+ repetition_penalty=repeat_penalty,
224
+ )
225
+
226
+ gen_tokens = generated_ids[:, input_length:-1]
227
+
228
+ return gen_tokens
229
+
230
+ def test():
231
+ model = ValleAR()
232
+
233
+ phone_ids = torch.LongTensor([[1,2,3,4,5,0],
234
+ [1,2,3,4,5,6]])
235
+ phone_mask = torch.LongTensor([[1,1,1,0,0,0],
236
+ [1,1,1,0,0,0]])
237
+ target_ids = torch.LongTensor([765, 234, 123, 234, 123,599]).expand(2,-1)
238
+ target_mask = torch.LongTensor([1,1,1,1,0,0]).expand(2,-1)
239
+
240
+ optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
241
+
242
+ for i in range(15):
243
+ optimizer.zero_grad()
244
+ out = model(
245
+ phone_ids=phone_ids,
246
+ phone_mask=phone_mask,
247
+ target_ids=target_ids,
248
+ target_mask=target_mask,
249
+ )
250
+ loss = out.loss
251
+
252
+ loss.backward()
253
+
254
+ optimizer.step()
255
+
256
+ print(f"iter={i}, {loss}.")
257
+
258
+ phone_ids = torch.LongTensor([1,2,3]).reshape(1,-1)
259
+ target_ids = torch.LongTensor([765, 234]).reshape(1,-1)
260
+ sampled = model.sample_hf(phone_ids, target_ids)
261
+
262
+ breakpoint()
263
+
264
+ if __name__ == '__main__':
265
+ test()
models/valle_nar.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import os
6
+ import torch.nn as nn
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
10
+
11
+ from transformers.models.bert.modeling_bert import BertEncoder
12
+
13
+ from models.transformer.position_embedding import SinePositionalEmbedding
14
+
15
+ NUM_PROMPT_TOKENS=225
16
+
17
+ def initialize(module):
18
+ if isinstance(module, (nn.Linear, nn.Embedding, nn.modules.linear.NonDynamicallyQuantizableLinear)):
19
+ module.weight.data.normal_(mean=0.0, std=0.02)
20
+ if isinstance(module, nn.Linear) and module.bias is not None:
21
+ module.bias.data.zero_()
22
+
23
+ from transformers.models.llama.modeling_llama import CrossEntropyLoss
24
+ from easydict import EasyDict as edict
25
+
26
+ from modules.encoder import TokenEmbedding
27
+ from modules.norms import AdaptiveLayerNorm, LayerNorm
28
+
29
+ class ValleNAR(nn.Module):
30
+ def __init__(
31
+ self,
32
+ phone_vocab_size=256,
33
+ target_vocab_size=1024,
34
+ hidden_size=1024,
35
+ intermediate_size=4096,
36
+ num_hidden_layers=12,
37
+ num_attention_heads=16,
38
+ pad_token_id=1024+256,
39
+ bos_target_id=1282,
40
+ eos_target_id=1283,
41
+ bos_phone_id=1284,
42
+ eos_phone_id=1285,
43
+ bos_prompt_id=1286,
44
+ eos_prompt_id=1287,
45
+ use_input_embeds=False,
46
+ emb_dim=256,
47
+ num_quantizers=8,
48
+ ):
49
+ super(ValleNAR, self).__init__()
50
+
51
+ self.phone_vocab_size = phone_vocab_size
52
+ self.target_vocab_size = target_vocab_size
53
+ self.pad_token_id = pad_token_id
54
+ self.bos_target_id = bos_target_id
55
+ self.eos_target_id = eos_target_id
56
+ self.bos_phone_id = bos_phone_id
57
+ self.eos_phone_id = eos_phone_id
58
+ self.bos_prompt_id = bos_prompt_id
59
+ self.eos_prompt_id = eos_prompt_id
60
+
61
+ self.phone_embedder = TokenEmbedding(hidden_size, phone_vocab_size)
62
+
63
+ self.audio_embeddings = nn.ModuleList(
64
+ [
65
+ TokenEmbedding(hidden_size, target_vocab_size+1)
66
+ ] + [
67
+ TokenEmbedding(hidden_size, target_vocab_size)
68
+ for i in range(num_quantizers-1)
69
+ ]
70
+ )
71
+
72
+ from modules.transformer.transformer import TransformerEncoder, TransformerEncoderLayer
73
+ self.decoder = TransformerEncoder(
74
+ TransformerEncoderLayer(
75
+ hidden_size,
76
+ num_attention_heads,
77
+ dim_feedforward=int(4*hidden_size),
78
+ dropout=0.1,
79
+ batch_first=True,
80
+ norm_first=True,
81
+ adaptive_layer_norm=True,
82
+ activation=F.silu,
83
+ ),
84
+ num_layers=num_hidden_layers,
85
+ norm=(
86
+ AdaptiveLayerNorm(
87
+ hidden_size, norm=nn.LayerNorm(hidden_size)
88
+ )
89
+ )
90
+ )
91
+
92
+ self.predict_layers = nn.ModuleList(
93
+ [
94
+ nn.Linear(hidden_size, target_vocab_size, bias=False)
95
+ for i in range(num_quantizers-1)
96
+ ]
97
+ )
98
+
99
+ self.stage_embedding = nn.ModuleList(
100
+ [TokenEmbedding(hidden_size, 1) for i in range(num_quantizers)]
101
+ )
102
+
103
+ self.text_position = SinePositionalEmbedding(
104
+ hidden_size,
105
+ dropout=0.1,
106
+ scale=False,
107
+ alpha=True,
108
+ )
109
+ self.audio_position = SinePositionalEmbedding(
110
+ hidden_size,
111
+ dropout=0.1,
112
+ scale=False,
113
+ alpha=True,
114
+ )
115
+
116
+ def _mask_out_acoustic_tokens(self, target_ids, target_quantization_layer, start_time=NUM_PROMPT_TOKENS+1):
117
+ '''Mask out target_ids after the target_quantization_layer, except for the first 240 tokens.
118
+ target_ids: [8, B, T], which is padded and added with bos and eos tokens
119
+ target_quantization_layer: int
120
+
121
+ returns: [8, B, T] masked input_token_ids
122
+ '''
123
+ mask = torch.ones_like(target_ids, dtype=torch.long, device=target_ids.device)
124
+ mask[target_quantization_layer:, :, start_time:] = 0
125
+ input_token_ids = target_ids * mask
126
+ input_token_ids += (1-mask)*self.mask_target_id
127
+
128
+ return input_token_ids
129
+
130
+ def forward(
131
+ self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None,
132
+ target_quantization_layer=None,
133
+ ):
134
+ '''
135
+ phone_ids: [B, T]
136
+ phone_mask: [B, T]
137
+ target_ids: [8,B,T]
138
+ '''
139
+ target_ids = target_ids * target_mask
140
+
141
+ phone_label = torch.ones_like(phone_ids, dtype=torch.long) * -100
142
+ # get phone embedding
143
+ phone_embedding = self.phone_embedder(phone_ids) # [B, T, H]
144
+ phone_embedding = self.text_position(phone_embedding)
145
+
146
+
147
+ # randomly select a target to predict
148
+ # total quant layer is 0 to 7
149
+ if target_quantization_layer is None:
150
+ target_quantization_layer = np.random.randint(1, 8)
151
+
152
+ # extract 8-level prompts
153
+ prompt_tokens = target_ids[:, :, :NUM_PROMPT_TOKENS]
154
+ prompt_mask = torch.ones_like(prompt_tokens[0])
155
+ # prompt_label = -100 * prompt_mask
156
+ prompt_label = prompt_tokens[target_quantization_layer]
157
+ # get prompt embedding
158
+ prompt_embedding = self.audio_embeddings[0](prompt_tokens[0]) # [B, T, H]
159
+ for i in range(1, 8):
160
+ prompt_embedding += self.audio_embeddings[i](prompt_tokens[i])
161
+
162
+
163
+ # get y embedding
164
+ y_mask = target_mask[..., NUM_PROMPT_TOKENS:]
165
+ y_tokens = target_ids[:target_quantization_layer, :, NUM_PROMPT_TOKENS:] * y_mask
166
+ y_label = target_ids[target_quantization_layer, :, NUM_PROMPT_TOKENS:] * y_mask + -100*(1-y_mask)
167
+ y_embedding = self.audio_embeddings[0](y_tokens[0])
168
+ for i in range(1, target_quantization_layer):
169
+ y_embedding += self.audio_embeddings[i](y_tokens[i])
170
+
171
+ # concat y embedding and prmpt embedding
172
+ y_embedding = torch.concat([prompt_embedding, y_embedding], dim=1)
173
+ y_embedding = self.audio_position(y_embedding)
174
+
175
+ xy_pos = torch.concat([phone_embedding, y_embedding], dim=1)
176
+ xy_padding_mask = ~torch.concat([phone_mask, prompt_mask, y_mask], dim=1).to(torch.bool)
177
+ xy_dec, _ = self.decoder(
178
+ (xy_pos, self.stage_embedding[target_quantization_layer-1].weight),
179
+ src_key_padding_mask=xy_padding_mask,
180
+ )
181
+
182
+ target_label = torch.concat([phone_label, prompt_label, y_label], dim=1)
183
+
184
+
185
+ logits = self.predict_layers[target_quantization_layer-1](xy_dec).permute(0, 2, 1)
186
+ loss = CrossEntropyLoss()(logits, target_label)
187
+
188
+ out = edict(
189
+ loss=loss,
190
+ logits=logits,
191
+ )
192
+ return out
193
+ # # prompt eos embedding
194
+ # prompt_eos_embedding = self.phone_embedder(torch.tensor(self.eos_prompt_id-self.target_vocab_size, device=phone_ids.device).reshape(1).expand(phone_ids.shape[0], -1)) # [B, 1, H]
195
+
196
+ # # input embeddings
197
+ # input_embeddings = torch.cat([phone_embedding, prompt_embedding, prompt_eos_embedding, target_embedding], dim=1)
198
+ # input_mask = torch.cat([phone_mask, prompt_mask, torch.ones((phone_mask.shape[0], 1), dtype=torch.long, device=phone_mask.device), target_mask], dim=1) # [B, T]
199
+ # prediction_target = torch.cat([phone_label, prompt_label, -100*torch.ones((phone_mask.shape[0], 1), dtype=torch.long, device=phone_mask.device), target_labels], dim=1) # [B, T]
200
+
201
+
202
+ # out = self.model(
203
+ # cond=torch.tensor(target_quantization_layer, device=prediction_target.device, dtype=torch.long),
204
+ # input_ids=input_embeddings,
205
+ # prediction_target=prediction_target,
206
+ # attention_mask=input_mask,
207
+ # return_dict=True,
208
+ # )
209
+ # return out
210
+
211
+ def add_phone_eos_bos_label(
212
+ self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
213
+ ):
214
+ # phone_ids: [B, T]
215
+ # phone_mask: [B, T]
216
+
217
+ phone_ids = phone_ids + self.target_vocab_size * phone_mask
218
+
219
+ phone_ids = phone_ids * phone_mask
220
+ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
221
+ 1 - phone_mask, (0, 1), value=1
222
+ ) # make pad token eos token, add eos token at the end
223
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
224
+ phone_ids = phone_ids * phone_mask + pad_token_id * (1 - phone_mask) # restore pad token ids
225
+ phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
226
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
227
+ phone_label = -100 * torch.ones_like(phone_ids) # loss for entire phone is not computed (passed to llama)
228
+ return phone_ids, phone_mask, phone_label
229
+
230
+ @torch.no_grad()
231
+ def sample_hf(
232
+ self,
233
+ phone_ids, # [B, T]
234
+ prompt_ids, # [8, B, T]
235
+ first_stage_ids, # [B, T]
236
+ ):
237
+ '''
238
+ phone_ids: [B, T]
239
+ prompt_ids: [8, B, T]
240
+ first_stage_ids: [B, T] result from first quant layer. Should be continuation of prompt_ids
241
+ '''
242
+ phone_mask = torch.ones_like(phone_ids, dtype=torch.long)
243
+
244
+ assert prompt_ids.shape[-1] >= NUM_PROMPT_TOKENS, "prompt_ids should have at least 240 tokens"
245
+ prompt_ids = prompt_ids[:, :, :NUM_PROMPT_TOKENS]
246
+ target_ids = torch.cat([prompt_ids, first_stage_ids.expand(prompt_ids.shape[0],-1,-1)], dim=-1)
247
+ target_mask = torch.ones_like(target_ids[0], dtype=torch.long)
248
+
249
+ gen_len = first_stage_ids.shape[-1]
250
+ for qnt_level in range(1, 8):
251
+ out = self.forward(
252
+ phone_ids=phone_ids,
253
+ phone_mask=phone_mask,
254
+ target_ids=target_ids,
255
+ target_mask=target_mask,
256
+ target_quantization_layer=qnt_level,
257
+ )
258
+ logits = out.logits
259
+ gen_tokens = torch.argmax(logits, dim=1)[0, -gen_len:] # [T], generated tokens in this level
260
+
261
+ # overwrite the target_ids with the generated tokens
262
+ target_ids[qnt_level, :, -gen_len:] = gen_tokens
263
+
264
+ return target_ids[:, :, -gen_len:]
265
+
266
+ def test():
267
+ model = ValleNAR().cuda()
268
+ model.apply(initialize)
269
+
270
+ phone_ids = torch.LongTensor([1,2,3,4,5]).reshape(1,-1).cuda()
271
+ phone_mask = torch.LongTensor([1,1,1,1,1]).reshape(1,-1).cuda()
272
+ target_ids = torch.randint(high=1024, size=(8,1,250), dtype=torch.long).cuda()
273
+ target_mask = torch.ones(1,250, dtype=torch.long).cuda()
274
+ optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
275
+
276
+ for i in range(200):
277
+ optimizer.zero_grad()
278
+ out = model(
279
+ phone_ids=phone_ids,
280
+ phone_mask=phone_mask,
281
+ target_ids=target_ids,
282
+ target_mask=target_mask,
283
+ target_quantization_layer=1+i%7,
284
+ )
285
+ loss = out.loss
286
+
287
+ loss.backward()
288
+
289
+ optimizer.step()
290
+
291
+ print(f"iter={i}, {loss}.")
292
+ target_ids_short = target_ids[:, :, :240]
293
+ sampled = model.sample_hf(phone_ids, prompt_ids=target_ids_short, first_stage_ids=target_ids[0, :, 240:])
294
+ breakpoint()
295
+
296
+ print(target_ids[:,:,-10:])
297
+ print(sampled)
298
+
299
+ print((sampled == target_ids[:,:,-10:]).all())
300
+
301
+
302
+ if __name__ == '__main__':
303
+ test()
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes). View file
 
modules/activation_functions/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .gated_activation_unit import GaU
7
+ from .snake import Snake, SnakeBeta
modules/activation_functions/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (271 Bytes). View file
 
modules/activation_functions/__pycache__/gated_activation_unit.cpython-39.pyc ADDED
Binary file (1.75 kB). View file
 
modules/activation_functions/__pycache__/snake.cpython-39.pyc ADDED
Binary file (3.69 kB). View file
 
modules/activation_functions/gated_activation_unit.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from modules.general.utils import Conv1d
10
+
11
+
12
+ class GaU(nn.Module):
13
+ r"""Gated Activation Unit (GaU) proposed in `Gated Activation Units for Neural
14
+ Networks <https://arxiv.org/pdf/1606.05328.pdf>`_.
15
+
16
+ Args:
17
+ channels: number of input channels.
18
+ kernel_size: kernel size of the convolution.
19
+ dilation: dilation rate of the convolution.
20
+ d_context: dimension of context tensor, None if don't use context.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ channels: int,
26
+ kernel_size: int = 3,
27
+ dilation: int = 1,
28
+ d_context: int = None,
29
+ ):
30
+ super().__init__()
31
+
32
+ self.context = d_context
33
+
34
+ self.conv = Conv1d(
35
+ channels,
36
+ channels * 2,
37
+ kernel_size,
38
+ dilation=dilation,
39
+ padding=dilation * (kernel_size - 1) // 2,
40
+ )
41
+
42
+ if self.context:
43
+ self.context_proj = Conv1d(d_context, channels * 2, 1)
44
+
45
+ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
46
+ r"""Calculate forward propagation.
47
+
48
+ Args:
49
+ x: input tensor with shape [B, C, T].
50
+ context: context tensor with shape [B, ``d_context``, T], default to None.
51
+ """
52
+
53
+ h = self.conv(x)
54
+
55
+ if self.context:
56
+ h = h + self.context_proj(context)
57
+
58
+ h1, h2 = h.chunk(2, 1)
59
+ h = torch.tanh(h1) * torch.sigmoid(h2)
60
+
61
+ return h
modules/activation_functions/snake.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch import nn, pow, sin
8
+ from torch.nn import Parameter
9
+
10
+
11
+ class Snake(nn.Module):
12
+ r"""Implementation of a sine-based periodic activation function.
13
+ Alpha is initialized to 1 by default, higher values means higher frequency.
14
+ It will be trained along with the rest of your model.
15
+
16
+ Args:
17
+ in_features: shape of the input
18
+ alpha: trainable parameter
19
+
20
+ Shape:
21
+ - Input: (B, C, T)
22
+ - Output: (B, C, T), same shape as the input
23
+
24
+ References:
25
+ This activation function is from this paper by Liu Ziyin, Tilman Hartwig,
26
+ Masahito Ueda: https://arxiv.org/abs/2006.08195
27
+
28
+ Examples:
29
+ >>> a1 = Snake(256)
30
+ >>> x = torch.randn(256)
31
+ >>> x = a1(x)
32
+ """
33
+
34
+ def __init__(
35
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
36
+ ):
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ r"""Forward pass of the function. Applies the function to the input elementwise.
53
+ Snake ∶= x + 1/a * sin^2 (ax)
54
+ """
55
+
56
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
57
+ if self.alpha_logscale:
58
+ alpha = torch.exp(alpha)
59
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
60
+
61
+ return x
62
+
63
+
64
+ class SnakeBeta(nn.Module):
65
+ r"""A modified Snake function which uses separate parameters for the magnitude
66
+ of the periodic components. Alpha is initialized to 1 by default,
67
+ higher values means higher frequency. Beta is initialized to 1 by default,
68
+ higher values means higher magnitude. Both will be trained along with the
69
+ rest of your model.
70
+
71
+ Args:
72
+ in_features: shape of the input
73
+ alpha: trainable parameter that controls frequency
74
+ beta: trainable parameter that controls magnitude
75
+
76
+ Shape:
77
+ - Input: (B, C, T)
78
+ - Output: (B, C, T), same shape as the input
79
+
80
+ References:
81
+ This activation function is a modified version based on this paper by Liu Ziyin,
82
+ Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195
83
+
84
+ Examples:
85
+ >>> a1 = SnakeBeta(256)
86
+ >>> x = torch.randn(256)
87
+ >>> x = a1(x)
88
+ """
89
+
90
+ def __init__(
91
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
92
+ ):
93
+ super(SnakeBeta, self).__init__()
94
+ self.in_features = in_features
95
+
96
+ # initialize alpha
97
+ self.alpha_logscale = alpha_logscale
98
+ if self.alpha_logscale: # log scale alphas initialized to zeros
99
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
100
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
101
+ else: # linear scale alphas initialized to ones
102
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
103
+ self.beta = Parameter(torch.ones(in_features) * alpha)
104
+
105
+ self.alpha.requires_grad = alpha_trainable
106
+ self.beta.requires_grad = alpha_trainable
107
+
108
+ self.no_div_by_zero = 0.000000001
109
+
110
+ def forward(self, x):
111
+ r"""Forward pass of the function. Applies the function to the input elementwise.
112
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
113
+ """
114
+
115
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
116
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
117
+ if self.alpha_logscale:
118
+ alpha = torch.exp(alpha)
119
+ beta = torch.exp(beta)
120
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
121
+
122
+ return x
modules/anti_aliasing/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .act import *
7
+ from .filter import *
8
+ from .resample import *
modules/anti_aliasing/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (218 Bytes). View file
 
modules/anti_aliasing/__pycache__/act.cpython-39.pyc ADDED
Binary file (1 kB). View file
 
modules/anti_aliasing/__pycache__/filter.cpython-39.pyc ADDED
Binary file (2.6 kB). View file
 
modules/anti_aliasing/__pycache__/resample.cpython-39.pyc ADDED
Binary file (1.91 kB). View file
 
modules/anti_aliasing/act.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+
8
+ from .resample import *
9
+
10
+ # This code is adopted from BigVGAN under the MIT License
11
+ # https://github.com/NVIDIA/BigVGAN
12
+
13
+
14
+ class Activation1d(nn.Module):
15
+ def __init__(
16
+ self,
17
+ activation,
18
+ up_ratio: int = 2,
19
+ down_ratio: int = 2,
20
+ up_kernel_size: int = 12,
21
+ down_kernel_size: int = 12,
22
+ ):
23
+ super().__init__()
24
+ self.up_ratio = up_ratio
25
+ self.down_ratio = down_ratio
26
+ self.act = activation
27
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
28
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
29
+
30
+ # x: [B,C,T]
31
+ def forward(self, x):
32
+ x = self.upsample(x)
33
+ x = self.act(x)
34
+ x = self.downsample(x)
35
+
36
+ return x
modules/anti_aliasing/filter.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import math
10
+
11
+ if "sinc" in dir(torch):
12
+ sinc = torch.sinc
13
+ else:
14
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
15
+ # https://adefossez.github.io/julius/julius/core.html
16
+ def sinc(x: torch.Tensor):
17
+ """
18
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
19
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
20
+ """
21
+ return torch.where(
22
+ x == 0,
23
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
24
+ torch.sin(math.pi * x) / math.pi / x,
25
+ )
26
+
27
+
28
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
29
+ # https://adefossez.github.io/julius/julius/lowpass.html
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
57
+ # of the constant component in the input signal.
58
+ filter_ /= filter_.sum()
59
+ filter = filter_.view(1, 1, kernel_size)
60
+
61
+ return filter
62
+
63
+
64
+ class LowPassFilter1d(nn.Module):
65
+ def __init__(
66
+ self,
67
+ cutoff=0.5,
68
+ half_width=0.6,
69
+ stride: int = 1,
70
+ padding: bool = True,
71
+ padding_mode: str = "replicate",
72
+ kernel_size: int = 12,
73
+ ):
74
+ # kernel_size should be even number for stylegan3 setup,
75
+ # in this implementation, odd number is also possible.
76
+ super().__init__()
77
+ if cutoff < -0.0:
78
+ raise ValueError("Minimum cutoff must be larger than zero.")
79
+ if cutoff > 0.5:
80
+ raise ValueError("A cutoff above 0.5 does not make sense.")
81
+ self.kernel_size = kernel_size
82
+ self.even = kernel_size % 2 == 0
83
+ self.pad_left = kernel_size // 2 - int(self.even)
84
+ self.pad_right = kernel_size // 2
85
+ self.stride = stride
86
+ self.padding = padding
87
+ self.padding_mode = padding_mode
88
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
89
+ self.register_buffer("filter", filter)
90
+
91
+ # input [B, C, T]
92
+ def forward(self, x):
93
+ _, C, _ = x.shape
94
+
95
+ if self.padding:
96
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
97
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
98
+
99
+ return out
modules/anti_aliasing/resample.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ #################### Anti-aliasing ####################
7
+
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from .filter import *
12
+
13
+ # This code is adopted from BigVGAN under the MIT License
14
+ # https://github.com/NVIDIA/BigVGAN
15
+
16
+
17
+ class UpSample1d(nn.Module):
18
+ def __init__(self, ratio=2, kernel_size=None):
19
+ super().__init__()
20
+ self.ratio = ratio
21
+ self.kernel_size = (
22
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
23
+ )
24
+ self.stride = ratio
25
+ self.pad = self.kernel_size // ratio - 1
26
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
27
+ self.pad_right = (
28
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
29
+ )
30
+ filter = kaiser_sinc_filter1d(
31
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
32
+ )
33
+ self.register_buffer("filter", filter)
34
+
35
+ # x: [B, C, T]
36
+ def forward(self, x):
37
+ _, C, _ = x.shape
38
+
39
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
40
+ x = self.ratio * F.conv_transpose1d(
41
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
42
+ )
43
+ x = x[..., self.pad_left : -self.pad_right]
44
+
45
+ return x
46
+
47
+
48
+ class DownSample1d(nn.Module):
49
+ def __init__(self, ratio=2, kernel_size=None):
50
+ super().__init__()
51
+ self.ratio = ratio
52
+ self.kernel_size = (
53
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
54
+ )
55
+ self.lowpass = LowPassFilter1d(
56
+ cutoff=0.5 / ratio,
57
+ half_width=0.6 / ratio,
58
+ stride=ratio,
59
+ kernel_size=self.kernel_size,
60
+ )
61
+
62
+ def forward(self, x):
63
+ xx = self.lowpass(x)
64
+
65
+ return xx
modules/base/base_module.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ class ConvReluNorm(nn.Module):
27
+ def __init__(
28
+ self,
29
+ in_channels,
30
+ hidden_channels,
31
+ out_channels,
32
+ kernel_size,
33
+ n_layers,
34
+ p_dropout,
35
+ ):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 0."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(
48
+ nn.Conv1d(
49
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
50
+ )
51
+ )
52
+ self.norm_layers.append(LayerNorm(hidden_channels))
53
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
54
+ for _ in range(n_layers - 1):
55
+ self.conv_layers.append(
56
+ nn.Conv1d(
57
+ hidden_channels,
58
+ hidden_channels,
59
+ kernel_size,
60
+ padding=kernel_size // 2,
61
+ )
62
+ )
63
+ self.norm_layers.append(LayerNorm(hidden_channels))
64
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
65
+ self.proj.weight.data.zero_()
66
+ self.proj.bias.data.zero_()
67
+
68
+ def forward(self, x, x_mask):
69
+ x_org = x
70
+ for i in range(self.n_layers):
71
+ x = self.conv_layers[i](x * x_mask)
72
+ x = self.norm_layers[i](x)
73
+ x = self.relu_drop(x)
74
+ x = x_org + self.proj(x)
75
+ return x * x_mask
modules/diffusion/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .bidilconv.bidilated_conv import BiDilConv
7
+ from .unet.unet import UNet
modules/diffusion/bidilconv/bidilated_conv.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch.nn as nn
9
+
10
+ from modules.general.utils import Conv1d, zero_module
11
+ from .residual_block import ResidualBlock
12
+
13
+
14
+ class BiDilConv(nn.Module):
15
+ r"""Dilated CNN architecture with residual connections, default diffusion decoder.
16
+
17
+ Args:
18
+ input_channel: The number of input channels.
19
+ base_channel: The number of base channels.
20
+ n_res_block: The number of residual blocks.
21
+ conv_kernel_size: The kernel size of convolutional layers.
22
+ dilation_cycle_length: The cycle length of dilation.
23
+ conditioner_size: The size of conditioner.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ input_channel,
29
+ base_channel,
30
+ n_res_block,
31
+ conv_kernel_size,
32
+ dilation_cycle_length,
33
+ conditioner_size,
34
+ output_channel: int = -1,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.input_channel = input_channel
39
+ self.base_channel = base_channel
40
+ self.n_res_block = n_res_block
41
+ self.conv_kernel_size = conv_kernel_size
42
+ self.dilation_cycle_length = dilation_cycle_length
43
+ self.conditioner_size = conditioner_size
44
+ self.output_channel = output_channel if output_channel > 0 else input_channel
45
+
46
+ self.input = nn.Sequential(
47
+ Conv1d(
48
+ input_channel,
49
+ base_channel,
50
+ 1,
51
+ ),
52
+ nn.ReLU(),
53
+ )
54
+
55
+ self.residual_blocks = nn.ModuleList(
56
+ [
57
+ ResidualBlock(
58
+ channels=base_channel,
59
+ kernel_size=conv_kernel_size,
60
+ dilation=2 ** (i % dilation_cycle_length),
61
+ d_context=conditioner_size,
62
+ )
63
+ for i in range(n_res_block)
64
+ ]
65
+ )
66
+
67
+ self.out_proj = nn.Sequential(
68
+ Conv1d(
69
+ base_channel,
70
+ base_channel,
71
+ 1,
72
+ ),
73
+ nn.ReLU(),
74
+ zero_module(
75
+ Conv1d(
76
+ base_channel,
77
+ self.output_channel,
78
+ 1,
79
+ ),
80
+ ),
81
+ )
82
+
83
+ def forward(self, x, y, context=None):
84
+ """
85
+ Args:
86
+ x: Noisy mel-spectrogram [B x ``n_mel`` x L]
87
+ y: FILM embeddings with the shape of (B, ``base_channel``)
88
+ context: Context with the shape of [B x ``d_context`` x L], default to None.
89
+ """
90
+
91
+ h = self.input(x)
92
+
93
+ skip = None
94
+ for i in range(self.n_res_block):
95
+ h, skip_connection = self.residual_blocks[i](h, y, context)
96
+ skip = skip_connection if skip is None else skip_connection + skip
97
+
98
+ out = skip / math.sqrt(self.n_res_block)
99
+
100
+ out = self.out_proj(out)
101
+
102
+ return out
modules/diffusion/bidilconv/residual_block.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from modules.activation_functions import GaU
12
+ from modules.general.utils import Conv1d
13
+
14
+
15
+ class ResidualBlock(nn.Module):
16
+ r"""Residual block with dilated convolution, main portion of ``BiDilConv``.
17
+
18
+ Args:
19
+ channels: The number of channels of input and output.
20
+ kernel_size: The kernel size of dilated convolution.
21
+ dilation: The dilation rate of dilated convolution.
22
+ d_context: The dimension of content encoder output, None if don't use context.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ channels: int = 256,
28
+ kernel_size: int = 3,
29
+ dilation: int = 1,
30
+ d_context: int = None,
31
+ ):
32
+ super().__init__()
33
+
34
+ self.context = d_context
35
+
36
+ self.gau = GaU(
37
+ channels,
38
+ kernel_size,
39
+ dilation,
40
+ d_context,
41
+ )
42
+
43
+ self.out_proj = Conv1d(
44
+ channels,
45
+ channels * 2,
46
+ 1,
47
+ )
48
+
49
+ def forward(
50
+ self,
51
+ x: torch.Tensor,
52
+ y_emb: torch.Tensor,
53
+ context: torch.Tensor = None,
54
+ ):
55
+ """
56
+ Args:
57
+ x: Latent representation inherited from previous residual block
58
+ with the shape of [B x C x T].
59
+ y_emb: Embeddings with the shape of [B x C], which will be FILM on the x.
60
+ context: Context with the shape of [B x ``d_context`` x T], default to None.
61
+ """
62
+
63
+ h = x + y_emb[..., None]
64
+
65
+ if self.context:
66
+ h = self.gau(h, context)
67
+ else:
68
+ h = self.gau(h)
69
+
70
+ h = self.out_proj(h)
71
+ res, skip = h.chunk(2, 1)
72
+
73
+ return (res + x) / math.sqrt(2.0), skip
modules/diffusion/karras/karras_diffusion.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Based on: https://github.com/crowsonkb/k-diffusion
8
+ """
9
+ import random
10
+
11
+ import numpy as np
12
+ import torch as th
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ # from piq import LPIPS
17
+ from utils.ssim import SSIM
18
+
19
+ from modules.diffusion.karras.random_utils import get_generator
20
+
21
+
22
+ def mean_flat(tensor):
23
+ """
24
+ Take the mean over all non-batch dimensions.
25
+ """
26
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
27
+
28
+
29
+ def append_dims(x, target_dims):
30
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
31
+ dims_to_append = target_dims - x.ndim
32
+ if dims_to_append < 0:
33
+ raise ValueError(
34
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
35
+ )
36
+ return x[(...,) + (None,) * dims_to_append]
37
+
38
+
39
+ def append_zero(x):
40
+ return th.cat([x, x.new_zeros([1])])
41
+
42
+
43
+ def get_weightings(weight_schedule, snrs, sigma_data):
44
+ if weight_schedule == "snr":
45
+ weightings = snrs
46
+ elif weight_schedule == "snr+1":
47
+ weightings = snrs + 1
48
+ elif weight_schedule == "karras":
49
+ weightings = snrs + 1.0 / sigma_data**2
50
+ elif weight_schedule == "truncated-snr":
51
+ weightings = th.clamp(snrs, min=1.0)
52
+ elif weight_schedule == "uniform":
53
+ weightings = th.ones_like(snrs)
54
+ else:
55
+ raise NotImplementedError()
56
+ return weightings
57
+
58
+
59
+ class KarrasDenoiser:
60
+ def __init__(
61
+ self,
62
+ sigma_data: float = 0.5,
63
+ sigma_max=80.0,
64
+ sigma_min=0.002,
65
+ rho=7.0,
66
+ weight_schedule="karras",
67
+ distillation=False,
68
+ loss_norm="l2",
69
+ ):
70
+ self.sigma_data = sigma_data
71
+ self.sigma_max = sigma_max
72
+ self.sigma_min = sigma_min
73
+ self.weight_schedule = weight_schedule
74
+ self.distillation = distillation
75
+ self.loss_norm = loss_norm
76
+ # if loss_norm == "lpips":
77
+ # self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")
78
+ if loss_norm == "ssim":
79
+ self.ssim_loss = SSIM()
80
+ self.rho = rho
81
+ self.num_timesteps = 40
82
+
83
+ def get_snr(self, sigmas):
84
+ return sigmas**-2
85
+
86
+ def get_sigmas(self, sigmas):
87
+ return sigmas
88
+
89
+ def get_scalings(self, sigma):
90
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
91
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
92
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
93
+ return c_skip, c_out, c_in
94
+
95
+ def get_scalings_for_boundary_condition(self, sigma):
96
+ c_skip = self.sigma_data**2 / (
97
+ (sigma - self.sigma_min) ** 2 + self.sigma_data**2
98
+ )
99
+ c_out = (
100
+ (sigma - self.sigma_min)
101
+ * self.sigma_data
102
+ / (sigma**2 + self.sigma_data**2) ** 0.5
103
+ )
104
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
105
+ return c_skip, c_out, c_in
106
+
107
+ def training_losses(self, model, x_start, sigmas, condition=None, noise=None):
108
+ if noise is None:
109
+ noise = th.randn_like(x_start)
110
+
111
+ terms = {}
112
+
113
+ dims = x_start.ndim
114
+ x_t = x_start + noise * append_dims(sigmas, dims)
115
+ model_output, denoised = self.denoise(model, x_t, sigmas, condition)
116
+
117
+ snrs = self.get_snr(sigmas)
118
+ weights = append_dims(
119
+ get_weightings(self.weight_schedule, snrs, self.sigma_data), dims
120
+ )
121
+ # terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
122
+ terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2)
123
+ # terms["mae"] = mean_flat(weights * th.abs(denoised - x_start))
124
+ # terms["mse"] = nn.MSELoss(reduction="none")(denoised, x_start)
125
+
126
+ # if "vb" in terms:
127
+ # terms["loss"] = terms["mse"] + terms["vb"]
128
+ # else:
129
+ terms["loss"] = terms["mse"]
130
+
131
+ return terms
132
+
133
+ def consistency_losses(
134
+ self,
135
+ model,
136
+ x_start,
137
+ num_scales,
138
+ # model_kwargs=None,
139
+ condition=None,
140
+ target_model=None,
141
+ teacher_model=None,
142
+ teacher_diffusion=None,
143
+ noise=None,
144
+ ):
145
+ if noise is None:
146
+ noise = th.randn_like(x_start)
147
+
148
+ dims = x_start.ndim
149
+
150
+ def denoise_fn(x, t):
151
+ return self.denoise(model, x, t, condition)[1]
152
+
153
+ if target_model:
154
+
155
+ @th.no_grad()
156
+ def target_denoise_fn(x, t):
157
+ return self.denoise(target_model, x, t, condition)[1]
158
+
159
+ else:
160
+ raise NotImplementedError("Must have a target model")
161
+
162
+ if teacher_model:
163
+
164
+ @th.no_grad()
165
+ def teacher_denoise_fn(x, t):
166
+ return teacher_diffusion.denoise(teacher_model, x, t, condition)[1]
167
+
168
+ @th.no_grad()
169
+ def heun_solver(samples, t, next_t, x0):
170
+ x = samples
171
+ if teacher_model is None:
172
+ denoiser = x0
173
+ else:
174
+ denoiser = teacher_denoise_fn(x, t)
175
+
176
+ d = (x - denoiser) / append_dims(t, dims)
177
+ samples = x + d * append_dims(next_t - t, dims)
178
+ if teacher_model is None:
179
+ denoiser = x0
180
+ else:
181
+ denoiser = teacher_denoise_fn(samples, next_t)
182
+
183
+ next_d = (samples - denoiser) / append_dims(next_t, dims)
184
+ samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)
185
+
186
+ return samples
187
+
188
+ @th.no_grad()
189
+ def euler_solver(samples, t, next_t, x0):
190
+ x = samples
191
+ if teacher_model is None:
192
+ denoiser = x0
193
+ else:
194
+ denoiser = teacher_denoise_fn(x, t)
195
+ d = (x - denoiser) / append_dims(t, dims)
196
+ samples = x + d * append_dims(next_t - t, dims)
197
+
198
+ return samples
199
+
200
+ indices = th.randint(
201
+ 0, num_scales - 1, (x_start.shape[0],), device=x_start.device
202
+ )
203
+
204
+ t = self.sigma_max ** (1 / self.rho) + indices / (num_scales - 1) * (
205
+ self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
206
+ )
207
+ t = t**self.rho
208
+
209
+ t2 = self.sigma_max ** (1 / self.rho) + (indices + 1) / (num_scales - 1) * (
210
+ self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
211
+ )
212
+ t2 = t2**self.rho
213
+
214
+ x_t = x_start + noise * append_dims(t, dims)
215
+
216
+ dropout_state = th.get_rng_state()
217
+ distiller = denoise_fn(x_t, t)
218
+
219
+ if teacher_model is None:
220
+ x_t2 = euler_solver(x_t, t, t2, x_start).detach()
221
+ else:
222
+ x_t2 = heun_solver(x_t, t, t2, x_start).detach()
223
+
224
+ th.set_rng_state(dropout_state)
225
+ distiller_target = target_denoise_fn(x_t2, t2)
226
+ distiller_target = distiller_target.detach()
227
+
228
+ snrs = self.get_snr(t)
229
+ weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
230
+ if self.loss_norm == "l1":
231
+ diffs = th.abs(distiller - distiller_target)
232
+ loss = mean_flat(diffs) * weights
233
+ elif self.loss_norm == "l2":
234
+ # diffs = (distiller - distiller_target) ** 2
235
+ loss = F.mse_loss(distiller, distiller_target)
236
+ # loss = mean_flat(diffs) * weights
237
+ elif self.loss_norm == "ssim":
238
+ loss = self.ssim_loss(distiller, distiller_target) * weights
239
+ # elif self.loss_norm == "l2-32":
240
+ # distiller = F.interpolate(distiller, size=32, mode="bilinear")
241
+ # distiller_target = F.interpolate(
242
+ # distiller_target,
243
+ # size=32,
244
+ # mode="bilinear",
245
+ # )
246
+ # diffs = (distiller - distiller_target) ** 2
247
+ # loss = mean_flat(diffs) * weights
248
+ # elif self.loss_norm == "lpips":
249
+ # if x_start.shape[-1] < 256:
250
+ # distiller = F.interpolate(distiller, size=224, mode="bilinear")
251
+ # distiller_target = F.interpolate(
252
+ # distiller_target, size=224, mode="bilinear"
253
+ # )
254
+
255
+ # loss = (
256
+ # self.lpips_loss(
257
+ # (distiller + 1) / 2.0,
258
+ # (distiller_target + 1) / 2.0,
259
+ # )
260
+ # * weights
261
+ # )
262
+ else:
263
+ raise ValueError(f"Unknown loss norm {self.loss_norm}")
264
+
265
+ terms = {}
266
+ terms["loss"] = loss
267
+
268
+ return terms
269
+
270
+ # def progdist_losses(
271
+ # self,
272
+ # model,
273
+ # x_start,
274
+ # num_scales,
275
+ # model_kwargs=None,
276
+ # teacher_model=None,
277
+ # teacher_diffusion=None,
278
+ # noise=None,
279
+ # ):
280
+ # if model_kwargs is None:
281
+ # model_kwargs = {}
282
+ # if noise is None:
283
+ # noise = th.randn_like(x_start)
284
+
285
+ # dims = x_start.ndim
286
+
287
+ # def denoise_fn(x, t):
288
+ # return self.denoise(model, x, t, **model_kwargs)[1]
289
+
290
+ # @th.no_grad()
291
+ # def teacher_denoise_fn(x, t):
292
+ # return teacher_diffusion.denoise(teacher_model, x, t, **model_kwargs)[1]
293
+
294
+ # @th.no_grad()
295
+ # def euler_solver(samples, t, next_t):
296
+ # x = samples
297
+ # denoiser = teacher_denoise_fn(x, t)
298
+ # d = (x - denoiser) / append_dims(t, dims)
299
+ # samples = x + d * append_dims(next_t - t, dims)
300
+
301
+ # return samples
302
+
303
+ # @th.no_grad()
304
+ # def euler_to_denoiser(x_t, t, x_next_t, next_t):
305
+ # denoiser = x_t - append_dims(t, dims) * (x_next_t - x_t) / append_dims(
306
+ # next_t - t, dims
307
+ # )
308
+ # return denoiser
309
+
310
+ # indices = th.randint(0, num_scales, (x_start.shape[0],), device=x_start.device)
311
+
312
+ # t = self.sigma_max ** (1 / self.rho) + indices / num_scales * (
313
+ # self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
314
+ # )
315
+ # t = t**self.rho
316
+
317
+ # t2 = self.sigma_max ** (1 / self.rho) + (indices + 0.5) / num_scales * (
318
+ # self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
319
+ # )
320
+ # t2 = t2**self.rho
321
+
322
+ # t3 = self.sigma_max ** (1 / self.rho) + (indices + 1) / num_scales * (
323
+ # self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
324
+ # )
325
+ # t3 = t3**self.rho
326
+
327
+ # x_t = x_start + noise * append_dims(t, dims)
328
+
329
+ # denoised_x = denoise_fn(x_t, t)
330
+
331
+ # x_t2 = euler_solver(x_t, t, t2).detach()
332
+ # x_t3 = euler_solver(x_t2, t2, t3).detach()
333
+
334
+ # target_x = euler_to_denoiser(x_t, t, x_t3, t3).detach()
335
+
336
+ # snrs = self.get_snr(t)
337
+ # weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
338
+ # if self.loss_norm == "l1":
339
+ # diffs = th.abs(denoised_x - target_x)
340
+ # loss = mean_flat(diffs) * weights
341
+ # elif self.loss_norm == "l2":
342
+ # diffs = (denoised_x - target_x) ** 2
343
+ # loss = mean_flat(diffs) * weights
344
+ # elif self.loss_norm == "lpips":
345
+ # if x_start.shape[-1] < 256:
346
+ # denoised_x = F.interpolate(denoised_x, size=224, mode="bilinear")
347
+ # target_x = F.interpolate(target_x, size=224, mode="bilinear")
348
+ # loss = (
349
+ # self.lpips_loss(
350
+ # (denoised_x + 1) / 2.0,
351
+ # (target_x + 1) / 2.0,
352
+ # )
353
+ # * weights
354
+ # )
355
+ # else:
356
+ # raise ValueError(f"Unknown loss norm {self.loss_norm}")
357
+
358
+ # terms = {}
359
+ # terms["loss"] = loss
360
+
361
+ # return terms
362
+
363
+ def denoise(self, model, x_t, sigmas, condition):
364
+ if not self.distillation:
365
+ c_skip, c_out, c_in = [
366
+ append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
367
+ ]
368
+ else:
369
+ c_skip, c_out, c_in = [
370
+ append_dims(x, x_t.ndim)
371
+ for x in self.get_scalings_for_boundary_condition(sigmas)
372
+ ]
373
+ rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
374
+ # rescaled_t = rescaled_t[:, None]
375
+ model_output = model(c_in * x_t, rescaled_t, condition)
376
+ denoised = c_out * model_output + c_skip * x_t
377
+ return model_output, denoised
378
+
379
+
380
+ def karras_sample(
381
+ diffusion,
382
+ model,
383
+ shape,
384
+ steps,
385
+ clip_denoised=True,
386
+ progress=True,
387
+ callback=None,
388
+ # model_kwargs=None,
389
+ condition=None,
390
+ device=None,
391
+ sigma_min=0.002,
392
+ sigma_max=80, # higher for highres?
393
+ rho=7.0,
394
+ sampler="heun",
395
+ s_churn=0.0,
396
+ s_tmin=0.0,
397
+ s_tmax=float("inf"),
398
+ s_noise=1.0,
399
+ generator=None,
400
+ ts=None,
401
+ ):
402
+ if generator is None:
403
+ generator = get_generator("dummy")
404
+
405
+ if sampler == "progdist":
406
+ sigmas = get_sigmas_karras(steps + 1, sigma_min, sigma_max, rho, device=device)
407
+ else:
408
+ sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
409
+ th.manual_seed(42)
410
+ x_T = generator.randn(*shape, device=device) * sigma_max
411
+ sigmas = sigmas.unsqueeze(-1)
412
+ sample_fn = {
413
+ "heun": sample_heun,
414
+ "dpm": sample_dpm,
415
+ "ancestral": sample_euler_ancestral,
416
+ "onestep": sample_onestep,
417
+ "progdist": sample_progdist,
418
+ "euler": sample_euler,
419
+ "multistep": stochastic_iterative_sampler,
420
+ }[sampler]
421
+
422
+ if sampler in ["heun", "dpm"]:
423
+ sampler_args = dict(
424
+ s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise
425
+ )
426
+ elif sampler == "multistep":
427
+ sampler_args = dict(
428
+ ts=ts, t_min=sigma_min, t_max=sigma_max, rho=diffusion.rho, steps=steps
429
+ )
430
+ else:
431
+ sampler_args = {}
432
+
433
+ def denoiser(x_t, sigma):
434
+ _, denoised = diffusion.denoise(model, x_t, sigma, condition)
435
+ if clip_denoised:
436
+ denoised = denoised.clamp(-1, 1)
437
+ return denoised
438
+
439
+ x_0 = sample_fn(
440
+ denoiser,
441
+ x_T,
442
+ sigmas,
443
+ generator,
444
+ progress=progress,
445
+ callback=callback,
446
+ **sampler_args,
447
+ )
448
+ return x_0.clamp(-1, 1)
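# A minimal usage sketch, assuming `diffusion` is the Karras-style denoiser object
# defined earlier in this file (exposing denoise(), rho and the sigma range) and
# `model` is any network matching the model(c_in * x_t, rescaled_t, condition) call:
#
#     device = "cuda" if th.cuda.is_available() else "cpu"
#     x_0 = karras_sample(
#         diffusion,
#         model,
#         shape=(1, 100, 128),   # illustrative (batch, channels, frames)
#         steps=40,
#         condition=cond,        # conditioning tensor forwarded to denoise()
#         device=device,
#         sampler="heun",
#     )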
449
+
450
+
451
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
452
+ """Constructs the noise schedule of Karras et al. (2022)."""
453
+ ramp = th.linspace(0, 1, n)
454
+ min_inv_rho = sigma_min ** (1 / rho)
455
+ max_inv_rho = sigma_max ** (1 / rho)
456
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
457
+ return append_zero(sigmas).to(device)
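# A small sanity sketch of the schedule endpoints: the first entry is sigma_max,
# the last nonzero entry is sigma_min, and append_zero adds the trailing 0.0 used
# for the final denoising step.
example_sigmas = get_sigmas_karras(10, sigma_min=0.002, sigma_max=80.0, rho=7.0)
# example_sigmas[0] ~= 80.0, example_sigmas[-2] ~= 0.002, example_sigmas[-1] == 0.0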
458
+
459
+
460
+ def to_d(x, sigma, denoised):
461
+ """Converts a denoiser output to a Karras ODE derivative."""
462
+ return (x - denoised) / append_dims(sigma, x.ndim)
463
+
464
+
465
+ def get_ancestral_step(sigma_from, sigma_to):
466
+ """Calculates the noise level (sigma_down) to step down to and the amount
467
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
468
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
469
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
470
+ return sigma_down, sigma_up
471
+
472
+
473
+ @th.no_grad()
474
+ def sample_euler_ancestral(model, x, sigmas, generator, progress=False, callback=None):
475
+ """Ancestral sampling with Euler method steps."""
476
+ s_in = x.new_ones([x.shape[0]])
477
+ indices = range(len(sigmas) - 1)
478
+ if progress:
479
+ from tqdm.auto import tqdm
480
+
481
+ indices = tqdm(indices)
482
+
483
+ for i in indices:
484
+ denoised = model(x, sigmas[i] * s_in)
485
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
486
+ if callback is not None:
487
+ callback(
488
+ {
489
+ "x": x,
490
+ "i": i,
491
+ "sigma": sigmas[i],
492
+ "sigma_hat": sigmas[i],
493
+ "denoised": denoised,
494
+ }
495
+ )
496
+ d = to_d(x, sigmas[i], denoised)
497
+ # Euler method
498
+ dt = sigma_down - sigmas[i]
499
+ x = x + d * dt
500
+ x = x + generator.randn_like(x) * sigma_up
501
+ return x
502
+
503
+
504
+ @th.no_grad()
505
+ def sample_midpoint_ancestral(model, x, ts, generator, progress=False, callback=None):
506
+ """Ancestral sampling with midpoint method steps."""
507
+ s_in = x.new_ones([x.shape[0]])
508
+ step_size = 1 / len(ts)
509
+ if progress:
510
+ from tqdm.auto import tqdm
511
+
512
+ ts = tqdm(ts)
513
+
514
+ for tn in ts:
515
+ dn = model(x, tn * s_in)
516
+ dn_2 = model(x + (step_size / 2) * dn, (tn + step_size / 2) * s_in)
517
+ x = x + step_size * dn_2
518
+ if callback is not None:
519
+ callback({"x": x, "tn": tn, "dn": dn, "dn_2": dn_2})
520
+ return x
521
+
522
+
523
+ @th.no_grad()
524
+ def sample_heun(
525
+ denoiser,
526
+ x,
527
+ sigmas,
528
+ generator,
529
+ progress=False,
530
+ callback=None,
531
+ s_churn=0.0,
532
+ s_tmin=0.0,
533
+ s_tmax=float("inf"),
534
+ s_noise=1.0,
535
+ ):
536
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
537
+ s_in = x.new_ones([x.shape[0]])
538
+ indices = range(len(sigmas) - 1)
539
+ if progress:
540
+ from tqdm.auto import tqdm
541
+
542
+ indices = tqdm(indices)
543
+
544
+ for i in indices:
545
+ gamma = (
546
+ min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
547
+ if s_tmin <= sigmas[i] <= s_tmax
548
+ else 0.0
549
+ )
550
+ eps = generator.randn_like(x) * s_noise
551
+ sigma_hat = sigmas[i] * (gamma + 1)
552
+ if gamma > 0:
553
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
554
+ denoised = denoiser(x, sigma_hat * s_in)
555
+ d = to_d(x, sigma_hat, denoised)
556
+ if callback is not None:
557
+ callback(
558
+ {
559
+ "x": x,
560
+ "i": i,
561
+ "sigma": sigmas[i],
562
+ "sigma_hat": sigma_hat,
563
+ "denoised": denoised,
564
+ }
565
+ )
566
+ dt = sigmas[i + 1] - sigma_hat
567
+ if sigmas[i + 1] == 0:
568
+ # Euler method
569
+ x = x + d * dt
570
+ else:
571
+ # Heun's method
572
+ x_2 = x + d * dt
573
+ denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
574
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
575
+ d_prime = (d + d_2) / 2
576
+ x = x + d_prime * dt
577
+ return x
578
+
579
+
580
+ @th.no_grad()
581
+ def sample_euler(
582
+ denoiser,
583
+ x,
584
+ sigmas,
585
+ generator,
586
+ progress=False,
587
+ callback=None,
588
+ ):
589
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
590
+ s_in = x.new_ones([x.shape[0]])
591
+ indices = range(len(sigmas) - 1)
592
+ if progress:
593
+ from tqdm.auto import tqdm
594
+
595
+ indices = tqdm(indices)
596
+
597
+ for i in indices:
598
+ sigma = sigmas[i]
599
+ denoised = denoiser(x, sigma * s_in)
600
+ d = to_d(x, sigma, denoised)
601
+ if callback is not None:
602
+ callback(
603
+ {
604
+ "x": x,
605
+ "i": i,
606
+ "sigma": sigmas[i],
607
+ "denoised": denoised,
608
+ }
609
+ )
610
+ dt = sigmas[i + 1] - sigma
611
+ x = x + d * dt
612
+ return x
613
+
614
+
615
+ @th.no_grad()
616
+ def sample_dpm(
617
+ denoiser,
618
+ x,
619
+ sigmas,
620
+ generator,
621
+ progress=False,
622
+ callback=None,
623
+ s_churn=0.0,
624
+ s_tmin=0.0,
625
+ s_tmax=float("inf"),
626
+ s_noise=1.0,
627
+ ):
628
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
629
+ s_in = x.new_ones([x.shape[0]])
630
+ indices = range(len(sigmas) - 1)
631
+ if progress:
632
+ from tqdm.auto import tqdm
633
+
634
+ indices = tqdm(indices)
635
+
636
+ for i in indices:
637
+ gamma = (
638
+ min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
639
+ if s_tmin <= sigmas[i] <= s_tmax
640
+ else 0.0
641
+ )
642
+ eps = generator.randn_like(x) * s_noise
643
+ sigma_hat = sigmas[i] * (gamma + 1)
644
+ if gamma > 0:
645
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
646
+ denoised = denoiser(x, sigma_hat * s_in)
647
+ d = to_d(x, sigma_hat, denoised)
648
+ if callback is not None:
649
+ callback(
650
+ {
651
+ "x": x,
652
+ "i": i,
653
+ "sigma": sigmas[i],
654
+ "sigma_hat": sigma_hat,
655
+ "denoised": denoised,
656
+ }
657
+ )
658
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
659
+ sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
660
+ dt_1 = sigma_mid - sigma_hat
661
+ dt_2 = sigmas[i + 1] - sigma_hat
662
+ x_2 = x + d * dt_1
663
+ denoised_2 = denoiser(x_2, sigma_mid * s_in)
664
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
665
+ x = x + d_2 * dt_2
666
+ return x
667
+
668
+
669
+ @th.no_grad()
670
+ def sample_onestep(
671
+ distiller,
672
+ x,
673
+ sigmas,
674
+ generator=None,
675
+ progress=False,
676
+ callback=None,
677
+ ):
678
+ """Single-step generation from a distilled model."""
679
+ s_in = x.new_ones([x.shape[0]])
680
+ return distiller(x, sigmas[0] * s_in)
681
+
682
+
683
+ @th.no_grad()
684
+ def stochastic_iterative_sampler(
685
+ distiller,
686
+ x,
687
+ sigmas,
688
+ generator,
689
+ ts,
690
+ progress=False,
691
+ callback=None,
692
+ t_min=0.002,
693
+ t_max=80.0,
694
+ rho=7.0,
695
+ steps=40,
696
+ ):
697
+ t_max_rho = t_max ** (1 / rho)
698
+ t_min_rho = t_min ** (1 / rho)
699
+ s_in = x.new_ones([x.shape[0]])
700
+
701
+ for i in range(len(ts) - 1):
702
+ t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
703
+ x0 = distiller(x, t * s_in)
704
+ next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
705
+ next_t = np.clip(next_t, t_min, t_max)
706
+ x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
707
+
708
+ return x
709
+
710
+
711
+ @th.no_grad()
712
+ def sample_progdist(
713
+ denoiser,
714
+ x,
715
+ sigmas,
716
+ generator=None,
717
+ progress=False,
718
+ callback=None,
719
+ ):
720
+ s_in = x.new_ones([x.shape[0]])
721
+ sigmas = sigmas[:-1] # skip the zero sigma
722
+
723
+ indices = range(len(sigmas) - 1)
724
+ if progress:
725
+ from tqdm.auto import tqdm
726
+
727
+ indices = tqdm(indices)
728
+
729
+ for i in indices:
730
+ sigma = sigmas[i]
731
+ denoised = denoiser(x, sigma * s_in)
732
+ d = to_d(x, sigma, denoised)
733
+ if callback is not None:
734
+ callback(
735
+ {
736
+ "x": x,
737
+ "i": i,
738
+ "sigma": sigma,
739
+ "denoised": denoised,
740
+ }
741
+ )
742
+ dt = sigmas[i + 1] - sigma
743
+ x = x + d * dt
744
+
745
+ return x
746
+
747
+
748
+ # @th.no_grad()
749
+ # def iterative_colorization(
750
+ # distiller,
751
+ # images,
752
+ # x,
753
+ # ts,
754
+ # t_min=0.002,
755
+ # t_max=80.0,
756
+ # rho=7.0,
757
+ # steps=40,
758
+ # generator=None,
759
+ # ):
760
+ # def obtain_orthogonal_matrix():
761
+ # vector = np.asarray([0.2989, 0.5870, 0.1140])
762
+ # vector = vector / np.linalg.norm(vector)
763
+ # matrix = np.eye(3)
764
+ # matrix[:, 0] = vector
765
+ # matrix = np.linalg.qr(matrix)[0]
766
+ # if np.sum(matrix[:, 0]) < 0:
767
+ # matrix = -matrix
768
+ # return matrix
769
+
770
+ # Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
771
+ # mask = th.zeros(*x.shape[1:], device=dist_util.dev())
772
+ # mask[0, ...] = 1.0
773
+
774
+ # def replacement(x0, x1):
775
+ # x0 = th.einsum("bchw,cd->bdhw", x0, Q)
776
+ # x1 = th.einsum("bchw,cd->bdhw", x1, Q)
777
+
778
+ # x_mix = x0 * mask + x1 * (1.0 - mask)
779
+ # x_mix = th.einsum("bdhw,cd->bchw", x_mix, Q)
780
+ # return x_mix
781
+
782
+ # t_max_rho = t_max ** (1 / rho)
783
+ # t_min_rho = t_min ** (1 / rho)
784
+ # s_in = x.new_ones([x.shape[0]])
785
+ # images = replacement(images, th.zeros_like(images))
786
+
787
+ # for i in range(len(ts) - 1):
788
+ # t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
789
+ # x0 = distiller(x, t * s_in)
790
+ # x0 = th.clamp(x0, -1.0, 1.0)
791
+ # x0 = replacement(images, x0)
792
+ # next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
793
+ # next_t = np.clip(next_t, t_min, t_max)
794
+ # x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
795
+
796
+ # return x, images
797
+
798
+
799
+ # @th.no_grad()
800
+ # def iterative_inpainting(
801
+ # distiller,
802
+ # images,
803
+ # x,
804
+ # ts,
805
+ # t_min=0.002,
806
+ # t_max=80.0,
807
+ # rho=7.0,
808
+ # steps=40,
809
+ # generator=None,
810
+ # ):
811
+ # from PIL import Image, ImageDraw, ImageFont
812
+
813
+ # image_size = x.shape[-1]
814
+
815
+ # # create a blank image with a white background
816
+ # img = Image.new("RGB", (image_size, image_size), color="white")
817
+
818
+ # # get a drawing context for the image
819
+ # draw = ImageDraw.Draw(img)
820
+
821
+ # # load a font
822
+ # font = ImageFont.truetype("arial.ttf", 250)
823
+
824
+ # # draw the letter "C" in black
825
+ # draw.text((50, 0), "S", font=font, fill=(0, 0, 0))
826
+
827
+ # # convert the image to a numpy array
828
+ # img_np = np.array(img)
829
+ # img_np = img_np.transpose(2, 0, 1)
830
+ # img_th = th.from_numpy(img_np).to(dist_util.dev())
831
+
832
+ # mask = th.zeros(*x.shape, device=dist_util.dev())
833
+ # mask = mask.reshape(-1, 7, 3, image_size, image_size)
834
+
835
+ # mask[::2, :, img_th > 0.5] = 1.0
836
+ # mask[1::2, :, img_th < 0.5] = 1.0
837
+ # mask = mask.reshape(-1, 3, image_size, image_size)
838
+
839
+ # def replacement(x0, x1):
840
+ # x_mix = x0 * mask + x1 * (1 - mask)
841
+ # return x_mix
842
+
843
+ # t_max_rho = t_max ** (1 / rho)
844
+ # t_min_rho = t_min ** (1 / rho)
845
+ # s_in = x.new_ones([x.shape[0]])
846
+ # images = replacement(images, -th.ones_like(images))
847
+
848
+ # for i in range(len(ts) - 1):
849
+ # t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
850
+ # x0 = distiller(x, t * s_in)
851
+ # x0 = th.clamp(x0, -1.0, 1.0)
852
+ # x0 = replacement(images, x0)
853
+ # next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
854
+ # next_t = np.clip(next_t, t_min, t_max)
855
+ # x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
856
+
857
+ # return x, images
858
+
859
+
860
+ # @th.no_grad()
861
+ # def iterative_superres(
862
+ # distiller,
863
+ # images,
864
+ # x,
865
+ # ts,
866
+ # t_min=0.002,
867
+ # t_max=80.0,
868
+ # rho=7.0,
869
+ # steps=40,
870
+ # generator=None,
871
+ # ):
872
+ # patch_size = 8
873
+
874
+ # def obtain_orthogonal_matrix():
875
+ # vector = np.asarray([1] * patch_size**2)
876
+ # vector = vector / np.linalg.norm(vector)
877
+ # matrix = np.eye(patch_size**2)
878
+ # matrix[:, 0] = vector
879
+ # matrix = np.linalg.qr(matrix)[0]
880
+ # if np.sum(matrix[:, 0]) < 0:
881
+ # matrix = -matrix
882
+ # return matrix
883
+
884
+ # Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
885
+
886
+ # image_size = x.shape[-1]
887
+
888
+ # def replacement(x0, x1):
889
+ # x0_flatten = (
890
+ # x0.reshape(-1, 3, image_size, image_size)
891
+ # .reshape(
892
+ # -1,
893
+ # 3,
894
+ # image_size // patch_size,
895
+ # patch_size,
896
+ # image_size // patch_size,
897
+ # patch_size,
898
+ # )
899
+ # .permute(0, 1, 2, 4, 3, 5)
900
+ # .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
901
+ # )
902
+ # x1_flatten = (
903
+ # x1.reshape(-1, 3, image_size, image_size)
904
+ # .reshape(
905
+ # -1,
906
+ # 3,
907
+ # image_size // patch_size,
908
+ # patch_size,
909
+ # image_size // patch_size,
910
+ # patch_size,
911
+ # )
912
+ # .permute(0, 1, 2, 4, 3, 5)
913
+ # .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
914
+ # )
915
+ # x0 = th.einsum("bcnd,de->bcne", x0_flatten, Q)
916
+ # x1 = th.einsum("bcnd,de->bcne", x1_flatten, Q)
917
+ # x_mix = x0.new_zeros(x0.shape)
918
+ # x_mix[..., 0] = x0[..., 0]
919
+ # x_mix[..., 1:] = x1[..., 1:]
920
+ # x_mix = th.einsum("bcne,de->bcnd", x_mix, Q)
921
+ # x_mix = (
922
+ # x_mix.reshape(
923
+ # -1,
924
+ # 3,
925
+ # image_size // patch_size,
926
+ # image_size // patch_size,
927
+ # patch_size,
928
+ # patch_size,
929
+ # )
930
+ # .permute(0, 1, 2, 4, 3, 5)
931
+ # .reshape(-1, 3, image_size, image_size)
932
+ # )
933
+ # return x_mix
934
+
935
+ # def average_image_patches(x):
936
+ # x_flatten = (
937
+ # x.reshape(-1, 3, image_size, image_size)
938
+ # .reshape(
939
+ # -1,
940
+ # 3,
941
+ # image_size // patch_size,
942
+ # patch_size,
943
+ # image_size // patch_size,
944
+ # patch_size,
945
+ # )
946
+ # .permute(0, 1, 2, 4, 3, 5)
947
+ # .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
948
+ # )
949
+ # x_flatten[..., :] = x_flatten.mean(dim=-1, keepdim=True)
950
+ # return (
951
+ # x_flatten.reshape(
952
+ # -1,
953
+ # 3,
954
+ # image_size // patch_size,
955
+ # image_size // patch_size,
956
+ # patch_size,
957
+ # patch_size,
958
+ # )
959
+ # .permute(0, 1, 2, 4, 3, 5)
960
+ # .reshape(-1, 3, image_size, image_size)
961
+ # )
962
+
963
+ # t_max_rho = t_max ** (1 / rho)
964
+ # t_min_rho = t_min ** (1 / rho)
965
+ # s_in = x.new_ones([x.shape[0]])
966
+ # images = average_image_patches(images)
967
+
968
+ # for i in range(len(ts) - 1):
969
+ # t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
970
+ # x0 = distiller(x, t * s_in)
971
+ # x0 = th.clamp(x0, -1.0, 1.0)
972
+ # x0 = replacement(images, x0)
973
+ # next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
974
+ # next_t = np.clip(next_t, t_min, t_max)
975
+ # x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
976
+
977
+ # return x, images
modules/diffusion/karras/random_utils.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch as th
7
+
8
+
9
+ def get_generator(generator, num_samples=0, seed=0):
10
+ if generator == "dummy":
11
+ return DummyGenerator()
12
+ elif generator == "determ":
13
+ return DeterministicGenerator(num_samples, seed)
14
+ elif generator == "determ-indiv":
15
+ return DeterministicIndividualGenerator(num_samples, seed)
16
+ else:
17
+ raise NotImplementedError
18
+
19
+
20
+ class DummyGenerator:
21
+ def randn(self, *args, **kwargs):
22
+ return th.randn(*args, **kwargs)
23
+
24
+ def randint(self, *args, **kwargs):
25
+ return th.randint(*args, **kwargs)
26
+
27
+ def randn_like(self, *args, **kwargs):
28
+ return th.randn_like(*args, **kwargs)
29
+
30
+
31
+ class DeterministicGenerator:
32
+ """
33
+ RNG that deterministically draws num_samples samples, independent of batch_size or the number of MPI machines.
34
+ Uses a single RNG to draw num_samples-sized noise and subsamples it at the current indices.
35
+ """
36
+
37
+ def __init__(self, num_samples, seed=0):
38
+ print("Warning: Distributed not initialised, using single rank")
39
+ self.rank = 0
40
+ self.world_size = 1
41
+ self.num_samples = num_samples
42
+ self.done_samples = 0
43
+ self.seed = seed
44
+ self.rng_cpu = th.Generator()
45
+ if th.cuda.is_available():
46
+ self.rng_cuda = th.Generator("cuda")  # dist_util is not imported here; fall back to the default CUDA device
47
+ self.set_seed(seed)
48
+
49
+ def get_global_size_and_indices(self, size):
50
+ global_size = (self.num_samples, *size[1:])
51
+ indices = th.arange(
52
+ self.done_samples + self.rank,
53
+ self.done_samples + self.world_size * int(size[0]),
54
+ self.world_size,
55
+ )
56
+ indices = th.clamp(indices, 0, self.num_samples - 1)
57
+ assert (
58
+ len(indices) == size[0]
59
+ ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
60
+ return global_size, indices
61
+
62
+ def get_generator(self, device):
63
+ return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda
64
+
65
+ def randn(self, *size, dtype=th.float, device="cpu"):
66
+ global_size, indices = self.get_global_size_and_indices(size)
67
+ generator = self.get_generator(device)
68
+ return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[
69
+ indices
70
+ ]
71
+
72
+ def randint(self, low, high, size, dtype=th.long, device="cpu"):
73
+ global_size, indices = self.get_global_size_and_indices(size)
74
+ generator = self.get_generator(device)
75
+ return th.randint(
76
+ low, high, generator=generator, size=global_size, dtype=dtype, device=device
77
+ )[indices]
78
+
79
+ def randn_like(self, tensor):
80
+ size, dtype, device = tensor.size(), tensor.dtype, tensor.device
81
+ return self.randn(*size, dtype=dtype, device=device)
82
+
83
+ def set_done_samples(self, done_samples):
84
+ self.done_samples = done_samples
85
+ self.set_seed(self.seed)
86
+
87
+ def get_seed(self):
88
+ return self.seed
89
+
90
+ def set_seed(self, seed):
91
+ self.rng_cpu.manual_seed(seed)
92
+ if th.cuda.is_available():
93
+ self.rng_cuda.manual_seed(seed)
94
+
95
+
96
+ class DeterministicIndividualGenerator:
97
+ """
98
+ RNG that deterministically draws num_samples samples, independent of batch_size or the number of MPI machines.
99
+ Uses a separate RNG for each sample to reduce memory usage.
100
+ """
101
+
102
+ def __init__(self, num_samples, seed=0):
103
+ print("Warning: Distributed not initialised, using single rank")
104
+ self.rank = 0
105
+ self.world_size = 1
106
+ self.num_samples = num_samples
107
+ self.done_samples = 0
108
+ self.seed = seed
109
+ self.rng_cpu = [th.Generator() for _ in range(num_samples)]
110
+ if th.cuda.is_available():
111
+ self.rng_cuda = [th.Generator("cuda") for _ in range(num_samples)]  # dist_util is not imported here; fall back to the default CUDA device
112
+ self.set_seed(seed)
113
+
114
+ def get_size_and_indices(self, size):
115
+ indices = th.arange(
116
+ self.done_samples + self.rank,
117
+ self.done_samples + self.world_size * int(size[0]),
118
+ self.world_size,
119
+ )
120
+ indices = th.clamp(indices, 0, self.num_samples - 1)
121
+ assert (
122
+ len(indices) == size[0]
123
+ ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
124
+ return (1, *size[1:]), indices
125
+
126
+ def get_generator(self, device):
127
+ return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda
128
+
129
+ def randn(self, *size, dtype=th.float, device="cpu"):
130
+ size, indices = self.get_size_and_indices(size)
131
+ generator = self.get_generator(device)
132
+ return th.cat(
133
+ [
134
+ th.randn(*size, generator=generator[i], dtype=dtype, device=device)
135
+ for i in indices
136
+ ],
137
+ dim=0,
138
+ )
139
+
140
+ def randint(self, low, high, size, dtype=th.long, device="cpu"):
141
+ size, indices = self.get_size_and_indices(size)
142
+ generator = self.get_generator(device)
143
+ return th.cat(
144
+ [
145
+ th.randint(
146
+ low,
147
+ high,
148
+ generator=generator[i],
149
+ size=size,
150
+ dtype=dtype,
151
+ device=device,
152
+ )
153
+ for i in indices
154
+ ],
155
+ dim=0,
156
+ )
157
+
158
+ def randn_like(self, tensor):
159
+ size, dtype, device = tensor.size(), tensor.dtype, tensor.device
160
+ return self.randn(*size, dtype=dtype, device=device)
161
+
162
+ def set_done_samples(self, done_samples):
163
+ self.done_samples = done_samples
164
+
165
+ def get_seed(self):
166
+ return self.seed
167
+
168
+ def set_seed(self, seed):
169
+ [
170
+ rng_cpu.manual_seed(i + self.num_samples * seed)
171
+ for i, rng_cpu in enumerate(self.rng_cpu)
172
+ ]
173
+ if th.cuda.is_available():
174
+ [
175
+ rng_cuda.manual_seed(i + self.num_samples * seed)
176
+ for i, rng_cuda in enumerate(self.rng_cuda)
177
+ ]
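# A minimal sketch: the deterministic generator ties noise to the global sample
# index, so the same samples are reproduced regardless of batch size (single rank,
# as hard-coded above). Runs on CPU.
gen = get_generator("determ", num_samples=8, seed=0)
first = gen.randn(4, 3, 32, 32)    # noise for global samples 0..3
gen.set_done_samples(4)
second = gen.randn(4, 3, 32, 32)   # noise for global samples 4..7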
modules/diffusion/karras/sample.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ from scipy.stats import norm
11
+ import torch.distributed as dist
12
+
13
+
14
+ def create_named_schedule_sampler(name, diffusion):
15
+ """
16
+ Create a ScheduleSampler from a library of pre-defined samplers.
17
+
18
+ :param name: the name of the sampler.
19
+ :param diffusion: the diffusion object to sample for.
20
+ """
21
+ if name == "uniform":
22
+ return UniformSampler(diffusion)
23
+ elif name == "loss-second-moment":
24
+ return LossSecondMomentResampler(diffusion)
25
+ elif name == "lognormal":
26
+ return LogNormalSampler()
27
+ else:
28
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
29
+
30
+
31
+ class ScheduleSampler(ABC):
32
+ """
33
+ A distribution over timesteps in the diffusion process, intended to reduce
34
+ variance of the objective.
35
+
36
+ By default, samplers perform unbiased importance sampling, in which the
37
+ objective's mean is unchanged.
38
+ However, subclasses may override sample() to change how the resampled
39
+ terms are reweighted, allowing for actual changes in the objective.
40
+ """
41
+
42
+ @abstractmethod
43
+ def weights(self):
44
+ """
45
+ Get a numpy array of weights, one per diffusion step.
46
+
47
+ The weights needn't be normalized, but must be positive.
48
+ """
49
+
50
+ def sample(self, batch_size, device):
51
+ """
52
+ Importance-sample timesteps for a batch.
53
+
54
+ :param batch_size: the number of timesteps.
55
+ :param device: the torch device to save to.
56
+ :return: a tuple (timesteps, weights):
57
+ - timesteps: a tensor of timestep indices.
58
+ - weights: a tensor of weights to scale the resulting losses.
59
+ """
60
+ w = self.weights()
61
+ p = w / np.sum(w)
62
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
63
+ indices = th.from_numpy(indices_np).long().to(device)
64
+ weights_np = 1 / (len(p) * p[indices_np])
65
+ weights = th.from_numpy(weights_np).float().to(device)
66
+ return indices, weights
67
+
68
+
69
+ class UniformSampler(ScheduleSampler):
70
+ def __init__(self, diffusion):
71
+ self.diffusion = diffusion
72
+ self._weights = np.ones([diffusion.num_timesteps])
73
+
74
+ def weights(self):
75
+ return self._weights
76
+
77
+
78
+ class LossAwareSampler(ScheduleSampler):
79
+ def update_with_local_losses(self, local_ts, local_losses):
80
+ """
81
+ Update the reweighting using losses from a model.
82
+
83
+ Call this method from each rank with a batch of timesteps and the
84
+ corresponding losses for each of those timesteps.
85
+ This method will perform synchronization to make sure all of the ranks
86
+ maintain the exact same reweighting.
87
+
88
+ :param local_ts: an integer Tensor of timesteps.
89
+ :param local_losses: a 1D Tensor of losses.
90
+ """
91
+ batch_sizes = [
92
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
93
+ for _ in range(dist.get_world_size())
94
+ ]
95
+ dist.all_gather(
96
+ batch_sizes,
97
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
98
+ )
99
+
100
+ # Pad all_gather batches to be the maximum batch size.
101
+ batch_sizes = [x.item() for x in batch_sizes]
102
+ max_bs = max(batch_sizes)
103
+
104
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
105
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
106
+ dist.all_gather(timestep_batches, local_ts)
107
+ dist.all_gather(loss_batches, local_losses)
108
+ timesteps = [
109
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
110
+ ]
111
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
112
+ self.update_with_all_losses(timesteps, losses)
113
+
114
+ @abstractmethod
115
+ def update_with_all_losses(self, ts, losses):
116
+ """
117
+ Update the reweighting using losses from a model.
118
+
119
+ Sub-classes should override this method to update the reweighting
120
+ using losses from the model.
121
+
122
+ This method directly updates the reweighting without synchronizing
123
+ between workers. It is called by update_with_local_losses from all
124
+ ranks with identical arguments. Thus, it should have deterministic
125
+ behavior to maintain state across workers.
126
+
127
+ :param ts: a list of int timesteps.
128
+ :param losses: a list of float losses, one per timestep.
129
+ """
130
+
131
+
132
+ class LossSecondMomentResampler(LossAwareSampler):
133
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
134
+ self.diffusion = diffusion
135
+ self.history_per_term = history_per_term
136
+ self.uniform_prob = uniform_prob
137
+ self._loss_history = np.zeros(
138
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
139
+ )
140
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed from NumPy
141
+
142
+ def weights(self):
143
+ if not self._warmed_up():
144
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
145
+ weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
146
+ weights /= np.sum(weights)
147
+ weights *= 1 - self.uniform_prob
148
+ weights += self.uniform_prob / len(weights)
149
+ return weights
150
+
151
+ def update_with_all_losses(self, ts, losses):
152
+ for t, loss in zip(ts, losses):
153
+ if self._loss_counts[t] == self.history_per_term:
154
+ # Shift out the oldest loss term.
155
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
156
+ self._loss_history[t, -1] = loss
157
+ else:
158
+ self._loss_history[t, self._loss_counts[t]] = loss
159
+ self._loss_counts[t] += 1
160
+
161
+ def _warmed_up(self):
162
+ return (self._loss_counts == self.history_per_term).all()
163
+
164
+
165
+ class LogNormalSampler:
166
+ def __init__(self, p_mean=-1.2, p_std=1.2, even=False):
167
+ self.p_mean = p_mean
168
+ self.p_std = p_std
169
+ self.even = even
170
+ if self.even:
171
+ self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std)
172
+ self.rank, self.size = dist.get_rank(), dist.get_world_size()
173
+
174
+ def sample(self, bs, device):
175
+ if self.even:
176
+ # buckets = [1/G]
177
+ start_i, end_i = self.rank * bs, (self.rank + 1) * bs
178
+ global_batch_size = self.size * bs
179
+ locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size
180
+ log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device)
181
+ else:
182
+ log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device)
183
+ sigmas = th.exp(log_sigmas)
184
+ weights = th.ones_like(sigmas)
185
+ return sigmas, weights
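# A minimal sketch: drawing per-sample noise levels with the lognormal sampler
# (even=False, so the distributed branch above is never entered).
sigma_sampler = LogNormalSampler(p_mean=-1.2, p_std=1.2)
sigmas, weights = sigma_sampler.sample(bs=16, device="cpu")
# sigmas ~ exp(N(-1.2, 1.2^2)); weights are all ones for this sampler.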
modules/diffusion/unet/attention.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from modules.general.utils import Conv1d, normalization, zero_module
11
+ from .basic import UNetBlock
12
+
13
+
14
+ class AttentionBlock(UNetBlock):
15
+ r"""A spatial transformer encoder block that allows spatial positions to attend
16
+ to each other. Reference from `latent diffusion repo
17
+ <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.
18
+
19
+ Args:
20
+ channels: Number of channels in the input.
21
+ num_head_channels: Number of channels per attention head.
22
+ num_heads: Number of attention heads. Overrides ``num_head_channels`` if set.
23
+ encoder_channels: Number of channels in the encoder output for cross-attention.
24
+ If ``None``, then self-attention is performed.
25
+ use_self_attention: Whether to use self-attention before cross-attention, only applicable if encoder_channels is set.
26
+ dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
27
+ h_dim: The dimension of the height, would be applied if ``dims`` is 2.
28
+ encoder_hdim: The dimension of the height of the encoder output, would be applied if ``dims`` is 2.
29
+ p_dropout: Dropout probability.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ channels: int,
35
+ num_head_channels: int = 32,
36
+ num_heads: int = -1,
37
+ encoder_channels: int = None,
38
+ use_self_attention: bool = False,
39
+ dims: int = 1,
40
+ h_dim: int = 100,
41
+ encoder_hdim: int = 384,
42
+ p_dropout: float = 0.0,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.channels = channels
47
+ self.p_dropout = p_dropout
48
+ self.dims = dims
49
+
50
+ if dims == 1:
51
+ self.channels = channels
52
+ elif dims == 2:
53
+ # We consider the channel as product of channel and height, i.e. C x H
54
+ # This is because we want to apply attention on the audio signal, which is 1D
55
+ self.channels = channels * h_dim
56
+ else:
57
+ raise ValueError(f"invalid number of dimensions: {dims}")
58
+
59
+ if num_head_channels == -1:
60
+ assert (
61
+ self.channels % num_heads == 0
62
+ ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
63
+ self.num_heads = num_heads
64
+ self.num_head_channels = self.channels // num_heads
65
+ else:
66
+ assert (
67
+ self.channels % num_head_channels == 0
68
+ ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
69
+ self.num_heads = self.channels // num_head_channels
70
+ self.num_head_channels = num_head_channels
71
+
72
+ if encoder_channels is not None:
73
+ self.use_self_attention = use_self_attention
74
+
75
+ if dims == 1:
76
+ self.encoder_channels = encoder_channels
77
+ elif dims == 2:
78
+ self.encoder_channels = encoder_channels * encoder_hdim
79
+ else:
80
+ raise ValueError(f"invalid number of dimensions: {dims}")
81
+
82
+ if use_self_attention:
83
+ self.self_attention = BasicAttentionBlock(
84
+ self.channels,
85
+ self.num_head_channels,
86
+ self.num_heads,
87
+ p_dropout=self.p_dropout,
88
+ )
89
+ self.cross_attention = BasicAttentionBlock(
90
+ self.channels,
91
+ self.num_head_channels,
92
+ self.num_heads,
93
+ self.encoder_channels,
94
+ p_dropout=self.p_dropout,
95
+ )
96
+ else:
97
+ self.encoder_channels = None
98
+ self.self_attention = BasicAttentionBlock(
99
+ self.channels,
100
+ self.num_head_channels,
101
+ self.num_heads,
102
+ p_dropout=self.p_dropout,
103
+ )
104
+
105
+ def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
106
+ r"""
107
+ Args:
108
+ x: input tensor with shape [B x ``channels`` x ...]
109
+ encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...], if ``None``, then self-attention is performed.
110
+
111
+ Returns:
112
+ output tensor with shape [B x ``channels`` x ...]
113
+ """
114
+ shape = x.size()
115
+ x = x.reshape(shape[0], self.channels, -1).contiguous()
116
+
117
+ if self.encoder_channels is None:
118
+ assert (
119
+ encoder_output is None
120
+ ), "encoder_output must be None for self-attention."
121
+ h = self.self_attention(x)
122
+
123
+ else:
124
+ assert (
125
+ encoder_output is not None
126
+ ), "encoder_output must be given for cross-attention."
127
+ encoder_output = encoder_output.reshape(
128
+ shape[0], self.encoder_channels, -1
129
+ ).contiguous()
130
+
131
+ if self.use_self_attention:
132
+ x = self.self_attention(x)
133
+ h = self.cross_attention(x, encoder_output)
134
+
135
+ return h.reshape(*shape).contiguous()
136
+
137
+
138
+ class BasicAttentionBlock(nn.Module):
139
+ def __init__(
140
+ self,
141
+ channels: int,
142
+ num_head_channels: int = 32,
143
+ num_heads: int = -1,
144
+ context_channels: int = None,
145
+ p_dropout: float = 0.0,
146
+ ):
147
+ super().__init__()
148
+
149
+ self.channels = channels
150
+ self.p_dropout = p_dropout
151
+ self.context_channels = context_channels
152
+
153
+ if num_head_channels == -1:
154
+ assert (
155
+ self.channels % num_heads == 0
156
+ ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
157
+ self.num_heads = num_heads
158
+ self.num_head_channels = self.channels // num_heads
159
+ else:
160
+ assert (
161
+ self.channels % num_head_channels == 0
162
+ ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
163
+ self.num_heads = self.channels // num_head_channels
164
+ self.num_head_channels = num_head_channels
165
+
166
+ if context_channels is not None:
167
+ self.to_q = nn.Sequential(
168
+ normalization(self.channels),
169
+ Conv1d(self.channels, self.channels, 1),
170
+ )
171
+ self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
172
+ else:
173
+ self.to_qkv = nn.Sequential(
174
+ normalization(self.channels),
175
+ Conv1d(self.channels, 3 * self.channels, 1),
176
+ )
177
+
178
+ self.linear = Conv1d(self.channels, self.channels)
179
+
180
+ self.proj_out = nn.Sequential(
181
+ normalization(self.channels),
182
+ Conv1d(self.channels, self.channels, 1),
183
+ nn.GELU(),
184
+ nn.Dropout(p=self.p_dropout),
185
+ zero_module(Conv1d(self.channels, self.channels, 1)),
186
+ )
187
+
188
+ def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
189
+ r"""
190
+ Args:
191
+ q: input tensor with shape [B, ``channels``, L]
192
+ kv: feature tensor with shape [B, ``context_channels``, T], if ``None``, then self-attention is performed.
193
+
194
+ Returns:
195
+ output tensor with shape [B, ``channels``, L]
196
+ """
197
+ N, C, L = q.size()
+ residual = q  # keep the block input; q is reassigned to the projected queries below
198
+
199
+ if self.context_channels is not None:
200
+ assert kv is not None, "kv must be given for cross-attention."
201
+
202
+ q = (
203
+ self.to_q(q)
204
+ .reshape(self.num_heads, self.num_head_channels, -1)
205
+ .transpose(-1, -2)
206
+ .contiguous()
207
+ )
208
+ kv = (
209
+ self.to_kv(kv)
210
+ .reshape(2, self.num_heads, self.num_head_channels, -1)
211
+ .transpose(-1, -2)
212
+ .chunk(2)
213
+ )
214
+ k, v = (
215
+ kv[0].squeeze(0).contiguous(),
216
+ kv[1].squeeze(0).contiguous(),
217
+ )
218
+
219
+ else:
220
+ qkv = (
221
+ self.to_qkv(q)
222
+ .reshape(3, self.num_heads, self.num_head_channels, -1)
223
+ .transpose(-1, -2)
224
+ .chunk(3)
225
+ )
226
+ q, k, v = (
227
+ qkv[0].squeeze(0).contiguous(),
228
+ qkv[1].squeeze(0).contiguous(),
229
+ qkv[2].squeeze(0).contiguous(),
230
+ )
231
+
232
+ h = F.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout).transpose(
233
+ -1, -2
234
+ )
235
+ h = h.reshape(N, -1, L).contiguous()
236
+ h = self.linear(h)
237
+
238
+ x = residual + h
239
+ h = self.proj_out(x)
240
+
241
+ return x + h
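# A minimal sketch (hedged): 1-D self-attention over a 64-channel feature map,
# assuming the Conv1d/normalization helpers in modules.general.utils behave like
# their guided-diffusion counterparts (1x1 convolution plus GroupNorm).
import torch

attn = AttentionBlock(channels=64, num_head_channels=16, dims=1)
x = torch.randn(1, 64, 200)
y = attn(x)   # encoder_output is None, so the self-attention path is used
# y.shape == (1, 64, 200)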
modules/diffusion/unet/basic.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+ from abc import abstractmethod
8
+
9
+
10
+ class UNetBlock(nn.Module):
11
+ r"""Any module where forward() takes timestep embeddings as a second argument."""
12
+
13
+ @abstractmethod
14
+ def forward(self, x, emb):
15
+ r"""Apply the module to `x` given `emb` timestep embeddings."""
modules/diffusion/unet/resblock.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from .basic import UNetBlock
10
+ from modules.general.utils import (
11
+ append_dims,
12
+ ConvNd,
13
+ normalization,
14
+ zero_module,
15
+ )
16
+
17
+
18
+ class ResBlock(UNetBlock):
19
+ r"""A residual block that can optionally change the number of channels.
20
+
21
+ Args:
22
+ channels: the number of input channels.
23
+ emb_channels: the number of timestep embedding channels.
24
+ dropout: the rate of dropout.
25
+ out_channels: if specified, the number of out channels.
26
+ use_conv: if True and out_channels is specified, use a spatial
27
+ convolution instead of a smaller 1x1 convolution to change the
28
+ channels in the skip connection.
29
+ dims: determines if the signal is 1D, 2D, or 3D.
30
+ up: if True, use this block for upsampling.
31
+ down: if True, use this block for downsampling.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ channels,
37
+ emb_channels,
38
+ dropout: float = 0.0,
39
+ out_channels=None,
40
+ use_conv=False,
41
+ use_scale_shift_norm=False,
42
+ dims=2,
43
+ up=False,
44
+ down=False,
45
+ ):
46
+ super().__init__()
47
+ self.channels = channels
48
+ self.emb_channels = emb_channels
49
+ self.dropout = dropout
50
+ self.out_channels = out_channels or channels
51
+ self.use_conv = use_conv
52
+ self.use_scale_shift_norm = use_scale_shift_norm
53
+
54
+ self.in_layers = nn.Sequential(
55
+ normalization(channels),
56
+ nn.SiLU(),
57
+ ConvNd(dims, channels, self.out_channels, 3, padding=1),
58
+ )
59
+
60
+ self.updown = up or down
61
+
62
+ if up:
63
+ self.h_upd = Upsample(channels, dims=dims)  # Upsample/Downsample here take (channels, dims, out_channels); there is no use_conv flag
64
+ self.x_upd = Upsample(channels, dims=dims)
65
+ elif down:
66
+ self.h_upd = Downsample(channels, dims=dims)
67
+ self.x_upd = Downsample(channels, dims=dims)
68
+ else:
69
+ self.h_upd = self.x_upd = nn.Identity()
70
+
71
+ self.emb_layers = nn.Sequential(
72
+ nn.SiLU(),
73
+ ConvNd(
74
+ dims,
75
+ emb_channels,
76
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
77
+ 1,
78
+ ),
79
+ )
80
+ self.out_layers = nn.Sequential(
81
+ normalization(self.out_channels),
82
+ nn.SiLU(),
83
+ nn.Dropout(p=dropout),
84
+ zero_module(
85
+ ConvNd(dims, self.out_channels, self.out_channels, 3, padding=1)
86
+ ),
87
+ )
88
+
89
+ if self.out_channels == channels:
90
+ self.skip_connection = nn.Identity()
91
+ elif use_conv:
92
+ self.skip_connection = ConvNd(
93
+ dims, channels, self.out_channels, 3, padding=1
94
+ )
95
+ else:
96
+ self.skip_connection = ConvNd(dims, channels, self.out_channels, 1)
97
+
98
+ def forward(self, x, emb):
99
+ """
100
+ Apply the block to a Tensor, conditioned on a timestep embedding.
101
+
102
+ x: an [N x C x ...] Tensor of features.
103
+ emb: an [N x emb_channels x ...] Tensor of timestep embeddings.
104
+ :return: an [N x C x ...] Tensor of outputs.
105
+ """
106
+ if self.updown:
107
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
108
+ h = in_rest(x)
109
+ h = self.h_upd(h)
110
+ x = self.x_upd(x)
111
+ h = in_conv(h)
112
+ else:
113
+ h = self.in_layers(x)
114
+ emb_out = self.emb_layers(emb)
115
+ emb_out = append_dims(emb_out, h.dim())
116
+ if self.use_scale_shift_norm:
117
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
118
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
119
+ h = out_norm(h) * (1 + scale) + shift
120
+ h = out_rest(h)
121
+ else:
122
+ h = h + emb_out
123
+ h = self.out_layers(h)
124
+ return self.skip_connection(x) + h
125
+
126
+
127
+ class Upsample(nn.Module):
128
+ r"""An upsampling layer with an optional convolution.
129
+
130
+ Args:
131
+ channels: channels in the inputs and outputs.
132
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
133
+ upsampling occurs in the inner-two dimensions.
134
+ out_channels: if specified, the number of out channels.
135
+ """
136
+
137
+ def __init__(self, channels, dims=2, out_channels=None):
138
+ super().__init__()
139
+ self.channels = channels
140
+ self.out_channels = out_channels or channels
141
+ self.dims = dims
142
+ self.conv = ConvNd(dims, self.channels, self.out_channels, 3, padding=1)
143
+
144
+ def forward(self, x):
145
+ assert x.shape[1] == self.channels
146
+ if self.dims == 3:
147
+ x = F.interpolate(
148
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
149
+ )
150
+ else:
151
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
152
+ x = self.conv(x)
153
+ return x
154
+
155
+
156
+ class Downsample(nn.Module):
157
+ r"""A downsampling layer with an optional convolution.
158
+
159
+ Args:
160
+ channels: channels in the inputs and outputs.
161
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
162
+ downsampling occurs in the inner-two dimensions.
163
+ out_channels: if specified, the number of output channels.
164
+ """
165
+
166
+ def __init__(self, channels, dims=2, out_channels=None):
167
+ super().__init__()
168
+ self.channels = channels
169
+ self.out_channels = out_channels or channels
170
+ self.dims = dims
171
+ stride = 2 if dims != 3 else (1, 2, 2)
172
+ self.op = ConvNd(
173
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
174
+ )
175
+
176
+ def forward(self, x):
177
+ assert x.shape[1] == self.channels
178
+ return self.op(x)
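# A minimal sketch (hedged, assuming ConvNd/normalization in modules.general.utils
# wrap the usual nn.Conv1d/GroupNorm): a 1-D ResBlock conditioned on a timestep
# embedding, followed by a strided Downsample that halves the sequence length.
import torch

block = ResBlock(channels=64, emb_channels=256, out_channels=128, dims=1)
x = torch.randn(2, 64, 100)
emb = torch.randn(2, 256, 1)   # broadcast over the time axis inside forward()
h = block(x, emb)              # -> (2, 128, 100)
down = Downsample(128, dims=1)
h2 = down(h)                   # -> (2, 128, 50)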
modules/diffusion/unet/unet.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from modules.encoder.position_encoder import PositionEncoder
10
+ from modules.general.utils import append_dims, ConvNd, normalization, zero_module
11
+ from .attention import AttentionBlock
12
+ from .resblock import Downsample, ResBlock, Upsample
13
+
14
+
15
+ class UNet(nn.Module):
16
+ r"""The full UNet model with attention and timestep embedding.
17
+
18
+ Args:
19
+ dims: determines if the signal is 1D (temporal) or 2D (spatial).
20
+ in_channels: channels in the input Tensor.
21
+ model_channels: base channel count for the model.
22
+ out_channels: channels in the output Tensor.
23
+ num_res_blocks: number of residual blocks per downsample.
24
+ channel_mult: channel multiplier for each level of the UNet.
25
+ num_attn_blocks: number of attention blocks applied at each attention resolution.
26
+ attention_resolutions: a collection of downsample rates at which attention will
27
+ take place. May be a set, list, or tuple. For example, if this contains 4,
28
+ then at 4x downsampling, attention will be used.
29
+ num_heads: the number of attention heads in each attention layer.
30
+ num_head_channels: if specified, ignore num_heads and instead use a fixed
31
+ channel width per attention head.
32
+ d_context: if specified, use for cross-attention channel project.
33
+ p_dropout: the dropout probability.
34
+ use_self_attention: Apply self attention before cross attention.
35
+ num_classes: if specified (as an int), then this model will be class-conditional
36
+ with ``num_classes`` classes.
37
+ use_extra_film: if specified, use an extra FiLM-like conditioning mechanism.
38
+ d_emb: if specified, use for FiLM-like conditioning.
39
+ use_scale_shift_norm: use a FiLM-like conditioning mechanism.
40
+ resblock_updown: use residual blocks for up/downsampling.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dims: int = 1,
46
+ in_channels: int = 100,
47
+ model_channels: int = 128,
48
+ out_channels: int = 100,
49
+ h_dim: int = 128,
50
+ num_res_blocks: int = 1,
51
+ channel_mult: tuple = (1, 2, 4),
52
+ num_attn_blocks: int = 1,
53
+ attention_resolutions: tuple = (1, 2, 4),
54
+ num_heads: int = 1,
55
+ num_head_channels: int = -1,
56
+ d_context: int = None,
57
+ context_hdim: int = 128,
58
+ p_dropout: float = 0.0,
59
+ num_classes: int = -1,
60
+ use_extra_film: str = None,
61
+ d_emb: int = None,
62
+ use_scale_shift_norm: bool = True,
63
+ resblock_updown: bool = False,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.dims = dims
68
+ self.in_channels = in_channels
69
+ self.model_channels = model_channels
70
+ self.out_channels = out_channels
71
+ self.num_res_blocks = num_res_blocks
72
+ self.channel_mult = channel_mult
73
+ self.num_attn_blocks = num_attn_blocks
74
+ self.attention_resolutions = attention_resolutions
75
+ self.num_heads = num_heads
76
+ self.num_head_channels = num_head_channels
77
+ self.d_context = d_context
78
+ self.p_dropout = p_dropout
79
+ self.num_classes = num_classes
80
+ self.use_extra_film = use_extra_film
81
+ self.d_emb = d_emb
82
+ self.use_scale_shift_norm = use_scale_shift_norm
83
+ self.resblock_updown = resblock_updown
84
+
85
+ time_embed_dim = model_channels * 4
86
+ self.pos_enc = PositionEncoder(model_channels, time_embed_dim)
87
+
88
+ assert (
89
+ num_classes == -1 or use_extra_film is None
90
+ ), "You cannot set both num_classes and use_extra_film."
91
+
92
+ if self.num_classes > 0:
93
+ # TODO: if used for singer, norm should be 1, correct?
94
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0)
95
+ elif use_extra_film is not None:
96
+ assert (
97
+ d_emb is not None
98
+ ), "d_emb must be specified if use_extra_film is not None"
99
+ assert use_extra_film in [
100
+ "add",
101
+ "concat",
102
+ ], f"use_extra_film only supported by add or concat. Your input is {use_extra_film}"
103
+ self.use_extra_film = use_extra_film
104
+ self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1)
105
+ if use_extra_film == "concat":
106
+ time_embed_dim *= 2
107
+
108
+ # Input blocks
109
+ ch = input_ch = int(channel_mult[0] * model_channels)
110
+ self.input_blocks = nn.ModuleList(
111
+ [UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))]
112
+ )
113
+ self._feature_size = ch
114
+ input_block_chans = [ch]
115
+ ds = 1
116
+ for level, mult in enumerate(channel_mult):
117
+ for _ in range(num_res_blocks):
118
+ layers = [
119
+ ResBlock(
120
+ ch,
121
+ time_embed_dim,
122
+ p_dropout,
123
+ out_channels=int(mult * model_channels),
124
+ dims=dims,
125
+ use_scale_shift_norm=use_scale_shift_norm,
126
+ )
127
+ ]
128
+ ch = int(mult * model_channels)
129
+ if ds in attention_resolutions:
130
+ for _ in range(num_attn_blocks):
131
+ layers.append(
132
+ AttentionBlock(
133
+ ch,
134
+ num_heads=num_heads,
135
+ num_head_channels=num_head_channels,
136
+ encoder_channels=d_context,
137
+ dims=dims,
138
+ h_dim=h_dim // (level + 1),
139
+ encoder_hdim=context_hdim,
140
+ p_dropout=p_dropout,
141
+ )
142
+ )
143
+ self.input_blocks.append(UNetSequential(*layers))
144
+ self._feature_size += ch
145
+ input_block_chans.append(ch)
146
+ if level != len(channel_mult) - 1:
147
+ out_ch = ch
148
+ self.input_blocks.append(
149
+ UNetSequential(
150
+ ResBlock(
151
+ ch,
152
+ time_embed_dim,
153
+ p_dropout,
154
+ out_channels=out_ch,
155
+ dims=dims,
156
+ use_scale_shift_norm=use_scale_shift_norm,
157
+ down=True,
158
+ )
159
+ if resblock_updown
160
+ else Downsample(ch, dims=dims, out_channels=out_ch)
161
+ )
162
+ )
163
+ ch = out_ch
164
+ input_block_chans.append(ch)
165
+ ds *= 2
166
+ self._feature_size += ch
167
+
168
+ # Middle blocks
169
+ self.middle_block = UNetSequential(
170
+ ResBlock(
171
+ ch,
172
+ time_embed_dim,
173
+ p_dropout,
174
+ dims=dims,
175
+ use_scale_shift_norm=use_scale_shift_norm,
176
+ ),
177
+ AttentionBlock(
178
+ ch,
179
+ num_heads=num_heads,
180
+ num_head_channels=num_head_channels,
181
+ encoder_channels=d_context,
182
+ dims=dims,
183
+ h_dim=h_dim // (level + 1),
184
+ encoder_hdim=context_hdim,
185
+ p_dropout=p_dropout,
186
+ ),
187
+ ResBlock(
188
+ ch,
189
+ time_embed_dim,
190
+ p_dropout,
191
+ dims=dims,
192
+ use_scale_shift_norm=use_scale_shift_norm,
193
+ ),
194
+ )
195
+ self._feature_size += ch
196
+
197
+ # Output blocks
198
+ self.output_blocks = nn.ModuleList([])
199
+ for level, mult in tuple(enumerate(channel_mult))[::-1]:
200
+ for i in range(num_res_blocks + 1):
201
+ ich = input_block_chans.pop()
202
+ layers = [
203
+ ResBlock(
204
+ ch + ich,
205
+ time_embed_dim,
206
+ p_dropout,
207
+ out_channels=int(model_channels * mult),
208
+ dims=dims,
209
+ use_scale_shift_norm=use_scale_shift_norm,
210
+ )
211
+ ]
212
+ ch = int(model_channels * mult)
213
+ if ds in attention_resolutions:
214
+ for _ in range(num_attn_blocks):
215
+ layers.append(
216
+ AttentionBlock(
217
+ ch,
218
+ num_heads=num_heads,
219
+ num_head_channels=num_head_channels,
220
+ encoder_channels=d_context,
221
+ dims=dims,
222
+ h_dim=h_dim // (level + 1),
223
+ encoder_hdim=context_hdim,
224
+ p_dropout=p_dropout,
225
+ )
226
+ )
227
+ if level and i == num_res_blocks:
228
+ out_ch = ch
229
+ layers.append(
230
+ ResBlock(
231
+ ch,
232
+ time_embed_dim,
233
+ p_dropout,
234
+ out_channels=out_ch,
235
+ dims=dims,
236
+ use_scale_shift_norm=use_scale_shift_norm,
237
+ up=True,
238
+ )
239
+ if resblock_updown
240
+ else Upsample(ch, dims=dims, out_channels=out_ch)
241
+ )
242
+ ds //= 2
243
+ self.output_blocks.append(UNetSequential(*layers))
244
+ self._feature_size += ch
245
+
246
+ # Final proj out
247
+ self.out = nn.Sequential(
248
+ normalization(ch),
249
+ nn.SiLU(),
250
+ zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)),
251
+ )
252
+
253
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
254
+ r"""Apply the model to an input batch.
255
+
256
+ Args:
257
+ x: an [N x C x ...] Tensor of inputs.
258
+ timesteps: a 1-D batch of timesteps, i.e. [N].
259
+ context: conditioning Tensor with shape of [N x ``d_context`` x ...] plugged
260
+ in via cross attention.
261
+ y: an [N] Tensor of labels, if **class-conditional**.
262
+ an [N x ``d_emb`` x ...] Tensor if **film-embed conditional**.
263
+
264
+ Returns:
265
+ an [N x C x ...] Tensor of outputs.
266
+ """
267
+ assert (y is None) or (
268
+ (y is not None)
269
+ and ((self.num_classes > 0) or (self.use_extra_film is not None))
270
+ ), f"y must be specified if num_classes or use_extra_film is not None. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n"
271
+
272
+ hs = []
273
+ emb = self.pos_enc(timesteps)
274
+ emb = append_dims(emb, x.dim())
275
+
276
+ if self.num_classes > 0:
277
+ assert y.size() == (x.size(0),)
278
+ emb = emb + self.label_emb(y)
279
+ elif self.use_extra_film is not None:
280
+ assert y.size() == (x.size(0), self.d_emb, *x.size()[2:])
281
+ y = self.film_emb(y)
282
+ if self.use_extra_film == "add":
283
+ emb = emb + y
284
+ elif self.use_extra_film == "concat":
285
+ emb = torch.cat([emb, y], dim=1)
286
+
287
+ h = x
288
+ for module in self.input_blocks:
289
+ h = module(h, emb, context)
290
+ hs.append(h)
291
+ h = self.middle_block(h, emb, context)
292
+ for module in self.output_blocks:
293
+ h = torch.cat([h, hs.pop()], dim=1)
294
+ h = module(h, emb, context)
295
+
296
+ return self.out(h)
297
+
298
+
299
+ class UNetSequential(nn.Sequential):
300
+ r"""A sequential module that passes embeddings to the children that support it."""
301
+
302
+ def forward(self, x, emb=None, context=None):
303
+ for layer in self:
304
+ if isinstance(layer, ResBlock):
305
+ x = layer(x, emb)
306
+ elif isinstance(layer, AttentionBlock):
307
+ x = layer(x, context)
308
+ else:
309
+ x = layer(x)
310
+ return x
modules/distributions/__init__.py ADDED
File without changes
modules/distributions/distributions.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import numpy as np
+
+
+ class AbstractDistribution:
+     def sample(self):
+         raise NotImplementedError()
+
+     def mode(self):
+         raise NotImplementedError()
+
+
+ class DiracDistribution(AbstractDistribution):
+     def __init__(self, value):
+         self.value = value
+
+     def sample(self):
+         return self.value
+
+     def mode(self):
+         return self.value
+
+
+ class DiagonalGaussianDistribution(object):
+     def __init__(self, parameters, deterministic=False):
+         self.parameters = parameters
+         self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+         self.deterministic = deterministic
+         self.std = torch.exp(0.5 * self.logvar)
+         self.var = torch.exp(self.logvar)
+         if self.deterministic:
+             self.var = self.std = torch.zeros_like(self.mean).to(
+                 device=self.parameters.device
+             )
+
+     def sample(self):
+         x = self.mean + self.std * torch.randn(self.mean.shape).to(
+             device=self.parameters.device
+         )
+         return x
+
+     def kl(self, other=None):
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         else:
+             if other is None:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                     dim=[1, 2, 3],
+                 )
+             else:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean - other.mean, 2) / other.var
+                     + self.var / other.var
+                     - 1.0
+                     - self.logvar
+                     + other.logvar,
+                     dim=[1, 2, 3],
+                 )
+
+     def nll(self, sample, dims=[1, 2, 3]):
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         logtwopi = np.log(2.0 * np.pi)
+         return 0.5 * torch.sum(
+             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+             dim=dims,
+         )
+
+     def mode(self):
+         return self.mean
+
+
+ def normal_kl(mean1, logvar1, mean2, logvar2):
+     """
+     source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+     Compute the KL divergence between two gaussians.
+     Shapes are automatically broadcasted, so batches can be compared to
+     scalars, among other use cases.
+     """
+     tensor = None
+     for obj in (mean1, logvar1, mean2, logvar2):
+         if isinstance(obj, torch.Tensor):
+             tensor = obj
+             break
+     assert tensor is not None, "at least one argument must be a Tensor"
+
+     # Force variances to be Tensors. Broadcasting helps convert scalars to
+     # Tensors, but it does not work for torch.exp().
+     logvar1, logvar2 = [
+         x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+         for x in (logvar1, logvar2)
+     ]
+
+     return 0.5 * (
+         -1.0
+         + logvar2
+         - logvar1
+         + torch.exp(logvar1 - logvar2)
+         + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+     )
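A minimal usage sketch of the DiagonalGaussianDistribution added above (not part of the commit; the tensor sizes are illustrative assumptions):

import torch
from modules.distributions.distributions import DiagonalGaussianDistribution

# [N, 2C, H, W]: the first C channels hold the mean, the last C the log-variance
params = torch.randn(4, 8, 16, 16)
dist = DiagonalGaussianDistribution(params)
z = dist.sample()   # [4, 4, 16, 16] reparameterized sample
kl = dist.kl()      # [4] KL divergence against a standard normal
nll = dist.nll(z)   # [4] negative log-likelihood of the sample
print(z.shape, kl.shape, nll.shape)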
modules/duration_predictor/__init__.py ADDED
File without changes
modules/duration_predictor/standard_duration_predictor.py ADDED
@@ -0,0 +1,53 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
+
+ import torch
+ from torch import nn
+ from modules.base.base_module import LayerNorm
+
+
+ class DurationPredictor(nn.Module):
+     def __init__(
+         self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.gin_channels = gin_channels
+
+         self.drop = nn.Dropout(p_dropout)
+         self.conv_1 = nn.Conv1d(
+             in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_1 = LayerNorm(filter_channels)
+         self.conv_2 = nn.Conv1d(
+             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_2 = LayerNorm(filter_channels)
+         self.proj = nn.Conv1d(filter_channels, 1, 1)
+
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+     def forward(self, x, x_mask, g=None):
+         x = torch.detach(x)
+         if g is not None:
+             g = torch.detach(g)
+             x = x + self.cond(g)
+         x = self.conv_1(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_1(x)
+         x = self.drop(x)
+         x = self.conv_2(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_2(x)
+         x = self.drop(x)
+         x = self.proj(x * x_mask)
+         return x * x_mask
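A hedged usage sketch of the deterministic DurationPredictor above; the channel sizes are assumptions, and the output is interpreted as per-frame log-durations as in VITS-style models:

import torch
from modules.duration_predictor.standard_duration_predictor import DurationPredictor

dp = DurationPredictor(in_channels=192, filter_channels=256, kernel_size=3, p_dropout=0.5)
x = torch.randn(2, 192, 50)    # [B, C, T] text-encoder hidden states
x_mask = torch.ones(2, 1, 50)  # [B, 1, T] padding mask
log_w = dp(x, x_mask)          # [B, 1, T] predicted log-durations
print(log_w.shape)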
modules/duration_predictor/stochastic_duration_predictor.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
+ import torch
7
+
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ import math
11
+ from modules.flow.modules import *
12
+
13
+
14
+ class StochasticDurationPredictor(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_channels,
18
+ filter_channels,
19
+ kernel_size,
20
+ p_dropout,
21
+ n_flows=4,
22
+ gin_channels=0,
23
+ ):
24
+ super().__init__()
25
+ filter_channels = in_channels
26
+ self.in_channels = in_channels
27
+ self.filter_channels = filter_channels
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.log_flow = Log()
34
+ self.flows = nn.ModuleList()
35
+ self.flows.append(ElementwiseAffine(2))
36
+ for i in range(n_flows):
37
+ self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
38
+ self.flows.append(Flip())
39
+
40
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
41
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
42
+ self.post_convs = DDSConv(
43
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
44
+ )
45
+ self.post_flows = nn.ModuleList()
46
+ self.post_flows.append(ElementwiseAffine(2))
47
+ for i in range(4):
48
+ self.post_flows.append(
49
+ ConvFlow(2, filter_channels, kernel_size, n_layers=3)
50
+ )
51
+ self.post_flows.append(Flip())
52
+
53
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
54
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
55
+ self.convs = DDSConv(
56
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
57
+ )
58
+ if gin_channels != 0:
59
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
60
+
61
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
62
+ x = torch.detach(x)
63
+ x = self.pre(x)
64
+ if g is not None:
65
+ g = torch.detach(g)
66
+ x = x + self.cond(g)
67
+ x = self.convs(x, x_mask)
68
+ x = self.proj(x) * x_mask
69
+
70
+ if not reverse:
71
+ flows = self.flows
72
+ assert w is not None
73
+
74
+ logdet_tot_q = 0
75
+ h_w = self.post_pre(w)
76
+ h_w = self.post_convs(h_w, x_mask)
77
+ h_w = self.post_proj(h_w) * x_mask
78
+ e_q = (
79
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
80
+ * x_mask
81
+ )
82
+ z_q = e_q
83
+ for flow in self.post_flows:
84
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
85
+ logdet_tot_q += logdet_q
86
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
87
+ u = torch.sigmoid(z_u) * x_mask
88
+ z0 = (w - u) * x_mask
89
+ logdet_tot_q += torch.sum(
90
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
91
+ )
92
+ logq = (
93
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
94
+ - logdet_tot_q
95
+ )
96
+
97
+ logdet_tot = 0
98
+ z0, logdet = self.log_flow(z0, x_mask)
99
+ logdet_tot += logdet
100
+ z = torch.cat([z0, z1], 1)
101
+ for flow in flows:
102
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
103
+ logdet_tot = logdet_tot + logdet
104
+ nll = (
105
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
106
+ - logdet_tot
107
+ )
108
+ return nll + logq
109
+ else:
110
+ flows = list(reversed(self.flows))
111
+ flows = flows[:-2] + [flows[-1]]
112
+ z = (
113
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
114
+ * noise_scale
115
+ )
116
+ for flow in flows:
117
+ z = flow(z, x_mask, g=x, reverse=reverse)
118
+ z0, z1 = torch.split(z, [1, 1], 1)
119
+ logw = z0
120
+ return logw
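A sketch of the two calling modes of StochasticDurationPredictor (the sizes and noise_scale are assumptions): training mode returns a per-utterance NLL term, reverse mode samples log-durations:

import torch
from modules.duration_predictor.stochastic_duration_predictor import (
    StochasticDurationPredictor,
)

sdp = StochasticDurationPredictor(
    in_channels=192, filter_channels=192, kernel_size=3, p_dropout=0.5
)
x = torch.randn(2, 192, 50)     # [B, C, T] text-encoder hidden states
x_mask = torch.ones(2, 1, 50)   # [B, 1, T]
w = torch.rand(2, 1, 50) + 1.0  # ground-truth durations, training mode
nll = sdp(x, x_mask, w=w)                              # [B]
log_w = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # [B, 1, T]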
modules/encoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .token_encoder import TokenEmbedding
modules/encoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (206 Bytes).
modules/encoder/__pycache__/token_encoder.cpython-39.pyc ADDED
Binary file (1.08 kB).
modules/encoder/condition_encoder.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchaudio.models import Conformer
10
+ from models.svc.transformer.transformer import PositionalEncoding
11
+
12
+ from utils.f0 import f0_to_coarse
13
+
14
+
15
+ class ContentEncoder(nn.Module):
16
+ def __init__(self, cfg, input_dim, output_dim):
17
+ super().__init__()
18
+ self.cfg = cfg
19
+
20
+ assert input_dim != 0
21
+ self.nn = nn.Linear(input_dim, output_dim)
22
+
23
+ # Introduce conformer or not
24
+ if (
25
+ "use_conformer_for_content_features" in cfg
26
+ and cfg.use_conformer_for_content_features
27
+ ):
28
+ self.pos_encoder = PositionalEncoding(input_dim)
29
+ self.conformer = Conformer(
30
+ input_dim=input_dim,
31
+ num_heads=2,
32
+ ffn_dim=256,
33
+ num_layers=6,
34
+ depthwise_conv_kernel_size=3,
35
+ )
36
+ else:
37
+ self.conformer = None
38
+
39
+ def forward(self, x, length=None):
40
+ # x: (N, seq_len, input_dim) -> (N, seq_len, output_dim)
41
+ if self.conformer:
42
+ x = self.pos_encoder(x)
43
+ x, _ = self.conformer(x, length)
44
+ return self.nn(x)
45
+
46
+
47
+ class MelodyEncoder(nn.Module):
48
+ def __init__(self, cfg):
49
+ super().__init__()
50
+ self.cfg = cfg
51
+
52
+ self.input_dim = self.cfg.input_melody_dim
53
+ self.output_dim = self.cfg.output_melody_dim
54
+ self.n_bins = self.cfg.n_bins_melody
55
+ self.pitch_min = self.cfg.pitch_min
56
+ self.pitch_max = self.cfg.pitch_max
57
+
58
+ if self.input_dim != 0:
59
+ if self.n_bins == 0:
60
+ # Not use quantization
61
+ self.nn = nn.Linear(self.input_dim, self.output_dim)
62
+ else:
63
+ self.f0_min = cfg.f0_min
64
+ self.f0_max = cfg.f0_max
65
+
66
+ self.nn = nn.Embedding(
67
+ num_embeddings=self.n_bins,
68
+ embedding_dim=self.output_dim,
69
+ padding_idx=None,
70
+ )
71
+ self.uv_embedding = nn.Embedding(2, self.output_dim)
72
+ # self.conformer = Conformer(
73
+ # input_dim=self.output_dim,
74
+ # num_heads=4,
75
+ # ffn_dim=128,
76
+ # num_layers=4,
77
+ # depthwise_conv_kernel_size=3,
78
+ # )
79
+
80
+ def forward(self, x, uv=None, length=None):
81
+ # x: (N, frame_len)
82
+ # print(x.shape)
83
+ if self.n_bins == 0:
84
+ x = x.unsqueeze(-1)
85
+ else:
86
+ x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
87
+ x = self.nn(x)
88
+ if uv is not None:
89
+ uv = self.uv_embedding(uv)
90
+ x = x + uv
91
+ # x, _ = self.conformer(x, length)
92
+ return x
93
+
94
+
95
+ class LoudnessEncoder(nn.Module):
96
+ def __init__(self, cfg):
97
+ super().__init__()
98
+ self.cfg = cfg
99
+
100
+ self.input_dim = self.cfg.input_loudness_dim
101
+ self.output_dim = self.cfg.output_loudness_dim
102
+ self.n_bins = self.cfg.n_bins_loudness
103
+
104
+ if self.input_dim != 0:
105
+ if self.n_bins == 0:
106
+ # Not use quantization
107
+ self.nn = nn.Linear(self.input_dim, self.output_dim)
108
+ else:
109
+ # TODO: set trivially now
110
+ self.loudness_min = 1e-30
111
+ self.loudness_max = 1.5
112
+
113
+ if cfg.use_log_loudness:
114
+ self.energy_bins = nn.Parameter(
115
+ torch.exp(
116
+ torch.linspace(
117
+ np.log(self.loudness_min),
118
+ np.log(self.loudness_max),
119
+ self.n_bins - 1,
120
+ )
121
+ ),
122
+ requires_grad=False,
123
+ )
124
+
125
+ self.nn = nn.Embedding(
126
+ num_embeddings=self.n_bins,
127
+ embedding_dim=self.output_dim,
128
+ padding_idx=None,
129
+ )
130
+
131
+ def forward(self, x):
132
+ # x: (N, frame_len)
133
+ if self.n_bins == 0:
134
+ x = x.unsqueeze(-1)
135
+ else:
136
+ x = torch.bucketize(x, self.energy_bins)
137
+ return self.nn(x)
138
+
139
+
140
+ class SingerEncoder(nn.Module):
141
+ def __init__(self, cfg):
142
+ super().__init__()
143
+ self.cfg = cfg
144
+
145
+ self.input_dim = 1
146
+ self.output_dim = self.cfg.output_singer_dim
147
+
148
+ self.nn = nn.Embedding(
149
+ num_embeddings=cfg.singer_table_size,
150
+ embedding_dim=self.output_dim,
151
+ padding_idx=None,
152
+ )
153
+
154
+ def forward(self, x):
155
+ # x: (N, 1) -> (N, 1, output_dim)
156
+ return self.nn(x)
157
+
158
+
159
+ class ConditionEncoder(nn.Module):
160
+ def __init__(self, cfg):
161
+ super().__init__()
162
+ self.cfg = cfg
163
+
164
+ self.merge_mode = cfg.merge_mode
165
+
166
+ if cfg.use_whisper:
167
+ self.whisper_encoder = ContentEncoder(
168
+ self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
169
+ )
170
+
171
+ if cfg.use_contentvec:
172
+ self.contentvec_encoder = ContentEncoder(
173
+ self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
174
+ )
175
+
176
+ if cfg.use_mert:
177
+ self.mert_encoder = ContentEncoder(
178
+ self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
179
+ )
180
+
181
+ if cfg.use_wenet:
182
+ self.wenet_encoder = ContentEncoder(
183
+ self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
184
+ )
185
+
186
+ self.melody_encoder = MelodyEncoder(self.cfg)
187
+ self.loudness_encoder = LoudnessEncoder(self.cfg)
188
+ if cfg.use_spkid:
189
+ self.singer_encoder = SingerEncoder(self.cfg)
190
+
191
+ def forward(self, x):
192
+ outputs = []
193
+
194
+ if "frame_pitch" in x.keys():
195
+ if "frame_uv" not in x.keys():
196
+ x["frame_uv"] = None
197
+ pitch_enc_out = self.melody_encoder(
198
+ x["frame_pitch"], uv=x["frame_uv"], length=x["target_len"]
199
+ )
200
+ outputs.append(pitch_enc_out)
201
+
202
+ if "frame_energy" in x.keys():
203
+ loudness_enc_out = self.loudness_encoder(x["frame_energy"])
204
+ outputs.append(loudness_enc_out)
205
+
206
+ if "whisper_feat" in x.keys():
207
+ # whisper_feat: [b, T, 1024]
208
+ whiser_enc_out = self.whisper_encoder(
209
+ x["whisper_feat"], length=x["target_len"]
210
+ )
211
+ outputs.append(whiser_enc_out)
212
+ seq_len = whiser_enc_out.shape[1]
213
+
214
+ if "contentvec_feat" in x.keys():
215
+ contentvec_enc_out = self.contentvec_encoder(
216
+ x["contentvec_feat"], length=x["target_len"]
217
+ )
218
+ outputs.append(contentvec_enc_out)
219
+ seq_len = contentvec_enc_out.shape[1]
220
+
221
+ if "mert_feat" in x.keys():
222
+ mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
223
+ outputs.append(mert_enc_out)
224
+ seq_len = mert_enc_out.shape[1]
225
+
226
+ if "wenet_feat" in x.keys():
227
+ wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
228
+ outputs.append(wenet_enc_out)
229
+ seq_len = wenet_enc_out.shape[1]
230
+
231
+ if "spk_id" in x.keys():
232
+ speaker_enc_out = self.singer_encoder(x["spk_id"]) # [b, 1, 384]
233
+ assert (
234
+ "whisper_feat" in x.keys()
235
+ or "contentvec_feat" in x.keys()
236
+ or "mert_feat" in x.keys()
237
+ or "wenet_feat" in x.keys()
238
+ )
239
+ singer_info = speaker_enc_out.expand(-1, seq_len, -1)
240
+ outputs.append(singer_info)
241
+
242
+ encoder_output = None
243
+ if self.merge_mode == "concat":
244
+ encoder_output = torch.cat(outputs, dim=-1)
245
+ if self.merge_mode == "add":
246
+ # (#modules, N, seq_len, output_dim)
247
+ outputs = torch.cat([out[None, :, :, :] for out in outputs], dim=0)
248
+ # (N, seq_len, output_dim)
249
+ encoder_output = torch.sum(outputs, dim=0)
250
+
251
+ return encoder_output
modules/encoder/conv_encoder.py ADDED
@@ -0,0 +1,103 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import spectral_norm
+ from modules.generic.conv import Conv1d
+
+
+ class ConvEncoder(nn.Module):
+     def __init__(self, in_channels, z_channels, spk_channels, num_dilation_layer=10):
+         super(ConvEncoder, self).__init__()
+
+         self.in_channels = in_channels
+         self.z_channels = z_channels
+         self.spk_channels = spk_channels
+
+         self.pre_process = Conv1d(in_channels, 512, kernel_size=3)
+
+         self.dilated_conv_layers = nn.ModuleList()
+         for i in range(num_dilation_layer):
+             dilation = 2**i
+             self.dilated_conv_layers.append(
+                 DilatedConvBlock(512, 512, z_channels, spk_channels, dilation)
+             )
+
+     def forward(self, inputs, z, s):
+         # inputs: (N, T, in_channels) -> (N, in_channels, T) for Conv1d
+         inputs = inputs.transpose(1, 2)
+         outputs = self.pre_process(inputs)
+         for layer in self.dilated_conv_layers:
+             outputs = layer(outputs, z, s)
+
+         encoder_outputs = outputs.transpose(1, 2)
+         return encoder_outputs
+
+
+ class DilatedConvBlock(nn.Module):
+     """A stack of dilated convolutions interspersed
+     with batch normalisation and ReLU activations"""
+
+     def __init__(self, in_channels, out_channels, z_channels, s_channels, dilation):
+         super(DilatedConvBlock, self).__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.z_channels = z_channels
+         self.s_channels = s_channels
+
+         self.conv1d = Conv1d(
+             in_channels, out_channels, kernel_size=3, dilation=dilation
+         )
+         self.batch_layer = BatchNorm1dLayer(out_channels, s_channels, z_channels)
+
+     def forward(self, inputs, z, s):
+         outputs = self.conv1d(inputs)
+         outputs = self.batch_layer(outputs, z, s)
+         return F.relu(outputs)
+
+
+ class BatchNorm1dLayer(nn.Module):
+     """The latents z and speaker embedding s modulate the scale and
+     shift parameters of the batch normalisation layers"""
+
+     def __init__(self, num_features, s_channels=128, z_channels=128):
+         super().__init__()
+
+         self.num_features = num_features
+         self.s_channels = s_channels
+         self.z_channels = z_channels
+         self.batch_norm = nn.BatchNorm1d(num_features, affine=False)
+
+         self.scale_layer = spectral_norm(nn.Linear(z_channels, num_features))
+         self.scale_layer.weight.data.normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
+         self.scale_layer.bias.data.zero_()  # Initialise bias at 0
+
+         self.shift_layer = spectral_norm(nn.Linear(s_channels, num_features))
+         self.shift_layer.weight.data.normal_(1, 0.02)  # Initialise shift weights at N(1, 0.02)
+         self.shift_layer.bias.data.zero_()  # Initialise bias at 0
+
+     def forward(self, inputs, z, s):
+         outputs = self.batch_norm(inputs)
+         scale = self.scale_layer(z)
+         scale = scale.view(-1, self.num_features, 1)
+
+         shift = self.shift_layer(s)
+         shift = shift.view(-1, self.num_features, 1)
+
+         outputs = scale * outputs + shift
+
+         return outputs
+
+
+ if __name__ == "__main__":
+     model = ConvEncoder(256, 64, 64)
+     # (batch, time, feature) inputs; the encoder returns a single tensor
+     encoder_inputs = torch.randn(2, 10, 256)
+     z = torch.randn(2, 64)
+     speaker = torch.randn(1, 64)
+     outputs = model(encoder_inputs, z, speaker)
+     print(outputs.shape)
modules/encoder/position_encoder.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ from modules.general.utils import Linear
+
+
+ class PositionEncoder(nn.Module):
+     r"""Encoder of positional embedding: generates PE and then
+     feeds it into 2 fully-connected layers with ``SiLU``.
+
+     Args:
+         d_raw_emb: The dimension of raw embedding vectors.
+         d_out: The dimension of output embedding vectors, default to ``d_raw_emb``.
+         d_mlp: The dimension of hidden layer in MLP, default to ``d_raw_emb`` * 4.
+         activation_function: The activation function used in MLP, default to ``SiLU``.
+         n_layer: The number of layers in MLP, default to 2.
+         max_period: controls the minimum frequency of the embeddings.
+     """
+
+     def __init__(
+         self,
+         d_raw_emb: int = 128,
+         d_out: int = None,
+         d_mlp: int = None,
+         activation_function: str = "SiLU",
+         n_layer: int = 2,
+         max_period: int = 10000,
+     ):
+         super().__init__()
+
+         self.d_raw_emb = d_raw_emb
+         self.d_out = d_raw_emb if d_out is None else d_out
+         self.d_mlp = d_raw_emb * 4 if d_mlp is None else d_mlp
+         self.n_layer = n_layer
+         self.max_period = max_period
+
+         # Normalize the activation name so that lower-case inputs such as
+         # "silu" still resolve to a valid ``torch.nn`` module below.
+         if activation_function.lower() == "silu":
+             self.activation_function = "SiLU"
+         elif activation_function.lower() == "relu":
+             self.activation_function = "ReLU"
+         elif activation_function.lower() == "gelu":
+             self.activation_function = "GELU"
+         else:
+             raise ValueError("activation_function must be one of SiLU, ReLU, GELU")
+
+         tmp = [Linear(self.d_raw_emb, self.d_mlp), getattr(nn, self.activation_function)()]
+         for _ in range(self.n_layer - 1):
+             tmp.append(Linear(self.d_mlp, self.d_mlp))
+             tmp.append(getattr(nn, self.activation_function)())
+         tmp.append(Linear(self.d_mlp, self.d_out))
+
+         self.out = nn.Sequential(*tmp)
+
+     def forward(self, steps: torch.Tensor) -> torch.Tensor:
+         r"""Create and return sinusoidal timestep embeddings directly.
+
+         Args:
+             steps: a 1D Tensor of N indices, one per batch element.
+                 These may be fractional.
+
+         Returns:
+             an [N x ``d_out``] Tensor of positional embeddings.
+         """
+
+         half = self.d_raw_emb // 2
+         freqs = torch.exp(
+             -math.log(self.max_period)
+             / half
+             * torch.arange(half, dtype=torch.float32, device=steps.device)
+         )
+         args = steps[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if self.d_raw_emb % 2:
+             embedding = torch.cat(
+                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+             )
+         return self.out(embedding)
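A small sketch of the timestep embedding produced by PositionEncoder (dimensions are assumptions):

import torch
from modules.encoder.position_encoder import PositionEncoder

pe = PositionEncoder(d_raw_emb=128, d_out=512)
steps = torch.randint(0, 1000, (8,)).float()  # one diffusion timestep per batch element
emb = pe(steps)                               # [8, 512]
print(emb.shape)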
modules/encoder/token_encoder.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This code is modified from https://github.com/lifeiteng/vall-e
+
+ import torch
+ import torch.nn as nn
+
+
+ class TokenEmbedding(nn.Module):
+     def __init__(self, dim_model: int, vocab_size: int, dropout: float = 0.0):
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         self.word_embeddings = nn.Embedding(vocab_size, dim_model)
+
+     @property
+     def weight(self) -> torch.Tensor:
+         return self.word_embeddings.weight
+
+     def forward(self, x: torch.Tensor):
+         x = self.word_embeddings(x)
+         x = self.dropout(x)
+         return x
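A tiny usage sketch of TokenEmbedding (vocabulary and model sizes are assumptions):

import torch
from modules.encoder.token_encoder import TokenEmbedding

emb = TokenEmbedding(dim_model=1024, vocab_size=1056)
tokens = torch.randint(0, 1056, (2, 100))  # [B, T] discrete token ids
x = emb(tokens)                            # [B, T, 1024]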
modules/flow/modules.py ADDED
@@ -0,0 +1,457 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/jaywalnut310/vits/
7
+
8
+ import math
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from torch.nn import Conv1d
14
+ from torch.nn.utils import weight_norm, remove_weight_norm
15
+
16
+ from utils.util import *
17
+ from modules.transformer.transforms import (
18
+ piecewise_rational_quadratic_transform,
19
+ )
20
+ from modules.base.base_module import LayerNorm
21
+
22
+ LRELU_SLOPE = 0.1
23
+
24
+
25
+ class DDSConv(nn.Module):
26
+ """
27
+ Dilated and Depth-Separable Convolution
28
+ """
29
+
30
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
31
+ super().__init__()
32
+ self.channels = channels
33
+ self.kernel_size = kernel_size
34
+ self.n_layers = n_layers
35
+ self.p_dropout = p_dropout
36
+
37
+ self.drop = nn.Dropout(p_dropout)
38
+ self.convs_sep = nn.ModuleList()
39
+ self.convs_1x1 = nn.ModuleList()
40
+ self.norms_1 = nn.ModuleList()
41
+ self.norms_2 = nn.ModuleList()
42
+ for i in range(n_layers):
43
+ dilation = kernel_size**i
44
+ padding = (kernel_size * dilation - dilation) // 2
45
+ self.convs_sep.append(
46
+ nn.Conv1d(
47
+ channels,
48
+ channels,
49
+ kernel_size,
50
+ groups=channels,
51
+ dilation=dilation,
52
+ padding=padding,
53
+ )
54
+ )
55
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
56
+ self.norms_1.append(LayerNorm(channels))
57
+ self.norms_2.append(LayerNorm(channels))
58
+
59
+ def forward(self, x, x_mask, g=None):
60
+ if g is not None:
61
+ x = x + g
62
+ for i in range(self.n_layers):
63
+ y = self.convs_sep[i](x * x_mask)
64
+ y = self.norms_1[i](y)
65
+ y = F.gelu(y)
66
+ y = self.convs_1x1[i](y)
67
+ y = self.norms_2[i](y)
68
+ y = F.gelu(y)
69
+ y = self.drop(y)
70
+ x = x + y
71
+ return x * x_mask
72
+
73
+
74
+ class WN(torch.nn.Module):
75
+ def __init__(
76
+ self,
77
+ hidden_channels,
78
+ kernel_size,
79
+ dilation_rate,
80
+ n_layers,
81
+ gin_channels=0,
82
+ p_dropout=0,
83
+ ):
84
+ super(WN, self).__init__()
85
+ assert kernel_size % 2 == 1
86
+ self.hidden_channels = hidden_channels
87
+ self.kernel_size = (kernel_size,)
88
+ self.dilation_rate = dilation_rate
89
+ self.n_layers = n_layers
90
+ self.gin_channels = gin_channels
91
+ self.p_dropout = p_dropout
92
+
93
+ self.in_layers = torch.nn.ModuleList()
94
+ self.res_skip_layers = torch.nn.ModuleList()
95
+ self.drop = nn.Dropout(p_dropout)
96
+
97
+ if gin_channels != 0:
98
+ cond_layer = torch.nn.Conv1d(
99
+ gin_channels, 2 * hidden_channels * n_layers, 1
100
+ )
101
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
102
+
103
+ for i in range(n_layers):
104
+ dilation = dilation_rate**i
105
+ padding = int((kernel_size * dilation - dilation) / 2)
106
+ in_layer = torch.nn.Conv1d(
107
+ hidden_channels,
108
+ 2 * hidden_channels,
109
+ kernel_size,
110
+ dilation=dilation,
111
+ padding=padding,
112
+ )
113
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
114
+ self.in_layers.append(in_layer)
115
+
116
+ # last one is not necessary
117
+ if i < n_layers - 1:
118
+ res_skip_channels = 2 * hidden_channels
119
+ else:
120
+ res_skip_channels = hidden_channels
121
+
122
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
123
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
124
+ self.res_skip_layers.append(res_skip_layer)
125
+
126
+ def forward(self, x, x_mask, g=None, **kwargs):
127
+ output = torch.zeros_like(x)
128
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
129
+
130
+ if g is not None:
131
+ g = self.cond_layer(g)
132
+
133
+ for i in range(self.n_layers):
134
+ x_in = self.in_layers[i](x)
135
+ if g is not None:
136
+ cond_offset = i * 2 * self.hidden_channels
137
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
138
+ else:
139
+ g_l = torch.zeros_like(x_in)
140
+
141
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
142
+ acts = self.drop(acts)
143
+
144
+ res_skip_acts = self.res_skip_layers[i](acts)
145
+ if i < self.n_layers - 1:
146
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
147
+ x = (x + res_acts) * x_mask
148
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
149
+ else:
150
+ output = output + res_skip_acts
151
+ return output * x_mask
152
+
153
+ def remove_weight_norm(self):
154
+ if self.gin_channels != 0:
155
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
156
+ for l in self.in_layers:
157
+ torch.nn.utils.remove_weight_norm(l)
158
+ for l in self.res_skip_layers:
159
+ torch.nn.utils.remove_weight_norm(l)
160
+
161
+
162
+ class ResBlock1(torch.nn.Module):
163
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
164
+ super(ResBlock1, self).__init__()
165
+ self.convs1 = nn.ModuleList(
166
+ [
167
+ weight_norm(
168
+ Conv1d(
169
+ channels,
170
+ channels,
171
+ kernel_size,
172
+ 1,
173
+ dilation=dilation[0],
174
+ padding=get_padding(kernel_size, dilation[0]),
175
+ )
176
+ ),
177
+ weight_norm(
178
+ Conv1d(
179
+ channels,
180
+ channels,
181
+ kernel_size,
182
+ 1,
183
+ dilation=dilation[1],
184
+ padding=get_padding(kernel_size, dilation[1]),
185
+ )
186
+ ),
187
+ weight_norm(
188
+ Conv1d(
189
+ channels,
190
+ channels,
191
+ kernel_size,
192
+ 1,
193
+ dilation=dilation[2],
194
+ padding=get_padding(kernel_size, dilation[2]),
195
+ )
196
+ ),
197
+ ]
198
+ )
199
+ self.convs1.apply(init_weights)
200
+
201
+ self.convs2 = nn.ModuleList(
202
+ [
203
+ weight_norm(
204
+ Conv1d(
205
+ channels,
206
+ channels,
207
+ kernel_size,
208
+ 1,
209
+ dilation=1,
210
+ padding=get_padding(kernel_size, 1),
211
+ )
212
+ ),
213
+ weight_norm(
214
+ Conv1d(
215
+ channels,
216
+ channels,
217
+ kernel_size,
218
+ 1,
219
+ dilation=1,
220
+ padding=get_padding(kernel_size, 1),
221
+ )
222
+ ),
223
+ weight_norm(
224
+ Conv1d(
225
+ channels,
226
+ channels,
227
+ kernel_size,
228
+ 1,
229
+ dilation=1,
230
+ padding=get_padding(kernel_size, 1),
231
+ )
232
+ ),
233
+ ]
234
+ )
235
+ self.convs2.apply(init_weights)
236
+
237
+ def forward(self, x, x_mask=None):
238
+ for c1, c2 in zip(self.convs1, self.convs2):
239
+ xt = F.leaky_relu(x, LRELU_SLOPE)
240
+ if x_mask is not None:
241
+ xt = xt * x_mask
242
+ xt = c1(xt)
243
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
244
+ if x_mask is not None:
245
+ xt = xt * x_mask
246
+ xt = c2(xt)
247
+ x = xt + x
248
+ if x_mask is not None:
249
+ x = x * x_mask
250
+ return x
251
+
252
+ def remove_weight_norm(self):
253
+ for l in self.convs1:
254
+ remove_weight_norm(l)
255
+ for l in self.convs2:
256
+ remove_weight_norm(l)
257
+
258
+
259
+ class ResBlock2(torch.nn.Module):
260
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
261
+ super(ResBlock2, self).__init__()
262
+ self.convs = nn.ModuleList(
263
+ [
264
+ weight_norm(
265
+ Conv1d(
266
+ channels,
267
+ channels,
268
+ kernel_size,
269
+ 1,
270
+ dilation=dilation[0],
271
+ padding=get_padding(kernel_size, dilation[0]),
272
+ )
273
+ ),
274
+ weight_norm(
275
+ Conv1d(
276
+ channels,
277
+ channels,
278
+ kernel_size,
279
+ 1,
280
+ dilation=dilation[1],
281
+ padding=get_padding(kernel_size, dilation[1]),
282
+ )
283
+ ),
284
+ ]
285
+ )
286
+ self.convs.apply(init_weights)
287
+
288
+ def forward(self, x, x_mask=None):
289
+ for c in self.convs:
290
+ xt = F.leaky_relu(x, LRELU_SLOPE)
291
+ if x_mask is not None:
292
+ xt = xt * x_mask
293
+ xt = c(xt)
294
+ x = xt + x
295
+ if x_mask is not None:
296
+ x = x * x_mask
297
+ return x
298
+
299
+ def remove_weight_norm(self):
300
+ for l in self.convs:
301
+ remove_weight_norm(l)
302
+
303
+
304
+ class Log(nn.Module):
305
+ def forward(self, x, x_mask, reverse=False, **kwargs):
306
+ if not reverse:
307
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
308
+ logdet = torch.sum(-y, [1, 2])
309
+ return y, logdet
310
+ else:
311
+ x = torch.exp(x) * x_mask
312
+ return x
313
+
314
+
315
+ class Flip(nn.Module):
316
+ def forward(self, x, *args, reverse=False, **kwargs):
317
+ x = torch.flip(x, [1])
318
+ if not reverse:
319
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
320
+ return x, logdet
321
+ else:
322
+ return x
323
+
324
+
325
+ class ElementwiseAffine(nn.Module):
326
+ def __init__(self, channels):
327
+ super().__init__()
328
+ self.channels = channels
329
+ self.m = nn.Parameter(torch.zeros(channels, 1))
330
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
331
+
332
+ def forward(self, x, x_mask, reverse=False, **kwargs):
333
+ if not reverse:
334
+ y = self.m + torch.exp(self.logs) * x
335
+ y = y * x_mask
336
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
337
+ return y, logdet
338
+ else:
339
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
340
+ return x
341
+
342
+
343
+ class ResidualCouplingLayer(nn.Module):
344
+ def __init__(
345
+ self,
346
+ channels,
347
+ hidden_channels,
348
+ kernel_size,
349
+ dilation_rate,
350
+ n_layers,
351
+ p_dropout=0,
352
+ gin_channels=0,
353
+ mean_only=False,
354
+ ):
355
+ assert channels % 2 == 0, "channels should be divisible by 2"
356
+ super().__init__()
357
+ self.channels = channels
358
+ self.hidden_channels = hidden_channels
359
+ self.kernel_size = kernel_size
360
+ self.dilation_rate = dilation_rate
361
+ self.n_layers = n_layers
362
+ self.half_channels = channels // 2
363
+ self.mean_only = mean_only
364
+
365
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
366
+ self.enc = WN(
367
+ hidden_channels,
368
+ kernel_size,
369
+ dilation_rate,
370
+ n_layers,
371
+ p_dropout=p_dropout,
372
+ gin_channels=gin_channels,
373
+ )
374
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
375
+ self.post.weight.data.zero_()
376
+ self.post.bias.data.zero_()
377
+
378
+ def forward(self, x, x_mask, g=None, reverse=False):
379
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
380
+ h = self.pre(x0) * x_mask
381
+ h = self.enc(h, x_mask, g=g)
382
+ stats = self.post(h) * x_mask
383
+ if not self.mean_only:
384
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
385
+ else:
386
+ m = stats
387
+ logs = torch.zeros_like(m)
388
+
389
+ if not reverse:
390
+ x1 = m + x1 * torch.exp(logs) * x_mask
391
+ x = torch.cat([x0, x1], 1)
392
+ logdet = torch.sum(logs, [1, 2])
393
+ return x, logdet
394
+ else:
395
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
396
+ x = torch.cat([x0, x1], 1)
397
+ return x
398
+
399
+
400
+ class ConvFlow(nn.Module):
401
+ def __init__(
402
+ self,
403
+ in_channels,
404
+ filter_channels,
405
+ kernel_size,
406
+ n_layers,
407
+ num_bins=10,
408
+ tail_bound=5.0,
409
+ ):
410
+ super().__init__()
411
+ self.in_channels = in_channels
412
+ self.filter_channels = filter_channels
413
+ self.kernel_size = kernel_size
414
+ self.n_layers = n_layers
415
+ self.num_bins = num_bins
416
+ self.tail_bound = tail_bound
417
+ self.half_channels = in_channels // 2
418
+
419
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
420
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
421
+ self.proj = nn.Conv1d(
422
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
423
+ )
424
+ self.proj.weight.data.zero_()
425
+ self.proj.bias.data.zero_()
426
+
427
+ def forward(self, x, x_mask, g=None, reverse=False):
428
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
429
+ h = self.pre(x0)
430
+ h = self.convs(h, x_mask, g=g)
431
+ h = self.proj(h) * x_mask
432
+
433
+ b, c, t = x0.shape
434
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
435
+
436
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
437
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
438
+ self.filter_channels
439
+ )
440
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
441
+
442
+ x1, logabsdet = piecewise_rational_quadratic_transform(
443
+ x1,
444
+ unnormalized_widths,
445
+ unnormalized_heights,
446
+ unnormalized_derivatives,
447
+ inverse=reverse,
448
+ tails="linear",
449
+ tail_bound=self.tail_bound,
450
+ )
451
+
452
+ x = torch.cat([x0, x1], 1) * x_mask
453
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
454
+ if not reverse:
455
+ return x, logdet
456
+ else:
457
+ return x
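A sketch checking forward/inverse consistency of the coupling layer defined above (sizes are assumptions; mean_only=True keeps the log-determinant zero):

import torch
from modules.flow.modules import ResidualCouplingLayer

layer = ResidualCouplingLayer(
    channels=192, hidden_channels=192, kernel_size=5,
    dilation_rate=1, n_layers=4, mean_only=True,
)
layer.eval()
x = torch.randn(2, 192, 40)
x_mask = torch.ones(2, 1, 40)
y, logdet = layer(x, x_mask)            # forward: transformed tensor + log-det
x_rec = layer(y, x_mask, reverse=True)  # inverse: recovers the input
print(torch.allclose(x, x_rec, atol=1e-5))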
modules/general/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .input_strategies import PromptedFeatures, PromptedPrecomputedFeatures
+ from .scaling import BalancedDoubleSwish
+ from .utils import Transpose
modules/general/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (339 Bytes).
modules/general/__pycache__/input_strategies.cpython-39.pyc ADDED
Binary file (5.64 kB).
modules/general/__pycache__/scaling.cpython-39.pyc ADDED
Binary file (39.7 kB).
modules/general/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.54 kB).
modules/general/input_strategies.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # This code is modified from
8
+ # https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/input_strategies.py
9
+ import random
10
+ from collections import defaultdict
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Tuple, Type
13
+
14
+ from lhotse import CutSet
15
+ from lhotse.dataset.collation import collate_features
16
+ from lhotse.dataset.input_strategies import (
17
+ ExecutorType,
18
+ PrecomputedFeatures,
19
+ _get_executor,
20
+ )
21
+ from lhotse.utils import fastcopy
22
+
23
+
24
+ class PromptedFeatures:
25
+ def __init__(self, prompts, features):
26
+ self.prompts = prompts
27
+ self.features = features
28
+
29
+ def to(self, device):
30
+ return PromptedFeatures(self.prompts.to(device), self.features.to(device))
31
+
32
+ def sum(self):
33
+ return self.features.sum()
34
+
35
+ @property
36
+ def ndim(self):
37
+ return self.features.ndim
38
+
39
+ @property
40
+ def data(self):
41
+ return (self.prompts, self.features)
42
+
43
+
44
+ class PromptedPrecomputedFeatures(PrecomputedFeatures):
45
+ def __init__(
46
+ self,
47
+ dataset: str,
48
+ cuts: CutSet,
49
+ num_workers: int = 0,
50
+ executor_type: Type[ExecutorType] = ThreadPoolExecutor,
51
+ ) -> None:
52
+ super().__init__(num_workers, executor_type)
53
+ self.utt2neighbors = self._create_utt2neighbors(dataset, cuts)
54
+
55
+ def __call__(self, cuts: CutSet) -> Tuple[PromptedFeatures, PromptedFeatures]:
56
+ features, features_lens = self._collate_features(cuts)
57
+ prompts, prompts_lens = self._collate_prompts(cuts)
58
+ return PromptedFeatures(prompts, features), PromptedFeatures(
59
+ prompts_lens, features_lens
60
+ )
61
+
62
+ def _create_utt2neighbors(self, dataset, cuts):
63
+ utt2neighbors = defaultdict(lambda: [])
64
+ utt2cut = {cut.id: cut for cut in cuts}
65
+ if dataset.lower() == "libritts":
66
+ self._process_libritts_dataset(utt2neighbors, utt2cut, cuts)
67
+ elif dataset.lower() == "ljspeech":
68
+ self._process_ljspeech_dataset(utt2neighbors, utt2cut, cuts)
69
+ else:
70
+ raise ValueError("Unsupported dataset")
71
+ return utt2neighbors
72
+
73
+ def _process_libritts_dataset(self, utt2neighbors, utt2cut, cuts):
74
+ speaker2utts = defaultdict(lambda: [])
75
+ for cut in cuts:
76
+ speaker = cut.supervisions[0].speaker
77
+ speaker2utts[speaker].append(cut.id)
78
+
79
+ for spk, uttids in speaker2utts.items():
80
+ sorted_uttids = sorted(uttids)
81
+ if len(sorted_uttids) == 1:
82
+ utt2neighbors[sorted_uttids[0]].append(utt2cut[sorted_uttids[0]])
83
+ continue
84
+
85
+ utt2prevutt = dict(
86
+ zip(sorted_uttids, [sorted_uttids[1]] + sorted_uttids[:-1])
87
+ )
88
+ utt2postutt = dict(zip(sorted_uttids[:-1], sorted_uttids[1:]))
89
+ for utt in sorted_uttids:
90
+ if utt in utt2prevutt:
91
+ utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
92
+ if utt in utt2postutt:
93
+ utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
94
+
95
+ def _process_ljspeech_dataset(self, utt2neighbors, utt2cut, cuts):
96
+ uttids = [cut.id for cut in cuts]
97
+ if len(uttids) == 1:
98
+ utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
99
+ return
100
+
101
+ utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
102
+ utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
103
+ for utt in uttids:
104
+ prevutt, postutt = utt2prevutt.get(utt), utt2postutt.get(utt)
105
+ if prevutt and utt[:5] == prevutt[:5]:
106
+ utt2neighbors[utt].append(utt2cut[prevutt])
107
+ if postutt and utt[:5] == postutt[:5]:
108
+ utt2neighbors[utt].append(utt2cut[postutt])
109
+
110
+ def _collate_features(self, cuts):
111
+ return collate_features(
112
+ cuts,
113
+ executor=_get_executor(self.num_workers, executor_type=self._executor_type),
114
+ )
115
+
116
+ def _collate_prompts(self, cuts):
117
+ prompts_cuts = []
118
+ for k, cut in enumerate(cuts):
119
+ prompts_cut = random.choice(self.utt2neighbors[cut.id])
120
+ prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
121
+
122
+ mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
123
+ prompts_cuts = CutSet(
124
+ cuts={k: cut for k, cut in enumerate(prompts_cuts)}
125
+ ).truncate(max_duration=mini_duration, offset_type="random", preserve_id=False)
126
+
127
+ return collate_features(
128
+ prompts_cuts,
129
+ executor=_get_executor(self.num_workers, executor_type=self._executor_type),
130
+ )
modules/general/scaling.py ADDED
@@ -0,0 +1,1349 @@
1
+ # This module is modified from https://github.com/Plachtaa/VALL-E-X/blob/3faaf8ccadb154d63b38070caf518ce9309ea0f4/modules/scaling.py
2
+
3
+
4
+ import logging
5
+ import random
6
+ import math
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+
13
+
14
+ class Transpose(nn.Identity):
15
+ """(N, T, D) -> (N, D, T)"""
16
+
17
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
18
+ return input.transpose(1, 2)
19
+
20
+
21
+ class ActivationBalancerFunction(torch.autograd.Function):
22
+ @staticmethod
23
+ def forward(
24
+ ctx,
25
+ x: Tensor,
26
+ scale_factor: Tensor,
27
+ sign_factor: Optional[Tensor],
28
+ channel_dim: int,
29
+ ) -> Tensor:
30
+ if channel_dim < 0:
31
+ channel_dim += x.ndim
32
+ ctx.channel_dim = channel_dim
33
+ xgt0 = x > 0
34
+ if sign_factor is None:
35
+ ctx.save_for_backward(xgt0, scale_factor)
36
+ else:
37
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
38
+ return x
39
+
40
+ @staticmethod
41
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
42
+ if len(ctx.saved_tensors) == 3:
43
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
44
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
45
+ scale_factor = scale_factor.unsqueeze(-1)
46
+ sign_factor = sign_factor.unsqueeze(-1)
47
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
48
+ else:
49
+ xgt0, scale_factor = ctx.saved_tensors
50
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
51
+ scale_factor = scale_factor.unsqueeze(-1)
52
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
53
+ neg_delta_grad = x_grad.abs() * factor
54
+ return (
55
+ x_grad - neg_delta_grad,
56
+ None,
57
+ None,
58
+ None,
59
+ )
60
+
61
+
62
+ def _compute_scale_factor(
63
+ x: Tensor,
64
+ channel_dim: int,
65
+ min_abs: float,
66
+ max_abs: float,
67
+ gain_factor: float,
68
+ max_factor: float,
69
+ ) -> Tensor:
70
+ if channel_dim < 0:
71
+ channel_dim += x.ndim
72
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
73
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
74
+
75
+ if min_abs == 0.0:
76
+ below_threshold = 0.0
77
+ else:
78
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
79
+ # x_abs_mean < min_abs.
80
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
81
+ min=0, max=max_factor
82
+ )
83
+
84
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
85
+ min=0, max=max_factor
86
+ )
87
+
88
+ return below_threshold - above_threshold
89
+
90
+
91
+ def _compute_sign_factor(
92
+ x: Tensor,
93
+ channel_dim: int,
94
+ min_positive: float,
95
+ max_positive: float,
96
+ gain_factor: float,
97
+ max_factor: float,
98
+ ) -> Tensor:
99
+ if channel_dim < 0:
100
+ channel_dim += x.ndim
101
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
102
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
103
+ if min_positive == 0.0:
104
+ factor1 = 0.0
105
+ else:
106
+ # 0 if proportion_positive >= min_positive, else can be
107
+ # as large as max_factor.
108
+ factor1 = (
109
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
110
+ ).clamp_(min=0, max=max_factor)
111
+
112
+ if max_positive == 1.0:
113
+ factor2 = 0.0
114
+ else:
115
+ # 0 if self.proportion_positive <= max_positive, else can be
116
+ # as large as -max_factor.
117
+ factor2 = (
118
+ (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
119
+ ).clamp_(min=0, max=max_factor)
120
+ sign_factor = factor1 - factor2
121
+ # require min_positive != 0 or max_positive != 1:
122
+ assert not isinstance(sign_factor, float)
123
+ return sign_factor
124
+
125
+
126
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
127
+ """
128
+ This object is used in class ActivationBalancer when the user specified
129
+ min_positive=0, max_positive=1, so there are no constraints on the signs
130
+ of the activations and only the absolute value has a constraint.
131
+ """
132
+
133
+ @staticmethod
134
+ def forward(
135
+ ctx,
136
+ x: Tensor,
137
+ sign_factor: Tensor,
138
+ scale_factor: Tensor,
139
+ channel_dim: int,
140
+ ) -> Tensor:
141
+ if channel_dim < 0:
142
+ channel_dim += x.ndim
143
+ ctx.channel_dim = channel_dim
144
+ xgt0 = x > 0
145
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
146
+ return x
147
+
148
+ @staticmethod
149
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
150
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
151
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
152
+ sign_factor = sign_factor.unsqueeze(-1)
153
+ scale_factor = scale_factor.unsqueeze(-1)
154
+
155
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
156
+ neg_delta_grad = x_grad.abs() * factor
157
+ return (
158
+ x_grad - neg_delta_grad,
159
+ None,
160
+ None,
161
+ None,
162
+ )
163
+
164
+
165
+ class RandomClampFunction(torch.autograd.Function):
166
+ @staticmethod
167
+ def forward(
168
+ ctx,
169
+ x: Tensor,
170
+ min: Optional[float],
171
+ max: Optional[float],
172
+ prob: float,
173
+ reflect: float,
174
+ ) -> Tensor:
175
+ x_clamped = torch.clamp(x, min=min, max=max)
176
+ mask = torch.rand_like(x) < prob
177
+ ans = torch.where(mask, x_clamped, x)
178
+ if x.requires_grad:
179
+ ctx.save_for_backward(ans == x)
180
+ ctx.reflect = reflect
181
+ if reflect != 0.0:
182
+ ans = ans * (1.0 + reflect) - (x * reflect)
183
+ return ans
184
+
185
+ @staticmethod
186
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
187
+ (is_same,) = ctx.saved_tensors
188
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
189
+ reflect = ctx.reflect
190
+ if reflect != 0.0:
191
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
192
+ return x_grad, None, None, None, None
193
+
194
+
195
+ def random_clamp(
196
+ x: Tensor,
197
+ min: Optional[float] = None,
198
+ max: Optional[float] = None,
199
+ prob: float = 0.5,
200
+ reflect: float = 0.0,
201
+ ):
202
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
203
+
204
+
205
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
206
+ """
207
+ A randomized way of casting a floating point value to half precision.
208
+ """
209
+ if x.dtype == torch.float16:
210
+ return x
211
+ x_abs = x.abs()
212
+ is_too_small = x_abs < min_abs
213
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
214
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
215
+ # for those elements].
216
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
217
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
218
+
219
+
220
+ class RandomGradFunction(torch.autograd.Function):
221
+ """
222
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
223
+ randomized approach that preserves expectations (intended to reduce roundoff).
224
+ """
225
+
226
+ @staticmethod
227
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
228
+ ctx.min_abs = min_abs
229
+ return x
230
+
231
+ @staticmethod
232
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
233
+ if ans_grad.dtype == torch.float16:
234
+ return (
235
+ random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs),
236
+ None,
237
+ )
238
+ else:
239
+ return ans_grad, None
240
+
241
+
242
+ class RandomGrad(torch.nn.Module):
243
+ """
244
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
245
+ accuracy of training when using amp (automatic mixed precision)
246
+ """
247
+
248
+ def __init__(self, min_abs: float = 5.0e-06):
249
+ super(RandomGrad, self).__init__()
250
+ self.min_abs = min_abs
251
+
252
+ def forward(self, x: Tensor):
253
+ if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
254
+ return x
255
+ else:
256
+ return RandomGradFunction.apply(x, self.min_abs)
257
+
258
+
259
+ class SoftmaxFunction(torch.autograd.Function):
260
+ """
261
+ Tries to handle half-precision derivatives in a randomized way that should
262
+ be more accurate for training than the default behavior.
263
+ """
264
+
265
+ @staticmethod
266
+ def forward(ctx, x: Tensor, dim: int):
267
+ ans = x.softmax(dim=dim)
268
+ # if x dtype is float16, x.softmax() returns a float32 because
269
+ # (presumably) that op does not support float16, and autocast
270
+ # is enabled.
271
+ if torch.is_autocast_enabled():
272
+ ans = ans.to(torch.float16)
273
+ ctx.save_for_backward(ans)
274
+ ctx.x_dtype = x.dtype
275
+ ctx.dim = dim
276
+ return ans
277
+
278
+ @staticmethod
279
+ def backward(ctx, ans_grad: Tensor):
280
+ (ans,) = ctx.saved_tensors
281
+ with torch.cuda.amp.autocast(enabled=False):
282
+ ans_grad = ans_grad.to(torch.float32)
283
+ ans = ans.to(torch.float32)
284
+ x_grad = ans_grad * ans
285
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
286
+ return x_grad, None
287
+
288
+
289
+ def softmax(x: Tensor, dim: int):
290
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
291
+ return x.softmax(dim)
292
+
293
+ return SoftmaxFunction.apply(x, dim)
294
+
295
+
296
+ class MaxEigLimiterFunction(torch.autograd.Function):
297
+ @staticmethod
298
+ def forward(
299
+ ctx,
300
+ x: Tensor,
301
+ coeffs: Tensor,
302
+ direction: Tensor,
303
+ channel_dim: int,
304
+ grad_scale: float,
305
+ ) -> Tensor:
306
+ ctx.channel_dim = channel_dim
307
+ ctx.grad_scale = grad_scale
308
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
309
+ return x
310
+
311
+ @staticmethod
312
+ def backward(ctx, x_grad, *args):
313
+ with torch.enable_grad():
314
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
315
+ x_orig.requires_grad = True
316
+ num_channels = x_orig.shape[ctx.channel_dim]
317
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
318
+ new_direction.requires_grad = False
319
+ x = x - x.mean(dim=0)
320
+ x_var = (x**2).mean()
321
+ x_residual = x - coeffs * new_direction
322
+ x_residual_var = (x_residual**2).mean()
323
+ # `variance_proportion` is the proportion of the variance accounted for
324
+ # by the top eigen-direction. This is to be minimized.
325
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
326
+ variance_proportion.backward()
327
+ x_orig_grad = x_orig.grad
328
+ x_extra_grad = (
329
+ x_orig.grad
330
+ * ctx.grad_scale
331
+ * x_grad.norm()
332
+ / (x_orig_grad.norm() + 1.0e-20)
333
+ )
334
+ return x_grad + x_extra_grad.detach(), None, None, None, None
335
+
336
+
337
+ class BasicNorm(torch.nn.Module):
338
+ """
339
+ This is intended to be a simpler, and hopefully cheaper, replacement for
340
+ LayerNorm. The observation this is based on, is that Transformer-type
341
+ networks, especially with pre-norm, sometimes seem to set one of the
342
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
343
+ the LayerNorm because the output magnitude is then not strongly dependent
344
+ on the other (useful) features. Presumably the weight and bias of the
345
+ LayerNorm are required to allow it to do this.
346
+
347
+ So the idea is to introduce this large constant value as an explicit
348
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
349
+ doesn't have to do this trick. We make the "eps" learnable.
350
+
351
+ Args:
352
+ num_channels: the number of channels, e.g. 512.
353
+ channel_dim: the axis/dimension corresponding to the channel,
354
+ interpreted as an offset from the input's ndim if negative.
355
+ this is NOT the num_channels; it should typically be one of
356
+ {-2, -1, 0, 1, 2, 3}.
357
+ eps: the initial "epsilon" that we add as ballast in:
358
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
359
+ Note: our epsilon is actually large, but we keep the name
360
+ to indicate the connection with conventional LayerNorm.
361
+ learn_eps: if true, we learn epsilon; if false, we keep it
362
+ at the initial value.
363
+ eps_min: float
364
+ eps_max: float
365
+ """
366
+
367
+ def __init__(
368
+ self,
369
+ num_channels: int,
370
+ channel_dim: int = -1, # CAUTION: see documentation.
371
+ eps: float = 0.25,
372
+ learn_eps: bool = True,
373
+ eps_min: float = -3.0,
374
+ eps_max: float = 3.0,
375
+ ) -> None:
376
+ super(BasicNorm, self).__init__()
377
+ self.num_channels = num_channels
378
+ self.channel_dim = channel_dim
379
+ if learn_eps:
380
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
381
+ else:
382
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
383
+ self.eps_min = eps_min
384
+ self.eps_max = eps_max
385
+
386
+ def forward(self, x: Tensor) -> Tensor:
387
+ assert x.shape[self.channel_dim] == self.num_channels
388
+ eps = self.eps
389
+ if self.training and random.random() < 0.25:
390
+ # with probability 0.25, in training mode, clamp eps between the min
391
+ # and max; this will encourage it to learn parameters within the
392
+ # allowed range by making parameters that are outside the allowed
393
+ # range noisy: the clamping gives out-of-range values
394
+ # gradients to allow the parameter to get back into the allowed
395
+ # region if it happens to exit it.
397
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
398
+ scales = (
399
+ torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp()
400
+ ) ** -0.5
401
+ return x * scales
402
+
403
+
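+ # Illustrative sketch, not part of the original module: BasicNorm only rescales
+ # (no mean subtraction, no learned per-channel weight/bias), so it is cheaper
+ # than LayerNorm; the output RMS sits slightly below 1 because of the learned
+ # "eps" ballast (initially 0.25).
+ def _demo_basic_norm():
+     norm = BasicNorm(num_channels=256, channel_dim=-1)
+     x = torch.randn(8, 50, 256)
+     y = norm(x)
+     print("input rms:", (x**2).mean().sqrt().item())
+     print("output rms:", (y**2).mean().sqrt().item())
+
+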
404
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
405
+ """
406
+ Behaves like a constructor of a modified version of nn.Linear
407
+ that gives an easy way to set the default initial parameter scale.
408
+
409
+ Args:
410
+ Accepts the standard args and kwargs that nn.Linear accepts
411
+ e.g. in_features, out_features, bias=False.
412
+
413
+ initial_scale: you can override this if you want to increase
414
+ or decrease the initial magnitude of the module's output
415
+ (affects the initialization of weight_scale and bias_scale).
416
+ Another option, if you want to do something like this, is
417
+ to re-initialize the parameters.
418
+ """
419
+ ans = nn.Linear(*args, **kwargs)
420
+ with torch.no_grad():
421
+ ans.weight[:] *= initial_scale
422
+ if ans.bias is not None:
423
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
424
+ return ans
425
+
426
+
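+ # Illustrative sketch, not part of the original module: initial_scale simply
+ # multiplies the default nn.Linear initialization, so a module created with a
+ # small initial_scale starts out with roughly proportionally smaller outputs.
+ def _demo_scaled_linear():
+     x = torch.randn(16, 512)
+     out_default = ScaledLinear(512, 256)(x)
+     out_small = ScaledLinear(512, 256, initial_scale=0.25)(x)
+     print("default output std:", out_default.std().item())
+     print("initial_scale=0.25 output std:", out_small.std().item())
+
+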
427
+ def ScaledConv1d(
428
+ *args,
429
+ initial_scale: float = 1.0,
430
+ kernel_size: int = 3,
431
+ padding: str = "same",
432
+ **kwargs,
433
+ ) -> nn.Conv1d:
434
+ """
435
+ Behaves like a constructor of a modified version of nn.Conv1d
436
+ that gives an easy way to set the default initial parameter scale.
437
+
438
+ Args:
439
+ Accepts the standard args and kwargs that nn.Conv1d accepts
440
+ e.g. in_channels, out_channels, bias=False.
441
+
442
+ initial_scale: you can override this if you want to increase
443
+ or decrease the initial magnitude of the module's output
444
+ (affects the initialization of weight_scale and bias_scale).
445
+ Another option, if you want to do something like this, is
446
+ to re-initialize the parameters.
447
+ """
448
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
449
+ with torch.no_grad():
450
+ ans.weight[:] *= initial_scale
451
+ if ans.bias is not None:
452
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
453
+ return ans
454
+
455
+
456
+ def TransposeScaledConv1d(
457
+ *args,
458
+ initial_scale: float = 1.0,
459
+ kernel_size: int = 3,
460
+ padding: str = "same",
461
+ **kwargs,
462
+ ) -> nn.Sequential:
463
+ """
464
+ Transpose -> ScaledConv1d
465
+ """
466
+ return nn.Sequential(
467
+ Transpose(),
468
+ ScaledConv1d(
469
+ *args,
470
+ initial_scale=initial_scale,
471
+ kernel_size=kernel_size,
472
+ padding=padding,
473
+ **kwargs,
474
+ ),
475
+ )
476
+
477
+
478
+ def ScaledConv1dTranspose(
479
+ *args,
480
+ initial_scale: float = 1.0,
481
+ kernel_size: int = 3,
482
+ padding: str = "same",
483
+ **kwargs,
484
+ ) -> nn.Sequential:
485
+ """
486
+ ScaledConv1d -> Transpose
487
+ """
488
+ return nn.Sequential(
489
+ ScaledConv1d(
490
+ *args,
491
+ initial_scale=initial_scale,
492
+ kernel_size=kernel_size,
493
+ padding=padding,
494
+ **kwargs,
495
+ ),
496
+ Transpose(),
497
+ )
498
+
499
+
500
+ def TransposeConv1d(
501
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
502
+ ) -> nn.Sequential:
503
+ """
504
+ Transpose -> Conv1d
505
+ """
506
+ return nn.Sequential(
507
+ Transpose(),
508
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
509
+ )
510
+
511
+
512
+ def Conv1dTranspose(
513
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
514
+ ) -> nn.Sequential:
515
+ """
516
+ Conv1d -> Transpose
517
+ """
518
+ return nn.Sequential(
519
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
520
+ Transpose(),
521
+ )
522
+
523
+
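+ # Illustrative sketch, not part of the original module: the Transpose* helpers
+ # exist so (batch, time, channels) activations can be fed to Conv1d, which
+ # expects (batch, channels, time).  Assuming Transpose() (defined earlier in
+ # this file) swaps the last two dimensions, the round trip below keeps the
+ # (batch, time, channels) layout.
+ def _demo_transpose_conv1d():
+     x = torch.randn(4, 100, 80)  # (batch, time, channels)
+     conv = TransposeConv1d(80, 256, kernel_size=3, padding="same")
+     back = Conv1dTranspose(256, 80, kernel_size=3, padding="same")
+     y = back(conv(x))
+     print(y.shape)  # expected: torch.Size([4, 100, 80])
+
+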
524
+ class SRLinear(nn.Linear):
525
+ """https://arxiv.org/abs/2303.06296
526
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
527
+ """
528
+
529
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
530
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
531
+ self.register_buffer(
532
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
533
+ )
534
+ with torch.no_grad():
535
+ sigma = self.get_sigma()
536
+ self.register_buffer("spectral_norm", sigma)
537
+ self.sigma = nn.Parameter(torch.ones(1))
538
+
539
+ def get_sigma(self):
540
+ with torch.no_grad():
541
+ u = self.u
542
+ v = self.weight.mv(u)
543
+ v = nn.functional.normalize(v, dim=0)
544
+ u = self.weight.T.mv(v)
545
+ u = nn.functional.normalize(u, dim=0)
546
+ self.u.data.copy_(u)
547
+ return torch.einsum("c,cd,d->", v, self.weight, u)
548
+
549
+ def get_weight(self):
550
+ sigma = self.get_sigma()
551
+ if self.training:
552
+ self.spectral_norm.data.copy_(sigma)
553
+ weight = (self.sigma / sigma) * self.weight
554
+ return weight
555
+
556
+ def forward(self, x):
557
+ return nn.functional.linear(x, self.get_weight(), self.bias)
558
+
559
+
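+ # Illustrative sketch, not part of the original module: SRLinear keeps a
+ # power-iteration estimate of the weight's largest singular value (sigma) and
+ # reparameterizes the effective weight as (learned sigma / estimated sigma) * W,
+ # so the spectral norm is controlled by a single learned scalar.
+ def _demo_sr_linear():
+     layer = SRLinear(64, 64)
+     x = torch.randn(10, 64)
+     y = layer(x)
+     print("estimated spectral norm:", layer.get_sigma().item())
+     print("output shape:", tuple(y.shape))
+
+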
560
+ class SRConv1d(SRLinear):
561
+ def __init__(
562
+ self,
563
+ in_features,
564
+ out_features,
565
+ kernel_size,
566
+ stride: int = 1,
567
+ padding: str = "same",
568
+ bias: bool = True,
569
+ **kwargs,
570
+ ):
571
+ in_features = in_features * kernel_size
572
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
573
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
574
+ self.kernel_size = kernel_size
575
+ self.stride = stride
576
+ self.padding = padding
577
+
578
+ def forward(self, x):
579
+ in_features = self.in_features // self.kernel_size
580
+ weight = self.get_weight().view(
581
+ self.out_features, in_features, self.kernel_size
582
+ )
583
+ return nn.functional.conv1d(
584
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
585
+ )
586
+
587
+
588
+ def TransposeSRConv1d(
589
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
590
+ ) -> nn.Sequential:
591
+ """
592
+ Transpose -> SRConv1d
593
+ """
594
+ return nn.Sequential(
595
+ Transpose(),
596
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
597
+ )
598
+
599
+
600
+ def SRConv1dTranspose(
601
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
602
+ ) -> nn.Sequential:
603
+ """
604
+ SRConv1d -> Transpose
605
+ """
606
+ return nn.Sequential(
607
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
608
+ Transpose(),
609
+ )
610
+
611
+
612
+ class ActivationBalancer(torch.nn.Module):
613
+ """
614
+ Modifies the backpropped derivatives of a function to try to encourage, for
615
+ each channel, that it is positive at least a proportion `min_positive` of the
616
+ time. It does this by multiplying negative derivative values by up to
617
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
618
+ interpolated from 1 at the threshold to those extremal values when none
619
+ of the inputs are positive.
620
+
621
+ Args:
622
+ num_channels: the number of channels
623
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
624
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
625
+ min_positive: the minimum, per channel, of the proportion of the time
626
+ that (x > 0), below which we start to modify the derivatives.
627
+ max_positive: the maximum, per channel, of the proportion of the time
628
+ that (x > 0), above which we start to modify the derivatives.
629
+ max_factor: the maximum factor by which we modify the derivatives for
630
+ either the sign constraint or the magnitude constraint;
631
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
632
+ values in the range [0.98..1.02].
633
+ sign_gain_factor: determines the 'gain' with which we increase the
634
+ change in gradient once the constraints on min_positive and max_positive
635
+ are violated.
636
+ scale_gain_factor: determines the 'gain' with which we increase the
637
+ change in gradient once the constraints on min_abs and max_abs
638
+ are violated.
639
+ min_abs: the minimum average-absolute-value difference from the mean
640
+ value per channel, which we allow, before we start to modify
641
+ the derivatives to prevent this.
642
+ max_abs: the maximum average-absolute-value difference from the mean
643
+ value per channel, which we allow, before we start to modify
644
+ the derivatives to prevent this.
645
+ min_prob: determines the minimum probability with which we modify the
646
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
647
+ on each forward(). This is done randomly to prevent all layers
648
+ from doing it at the same time. Early in training we may use
649
+ higher probabilities than this; it will decay to this value.
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ num_channels: int,
655
+ channel_dim: int,
656
+ min_positive: float = 0.05,
657
+ max_positive: float = 0.95,
658
+ max_factor: float = 0.04,
659
+ sign_gain_factor: float = 0.01,
660
+ scale_gain_factor: float = 0.02,
661
+ min_abs: float = 0.2,
662
+ max_abs: float = 100.0,
663
+ min_prob: float = 0.1,
664
+ ):
665
+ super(ActivationBalancer, self).__init__()
666
+ self.num_channels = num_channels
667
+ self.channel_dim = channel_dim
668
+ self.min_positive = min_positive
669
+ self.max_positive = max_positive
670
+ self.max_factor = max_factor
671
+ self.min_abs = min_abs
672
+ self.max_abs = max_abs
673
+ self.min_prob = min_prob
674
+ self.sign_gain_factor = sign_gain_factor
675
+ self.scale_gain_factor = scale_gain_factor
676
+
677
+ # count measures how many times the forward() function has been called.
678
+ # We occasionally sync this to a tensor called `count`, that exists to
679
+ # make sure it is synced to disk when we load and save the model.
680
+ self.cpu_count = 0
681
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
682
+
683
+ def forward(self, x: Tensor) -> Tensor:
684
+ if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
685
+ return _no_op(x)
686
+
687
+ count = self.cpu_count
688
+ self.cpu_count += 1
689
+
690
+ if random.random() < 0.01:
691
+ # Occasionally sync self.cpu_count with self.count.
692
+ # count affects the decay of 'prob'. don't do this on every iter,
693
+ # because syncing with the GPU is slow.
694
+ self.cpu_count = max(self.cpu_count, self.count.item())
695
+ self.count.fill_(self.cpu_count)
696
+
697
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
698
+ # a floor at min_prob (==0.1, by default)
699
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
700
+
701
+ if random.random() < prob:
702
+ sign_gain_factor = 0.5
703
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
704
+ sign_factor = _compute_sign_factor(
705
+ x,
706
+ self.channel_dim,
707
+ self.min_positive,
708
+ self.max_positive,
709
+ gain_factor=self.sign_gain_factor / prob,
710
+ max_factor=self.max_factor,
711
+ )
712
+ else:
713
+ sign_factor = None
714
+
715
+ scale_factor = _compute_scale_factor(
716
+ x.detach(),
717
+ self.channel_dim,
718
+ min_abs=self.min_abs,
719
+ max_abs=self.max_abs,
720
+ gain_factor=self.scale_gain_factor / prob,
721
+ max_factor=self.max_factor,
722
+ )
723
+ return ActivationBalancerFunction.apply(
724
+ x,
725
+ scale_factor,
726
+ sign_factor,
727
+ self.channel_dim,
728
+ )
729
+ else:
730
+ return _no_op(x)
731
+
732
+
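+ # Illustrative sketch, not part of the original module: ActivationBalancer is
+ # an identity in the forward pass and only nudges gradients so each channel
+ # stays mostly within its sign/magnitude constraints; a typical placement is
+ # between a Linear layer and its nonlinearity.
+ def _demo_activation_balancer():
+     d_model = 256
+     block = nn.Sequential(
+         nn.Linear(128, d_model),
+         ActivationBalancer(d_model, channel_dim=-1),
+         nn.ReLU(),
+     )
+     y = block(torch.randn(4, 10, 128))
+     print(y.shape)  # torch.Size([4, 10, 256])
+
+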
733
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
734
+ """
735
+ Returns x unmodified, but in backprop will put a penalty for the excess of
736
+ the absolute values of elements of x over the limit "limit". E.g. if
737
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
738
+
739
+ Caution: the value of this penalty will be affected by grad scaling used
740
+ in automatic mixed precision training. For this reasons we use this,
741
+ it shouldn't really matter, or may even be helpful; we just use this
742
+ to disallow really implausible values of scores to be given to softmax.
743
+ """
744
+ x_sign = x.sign()
745
+ over_limit = (x.abs() - limit) > 0
746
+ # The following is a memory efficient way to penalize the absolute values of
747
+ # x that's over the limit. (The memory efficiency comes when you think
748
+ # about which items torch needs to cache for the autograd, and which ones it
749
+ # can throw away). The numerical value of aux_loss as computed here will
750
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
751
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
752
+ # limit).relu().
753
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
754
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
755
+ # sum() due to how with_loss() works.
756
+ x = with_loss(x, aux_loss)
757
+ # you must use x for something, or this will be ineffective.
758
+ return x
759
+
760
+
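+ # Illustrative sketch, not part of the original module: the penalty is attached
+ # to whatever consumes the returned tensor, so it is typically applied to raw
+ # attention scores just before the softmax.
+ def _demo_penalize_abs_values_gt():
+     scores = 30.0 * torch.randn(2, 4, 4)
+     scores.requires_grad_(True)
+     limited = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)
+     attn = limited.softmax(dim=-1)
+     attn.sum().backward()  # gradient now includes the out-of-range penalty
+     print(scores.grad.abs().max().item())
+
+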
761
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
762
+ if x.ndim == 2:
763
+ return x.diag()
764
+ else:
765
+ (batch, dim, dim) = x.shape
766
+ x = x.reshape(batch, dim * dim)
767
+ x = x[:, :: dim + 1]
768
+ assert x.shape == (batch, dim)
769
+ return x
770
+
771
+
772
+ def _whitening_metric(x: Tensor, num_groups: int):
773
+ """
774
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
775
+ of the centered feature covariance are the same within each group's covariance matrix
776
+ and also between groups.
777
+ Args:
778
+ x: a Tensor of shape (*, num_channels)
779
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
780
+ Returns:
781
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
782
+ greater than 1.0 otherwise.
783
+ """
784
+ assert x.dtype != torch.float16
785
+ x = x.reshape(-1, x.shape[-1])
786
+ (num_frames, num_channels) = x.shape
787
+ assert num_channels % num_groups == 0
788
+ channels_per_group = num_channels // num_groups
789
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
790
+ # x now has shape (num_groups, num_frames, channels_per_group)
791
+ # subtract the mean so we use the centered, not uncentered, covariance.
792
+ # My experience has been that when we "mess with the gradients" like this,
793
+ # it's better not to do anything that tries to move the mean around, because
794
+ # that can easily cause instability.
795
+ x = x - x.mean(dim=1, keepdim=True)
796
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
797
+ x_covar = torch.matmul(x.transpose(1, 2), x)
798
+ x_covar_mean_diag = _diag(x_covar).mean()
799
+ # the following expression is what we'd get if we took the matrix product
800
+ # of each covariance and measured the mean of its trace, i.e.
801
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
802
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
803
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
804
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
805
+ return metric
806
+
807
+
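+ # Illustrative sketch, not part of the original module: the metric is close to
+ # 1.0 for isotropic ("white") features and grows as the covariance becomes
+ # dominated by a few directions.
+ def _demo_whitening_metric():
+     white = torch.randn(1000, 64)
+     correlated = white.clone()
+     correlated[:, 0] += 8.0 * correlated[:, 1]  # make one channel dominate
+     print("white:", _whitening_metric(white, num_groups=1).item())
+     print("correlated:", _whitening_metric(correlated, num_groups=1).item())
+
+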
808
+ class WhiteningPenaltyFunction(torch.autograd.Function):
809
+ @staticmethod
810
+ def forward(
811
+ ctx,
812
+ x: Tensor,
813
+ num_groups: int,
814
+ whitening_limit: float,
815
+ grad_scale: float,
816
+ ) -> Tensor:
817
+ ctx.save_for_backward(x)
818
+ ctx.num_groups = num_groups
819
+ ctx.whitening_limit = whitening_limit
820
+ ctx.grad_scale = grad_scale
821
+ return x
822
+
823
+ @staticmethod
824
+ def backward(ctx, x_grad: Tensor):
825
+ (x_orig,) = ctx.saved_tensors
826
+ with torch.enable_grad():
827
+ with torch.cuda.amp.autocast(enabled=False):
828
+ x_detached = x_orig.to(torch.float32).detach()
829
+ x_detached.requires_grad = True
830
+
831
+ metric = _whitening_metric(x_detached, ctx.num_groups)
832
+
833
+ if random.random() < 0.005 or __name__ == "__main__":
834
+ logging.info(
835
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
836
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
837
+ )
838
+
839
+ (metric - ctx.whitening_limit).relu().backward()
840
+ penalty_grad = x_detached.grad
841
+ scale = ctx.grad_scale * (
842
+ x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
843
+ )
844
+ penalty_grad = penalty_grad * scale
845
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
846
+
847
+
848
+ class Whiten(nn.Module):
849
+ def __init__(
850
+ self,
851
+ num_groups: int,
852
+ whitening_limit: float,
853
+ prob: Union[float, Tuple[float, float]],
854
+ grad_scale: float,
855
+ ):
856
+ """
857
+ Args:
858
+ num_groups: the number of groups to divide the channel dim into before
859
+ whitening. We will attempt to make the feature covariance
860
+ within each group, after mean subtraction, as "white" as possible,
861
+ while having the same trace across all groups.
862
+ whitening_limit: a value greater than 1.0, that dictates how much
863
+ freedom we have to violate the constraints. 1.0 would mean perfectly
864
+ white, with exactly the same trace across groups; larger values
865
+ give more freedom. E.g. 2.0.
866
+ prob: the probability with which we apply the gradient modification
867
+ (also affects the grad scale). May be supplied as a float,
868
+ or as a pair (min_prob, max_prob)
869
+
870
+ grad_scale: determines the scale on the gradient term from this object,
871
+ relative to the rest of the gradient on the attention weights.
872
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
873
+ """
874
+ super(Whiten, self).__init__()
875
+ assert num_groups >= 1
876
+ assert whitening_limit >= 1
877
+ assert grad_scale >= 0
878
+ self.num_groups = num_groups
879
+ self.whitening_limit = whitening_limit
880
+ if isinstance(prob, float):
881
+ assert 0 < prob <= 1
882
+ self.prob = prob
883
+ else:
884
+ (self.min_prob, self.max_prob) = prob
885
+ assert 0 < self.min_prob < self.max_prob <= 1
886
+ self.prob = self.max_prob
887
+
888
+ self.grad_scale = grad_scale
889
+
890
+ def forward(self, x: Tensor) -> Tensor:
891
+ """
892
+ In the forward pass, this function just returns the input unmodified.
893
+ In the backward pass, it will modify the gradients to ensure that the
894
+ distribution in each group has close to (lambda times I) as the covariance
895
+ after mean subtraction, with the same lambda across groups.
896
+ For whitening_limit > 1, there will be more freedom to violate this
897
+ constraint.
898
+
899
+ Args:
900
+ x: the input of shape (*, num_channels)
901
+
902
+ Returns:
903
+ x, unmodified. You should make sure
904
+ you use the returned value, or the graph will be freed
905
+ and nothing will happen in backprop.
906
+ """
907
+ if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
908
+ return _no_op(x)
909
+ else:
910
+ if hasattr(self, "min_prob") and random.random() < 0.25:
911
+ # occasionally switch between min_prob and max_prob, based on whether
912
+ # we are above or below the threshold.
913
+ if (
914
+ _whitening_metric(x.to(torch.float32), self.num_groups)
915
+ > self.whitening_limit
916
+ ):
917
+ # there would be a change to the grad.
918
+ self.prob = self.max_prob
919
+ else:
920
+ self.prob = self.min_prob
921
+
922
+ return WhiteningPenaltyFunction.apply(
923
+ x, self.num_groups, self.whitening_limit, self.grad_scale
924
+ )
925
+
926
+
927
+ class WithLoss(torch.autograd.Function):
928
+ @staticmethod
929
+ def forward(ctx, x: Tensor, y: Tensor):
930
+ ctx.y_shape = y.shape
931
+ return x
932
+
933
+ @staticmethod
934
+ def backward(ctx, ans_grad: Tensor):
935
+ return ans_grad, torch.ones(
936
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
937
+ )
938
+
939
+
940
+ def with_loss(x, y):
941
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
942
+ return x
943
+ # returns x but adds y.sum() to the loss function.
944
+ return WithLoss.apply(x, y)
945
+
946
+
947
+ def _no_op(x: Tensor) -> Tensor:
948
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
949
+ return x
950
+ else:
951
+ # a no-op function that will have a node in the autograd graph,
952
+ # to avoid certain bugs relating to backward hooks
953
+ return x.chunk(1, dim=-1)[0]
954
+
955
+
956
+ class Identity(torch.nn.Module):
957
+ def __init__(self):
958
+ super(Identity, self).__init__()
959
+
960
+ def forward(self, x):
961
+ return _no_op(x)
962
+
963
+
964
+ class MaxEig(torch.nn.Module):
965
+ """
966
+ Modifies the backpropped derivatives of a function to try to discourage
967
+ that any given direction in activation space accounts for more than
968
+ a specified proportion of the covariance (e.g. 0.2).
969
+
970
+
971
+ Args:
972
+ num_channels: the number of channels
973
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
974
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
975
+ max_var_per_eig: the maximum proportion of the variance of the
976
+ features/channels, after mean subtraction, that can come from
977
+ any given eigenvalue.
978
+ min_prob: the minimum probability with which we apply this during any invocation
979
+ of forward(), assuming last time we applied the constraint it was
980
+ not active; supplied for speed.
981
+ scale: determines the scale with which we modify the gradients, relative
982
+ to the existing / unmodified gradients
983
+ """
984
+
985
+ def __init__(
986
+ self,
987
+ num_channels: int,
988
+ channel_dim: int,
989
+ max_var_per_eig: float = 0.2,
990
+ min_prob: float = 0.01,
991
+ scale: float = 0.01,
992
+ ):
993
+ super(MaxEig, self).__init__()
994
+ self.num_channels = num_channels
995
+ self.channel_dim = channel_dim
996
+ self.scale = scale
997
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
998
+ self.max_var_per_eig = max_var_per_eig
999
+
1000
+ # we figure out the dominant direction using the power method: starting with
1001
+ # a random vector, keep multiplying by the covariance and renormalizing.
1002
+ with torch.no_grad():
1003
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1004
+ # random parameters unchanged for comparison
1005
+ direction = torch.arange(num_channels).to(torch.float)
1006
+ direction = direction / direction.norm()
1007
+ self.register_buffer("max_eig_direction", direction)
1008
+
1009
+ self.min_prob = min_prob
1010
+ # cur_prob is the current probability we'll use to apply the constraint.
1011
+ # We'll regress this towards min_prob, each time we try to apply it and it is not
1012
+ # active.
1013
+ self.cur_prob = 1.0
1014
+
1015
+ def forward(self, x: Tensor) -> Tensor:
1016
+ if (
1017
+ torch.jit.is_scripting()
1018
+ or self.max_var_per_eig <= 0
1019
+ or random.random() > self.cur_prob
1020
+ or torch.jit.is_tracing()
1021
+ ):
1022
+ return _no_op(x)
1023
+
1024
+ with torch.cuda.amp.autocast(enabled=False):
1025
+ eps = 1.0e-20
1026
+ orig_x = x
1027
+ x = x.to(torch.float32)
1028
+ with torch.no_grad():
1029
+ x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
1030
+ x = x - x.mean(dim=0)
1031
+ new_direction, coeffs = self._find_direction_coeffs(
1032
+ x, self.max_eig_direction
1033
+ )
1034
+ x_var = (x**2).mean()
1035
+ x_residual = x - coeffs * new_direction
1036
+ x_residual_var = (x_residual**2).mean()
1037
+
1038
+ # `variance_proportion` is the proportion of the variance accounted for
1039
+ # by the top eigen-direction.
1040
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
1041
+
1042
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1043
+ self._set_direction(0.1 * self.max_eig_direction + new_direction)
1044
+
1045
+ if random.random() < 0.01 or __name__ == "__main__":
1046
+ logging.info(
1047
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1048
+ )
1049
+
1050
+ if variance_proportion >= self.max_var_per_eig:
1051
+ # The constraint is active. Note, we should quite rarely
1052
+ # reach here, only near the beginning of training if we are
1053
+ # starting to diverge, should this constraint be active.
1054
+ cur_prob = self.cur_prob
1055
+ self.cur_prob = 1.0 # next time, do the update with probability 1.0.
1056
+ return MaxEigLimiterFunction.apply(
1057
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1058
+ )
1059
+ else:
1060
+ # let self.cur_prob exponentially approach self.min_prob, as
1061
+ # long as the constraint is inactive.
1062
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1063
+ return orig_x
1064
+
1065
+ def _set_direction(self, direction: Tensor):
1066
+ """
1067
+ Sets self.max_eig_direction to a normalized version of `direction`
1068
+ """
1069
+ direction = direction.detach()
1070
+ direction = direction / direction.norm()
1071
+ direction_sum = direction.sum().item()
1072
+ if direction_sum - direction_sum == 0: # no inf/nan
1073
+ self.max_eig_direction[:] = direction
1074
+ else:
1075
+ logging.info(
1076
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1077
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1078
+ )
1079
+
1080
+ def _find_direction_coeffs(
1081
+ self, x: Tensor, prev_direction: Tensor
1082
+ ) -> Tuple[Tensor, Tensor]:
1083
+ """
1084
+ Figure out (an approximation to) the proportion of the variance of a set of
1085
+ feature vectors that can be attributed to the top eigen-direction.
1086
+ Args:
1087
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1088
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1089
+ of the top eigen-direction, or a random direction if this is the first
1090
+ iteration. Does not have to be normalized, but should be nonzero.
1091
+
1092
+ Returns: (cur_direction, coeffs), where:
1093
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1094
+ estimate of the top eigen-direction.
1095
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1096
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1097
+ """
1098
+ (num_frames, num_channels) = x.shape
1099
+ assert num_channels > 1 and num_frames > 1
1100
+ assert prev_direction.shape == (num_channels,)
1101
+ # `coeffs` are the coefficients of `prev_direction` in x.
1102
+ # actually represent the coeffs up to a constant positive factor.
1103
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1104
+ cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)
1105
+ return cur_direction, coeffs
1106
+
1107
+
1108
+ class DoubleSwishFunction(torch.autograd.Function):
1109
+ """
1110
+ double_swish(x) = x * torch.sigmoid(x-1)
1111
+ This is a definition, originally motivated by its close numerical
1112
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1113
+
1114
+ Memory-efficient derivative computation:
1115
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1116
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1117
+ Now, s'(x) = s(x) * (1-s(x)).
1118
+ double_swish'(x) = x * s'(x) + s(x).
1119
+ = x * s(x) * (1-s(x)) + s(x).
1120
+ = double_swish(x) * (1-s(x)) + s(x)
1121
+ ... so we just need to remember s(x) but not x itself.
1122
+ """
1123
+
1124
+ @staticmethod
1125
+ def forward(ctx, x: Tensor) -> Tensor:
1126
+ requires_grad = x.requires_grad
1127
+ x_dtype = x.dtype
1128
+ if x.dtype == torch.float16:
1129
+ x = x.to(torch.float32)
1130
+
1131
+ s = torch.sigmoid(x - 1.0)
1132
+ y = x * s
1133
+
1134
+ if requires_grad:
1135
+ deriv = y * (1 - s) + s
1136
+ # notes on derivative of x * sigmoid(x - 1):
1137
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1138
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
1139
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1140
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1141
+ # floors), should be expectation-preserving.
1142
+ floor = -0.043637
1143
+ ceil = 1.2
1144
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1145
+ deriv
1146
+ )
1147
+ if __name__ == "__main__":
1148
+ # for self-testing only.
1149
+ assert d_scaled.min() >= 0.0
1150
+ assert d_scaled.max() < 256.0
1151
+ d_int = d_scaled.to(torch.uint8)
1152
+ ctx.save_for_backward(d_int)
1153
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1154
+ y = y.to(torch.float16)
1155
+ return y
1156
+
1157
+ @staticmethod
1158
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1159
+ (d,) = ctx.saved_tensors
1160
+ # the same constants as used in forward pass.
1161
+ floor = -0.043637
1162
+ ceil = 1.2
1163
+ d = d * ((ceil - floor) / 255.0) + floor
1164
+ return y_grad * d
1165
+
1166
+
1167
+ class DoubleSwish(torch.nn.Module):
1168
+ def forward(self, x: Tensor) -> Tensor:
1169
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1170
+ that we approximate closely with x * sigmoid(x-1).
1171
+ """
1172
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1173
+ return x * torch.sigmoid(x - 1.0)
1174
+ return DoubleSwishFunction.apply(x)
1175
+
1176
+
1177
+ def BalancedDoubleSwish(
1178
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1179
+ ) -> nn.Sequential:
1180
+ """
1181
+ ActivationBalancer -> DoubleSwish
1182
+ """
1183
+ balancer = ActivationBalancer(
1184
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1185
+ )
1186
+ return nn.Sequential(
1187
+ balancer,
1188
+ DoubleSwish(),
1189
+ )
1190
+
1191
+
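+ # Illustrative sketch, not part of the original module: BalancedDoubleSwish is
+ # a drop-in activation for feedforward blocks; the balancer part only changes
+ # gradients, so inference behaviour is just DoubleSwish.
+ def _demo_balanced_double_swish():
+     d_ff = 1024
+     ff = nn.Sequential(
+         nn.Linear(256, d_ff),
+         BalancedDoubleSwish(d_ff),
+         nn.Linear(d_ff, 256),
+     )
+     y = ff(torch.randn(2, 7, 256))
+     print(y.shape)  # torch.Size([2, 7, 256])
+
+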
1192
+ def _test_max_eig():
1193
+ for proportion in [0.1, 0.5, 10.0]:
1194
+ logging.info(f"proportion = {proportion}")
1195
+ x = torch.randn(100, 128)
1196
+ direction = torch.randn(128)
1197
+ coeffs = torch.randn(100, 1)
1198
+ x += proportion * direction * coeffs
1199
+
1200
+ x.requires_grad = True
1201
+
1202
+ num_channels = 128
1203
+ m = MaxEig(
1204
+ num_channels, channel_dim=1, max_var_per_eig=0.5, scale=0.1
1205
+ )
1206
+
1207
+ for _ in range(4):
1208
+ y = m(x)
1209
+
1210
+ y_grad = torch.randn_like(x)
1211
+ y.backward(gradient=y_grad)
1212
+
1213
+ if proportion < 0.2:
1214
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1215
+ elif proportion > 1.0:
1216
+ assert not torch.allclose(x.grad, y_grad)
1217
+
1218
+
1219
+ def _test_whiten():
1220
+ for proportion in [0.1, 0.5, 10.0]:
1221
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1222
+ x = torch.randn(100, 128)
1223
+ direction = torch.randn(128)
1224
+ coeffs = torch.randn(100, 1)
1225
+ x += proportion * direction * coeffs
1226
+
1227
+ x.requires_grad = True
1228
+
1229
+ num_channels = 128
1230
+ m = Whiten(
1231
+ num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.1
1232
+ )
1233
+
1234
+ for _ in range(4):
1235
+ y = m(x)
1236
+
1237
+ y_grad = torch.randn_like(x)
1238
+ y.backward(gradient=y_grad)
1239
+
1240
+ if proportion < 0.2:
1241
+ assert torch.allclose(x.grad, y_grad)
1242
+ elif proportion > 1.0:
1243
+ assert not torch.allclose(x.grad, y_grad)
1244
+
1245
+
1246
+ def _test_activation_balancer_sign():
1247
+ probs = torch.arange(0, 1, 0.01)
1248
+ N = 1000
1249
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
1250
+ x = x.detach()
1251
+ x.requires_grad = True
1252
+ m = ActivationBalancer(
1253
+ probs.numel(),
1254
+ channel_dim=0,
1255
+ min_positive=0.05,
1256
+ max_positive=0.95,
1257
+ max_factor=0.2,
1258
+ min_abs=0.0,
1259
+ )
1260
+
1261
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1262
+
1263
+ y = m(x)
1264
+ y.backward(gradient=y_grad)
1265
+ print("_test_activation_balancer_sign: x = ", x)
1266
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1267
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1268
+
1269
+
1270
+ def _test_activation_balancer_magnitude():
1271
+ magnitudes = torch.arange(0, 1, 0.01)
1272
+ N = 1000
1273
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
1274
+ x = x.detach()
1275
+ x.requires_grad = True
1276
+ m = ActivationBalancer(
1277
+ magnitudes.numel(),
1278
+ channel_dim=0,
1279
+ min_positive=0.0,
1280
+ max_positive=1.0,
1281
+ max_factor=0.2,
1282
+ min_abs=0.2,
1283
+ max_abs=0.8,
1284
+ min_prob=1.0,
1285
+ )
1286
+
1287
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1288
+
1289
+ y = m(x)
1290
+ y.backward(gradient=y_grad)
1291
+ print("_test_activation_balancer_magnitude: x = ", x)
1292
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1293
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1294
+
1295
+
1296
+ def _test_basic_norm():
1297
+ num_channels = 128
1298
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1299
+
1300
+ x = torch.randn(500, num_channels)
1301
+
1302
+ y = m(x)
1303
+
1304
+ assert y.shape == x.shape
1305
+ x_rms = (x**2).mean().sqrt()
1306
+ y_rms = (y**2).mean().sqrt()
1307
+ print("x rms = ", x_rms)
1308
+ print("y rms = ", y_rms)
1309
+ assert y_rms < x_rms
1310
+ assert y_rms > 0.5 * x_rms
1311
+
1312
+
1313
+ def _test_double_swish_deriv():
1314
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1315
+ x.requires_grad = True
1316
+ m = DoubleSwish()
1317
+
1318
+ tol = (1.2 - (-0.043637)) / 255.0
1319
+ torch.autograd.gradcheck(m, x, atol=tol)
1320
+
1321
+ # for self-test.
1322
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1323
+ x.requires_grad = True
1324
+ y = m(x)
1325
+
1326
+
1327
+ def _test_softmax():
1328
+ a = torch.randn(2, 10, dtype=torch.float64)
1329
+ b = a.clone()
1330
+ a.requires_grad = True
1331
+ b.requires_grad = True
1332
+ a.softmax(dim=1)[:, 0].sum().backward()
1333
+ print("a grad = ", a.grad)
1334
+ softmax(b, dim=1)[:, 0].sum().backward()
1335
+ print("b grad = ", b.grad)
1336
+ assert torch.allclose(a.grad, b.grad)
1337
+
1338
+
1339
+ if __name__ == "__main__":
1340
+ logging.getLogger().setLevel(logging.INFO)
1341
+ torch.set_num_threads(1)
1342
+ torch.set_num_interop_threads(1)
1343
+ _test_softmax()
1344
+ _test_whiten()
1345
+ _test_max_eig()
1346
+ _test_activation_balancer_sign()
1347
+ _test_activation_balancer_magnitude()
1348
+ _test_basic_norm()
1349
+ _test_double_swish_deriv()