buttercrab committed
Commit 1034391 · 1 Parent(s): 0d09dd0

initial commit

Files changed (8)
  1. .gitignore +20 -0
  2. app.py +390 -0
  3. dia/__init__.py +0 -0
  4. dia/audio.py +280 -0
  5. dia/config.py +206 -0
  6. dia/layers.py +873 -0
  7. dia/model.py +431 -0
  8. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,20 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ .gradio
13
+
14
+ **/*.pth
15
+ **/*.mp3
16
+ !example_prompt.mp3
17
+
18
+ .ruff_cache
19
+ .ipynb_checkpoints
20
+ config.json
app.py ADDED
@@ -0,0 +1,390 @@
1
+ import argparse
2
+ import tempfile
3
+ import time
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from dia.model import Dia
13
+
14
+
15
+ # --- Global Setup ---
16
+ parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
17
+ parser.add_argument(
18
+ "--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')"
19
+ )
20
+ parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
21
+
22
+ args = parser.parse_args()
23
+
24
+
25
+ # Determine device
26
+ if args.device:
27
+ device = torch.device(args.device)
28
+ elif torch.cuda.is_available():
29
+ device = torch.device("cuda")
30
+ # Simplified MPS check for broader compatibility
31
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
32
+ # Basic check is usually sufficient, detailed check can be problematic
33
+ device = torch.device("mps")
34
+ else:
35
+ device = torch.device("cpu")
36
+
37
+ print(f"Using device: {device}")
38
+
39
+ # Load Nari model and config
40
+ print("Loading Nari model...")
41
+ try:
42
+ # Load the pretrained Dia model from the Hugging Face Hub
43
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B")
44
+ except Exception as e:
45
+ print(f"Error loading Nari model: {e}")
46
+ raise
47
+
48
+
49
+ def run_inference(
50
+ text_input: str,
51
+ audio_prompt_input: Optional[Tuple[int, np.ndarray]],
52
+ max_new_tokens: int,
53
+ cfg_scale: float,
54
+ temperature: float,
55
+ top_p: float,
56
+ cfg_filter_top_k: int,
57
+ speed_factor: float,
58
+ ):
59
+ """
60
+ Runs Nari inference using the globally loaded model and provided inputs.
61
+ Writes the audio prompt to a temporary WAV file so it can be passed to model.generate by path.
62
+ """
63
+ global model, device # Access the globally loaded model and device
64
+
65
+ if not text_input or text_input.isspace():
66
+ raise gr.Error("Text input cannot be empty.")
67
+
68
+ temp_txt_file_path = None
69
+ temp_audio_prompt_path = None
70
+ output_audio = (44100, np.zeros(1, dtype=np.float32))
71
+
72
+ try:
73
+ prompt_path_for_generate = None
74
+ if audio_prompt_input is not None:
75
+ sr, audio_data = audio_prompt_input
76
+ # Check if audio_data is valid
77
+ if (
78
+ audio_data is None or audio_data.size == 0 or audio_data.max() == 0
79
+ ): # Check for silence/empty
80
+ gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
81
+ else:
82
+ # Save prompt audio to a temporary WAV file
83
+ with tempfile.NamedTemporaryFile(
84
+ mode="wb", suffix=".wav", delete=False
85
+ ) as f_audio:
86
+ temp_audio_prompt_path = f_audio.name # Store path for cleanup
87
+
88
+ # Basic audio preprocessing for consistency
89
+ # Convert to float32 in [-1, 1] range if integer type
90
+ if np.issubdtype(audio_data.dtype, np.integer):
91
+ max_val = np.iinfo(audio_data.dtype).max
92
+ audio_data = audio_data.astype(np.float32) / max_val
93
+ elif not np.issubdtype(audio_data.dtype, np.floating):
94
+ gr.Warning(
95
+ f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
96
+ )
97
+ # Attempt conversion, might fail for complex types
98
+ try:
99
+ audio_data = audio_data.astype(np.float32)
100
+ except Exception as conv_e:
101
+ raise gr.Error(
102
+ f"Failed to convert audio prompt to float32: {conv_e}"
103
+ )
104
+
105
+ # Ensure mono (average channels if stereo)
106
+ if audio_data.ndim > 1:
107
+ if audio_data.shape[0] == 2: # Assume (2, N)
108
+ audio_data = np.mean(audio_data, axis=0)
109
+ elif audio_data.shape[1] == 2: # Assume (N, 2)
110
+ audio_data = np.mean(audio_data, axis=1)
111
+ else:
112
+ gr.Warning(
113
+ f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
114
+ )
115
+ audio_data = (
116
+ audio_data[0]
117
+ if audio_data.shape[0] < audio_data.shape[1]
118
+ else audio_data[:, 0]
119
+ )
120
+ audio_data = np.ascontiguousarray(
121
+ audio_data
122
+ ) # Ensure contiguous after slicing/mean
123
+
124
+ # Write using soundfile
125
+ try:
126
+ sf.write(
127
+ temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
128
+ ) # Explicitly use FLOAT subtype
129
+ prompt_path_for_generate = temp_audio_prompt_path
130
+ print(
131
+ f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
132
+ )
133
+ except Exception as write_e:
134
+ print(f"Error writing temporary audio file: {write_e}")
135
+ raise gr.Error(f"Failed to save audio prompt: {write_e}")
136
+
137
+ # 3. Run Generation
138
+
139
+ start_time = time.time()
140
+
141
+ # Use torch.inference_mode() context manager for the generation call
142
+ with torch.inference_mode():
143
+ output_audio_np = model.generate(
144
+ text_input,
145
+ max_tokens=max_new_tokens,
146
+ cfg_scale=cfg_scale,
147
+ temperature=temperature,
148
+ top_p=top_p,
149
+ use_cfg_filter=True,
150
+ cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
151
+ use_torch_compile=False, # Keep False for Gradio stability
152
+ audio_prompt_path=prompt_path_for_generate,
153
+ )
154
+
155
+ end_time = time.time()
156
+ print(f"Generation finished in {end_time - start_time:.2f} seconds.")
157
+
158
+ # 4. Convert Codes to Audio
159
+ if output_audio_np is not None:
160
+ # Get sample rate from the loaded DAC model
161
+ output_sr = 44100
162
+
163
+ # --- Slow down audio ---
164
+ original_len = len(output_audio_np)
165
+ # Ensure speed_factor is positive and not excessively small/large to avoid issues
166
+ speed_factor = max(0.1, min(speed_factor, 5.0))
167
+ target_len = int(
168
+ original_len / speed_factor
169
+ ) # Target length based on speed_factor
170
+ if (
171
+ target_len != original_len and target_len > 0
172
+ ): # Only interpolate if length changes and is valid
173
+ x_original = np.arange(original_len)
174
+ x_resampled = np.linspace(0, original_len - 1, target_len)
175
+ resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
176
+ output_audio = (
177
+ output_sr,
178
+ resampled_audio_np.astype(np.float32),
179
+ ) # Use resampled audio
180
+ print(
181
+ f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
182
+ )
183
+ else:
184
+ output_audio = (
185
+ output_sr,
186
+ output_audio_np,
187
+ ) # Keep original if calculation fails or no change
188
+ print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
189
+ # --- End slowdown ---
190
+
191
+ print(
192
+ f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
193
+ )
194
+
195
+ else:
196
+ print("\nGeneration finished, but no valid tokens were produced.")
197
+ # Return default silence
198
+ gr.Warning("Generation produced no output.")
199
+
200
+ except Exception as e:
201
+ print(f"Error during inference: {e}")
202
+ import traceback
203
+
204
+ traceback.print_exc()
205
+ # Re-raise as Gradio error to display nicely in the UI
206
+ raise gr.Error(f"Inference failed: {e}")
207
+
208
+ finally:
209
+ # 5. Cleanup Temporary Files defensively
210
+ if temp_txt_file_path and Path(temp_txt_file_path).exists():
211
+ try:
212
+ Path(temp_txt_file_path).unlink()
213
+ print(f"Deleted temporary text file: {temp_txt_file_path}")
214
+ except OSError as e:
215
+ print(
216
+ f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}"
217
+ )
218
+ if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
219
+ try:
220
+ Path(temp_audio_prompt_path).unlink()
221
+ print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
222
+ except OSError as e:
223
+ print(
224
+ f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
225
+ )
226
+
227
+ return output_audio
228
+
229
+
230
+ # --- Create Gradio Interface ---
231
+ css = """
232
+ #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
233
+ """
234
+ # Attempt to load default text from example.txt
235
+ default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
236
+ example_txt_path = Path("./example.txt")
237
+ if example_txt_path.exists():
238
+ try:
239
+ default_text = example_txt_path.read_text(encoding="utf-8").strip()
240
+ if not default_text: # Handle empty example file
241
+ default_text = "Example text file was empty."
242
+ except Exception as e:
243
+ print(f"Warning: Could not read example.txt: {e}")
244
+
245
+
246
+ # Build Gradio UI
247
+ with gr.Blocks(css=css) as demo:
248
+ gr.Markdown("# Nari Text-to-Speech Synthesis")
249
+
250
+ with gr.Row(equal_height=False):
251
+ with gr.Column(scale=1):
252
+ text_input = gr.Textbox(
253
+ label="Input Text",
254
+ placeholder="Enter text here...",
255
+ value=default_text,
256
+ lines=5, # Increased lines
257
+ )
258
+ audio_prompt_input = gr.Audio(
259
+ label="Audio Prompt (Optional)",
260
+ show_label=True,
261
+ sources=["upload", "microphone"],
262
+ type="numpy",
263
+ )
264
+ with gr.Accordion("Generation Parameters", open=False):
265
+ max_new_tokens = gr.Slider(
266
+ label="Max New Tokens (Audio Length)",
267
+ minimum=860,
268
+ maximum=3072,
269
+ value=model.config.data.audio_length, # Use config default if available, else fallback
270
+ step=50,
271
+ info="Controls the maximum length of the generated audio (more tokens = longer audio).",
272
+ )
273
+ cfg_scale = gr.Slider(
274
+ label="CFG Scale (Guidance Strength)",
275
+ minimum=1.0,
276
+ maximum=5.0,
277
+ value=3.0, # Default from inference.py
278
+ step=0.1,
279
+ info="Higher values increase adherence to the text prompt.",
280
+ )
281
+ temperature = gr.Slider(
282
+ label="Temperature (Randomness)",
283
+ minimum=1.0,
284
+ maximum=1.5,
285
+ value=1.3, # Default from inference.py
286
+ step=0.05,
287
+ info="Lower values make the output more deterministic, higher values increase randomness.",
288
+ )
289
+ top_p = gr.Slider(
290
+ label="Top P (Nucleus Sampling)",
291
+ minimum=0.80,
292
+ maximum=1.0,
293
+ value=0.95, # Default from inference.py
294
+ step=0.01,
295
+ info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
296
+ )
297
+ cfg_filter_top_k = gr.Slider(
298
+ label="CFG Filter Top K",
299
+ minimum=15,
300
+ maximum=50,
301
+ value=30,
302
+ step=1,
303
+ info="Top k filter for CFG guidance.",
304
+ )
305
+ speed_factor_slider = gr.Slider(
306
+ label="Speed Factor",
307
+ minimum=0.8,
308
+ maximum=1.0,
309
+ value=0.94,
310
+ step=0.02,
311
+ info="Adjusts the speed of the generated audio (1.0 = original speed).",
312
+ )
313
+
314
+ run_button = gr.Button("Generate Audio", variant="primary")
315
+
316
+ with gr.Column(scale=1):
317
+ audio_output = gr.Audio(
318
+ label="Generated Audio",
319
+ type="numpy",
320
+ autoplay=False,
321
+ )
322
+
323
+ # Link button click to function
324
+ run_button.click(
325
+ fn=run_inference,
326
+ inputs=[
327
+ text_input,
328
+ audio_prompt_input,
329
+ max_new_tokens,
330
+ cfg_scale,
331
+ temperature,
332
+ top_p,
333
+ cfg_filter_top_k,
334
+ speed_factor_slider,
335
+ ],
336
+ outputs=[audio_output], # Add status_output here if using it
337
+ api_name="generate_audio",
338
+ )
339
+
340
+ # Add examples (ensure the prompt path is correct or remove it if example file doesn't exist)
341
+ example_prompt_path = "./example_prompt.mp3" # Adjust if needed
342
+ examples_list = [
343
+ [
344
+ "[S1] Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway! ",
345
+ None,
346
+ 3072,
347
+ 3.0,
348
+ 1.3,
349
+ 0.95,
350
+ 35,
351
+ 0.94,
352
+ ],
353
+ [
354
+ "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
355
+ example_prompt_path if Path(example_prompt_path).exists() else None,
356
+ 3072,
357
+ 3.0,
358
+ 1.3,
359
+ 0.95,
360
+ 35,
361
+ 0.94,
362
+ ],
363
+ ]
364
+
365
+ if examples_list:
366
+ gr.Examples(
367
+ examples=examples_list,
368
+ inputs=[
369
+ text_input,
370
+ audio_prompt_input,
371
+ max_new_tokens,
372
+ cfg_scale,
373
+ temperature,
374
+ top_p,
375
+ cfg_filter_top_k,
376
+ speed_factor_slider,
377
+ ],
378
+ outputs=[audio_output],
379
+ fn=run_inference,
380
+ cache_examples=False,
381
+ label="Examples (Click to Run)",
382
+ )
383
+ else:
384
+ gr.Markdown("_(No examples configured or example prompt file missing)_")
385
+
386
+
387
+ # --- Launch the App ---
388
+ if __name__ == "__main__":
389
+ print("Launching Gradio interface...")
390
+ demo.launch()
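
For readers who want to exercise the model outside the Gradio UI, here is a minimal sketch built only from the calls used in run_inference above (it is not part of this commit). The output file name is illustrative, and the 44.1 kHz sample rate mirrors the hard-coded output_sr in the app.

# Minimal programmatic use of the model, mirroring run_inference (sketch, not part of this commit).
import soundfile as sf
import torch

from dia.model import Dia

model = Dia.from_pretrained("nari-labs/Dia-1.6B")

text = "[S1] Dia is an open weights text to dialogue model. [S2] Try it now."

with torch.inference_mode():
    audio_np = model.generate(
        text,
        max_tokens=3072,            # "Max New Tokens" slider
        cfg_scale=3.0,              # "CFG Scale" slider
        temperature=1.3,
        top_p=0.95,
        use_cfg_filter=True,
        cfg_filter_top_k=30,        # "CFG Filter Top K" slider
        use_torch_compile=False,
        audio_prompt_path=None,     # or a path to a WAV prompt, as written by run_inference
    )

# The app treats the generated waveform as 44.1 kHz float audio; "output.wav" is an illustrative name.
sf.write("output.wav", audio_np, 44100)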
dia/__init__.py ADDED
File without changes
dia/audio.py ADDED
@@ -0,0 +1,280 @@
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from .config import DataConfig
6
+
7
+
8
+ def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
9
+ """
10
+ Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
11
+ Negative t_idx => BOS; t_idx >= T => PAD.
12
+ """
13
+ delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
14
+
15
+ t_idx_BxT = torch.broadcast_to(
16
+ torch.arange(T, dtype=torch.int32)[None, :],
17
+ [B, T],
18
+ )
19
+ t_idx_BxTx1 = t_idx_BxT[..., None]
20
+ t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
21
+
22
+ b_idx_BxTxC = torch.broadcast_to(
23
+ torch.arange(B, dtype=torch.int32).view(B, 1, 1),
24
+ [B, T, C],
25
+ )
26
+ c_idx_BxTxC = torch.broadcast_to(
27
+ torch.arange(C, dtype=torch.int32).view(1, 1, C),
28
+ [B, T, C],
29
+ )
30
+
31
+ # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
32
+ t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
33
+
34
+ indices_BTCx3 = torch.stack(
35
+ [
36
+ b_idx_BxTxC.reshape(-1),
37
+ t_clamped_BxTxC.reshape(-1),
38
+ c_idx_BxTxC.reshape(-1),
39
+ ],
40
+ dim=1,
41
+ ).long() # Ensure indices are long type for indexing
42
+
43
+ return t_idx_BxTxC, indices_BTCx3
44
+
45
+
46
+ def apply_audio_delay(
47
+ audio_BxTxC: torch.Tensor,
48
+ pad_value: int,
49
+ bos_value: int,
50
+ precomp: tp.Tuple[torch.Tensor, torch.Tensor],
51
+ ) -> torch.Tensor:
52
+ """
53
+ Applies the delay pattern to batched audio tokens using precomputed indices,
54
+ inserting BOS where t_idx < 0 and PAD where t_idx >= T.
55
+
56
+ Args:
57
+ audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
58
+ pad_value: the padding token
59
+ bos_value: the BOS token
60
+ precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
61
+
62
+ Returns:
63
+ result_BxTxC: [B, T, C] delayed audio tokens
64
+ """
65
+ device = audio_BxTxC.device # Get device from input tensor
66
+ t_idx_BxTxC, indices_BTCx3 = precomp
67
+ t_idx_BxTxC = t_idx_BxTxC.to(device) # Move precomputed indices to device
68
+ indices_BTCx3 = indices_BTCx3.to(device)
69
+
70
+ # Equivalent of tf.gather_nd using advanced indexing
71
+ # Ensure indices are long type if not already (build_delay_indices should handle this)
72
+ gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
73
+ gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
74
+
75
+ # Create masks on the correct device
76
+ mask_bos = t_idx_BxTxC < 0 # => place bos_value
77
+ mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1] # => place pad_value
78
+
79
+ # Create scalar tensors on the correct device
80
+ bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
81
+ pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
82
+
83
+ # If mask_bos, BOS; else if mask_pad, PAD; else original gather
84
+ # All tensors should now be on the same device
85
+ result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
86
+
87
+ return result_BxTxC
88
+
89
+
90
+ @torch.no_grad()
91
+ @torch.inference_mode()
92
+ def audio_to_codebook(
93
+ model,
94
+ input_values,
95
+ data_config: DataConfig,
96
+ padding_mask=None,
97
+ sample_rate=44100,
98
+ ):
99
+ """
100
+ Encodes the input audio waveform into discrete codes.
101
+
102
+ Args:
103
+ model: The model to use for encoding.
104
+ input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
105
+ Float values of the input audio waveform.
106
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
107
+ Padding mask used to pad the `input_values`.
108
+ sample_rate (`int`, *optional*) :
109
+ Signal sampling_rate
110
+
111
+ Returns:
112
+ A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
113
+ factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
114
+ `codebook` of shape `[batch_size, num_codebooks, frames]`.
115
+ Scale is not used here.
116
+
117
+ """
118
+ audio_data = model.preprocess(input_values, sample_rate)
119
+
120
+ if padding_mask is None:
121
+ padding_mask = torch.ones_like(input_values).bool()
122
+
123
+ _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None) # 1, C, T
124
+ seq_length = encoded_frame.shape[2]
125
+
126
+ t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
127
+ B=1,
128
+ T=seq_length,
129
+ C=data_config.channels,
130
+ delay_pattern=data_config.delay_pattern,
131
+ )
132
+
133
+ encoded_frame = apply_audio_delay(
134
+ audio_BxTxC=encoded_frame.transpose(1, 2), # 1, T, C
135
+ pad_value=data_config.audio_pad_value,
136
+ bos_value=data_config.audio_bos_value,
137
+ precomp=(t_idx_BxTxC, indices_BTCx3),
138
+ )
139
+
140
+ return encoded_frame
141
+
142
+
143
+ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
144
+ """
145
+ Precompute indices for the revert operation using PyTorch.
146
+
147
+ Returns:
148
+ A tuple (t_idx_BxTxC, indices_BTCx3) where:
149
+ - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
150
+ - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
151
+ batch indices, clamped time indices, and channel indices.
152
+ """
153
+ # Use default device unless specified otherwise; assumes inputs might define device later
154
+ device = None # Or determine dynamically if needed, e.g., from a model parameter
155
+
156
+ delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
157
+
158
+ t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
159
+ t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
160
+
161
+ t_idx_BxTxC = torch.minimum(
162
+ t_idx_BT1 + delay_arr.view(1, 1, C),
163
+ torch.tensor(T - 1, device=device),
164
+ )
165
+ b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
166
+ c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
167
+
168
+ indices_BTCx3 = torch.stack(
169
+ [
170
+ b_idx_BxTxC.reshape(-1),
171
+ t_idx_BxTxC.reshape(-1),
172
+ c_idx_BxTxC.reshape(-1),
173
+ ],
174
+ axis=1,
175
+ ).long() # Ensure indices are long type
176
+
177
+ return t_idx_BxTxC, indices_BTCx3
178
+
179
+
180
+ def revert_audio_delay(
181
+ audio_BxTxC: torch.Tensor,
182
+ pad_value: int,
183
+ precomp: tp.Tuple[torch.Tensor, torch.Tensor],
184
+ T: int,
185
+ ) -> torch.Tensor:
186
+ """
187
+ Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
188
+
189
+ Args:
190
+ audio_BxTxC: Input delayed audio tensor
191
+ pad_value: Padding value for out-of-bounds indices
192
+ precomp: Precomputed revert indices tuple containing:
193
+ - t_idx_BxTxC: Time offset indices tensor
194
+ - indices_BTCx3: Gather indices tensor for original audio
195
+ T: Original sequence length before padding
196
+
197
+ Returns:
198
+ Reverted audio tensor with same shape as input
199
+ """
200
+ t_idx_BxTxC, indices_BTCx3 = precomp
201
+ device = audio_BxTxC.device # Get device from input tensor
202
+
203
+ # Move precomputed indices to the same device as audio_BxTxC if they aren't already
204
+ t_idx_BxTxC = t_idx_BxTxC.to(device)
205
+ indices_BTCx3 = indices_BTCx3.to(device)
206
+
207
+ # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
208
+ gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
209
+ gathered_BxTxC = gathered_flat.view(audio_BxTxC.size()) # Use .size() for robust reshaping
210
+
211
+ # Create pad_tensor on the correct device
212
+ pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
213
+ # Create T tensor on the correct device for comparison
214
+ T_tensor = torch.tensor(T, device=device)
215
+
216
+ result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC) # Changed np.where to torch.where
217
+
218
+ return result_BxTxC
219
+
220
+
221
+ @torch.no_grad()
222
+ @torch.inference_mode()
223
+ def decode(
224
+ model,
225
+ audio_codes,
226
+ ):
227
+ """
228
+ Decodes the given frames into an output audio waveform
229
+ """
230
+ if len(audio_codes) != 1:
231
+ raise ValueError(f"Expected one frame, got {len(audio_codes)}")
232
+
233
+ try:
234
+ audio_values = model.quantizer.from_codes(audio_codes)
235
+ audio_values = model.decode(audio_values[0])
236
+
237
+ return audio_values
238
+ except Exception as e:
239
+ print(f"Error in decode method: {str(e)}")
240
+ raise
241
+
242
+
243
+ def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
244
+ """Process a single codebook file to generate audio"""
245
+ # Remove BOS token
246
+ generated_codes = generated_codes[:, 1:]
247
+
248
+ if generated_codes.shape[1] > T:
249
+ generated_codes = generated_codes[:, :T]
250
+
251
+ seq_length = generated_codes.shape[1]
252
+
253
+ # Build revert indices
254
+ t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
255
+
256
+ # Transpose and add batch dimension
257
+ audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
258
+ reverted_codebook = revert_audio_delay(
259
+ audio_BxTxC=audio_BxTxC,
260
+ pad_value=0,
261
+ precomp=(t_idx_BxTxC, indices_BTCx3),
262
+ T=seq_length,
263
+ )
264
+ reverted_codebook = reverted_codebook[:, :-30, :]
265
+
266
+ codebook = reverted_codebook.transpose(1, 2)
267
+
268
+ min_valid_index = 0
269
+ max_valid_index = 1023
270
+ invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
271
+
272
+ num_invalid = torch.sum(invalid_mask).item()
273
+ if num_invalid > 0:
274
+ print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
275
+
276
+ # Set invalid values to 0 (modify the tensor in-place)
277
+ codebook[invalid_mask] = 0
278
+ audio_array = decode(model, codebook)
279
+
280
+ return audio_array
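
To make the delay pattern concrete, here is a tiny worked example (a sketch, not part of this commit) using the two helpers defined above. With delay_pattern=[0, 1], channel 1 lags channel 0 by one step, so its first position becomes the BOS token; the pad/bos values 1025/1026 are the DataConfig defaults from dia/config.py.

import torch

from dia.audio import apply_audio_delay, build_delay_indices

B, T, C = 1, 4, 2
tokens = torch.arange(B * T * C, dtype=torch.int32).reshape(B, T, C)
# tokens[0] == [[0, 1], [2, 3], [4, 5], [6, 7]]

precomp = build_delay_indices(B=B, T=T, C=C, delay_pattern=[0, 1])
delayed = apply_audio_delay(tokens, pad_value=1025, bos_value=1026, precomp=precomp)

# Channel 0 is untouched; channel 1 starts with BOS and is shifted one step later:
# delayed[0, :, 0] == [0, 2, 4, 6]
# delayed[0, :, 1] == [1026, 1, 3, 5]
print(delayed)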
dia/config.py ADDED
@@ -0,0 +1,206 @@
1
+ """Configuration management module for the Dia model.
2
+
3
+ This module provides comprehensive configuration management for the Dia model,
4
+ utilizing Pydantic for validation. It defines configurations for data processing,
5
+ model architecture (encoder and decoder), and training settings.
6
+
7
+ Key components:
8
+ - DataConfig: Parameters for data loading and preprocessing.
9
+ - EncoderConfig: Architecture details for the encoder module.
10
+ - DecoderConfig: Architecture details for the decoder module.
11
+ - ModelConfig: Combined model architecture settings.
12
+ - TrainingConfig: Training hyperparameters and settings.
13
+ - DiaConfig: Master configuration combining all components.
14
+ """
15
+
16
+ import os
17
+ from typing import Annotated
18
+
19
+ from pydantic import BaseModel, BeforeValidator, Field
20
+
21
+
22
+ class DataConfig(BaseModel, frozen=True):
23
+ """Configuration for data loading and preprocessing.
24
+
25
+ Attributes:
26
+ text_length: Maximum length of text sequences (must be multiple of 128).
27
+ audio_length: Maximum length of audio sequences (must be multiple of 128).
28
+ channels: Number of audio channels.
29
+ text_pad_value: Value used for padding text sequences.
30
+ audio_eos_value: Value representing the end of audio sequences.
31
+ audio_bos_value: Value representing the beginning of audio sequences.
32
+ audio_pad_value: Value used for padding audio sequences.
33
+ delay_pattern: List of delay values for each audio channel.
34
+ """
35
+
36
+ text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
37
+ audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
38
+ channels: int = Field(default=9, gt=0, multiple_of=1)
39
+ text_pad_value: int = Field(default=0)
40
+ audio_eos_value: int = Field(default=1024)
41
+ audio_pad_value: int = Field(default=1025)
42
+ audio_bos_value: int = Field(default=1026)
43
+ delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
44
+
45
+ def __hash__(self) -> int:
46
+ """Generate a hash based on all fields of the config."""
47
+ return hash(
48
+ (
49
+ self.text_length,
50
+ self.audio_length,
51
+ self.channels,
52
+ self.text_pad_value,
53
+ self.audio_pad_value,
54
+ self.audio_bos_value,
55
+ self.audio_eos_value,
56
+ tuple(self.delay_pattern),
57
+ )
58
+ )
59
+
60
+
61
+ class EncoderConfig(BaseModel, frozen=True):
62
+ """Configuration for the encoder component of the Dia model.
63
+
64
+ Attributes:
65
+ n_layer: Number of transformer layers.
66
+ n_embd: Embedding dimension.
67
+ n_hidden: Hidden dimension size in the MLP layers.
68
+ n_head: Number of attention heads.
69
+ head_dim: Dimension per attention head.
70
+ mlp_activations: List of activation functions for the MLP layers.
71
+ use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
72
+ """
73
+
74
+ n_layer: int = Field(gt=0)
75
+ n_embd: int = Field(gt=0)
76
+ n_hidden: int = Field(gt=0)
77
+ n_head: int = Field(gt=0)
78
+ head_dim: int = Field(gt=0)
79
+ mlp_activations: list[str] = Field(default=["silu", "linear"])
80
+ use_pre_norm: bool = Field(default=False)
81
+
82
+
83
+ class DecoderConfig(BaseModel, frozen=True):
84
+ """Configuration for the decoder component of the Dia model.
85
+
86
+ Attributes:
87
+ n_layer: Number of transformer layers.
88
+ n_embd: Embedding dimension.
89
+ n_hidden: Hidden dimension size in the MLP layers.
90
+ gqa_query_heads: Number of query heads for grouped-query self-attention.
91
+ kv_heads: Number of key/value heads for grouped-query self-attention.
92
+ gqa_head_dim: Dimension per query head for grouped-query self-attention.
93
+ cross_query_heads: Number of query heads for cross-attention.
94
+ cross_head_dim: Dimension per cross-attention head.
95
+ mlp_activations: List of activation functions for the MLP layers.
96
+ use_pre_norm: Whether to use pre-normalization.
97
+ """
98
+
99
+ n_layer: int = Field(gt=0)
100
+ n_embd: int = Field(gt=0)
101
+ n_hidden: int = Field(gt=0)
102
+ gqa_query_heads: int = Field(gt=0)
103
+ kv_heads: int = Field(gt=0)
104
+ gqa_head_dim: int = Field(gt=0)
105
+ cross_query_heads: int = Field(gt=0)
106
+ cross_head_dim: int = Field(gt=0)
107
+ mlp_activations: list[str] = Field(default=["silu", "linear"])
108
+ use_pre_norm: bool = Field(default=False)
109
+
110
+
111
+ class ModelConfig(BaseModel, frozen=True):
112
+ """Main configuration container for the Dia model architecture.
113
+
114
+ Attributes:
115
+ encoder: Configuration for the encoder component.
116
+ decoder: Configuration for the decoder component.
117
+ src_vocab_size: Size of the source (text) vocabulary.
118
+ tgt_vocab_size: Size of the target (audio code) vocabulary.
119
+ dropout: Dropout probability applied within the model.
120
+ normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
121
+ weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
122
+ rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
123
+ rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
124
+ """
125
+
126
+ encoder: EncoderConfig
127
+ decoder: DecoderConfig
128
+ src_vocab_size: int = Field(default=128, gt=0)
129
+ tgt_vocab_size: int = Field(default=1028, gt=0)
130
+ dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
131
+ normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
132
+ weight_dtype: str = Field(default="float32", description="Weight precision")
133
+ rope_min_timescale: int = Field(default=1, description="Timescale For global Attention")
134
+ rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention")
135
+
136
+
137
+ class TrainingConfig(BaseModel, frozen=True):
138
+ """Training process configuration and hyperparameters.
139
+
140
+ Note: This configuration currently only includes precision settings.
141
+ Other training parameters (like batch size, learning rate, optimizer settings)
142
+ are assumed to be handled externally.
143
+
144
+ Attributes:
145
+ dtype: Data type for activations during training (e.g., "bfloat16", "float32").
146
+ logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
147
+ """
148
+
149
+ dtype: str = Field(default="bfloat16", description="Activation precision")
150
+ logits_dot_in_fp32: bool = Field(default=False)
151
+
152
+
153
+ class DiaConfig(BaseModel, frozen=True):
154
+ """Master configuration for the Dia model.
155
+
156
+ Combines all sub-configurations into a single validated object.
157
+
158
+ Attributes:
159
+ version: Configuration version string.
160
+ model: Model architecture configuration.
161
+ training: Training process configuration (precision settings).
162
+ data: Data loading and processing configuration.
163
+ """
164
+
165
+ version: str = Field(default="1.0")
166
+ model: ModelConfig
167
+ training: TrainingConfig
168
+ data: DataConfig
169
+
170
+ def save(self, path: str) -> None:
171
+ """Save the current configuration instance to a JSON file.
172
+
173
+ Ensures the parent directory exists and the file has a .json extension.
174
+
175
+ Args:
176
+ path: The target file path to save the configuration.
177
+
178
+ Raises:
179
+ ValueError: If the path is not a file with a .json extension.
180
+ """
181
+ os.makedirs(os.path.dirname(path), exist_ok=True)
182
+ config_json = self.model_dump_json(indent=2)
183
+ with open(path, "w") as f:
184
+ f.write(config_json)
185
+
186
+ @classmethod
187
+ def load(cls, path: str) -> "DiaConfig | None":
188
+ """Load and validate a Dia configuration from a JSON file.
189
+
190
+ Args:
191
+ path: The path to the configuration file.
192
+
193
+ Returns:
194
+ A validated DiaConfig instance if the file exists and is valid,
195
+ otherwise None if the file is not found.
196
+
197
+ Raises:
198
+ ValueError: If the path does not point to an existing .json file.
199
+ pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
200
+ """
201
+ try:
202
+ with open(path, "r") as f:
203
+ content = f.read()
204
+ return cls.model_validate_json(content)
205
+ except FileNotFoundError:
206
+ return None
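
As a usage sketch (not part of this commit), a DiaConfig can be assembled from its sub-configs and round-tripped through save/load. The architecture numbers below are illustrative, not the released checkpoint's values. Note that the BeforeValidator on text_length and audio_length rounds up to the next multiple of 128 (so text_length=1000 is stored as 1024), and that save() as written expects a path with a directory component.

from dia.config import DataConfig, DecoderConfig, DiaConfig, EncoderConfig, ModelConfig, TrainingConfig

config = DiaConfig(
    model=ModelConfig(
        encoder=EncoderConfig(n_layer=12, n_embd=1024, n_hidden=4096, n_head=16, head_dim=64),
        decoder=DecoderConfig(
            n_layer=18,
            n_embd=2048,
            n_hidden=8192,
            gqa_query_heads=16,
            kv_heads=4,
            gqa_head_dim=128,
            cross_query_heads=16,
            cross_head_dim=128,
        ),
    ),
    training=TrainingConfig(),
    data=DataConfig(text_length=1000, audio_length=3072),  # text_length is rounded up to 1024
)

config.save("checkpoints/config.json")                 # writes indented JSON, creating the directory
restored = DiaConfig.load("checkpoints/config.json")   # returns None if the file does not exist
assert restored is not None and restored.data.text_length == 1024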
dia/layers.py ADDED
@@ -0,0 +1,873 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+ from torch.nn import RMSNorm
8
+
9
+ from .config import DiaConfig
10
+
11
+
12
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14
+
15
+
16
+ def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
17
+ # Allow None for default behavior
18
+ if dtype_str is None or dtype_str.lower() == "none":
19
+ return None
20
+ if dtype_str == "float32":
21
+ return torch.float32
22
+ elif dtype_str == "float16":
23
+ return torch.float16
24
+ elif dtype_str == "bfloat16":
25
+ return torch.bfloat16
26
+ else:
27
+ raise ValueError(f"Unsupported dtype string: {dtype_str}")
28
+
29
+
30
+ class DenseGeneral(nn.Module):
31
+ """
32
+ PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
33
+
34
+ Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
35
+ for the generalized matrix multiplication. Weight/bias shapes are calculated
36
+ and parameters created during initialization based on config.
37
+ `load_weights` validates shapes and copies data.
38
+
39
+ Attributes:
40
+ axis (Tuple[int, ...]): Input axis or axes to contract.
41
+ in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
42
+ out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
43
+ use_bias (bool): Whether to add a bias term.
44
+ weight (nn.Parameter): The kernel parameter.
45
+ bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ in_shapes: tuple[int, ...],
51
+ out_features: tuple[int, ...],
52
+ axis: tuple[int, ...] = (-1,),
53
+ dtype: torch.dtype | None = None,
54
+ weight_dtype: torch.dtype | None = None,
55
+ device: torch.device | None = None,
56
+ ):
57
+ super().__init__()
58
+ self.in_shapes = in_shapes
59
+ self.out_features = out_features
60
+ self.axis = axis
61
+ self.dtype = dtype
62
+ self.kernel_shape = self.in_shapes + self.out_features
63
+
64
+ factory_kwargs = {"device": device, "dtype": weight_dtype}
65
+ self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
66
+ self.register_parameter("bias", None)
67
+
68
+ def forward(self, inputs: Tensor) -> Tensor:
69
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
70
+ kernel_contract_axes = tuple(range(len(norm_axis)))
71
+
72
+ output = torch.tensordot(
73
+ inputs.float(),
74
+ self.weight.float(),
75
+ dims=(norm_axis, kernel_contract_axes),
76
+ ).to(inputs.dtype)
77
+ return output
78
+
79
+
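
Shape-wise, DenseGeneral generalizes nn.Linear to multi-axis inputs and outputs: the axes named by `axis` are contracted against the leading axes of the kernel, whose shape is in_shapes + out_features. A small shape check as a sketch (illustrative sizes, random initialization; real weights are loaded elsewhere, not part of this commit):

import torch

from dia.layers import DenseGeneral

B, T, D, N, H = 2, 5, 1024, 16, 64

# (B, T, D) -> (B, T, N, H), mirroring how q_proj is built in the Attention module below.
q_proj = DenseGeneral(in_shapes=(D,), out_features=(N, H), axis=(-1,))
torch.nn.init.normal_(q_proj.weight)  # parameters are created empty and normally filled from a checkpoint
out = q_proj(torch.randn(B, T, D))
assert out.shape == (B, T, N, H)

# (B, T, N, H) -> (B, T, D), contracting both head axes like o_proj.
o_proj = DenseGeneral(in_shapes=(N, H), out_features=(D,), axis=(-2, -1))
torch.nn.init.normal_(o_proj.weight)
assert o_proj(out).shape == (B, T, D)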
80
+ def get_activation_fn(activation_string: str) -> nn.Module: # Return Module instance
81
+ """Maps activation string to PyTorch activation function module."""
82
+ if activation_string == "gelu":
83
+ return nn.GELU()
84
+ elif activation_string == "relu":
85
+ return nn.ReLU()
86
+ elif activation_string == "silu" or activation_string == "swish":
87
+ return nn.SiLU()
88
+ elif activation_string == "linear":
89
+ return nn.Identity()
90
+ else:
91
+ raise ValueError(f"Unsupported activation function: {activation_string}")
92
+
93
+
94
+ class MlpBlock(nn.Module):
95
+ """MLP block using DenseGeneral."""
96
+
97
+ def __init__(
98
+ self,
99
+ config: DiaConfig,
100
+ embed_dim: int,
101
+ intermediate_dim: int,
102
+ dropout_rate: float,
103
+ activations: list[str] = ["silu", "linear"],
104
+ use_pre_norm: bool = False,
105
+ ):
106
+ super().__init__()
107
+ self.use_pre_norm = use_pre_norm
108
+ num_activations = len(activations)
109
+ compute_dtype = _str_to_dtype(config.training.dtype)
110
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
111
+ self.dtype = compute_dtype
112
+ # Assume default device for now, could be passed in config
113
+
114
+ if use_pre_norm:
115
+ self.pre_norm = RMSNorm(
116
+ embed_dim,
117
+ eps=config.model.normalization_layer_epsilon,
118
+ dtype=torch.float32,
119
+ )
120
+
121
+ self.wi_fused = DenseGeneral(
122
+ in_shapes=(embed_dim,),
123
+ out_features=(
124
+ num_activations,
125
+ intermediate_dim,
126
+ ),
127
+ axis=(-1,),
128
+ dtype=compute_dtype,
129
+ weight_dtype=weight_dtype,
130
+ )
131
+
132
+ self.activation_fn_0 = get_activation_fn(activations[0]) # silu
133
+ self.activation_fn_1 = get_activation_fn(activations[1]) # linear
134
+
135
+ self.dropout = nn.Dropout(dropout_rate)
136
+
137
+ # Output layer using DenseGeneral
138
+ self.wo = DenseGeneral(
139
+ in_shapes=(intermediate_dim,),
140
+ out_features=(embed_dim,),
141
+ axis=(-1,),
142
+ dtype=compute_dtype,
143
+ weight_dtype=weight_dtype,
144
+ )
145
+
146
+ def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
147
+ """Forward pass."""
148
+ if self.use_pre_norm and hasattr(self, "pre_norm"):
149
+ x = self.pre_norm(x)
150
+
151
+ fused_x = self.wi_fused(x)
152
+
153
+ gate_input = fused_x[..., 0, :]
154
+ up_input = fused_x[..., 1, :]
155
+
156
+ gate = self.activation_fn_0(gate_input)
157
+ up = self.activation_fn_1(up_input)
158
+ hidden = torch.mul(gate, up).to(self.dtype)
159
+
160
+ if not deterministic:
161
+ hidden = self.dropout(hidden)
162
+
163
+ output = self.wo(hidden)
164
+ return output
165
+
166
+
167
+ class RotaryEmbedding(nn.Module):
168
+ """Rotary Position Embedding (RoPE) implementation in PyTorch."""
169
+
170
+ def __init__(
171
+ self,
172
+ embedding_dims: int,
173
+ min_timescale: int = 1,
174
+ max_timescale: int = 10000,
175
+ dtype: torch.dtype = torch.float32,
176
+ ):
177
+ super().__init__()
178
+ if embedding_dims % 2 != 0:
179
+ raise ValueError("Embedding dim must be even for RoPE.")
180
+ self.embedding_dims = embedding_dims
181
+ self.min_timescale = min_timescale
182
+ self.max_timescale = max_timescale
183
+ self.dtype = dtype
184
+
185
+ half_embedding_dim = embedding_dims // 2
186
+ fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
187
+ self.register_buffer(
188
+ "timescale",
189
+ self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
190
+ persistent=False,
191
+ )
192
+
193
+ def extra_repr(self) -> str:
194
+ s = f"{self.timescale.shape}"
195
+ return s
196
+
197
+ def forward(self, inputs: torch.Tensor, position: torch.Tensor):
198
+ """Applies RoPE."""
199
+ position = position.unsqueeze(-1).unsqueeze(-1)
200
+ timescale = self.timescale.to(inputs.device)
201
+ sinusoid_inp = position / timescale
202
+ sin = torch.sin(sinusoid_inp).to(inputs.dtype)
203
+ cos = torch.cos(sinusoid_inp).to(inputs.dtype)
204
+ first_half, second_half = torch.chunk(inputs, 2, dim=-1)
205
+ first_part = first_half * cos - second_half * sin
206
+ second_part = second_half * cos + first_half * sin
207
+ return torch.cat((first_part, second_part), dim=-1)
208
+
209
+
210
+ class KVCache:
211
+ def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
212
+ self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
213
+ self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
214
+ self.current_idx = 0
215
+ self.max_len = max_len
216
+
217
+ def get_kv_for_attention(self, current_k, current_v):
218
+ if self.current_idx == 0:
219
+ return current_k, current_v
220
+ else:
221
+ past_k = self.k[:, :, : self.current_idx, :]
222
+ past_v = self.v[:, :, : self.current_idx, :]
223
+ attn_k = torch.cat((past_k, current_k), dim=2)
224
+ attn_v = torch.cat((past_v, current_v), dim=2)
225
+ return attn_k, attn_v
226
+
227
+ def update_cache(self, k, v):
228
+ assert self.current_idx < self.max_len
229
+ self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
230
+ self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
231
+ self.current_idx += 1
232
+
233
+ def prefill_kv(self, k, v):
234
+ prefill_len = k.shape[2]
235
+ assert prefill_len <= self.max_len
236
+ self.k[:, :, :prefill_len, :] = k
237
+ self.v[:, :, :prefill_len, :] = v
238
+ self.current_idx = prefill_len
239
+
240
+
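
The cache is allocated for a fixed batch of 2, matching the paired sequences the decoder comments below refer to (e.g. the (2, 1, D) shapes in DecoderLayer, presumably the two classifier-free-guidance streams). A short sketch of the intended call pattern during decoding, with illustrative sizes (not part of this commit):

import torch

from dia.layers import KVCache

num_heads, max_len, head_dim = 16, 32, 64
cache = KVCache(num_heads, max_len, head_dim, device="cpu")

# Prefill with the prompt's projected keys/values (batch of 2, 10 prompt steps here).
k_prompt = torch.randn(2, num_heads, 10, head_dim)
v_prompt = torch.randn(2, num_heads, 10, head_dim)
cache.prefill_kv(k_prompt, v_prompt)
assert cache.current_idx == 10

# One autoregressive step: attend over past + current, then append the current step.
k_step = torch.randn(2, num_heads, 1, head_dim)
v_step = torch.randn(2, num_heads, 1, head_dim)
attn_k, attn_v = cache.get_kv_for_attention(k_step, v_step)
assert attn_k.shape == (2, num_heads, 11, head_dim)
cache.update_cache(k_step, v_step)
assert cache.current_idx == 11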
241
+ class Attention(nn.Module):
242
+ """Attention using DenseGeneral."""
243
+
244
+ def __init__(
245
+ self,
246
+ config: DiaConfig,
247
+ q_embed_dim: int,
248
+ kv_embed_dim: int,
249
+ num_query_heads: int,
250
+ num_kv_heads: int,
251
+ head_dim: int,
252
+ dropout_rate: float,
253
+ is_cross_attn: bool = False,
254
+ out_embed_dim: int | None = None,
255
+ ):
256
+ super().__init__()
257
+ self.num_query_heads = num_query_heads
258
+ self.num_kv_heads = num_kv_heads
259
+ self.head_dim = head_dim
260
+ self.is_cross_attn = is_cross_attn
261
+ self.dropout_rate = dropout_rate
262
+ compute_dtype = _str_to_dtype(config.training.dtype)
263
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
264
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
265
+ self.projected_query_dim = num_query_heads * head_dim
266
+ if num_query_heads % num_kv_heads != 0:
267
+ raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
268
+ self.num_gqa_groups = num_query_heads // num_kv_heads
269
+
270
+ # --- Projection Layers using DenseGeneral ---
271
+ self.q_proj = DenseGeneral(
272
+ in_shapes=(q_embed_dim,),
273
+ out_features=(num_query_heads, head_dim),
274
+ axis=(-1,),
275
+ dtype=compute_dtype,
276
+ weight_dtype=weight_dtype,
277
+ )
278
+ self.k_proj = DenseGeneral(
279
+ in_shapes=(kv_embed_dim,),
280
+ out_features=(num_kv_heads, head_dim),
281
+ axis=(-1,),
282
+ dtype=compute_dtype,
283
+ weight_dtype=weight_dtype,
284
+ )
285
+ self.v_proj = DenseGeneral(
286
+ in_shapes=(kv_embed_dim,),
287
+ out_features=(num_kv_heads, head_dim),
288
+ axis=(-1,),
289
+ dtype=compute_dtype,
290
+ weight_dtype=weight_dtype,
291
+ )
292
+ self.o_proj = DenseGeneral(
293
+ in_shapes=(num_query_heads, head_dim),
294
+ out_features=(self.output_dim,),
295
+ axis=(-2, -1),
296
+ dtype=compute_dtype,
297
+ weight_dtype=weight_dtype,
298
+ )
299
+
300
+ # --- Rotary Embedding ---
301
+ self.rotary_emb = RotaryEmbedding(
302
+ embedding_dims=self.head_dim,
303
+ min_timescale=config.model.rope_min_timescale,
304
+ max_timescale=config.model.rope_max_timescale,
305
+ dtype=compute_dtype,
306
+ )
307
+
308
+ def forward(
309
+ self,
310
+ Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
311
+ Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
312
+ q_positions: torch.Tensor, # (B, T)
313
+ kv_positions: torch.Tensor | None = None, # (B, S)
314
+ deterministic: bool = True,
315
+ attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
316
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
317
+ prefill: bool = False, # True only when prefilling KV Cache
318
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
319
+ """
320
+ Performs attention calculation with optional KV caching.
321
+
322
+ Args:
323
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
324
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
325
+ q_positions: Positions for queries (B, T).
326
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
327
+ deterministic: If True, disable dropout.
328
+ attn_mask: Attention mask.
329
+ cache: KVCache.
330
+ prefill: If True, use prefill mode.
331
+
332
+ Returns:
333
+ A tuple containing:
334
+ - output: The attention output tensor (B, T, output_dim).
335
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
336
+ """
337
+ if kv_positions is None:
338
+ kv_positions = q_positions
339
+ original_dtype = Xq.dtype
340
+
341
+ Xq_BxTxNxH = self.q_proj(Xq)
342
+ Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
343
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
344
+
345
+ # Input values into attention calculation
346
+ attn_k: torch.Tensor | None = None
347
+ attn_v: torch.Tensor | None = None
348
+ new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
349
+
350
+ # Decoder Cross Attention
351
+ if self.is_cross_attn:
352
+ # Directly use cache (no need to check index)
353
+ attn_k, attn_v = cache.k, cache.v
354
+ if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
355
+ raise ValueError(
356
+ f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
357
+ f"does not match num_query_heads ({self.num_query_heads}). "
358
+ "Cache should be pre-repeated for GQA."
359
+ )
360
+ # Self Attention
361
+ else:
362
+ Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
363
+ Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
364
+ Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions) # (B, S, K, H)
365
+
366
+ Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
367
+ Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
368
+ # S=1 for Decode Step
369
+
370
+ if self.num_gqa_groups > 1:
371
+ Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
372
+ Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
373
+ else:
374
+ Xk_BxNxSxH = Xk_BxKxSxH
375
+ Xv_BxNxSxH = Xv_BxKxSxH
376
+
377
+ # Encoder Self Attention
378
+ if cache is None:
379
+ attn_k = Xk_BxNxSxH
380
+ attn_v = Xv_BxNxSxH
381
+ # Decoder Self Attention
382
+ else:
383
+ # In prefill mode, we fill in cache until prefill length
384
+ if prefill:
385
+ attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
386
+ cache.prefill_kv(attn_k, attn_v)
387
+ # In decode step, we add current K/V to cache step by step
388
+ else:
389
+ new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
390
+ attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
391
+
392
+ attn_output = F.scaled_dot_product_attention(
393
+ Xq_BxNxTxH,
394
+ attn_k,
395
+ attn_v,
396
+ attn_mask=attn_mask,
397
+ dropout_p=self.dropout_rate if not deterministic else 0.0,
398
+ scale=1.0,
399
+ )
400
+
401
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
402
+ output = self.o_proj(attn_output)
403
+
404
+ return output.to(original_dtype), new_kv_cache
405
+
406
+
407
+ class EncoderLayer(nn.Module):
408
+ """Transformer Encoder Layer using DenseGeneral."""
409
+
410
+ def __init__(self, config: DiaConfig):
411
+ super().__init__()
412
+ self.config = config
413
+ model_config = config.model
414
+ enc_config = config.model.encoder
415
+ embed_dim = enc_config.n_embd
416
+
417
+ self.pre_sa_norm = RMSNorm(
418
+ embed_dim,
419
+ eps=model_config.normalization_layer_epsilon,
420
+ dtype=torch.float32,
421
+ )
422
+ self.self_attention = Attention(
423
+ config=config,
424
+ q_embed_dim=embed_dim,
425
+ kv_embed_dim=embed_dim,
426
+ num_query_heads=enc_config.n_head,
427
+ num_kv_heads=enc_config.n_head,
428
+ head_dim=enc_config.head_dim,
429
+ dropout_rate=model_config.dropout,
430
+ is_cross_attn=False,
431
+ out_embed_dim=embed_dim,
432
+ )
433
+ self.post_sa_norm = RMSNorm(
434
+ embed_dim,
435
+ eps=model_config.normalization_layer_epsilon,
436
+ dtype=torch.float32,
437
+ )
438
+ self.mlp = MlpBlock(
439
+ config=config,
440
+ embed_dim=embed_dim,
441
+ intermediate_dim=enc_config.n_hidden,
442
+ activations=enc_config.mlp_activations,
443
+ dropout_rate=model_config.dropout,
444
+ use_pre_norm=enc_config.use_pre_norm,
445
+ )
446
+ self.dropout = nn.Dropout(model_config.dropout)
447
+
448
+ def forward(
449
+ self,
450
+ x: torch.Tensor,
451
+ src_positions: torch.Tensor | None = None,
452
+ deterministic: bool = True,
453
+ attn_mask: torch.Tensor | None = None,
454
+ ) -> torch.Tensor:
455
+ residual = x
456
+ x_norm = self.pre_sa_norm(x)
457
+
458
+ sa_out, _ = self.self_attention(
459
+ Xq=x_norm,
460
+ Xkv=x_norm,
461
+ q_positions=src_positions,
462
+ kv_positions=src_positions,
463
+ deterministic=deterministic,
464
+ attn_mask=attn_mask,
465
+ )
466
+ x = residual + sa_out
467
+
468
+ residual = x
469
+ x_norm = self.post_sa_norm(x)
470
+ mlp_out = self.mlp(x_norm, deterministic=deterministic)
471
+ x = residual + mlp_out
472
+
473
+ if not deterministic:
474
+ x = self.dropout(x)
475
+ return x
476
+
477
+
478
+ class Encoder(nn.Module):
479
+ """Transformer Encoder Stack using DenseGeneral."""
480
+
481
+ def __init__(self, config: DiaConfig):
482
+ super().__init__()
483
+ self.config = config
484
+ model_config = config.model
485
+ enc_config = config.model.encoder
486
+ compute_dtype = _str_to_dtype(config.training.dtype)
487
+
488
+ self.embedding = nn.Embedding(
489
+ model_config.src_vocab_size,
490
+ enc_config.n_embd,
491
+ dtype=compute_dtype,
492
+ )
493
+ self.dropout = nn.Dropout(model_config.dropout)
494
+ self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
495
+ self.norm = RMSNorm(
496
+ enc_config.n_embd,
497
+ eps=model_config.normalization_layer_epsilon,
498
+ dtype=torch.float32,
499
+ )
500
+
501
+ def forward(
502
+ self,
503
+ x_ids: torch.Tensor,
504
+ src_positions: torch.Tensor | None = None,
505
+ deterministic: bool = True,
506
+ attn_mask: torch.Tensor | None = None,
507
+ ) -> torch.Tensor:
508
+ x = self.embedding(x_ids)
509
+
510
+ if not deterministic:
511
+ x = self.dropout(x)
512
+
513
+ for layer in self.layers:
514
+ x = layer(
515
+ x,
516
+ src_positions=src_positions,
517
+ deterministic=deterministic,
518
+ attn_mask=attn_mask,
519
+ )
520
+ x = self.norm(x)
521
+ if not deterministic:
522
+ x = self.dropout(x)
523
+ return x
524
+
525
+
526
+ class DecoderLayer(nn.Module):
527
+ """Transformer Decoder Layer using DenseGeneral."""
528
+
529
+ def __init__(self, config: DiaConfig):
530
+ super().__init__()
531
+ self.config = config
532
+ model_config = config.model
533
+ dec_config = config.model.decoder
534
+ enc_config = config.model.encoder
535
+ dec_embed_dim = dec_config.n_embd
536
+ enc_embed_dim = enc_config.n_embd
537
+
538
+ # Norms
539
+ self.pre_sa_norm = RMSNorm(
540
+ dec_embed_dim,
541
+ eps=model_config.normalization_layer_epsilon,
542
+ dtype=torch.float32,
543
+ )
544
+ self.pre_ca_norm = RMSNorm(
545
+ dec_embed_dim,
546
+ eps=model_config.normalization_layer_epsilon,
547
+ dtype=torch.float32,
548
+ )
549
+ self.pre_mlp_norm = RMSNorm(
550
+ dec_embed_dim,
551
+ eps=model_config.normalization_layer_epsilon,
552
+ dtype=torch.float32,
553
+ )
554
+
555
+ # Self-Attention (GQA) with Causal Masking
556
+ self.self_attention = Attention(
557
+ config=config,
558
+ q_embed_dim=dec_embed_dim,
559
+ kv_embed_dim=dec_embed_dim,
560
+ num_query_heads=dec_config.gqa_query_heads,
561
+ num_kv_heads=dec_config.kv_heads,
562
+ head_dim=dec_config.gqa_head_dim,
563
+ dropout_rate=model_config.dropout,
564
+ is_cross_attn=False,
565
+ out_embed_dim=dec_embed_dim,
566
+ )
567
+ # Cross-Attention (MHA)
568
+ self.cross_attention = Attention(
569
+ config=config,
570
+ q_embed_dim=dec_embed_dim,
571
+ kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
572
+ num_query_heads=dec_config.cross_query_heads,
573
+ num_kv_heads=dec_config.cross_query_heads,
574
+ head_dim=dec_config.cross_head_dim,
575
+ dropout_rate=model_config.dropout,
576
+ is_cross_attn=True,
577
+ out_embed_dim=dec_embed_dim,
578
+ )
579
+ # MLP
580
+ self.mlp = MlpBlock(
581
+ config=config,
582
+ embed_dim=dec_embed_dim,
583
+ intermediate_dim=dec_config.n_hidden,
584
+ activations=dec_config.mlp_activations,
585
+ dropout_rate=model_config.dropout,
586
+ use_pre_norm=dec_config.use_pre_norm,
587
+ )
588
+
589
+ def forward(
590
+ self,
591
+ x: torch.Tensor,
592
+ encoder_out: torch.Tensor,
593
+ tgt_positions: torch.Tensor,
594
+ src_positions: torch.Tensor | None,
595
+ deterministic: bool,
596
+ self_attn_mask: torch.Tensor,
597
+ cross_attn_mask: torch.Tensor,
598
+ self_attn_cache: KVCache,
599
+ cross_attn_cache: KVCache,
600
+ prefill: bool = False,
601
+ ) -> torch.Tensor:
602
+ residual = x
603
+ x_norm = self.pre_sa_norm(x)
604
+
605
+ sa_out, new_kv_cache = self.self_attention(
606
+ Xq=x_norm, # (2, 1, D)
607
+ Xkv=x_norm, # (2, 1, D)
608
+ q_positions=tgt_positions, # (2, 1)
609
+ kv_positions=tgt_positions, # (2, 1)
610
+ deterministic=deterministic,
611
+ attn_mask=self_attn_mask, # (2, 1, 1, S_max)
612
+ cache=self_attn_cache,
613
+ prefill=prefill,
614
+ )
615
+
616
+ x = residual + sa_out
617
+
618
+ # 2. Cross-Attention
619
+ residual = x
620
+ x_norm = self.pre_ca_norm(x)
621
+ ca_out, _ = self.cross_attention(
622
+ Xq=x_norm,
623
+ Xkv=encoder_out,
624
+ q_positions=tgt_positions,
625
+ kv_positions=src_positions,
626
+ deterministic=deterministic,
627
+ attn_mask=cross_attn_mask,
628
+ cache=cross_attn_cache,
629
+ )
630
+ x = residual + ca_out
631
+
632
+ # 3. MLP
633
+ residual = x
634
+ x_norm = self.pre_mlp_norm(x)
635
+ mlp_out = self.mlp(x_norm, deterministic=deterministic)
636
+ x = residual + mlp_out
637
+
638
+ return x, new_kv_cache
639
+
640
+
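Each sub-block of the decoder layer above uses the same pre-norm residual pattern: normalize the input, run the sub-layer (self-attention, cross-attention, or MLP), and add the result back onto the residual stream. A minimal standalone sketch of that pattern with stand-in modules (illustrative only; nn.RMSNorm ships with recent PyTorch, while the model defines its own RMSNorm):

import torch
import torch.nn as nn

D = 16
norm = nn.RMSNorm(D)        # stand-in for the model's RMSNorm
sublayer = nn.Linear(D, D)  # stand-in for Attention / MlpBlock

x = torch.randn(2, 5, D)    # (batch, time, embed)
x = x + sublayer(norm(x))   # residual = x; x = residual + sublayer(norm(x))
print(x.shape)              # torch.Size([2, 5, 16])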
641
+ class Decoder(nn.Module):
642
+ """Transformer Decoder Stack using DenseGeneral."""
643
+
644
+ def __init__(self, config: DiaConfig):
645
+ super().__init__()
646
+ self.config = config
647
+ model_config = config.model
648
+ dec_config = config.model.decoder
649
+ train_config = config.training
650
+ data_config = config.data
651
+ compute_dtype = _str_to_dtype(config.training.dtype)
652
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
653
+ self.num_channels = data_config.channels
654
+ self.num_layers = dec_config.n_layer
655
+
656
+ self.embeddings = nn.ModuleList(
657
+ [
658
+ nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
659
+ for _ in range(self.num_channels)
660
+ ]
661
+ )
662
+ self.dropout = nn.Dropout(model_config.dropout)
663
+ self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
664
+ self.norm = RMSNorm(
665
+ dec_config.n_embd,
666
+ eps=model_config.normalization_layer_epsilon,
667
+ dtype=torch.float32,
668
+ )
669
+
670
+ # Final Logits Projection using DenseGeneral
671
+ self.logits_dense = DenseGeneral(
672
+ in_shapes=(dec_config.n_embd,),
673
+ out_features=(self.num_channels, model_config.tgt_vocab_size),
674
+ axis=(-1,),
675
+ dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
676
+ weight_dtype=weight_dtype,
677
+ )
678
+ self.logits_in_fp32 = train_config.logits_dot_in_fp32
679
+
680
+ def precompute_cross_attention_kv(
681
+ self,
682
+ max_len: int,
683
+ encoder_out: torch.Tensor, # (B, S, E)
684
+ src_positions: torch.Tensor | None, # (B, S)
685
+ ) -> list[KVCache]:
686
+ """
687
+ Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
688
+ """
689
+ per_layer_kv_cache: list[KVCache] = []
690
+
691
+ for layer in self.layers:
692
+ cross_attn_module = layer.cross_attention
693
+ k_proj = cross_attn_module.k_proj(encoder_out)
694
+ v_proj = cross_attn_module.v_proj(encoder_out)
695
+
696
+ k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
697
+ k = k_proj.transpose(1, 2)
698
+ v = v_proj.transpose(1, 2)
699
+
700
+ per_layer_kv_cache.append(
701
+ KVCache(
702
+ cross_attn_module.num_kv_heads,
703
+ max_len,
704
+ cross_attn_module.head_dim,
705
+ k.device,
706
+ k=k,
707
+ v=v,
708
+ )
709
+ )
710
+
711
+ return per_layer_kv_cache
712
+
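Because the encoder output is fixed for the whole generation, the cross-attention keys and values can be projected once up front and reused at every decoding step; only the self-attention cache has to grow. A rough standalone illustration of that idea with toy projections (hypothetical sizes, not the real Attention module):

import torch
import torch.nn as nn
import torch.nn.functional as F

S, E, H, Dh = 12, 32, 4, 8                  # toy source length, encoder width, heads, head dim
k_proj = nn.Linear(E, H * Dh, bias=False)
v_proj = nn.Linear(E, H * Dh, bias=False)

encoder_out = torch.randn(1, S, E)

# Project K/V once, in (B, H, S, Dh) layout, and keep them for the whole generation.
k = k_proj(encoder_out).view(1, S, H, Dh).transpose(1, 2)
v = v_proj(encoder_out).view(1, S, H, Dh).transpose(1, 2)

# Every decode step then only produces a fresh single-position query.
for _ in range(3):
    q = torch.randn(1, H, 1, Dh)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)                        # torch.Size([1, 4, 1, 8])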
713
+ def decode_step(
714
+ self,
715
+ tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
716
+ tgt_pos_Bx1: torch.Tensor, # [B, 1]
717
+ encoder_out: torch.Tensor, # [B, S, E]
718
+ self_attn_mask: Any, # None
719
+ cross_attn_mask: torch.Tensor, # [B, 1, 1, S]
720
+ self_attention_cache: list[KVCache],
721
+ cross_attention_cache: list[KVCache],
722
+ ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
723
+ """
724
+ Performs a single decoding step, managing KV caches layer by layer.
725
+
726
+ Returns:
727
+ A tuple containing:
728
+ - logits_Bx1xCxV: The final output logits for the current step (B, 1, C, V), cast to float32.
+ - new_cache: per-layer self-attention cache updates (k, v) for this step.
729
+ """
730
+ assert self_attn_mask is None, "Self-attention mask should be None here; the argument is kept only for interface consistency"
731
+
732
+ x = None
733
+ for i in range(self.num_channels):
734
+ channel_tokens = tgt_ids_Bx1xC[..., i]
735
+ channel_embed = self.embeddings[i](channel_tokens)
736
+ x = channel_embed if x is None else x + channel_embed
737
+
738
+ new_cache = []
739
+
740
+ for i, layer in enumerate(self.layers):
741
+ self_cache = self_attention_cache[i]
742
+ cross_cache = cross_attention_cache[i]
743
+ x, new_kv_cache = layer(
744
+ x, # (2, 1, D)
745
+ encoder_out, # (2, S, E)
746
+ src_positions=None, # CA KV is already computed
747
+ tgt_positions=tgt_pos_Bx1, # (2, 1)
748
+ deterministic=True,
749
+ self_attn_mask=None,
750
+ cross_attn_mask=cross_attn_mask,
751
+ self_attn_cache=self_cache,
752
+ cross_attn_cache=cross_cache,
753
+ )
754
+ new_cache.append(new_kv_cache)
755
+
756
+ x = self.norm(x)
757
+ logits_Bx1xCxV = self.logits_dense(x)
758
+
759
+ return logits_Bx1xCxV.to(torch.float32), new_cache
760
+
761
+ def forward(
762
+ self,
763
+ tgt_ids_BxTxC: torch.Tensor,
764
+ encoder_out: torch.Tensor,
765
+ tgt_positions: torch.Tensor,
766
+ src_positions: torch.Tensor,
767
+ deterministic: bool,
768
+ self_attn_mask: torch.Tensor,
769
+ cross_attn_mask: torch.Tensor,
770
+ self_attention_cache: list[KVCache],
771
+ cross_attention_cache: list[KVCache],
772
+ ) -> torch.Tensor:
773
+ """
774
+ Forward pass for the Decoder stack, managing KV caches.
775
+
776
+ Args:
777
+ tgt_ids_BxTxC: Target token IDs (B, T, C).
778
+ encoder_out: Output from the encoder (B, S, E).
779
+ tgt_positions: Positions for target sequence (B, T).
780
+ src_positions: Positions for source sequence (B, S).
781
+ deterministic: Disable dropout if True.
782
+ self_attn_mask: Mask for self-attention.
783
+ cross_attn_mask: Mask for cross-attention.
784
+ self_attention_cache: Per-layer self-attention KV caches, one entry per decoder
+ layer (`len(self_attention_cache)` must equal `num_layers`).
+ cross_attention_cache: Per-layer cross-attention KV caches precomputed from
+ `encoder_out` (see `precompute_cross_attention_kv`).
+
+ Returns:
+ logits: The final output logits (B, T, C, V), cast to float32.
796
+ """
797
+ _, _, num_channels_in = tgt_ids_BxTxC.shape
798
+ assert num_channels_in == self.num_channels, "Input channels mismatch"
799
+
800
+ # Embeddings
801
+ x = None
802
+ for i in range(self.num_channels):
803
+ channel_tokens = tgt_ids_BxTxC[..., i]
804
+ channel_embed = self.embeddings[i](channel_tokens)
805
+ x = channel_embed if x is None else x + channel_embed
806
+
807
+ if not deterministic:
808
+ x = self.dropout(x)
809
+
810
+ for i, layer in enumerate(self.layers):
811
+ x, _ = layer(
812
+ x,
813
+ encoder_out,
814
+ tgt_positions=tgt_positions,
815
+ src_positions=src_positions,
816
+ deterministic=deterministic,
817
+ self_attn_mask=self_attn_mask,
818
+ cross_attn_mask=cross_attn_mask,
819
+ self_attn_cache=self_attention_cache[i],
820
+ cross_attn_cache=cross_attention_cache[i],
821
+ prefill=True,
822
+ )
823
+
824
+ # Final Norm
825
+ x = self.norm(x)
826
+ logits_BxTxCxV = self.logits_dense(x)
827
+
828
+ return logits_BxTxCxV.to(torch.float32)
829
+
830
+
831
+ class DiaModel(nn.Module):
832
+ """PyTorch Dia Model using DenseGeneral."""
833
+
834
+ def __init__(self, config: DiaConfig):
835
+ super().__init__()
836
+ self.config = config
837
+ self.encoder = Encoder(config)
838
+ self.decoder = Decoder(config)
839
+
840
+ def forward(
841
+ self,
842
+ src_BxS: torch.Tensor,
843
+ tgt_BxTxC: torch.Tensor,
844
+ src_positions: torch.Tensor | None = None,
845
+ tgt_positions: torch.Tensor | None = None,
846
+ enc_self_attn_mask: torch.Tensor | None = None,
847
+ dec_self_attn_mask: torch.Tensor | None = None,
848
+ dec_cross_attn_mask: torch.Tensor | None = None,
849
+ enable_dropout: bool = True,
850
+ ):
851
+ deterministic = not enable_dropout
852
+
853
+ # --- Encoder Pass ---
854
+ encoder_out = self.encoder(
855
+ x_ids=src_BxS,
856
+ src_positions=src_positions,
857
+ deterministic=deterministic,
858
+ attn_mask=enc_self_attn_mask,
859
+ )
860
+
861
+ # --- Decoder Pass ---
862
+ logits = self.decoder(
863
+ tgt_ids_BxTxC=tgt_BxTxC,
864
+ encoder_out=encoder_out,
865
+ tgt_positions=tgt_positions,
866
+ src_positions=src_positions,
867
+ deterministic=deterministic,
868
+ self_attn_mask=dec_self_attn_mask,
869
+ cross_attn_mask=dec_cross_attn_mask,
870
+ # NOTE: Decoder.forward additionally expects self_attention_cache and
+ # cross_attention_cache; this full-sequence path is not used by Dia.generate,
+ # which drives the encoder and decoder directly.
871
+ )
872
+
873
+ return logits
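The decoder treats each of the C audio codebook channels as its own token stream: the per-channel embeddings are summed into a single hidden state, and the final projection maps that state back to one vocabulary distribution per channel. A minimal standalone sketch of that shape flow with a plain nn.Linear in place of DenseGeneral (toy sizes, all names hypothetical):

import torch
import torch.nn as nn

B, T, C, V, D = 2, 4, 9, 1028, 16            # toy batch, steps, channels, vocab, width
embeddings = nn.ModuleList([nn.Embedding(V, D) for _ in range(C)])
to_logits = nn.Linear(D, C * V)              # stand-in for the DenseGeneral projection

tokens_BxTxC = torch.randint(0, V, (B, T, C))

# Sum the per-channel embeddings into one hidden state per position.
x = None
for i in range(C):
    emb = embeddings[i](tokens_BxTxC[..., i])    # (B, T, D)
    x = emb if x is None else x + emb

logits_BxTxCxV = to_logits(x).view(B, T, C, V)   # one distribution per channel
print(logits_BxTxCxV.shape)                      # torch.Size([2, 4, 9, 1028])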
dia/model.py ADDED
@@ -0,0 +1,431 @@
1
+ import dac
2
+ import numpy as np
3
+ import torch
4
+ import torchaudio
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from .audio import audio_to_codebook, codebook_to_audio
8
+ from .config import DiaConfig
9
+ from .layers import DiaModel, KVCache
10
+
11
+
12
+ def _sample_next_token(
13
+ logits_BCxV: torch.Tensor,
14
+ temperature: float,
15
+ top_p: float,
16
+ use_cfg_filter: bool,
17
+ cfg_filter_top_k: int | None = None,
18
+ ) -> torch.Tensor:
19
+ if temperature == 0.0:
20
+ return torch.argmax(logits_BCxV, dim=-1)
21
+
22
+ logits_BCxV = logits_BCxV / temperature
23
+ if use_cfg_filter and cfg_filter_top_k is not None:
24
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
25
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
26
+ mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
27
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
28
+
29
+ if top_p < 1.0:
30
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
31
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
32
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
33
+
34
+ # Calculate indices to remove based on top_p
35
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
36
+ # Shift the mask to the right to keep the first token above the threshold
37
+ sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
38
+ sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token
39
+
40
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
41
+ indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
42
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
43
+
44
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
45
+
46
+ sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
47
+ sampled_indices_C = sampled_indices_BC.squeeze(-1)
48
+ return sampled_indices_C
49
+
50
+
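_sample_next_token applies temperature scaling, an optional top-k restriction used together with CFG, and nucleus (top-p) filtering before drawing from the distribution. The top-p step keeps the smallest set of tokens whose cumulative probability exceeds top_p; a self-contained sketch of just that filter on toy logits (illustrative values only):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_p = 0.9

probs = torch.softmax(logits, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
cum_probs = torch.cumsum(sorted_probs, dim=-1)

# Mark everything past the nucleus, shifted right so the token that crosses
# the threshold is still kept, and always keep the most probable token.
remove_sorted = cum_probs > top_p
remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
remove_sorted[..., 0] = False

remove = torch.zeros_like(remove_sorted).scatter_(-1, sorted_idx, remove_sorted)
filtered = logits.masked_fill(remove, float("-inf"))

next_token = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
print(filtered, next_token)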
51
+ class Dia:
52
+ def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
53
+ """Initializes the Dia model.
54
+
55
+ Args:
56
+ config: The configuration object for the model.
57
+ device: The device to load the model onto.
58
+
59
+ Raises:
60
+ RuntimeError: If there is an error loading the DAC model.
61
+ """
62
+ super().__init__()
63
+ self.config = config
64
+ self.device = device
65
+ self.model = DiaModel(config)
66
+ self.dac_model = None
67
+
68
+ @classmethod
69
+ def from_local(cls, config_path: str, checkpoint_path: str, device: torch.device = torch.device("cuda")) -> "Dia":
70
+ """Loads the Dia model from local configuration and checkpoint files.
71
+
72
+ Args:
73
+ config_path: Path to the configuration JSON file.
74
+ checkpoint_path: Path to the model checkpoint (.pth) file.
75
+ device: The device to load the model onto.
76
+
77
+ Returns:
78
+ An instance of the Dia model loaded with weights and set to eval mode.
79
+
80
+ Raises:
81
+ FileNotFoundError: If the config or checkpoint file is not found.
82
+ RuntimeError: If there is an error loading the checkpoint.
83
+ """
84
+ config = DiaConfig.load(config_path)
85
+ if config is None:
86
+ raise FileNotFoundError(f"Config file not found at {config_path}")
87
+
88
+ dia = cls(config, device)
89
+
90
+ try:
91
+ dia.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
92
+ except FileNotFoundError:
93
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
94
+ except Exception as e:
95
+ raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
96
+
97
+ dia.model.to(device)
98
+ dia.model.eval()
99
+ dia._load_dac_model()
100
+ return dia
101
+
102
+ @classmethod
103
+ def from_pretrained(
104
+ cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
105
+ ) -> "Dia":
106
+ """Loads the Dia model from a Hugging Face Hub repository.
107
+
108
+ Downloads the configuration and checkpoint files from the specified
109
+ repository ID and then loads the model.
110
+
111
+ Args:
112
+ model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
113
+ device: The device to load the model onto.
114
+
115
+ Returns:
116
+ An instance of the Dia model loaded with weights and set to eval mode.
117
+
118
+ Raises:
119
+ FileNotFoundError: If config or checkpoint download/loading fails.
120
+ RuntimeError: If there is an error loading the checkpoint.
121
+ """
122
+ config_path = hf_hub_download(repo_id=model_name, filename="config.json")
123
+ checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
124
+ return cls.from_local(config_path, checkpoint_path, device)
125
+
126
+ def _load_dac_model(self):
127
+ try:
128
+ dac_model_path = dac.utils.download()
129
+ dac_model = dac.DAC.load(dac_model_path).to(self.device)
130
+ except Exception as e:
131
+ raise RuntimeError("Failed to load DAC model") from e
132
+ self.dac_model = dac_model
133
+
134
+ def _create_attn_mask(
135
+ self,
136
+ q_padding_mask_1d: torch.Tensor,
137
+ k_padding_mask_1d: torch.Tensor,
138
+ is_causal: bool = False,
139
+ ) -> torch.Tensor:
140
+ """
141
+ Creates the attention mask (self or cross) mimicking JAX segment ID logic.
142
+ """
143
+ B1, Tq = q_padding_mask_1d.shape
144
+ B2, Tk = k_padding_mask_1d.shape
145
+ assert B1 == B2, "Query and key batch dimensions must match"
146
+
147
+ p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
148
+ p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
149
+
150
+ # Condition A: Non-padding query attends to non-padding key
151
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
152
+
153
+ # Condition B: Padding query attends to padding key
154
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
155
+
156
+ # Combine: True if padding status is compatible (both non-pad OR both pad)
157
+ # This mirrors the segment-ID handling of the JAX TPU splash attention kernel
158
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
159
+
160
+ if is_causal:
161
+ # Ensure causality for self-attention (Tq == Tk)
162
+ assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
163
+ # Standard lower-triangular causal mask (True means allow)
164
+ causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device)) # Shape [Tq, Tk]
165
+ causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
166
+ return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
167
+ else:
168
+ # For cross-attention or non-causal self-attention
169
+ return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
170
+
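The mask lets a query position attend to a key only when the two share the same padding status (both real tokens, or both padding), mirroring JAX segment-ID semantics; for self-attention a causal triangle is intersected on top. A small sketch of the padding-compatibility rule on toy masks (illustrative values):

import torch

q_pad = torch.tensor([[True, True, False]])          # (B, Tq); True = real token
k_pad = torch.tensor([[True, False, False, True]])   # (B, Tk)

p_q = q_pad.unsqueeze(2)                              # (B, Tq, 1)
p_k = k_pad.unsqueeze(1)                              # (B, 1, Tk)

# Allowed when both sides are non-padding OR both are padding.
mask = (p_q & p_k) | (~p_q & ~p_k)                    # (B, Tq, Tk)
print(mask.int())
# tensor([[[1, 0, 0, 1],
#          [1, 0, 0, 1],
#          [0, 1, 1, 0]]])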
171
+ def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
+ """Encodes text prompt, pads, and creates attention mask and positions."""
173
+ text_pad_value = self.config.data.text_pad_value
174
+ max_len = self.config.data.text_length
175
+
176
+ byte_text = text.encode("utf-8")
177
+ replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
178
+ text_tokens = list(replaced_bytes)
179
+
180
+ current_len = len(text_tokens)
181
+ padding_needed = max_len - current_len
182
+ if padding_needed <= 0:
183
+ text_tokens = text_tokens[:max_len]
184
+ padded_text_np = np.array(text_tokens, dtype=np.uint8)
185
+ else:
186
+ padded_text_np = np.pad(
187
+ text_tokens,
188
+ (0, padding_needed),
189
+ mode="constant",
190
+ constant_values=text_pad_value,
191
+ ).astype(np.uint8)
192
+
193
+ src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0) # [1, S]
194
+ src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0) # [1, S]
195
+
196
+ src_padding_mask = (src_tokens != text_pad_value).to(self.device) # [1, S]
197
+
198
+ enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False) # [1, 1, S, S]
199
+
200
+ return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
201
+
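Text is tokenized at the byte level, with the speaker tags [S1] and [S2] collapsed to the control bytes 0x01 and 0x02 before padding out to the configured text length. A quick standalone illustration of that encoding step (the pad value and length below are placeholders; the real ones come from DiaConfig.data):

import numpy as np

text = "[S1] Hello there. [S2] Hi!"
max_len, pad_value = 32, 0   # placeholder values

raw = text.encode("utf-8").replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
tokens = list(raw)[:max_len]
padded = np.pad(tokens, (0, max_len - len(tokens)), constant_values=pad_value).astype(np.uint8)

print(len(raw), padded)      # each 4-byte tag has become a single control byte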
202
+ @torch.inference_mode()
203
+ def generate(
204
+ self,
205
+ text: str,
206
+ max_tokens: int | None = None,
207
+ cfg_scale: float = 3.0,
208
+ temperature: float = 1.3,
209
+ top_p: float = 0.95,
210
+ use_cfg_filter: bool = True,
211
+ use_torch_compile: bool = True,
212
+ cfg_filter_top_k: int = 100,
213
+ audio_prompt_path: str | None = None,
214
+ ) -> np.ndarray:
215
+ """
216
+ Generates audio from a text prompt (and optional audio prompt) using the Nari model.
217
+
218
+ Returns:
219
+ The generated audio as a 1-D NumPy float array of waveform samples (not audio codes).
220
+ """
221
+ num_channels = self.config.data.channels
222
+ audio_bos_value = self.config.data.audio_bos_value
223
+ audio_eos_value = self.config.data.audio_eos_value
224
+ audio_pad_value = self.config.data.audio_pad_value
225
+ delay_pattern = self.config.data.delay_pattern
226
+ max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
227
+ delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
228
+ max_delay_pattern = max(delay_pattern)
229
+ self.model.eval()
230
+
231
+ # 1. Prepare Text Input (batch of 2 for CFG: unconditional + conditional)
+ (
232
+ cond_src_BxS,
233
+ cond_src_positions_BxS,
234
+ cond_src_padding_mask_BxS,
235
+ cond_enc_self_attn_mask_Bx1xSxS,
236
+ ) = self._prepare_text_input(text)
237
+
238
+ unc_src_BxS = torch.zeros_like(cond_src_BxS)
239
+ src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
240
+ src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
241
+ src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
242
+ enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
243
+
244
+ # 2. Encoder Pass
245
+ # with torch.autocast(device_type="cuda", dtype=forward_dtype):
246
+ encoder_out = self.model.encoder(
247
+ x_ids=src_BxS,
248
+ src_positions=src_positions_BxS,
249
+ deterministic=True,
250
+ attn_mask=enc_self_attn_mask_Bx1xSxS,
251
+ ) # Shape: (B, S, E)
252
+
253
+ # 3. Prepare Decoder Inputs
254
+ # 3-1. Allocate KV Cache (Static)
255
+ decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
256
+ max_tokens, encoder_out, src_positions_BxS
257
+ )
258
+
259
+ decoder_self_attention_cache: list[KVCache] = []
260
+ for _ in range(self.model.decoder.num_layers):
261
+ decoder_self_attention_cache.append(
262
+ KVCache(
263
+ self.config.model.decoder.gqa_query_heads,
264
+ max_tokens,
265
+ self.config.model.decoder.gqa_head_dim,
266
+ self.device,
267
+ )
268
+ )
269
+
270
+ # 3-2. Initialize Decoder Inputs
271
+ generated_BxTxC = torch.full(
272
+ (2, 1, num_channels),
273
+ fill_value=audio_bos_value,
274
+ dtype=torch.long,
275
+ device=self.device,
276
+ )
277
+
278
+ current_step = 0
279
+ prompt_len_inc_bos = 1 # Start with BOS length
280
+
281
+ # 3-3. Load Audio Prompt (if provided)
282
+ if audio_prompt_path is not None:
283
+ audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True) # C, T
284
+ if sr != 44100: # Resample to 44.1kHz
285
+ audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
286
+ audio_prompt = audio_prompt.to(self.device).unsqueeze(0) # 1, C, T
287
+ audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
288
+ generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
289
+
290
+ prefill_len = generated_BxTxC.shape[1]
291
+ prompt_len_inc_bos = prefill_len
292
+ prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
293
+ prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
294
+
295
+ prefill_self_attn_mask = self._create_attn_mask(
296
+ prefill_tgt_padding_mask,
297
+ prefill_tgt_padding_mask,
298
+ is_causal=True,
299
+ )
300
+ prefill_cross_attn_mask = self._create_attn_mask(
301
+ prefill_tgt_padding_mask,
302
+ src_padding_mask_BxS,
303
+ is_causal=False,
304
+ )
305
+
306
+ _ = self.model.decoder.forward(
307
+ tgt_ids_BxTxC=generated_BxTxC,
308
+ encoder_out=encoder_out,
309
+ tgt_positions=prefill_tgt_pos,
310
+ src_positions=src_positions_BxS,
311
+ deterministic=True,
312
+ self_attn_mask=prefill_self_attn_mask,
313
+ cross_attn_mask=prefill_cross_attn_mask,
314
+ self_attention_cache=decoder_self_attention_cache,
315
+ cross_attention_cache=decoder_cross_attention_cache,
316
+ )
317
+
318
+ current_step = prefill_len - 1
319
+
320
+ # 4. Autoregressive Generation Loop
321
+ eos_detected_channel_0 = False
322
+ eos_countdown = -1
323
+ extra_steps_after_eos = 30
324
+ # Make generated_BxTxC a fixed size tensor
325
+ # Its length is 1 (BOS) + max_tokens, or 1 + prompt length + max_tokens when an audio prompt is given
326
+ generated_BxTxC = torch.cat(
327
+ [
328
+ generated_BxTxC,
329
+ torch.full(
330
+ (2, max_tokens, num_channels),
331
+ fill_value=-1,
332
+ dtype=torch.long,
333
+ device=self.device,
334
+ ),
335
+ ],
336
+ dim=1,
337
+ )
338
+
339
+ decode_step = self.model.decoder.decode_step
340
+ if use_torch_compile:
341
+ decode_step = torch.compile(
342
+ self.model.decoder.decode_step,
343
+ mode="default",
344
+ )
345
+
346
+ tgt_padding_mask = (
347
+ (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device)
348
+ ) # [B, 1]
349
+ # Generated tokens are never PAD, so we use fixed mask
350
+ decoder_cross_attn_mask = self._create_attn_mask(
351
+ tgt_padding_mask, # Query mask [B, 1]
352
+ src_padding_mask_BxS, # Key mask [B, S]
353
+ is_causal=False,
354
+ ) # [B, 1, 1, S]
355
+
356
+ for step in range(current_step, current_step + max_tokens):
357
+ tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
358
+ tgt_pos_Bx1 = torch.full(
359
+ (2, 1),
360
+ fill_value=step,
361
+ dtype=torch.long,
362
+ device=self.device,
363
+ )
364
+
365
+ logits_Bx1xCxV, new_cache = decode_step(
366
+ tgt_ids_Bx1xC=tgt_ids_Bx1xC,
367
+ tgt_pos_Bx1=tgt_pos_Bx1,
368
+ encoder_out=encoder_out,
369
+ self_attn_mask=None,
370
+ cross_attn_mask=decoder_cross_attn_mask,
371
+ self_attention_cache=decoder_self_attention_cache,
372
+ cross_attention_cache=decoder_cross_attention_cache,
373
+ )
374
+
375
+ for i, layer_cache in enumerate(decoder_self_attention_cache):
376
+ layer_cache.update_cache(new_cache[i][0], new_cache[i][1])
377
+
378
+ V = self.config.model.tgt_vocab_size
379
+ logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :] # B, C, V
380
+ uncond_logits_CxV = logits_last_BxCxV[0, :, :]
381
+ cond_logits_CxV = logits_last_BxCxV[1, :, :]
382
+
383
+ cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
384
+
385
+ logits_CxV = cfg_logits_CxV.reshape((-1, V)) # C, V
386
+ logits_CxV[:, 1025:] = -torch.inf  # mask out token indices above 1024 so they are never sampled
387
+
388
+ # Sample next token
389
+ pred_C = _sample_next_token(
390
+ logits_CxV.float(),
391
+ temperature=temperature,
392
+ top_p=top_p,
393
+ use_cfg_filter=use_cfg_filter,
394
+ cfg_filter_top_k=cfg_filter_top_k,
395
+ )
396
+
397
+ generation_step_index = step - current_step
398
+ if audio_prompt_path is None:
399
+ pred_C = torch.where(
400
+ generation_step_index >= delay_tensor,
401
+ pred_C,
402
+ audio_bos_value,
403
+ )
404
+
405
+ generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
406
+
407
+ if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
408
+ eos_detected_channel_0 = True
409
+ eos_countdown = extra_steps_after_eos
410
+
411
+ if eos_countdown > 0:
412
+ step_after_eos = max_delay_pattern - eos_countdown
413
+ for i, d in enumerate(delay_pattern):
414
+ if step_after_eos == d:
415
+ generated_BxTxC[:, step + 1, i] = audio_eos_value
416
+ elif step_after_eos > d:
417
+ generated_BxTxC[:, step + 1, i] = audio_pad_value
418
+ eos_countdown -= 1
419
+ if eos_countdown == 0:
420
+ break
421
+
422
+ generation_step_index = step - current_step + 1
423
+
424
+ output_codes = generated_BxTxC[:, prompt_len_inc_bos : step + 1, :]
425
+
426
+ generated_codes = output_codes[0]
427
+
428
+ audio = codebook_to_audio(
429
+ generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels
430
+ )
431
+ return audio.squeeze().cpu().numpy()
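Putting it together, typical usage mirrors what the Gradio app does: load the pretrained weights, generate from a speaker-tagged script, and write the returned waveform to disk. A minimal sketch (the output path and generation settings are arbitrary; 44.1 kHz is assumed because the DAC codec used here operates at that rate):

import soundfile as sf
import torch

from dia.model import Dia

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Dia.from_pretrained("nari-labs/Dia-1.6B", device=device)

audio = model.generate(
    "[S1] Hello there. [S2] Hi, how are you?",
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    use_torch_compile=False,  # skip compilation for a quick one-off run
)

sf.write("output.wav", audio, 44100)  # generated waveform as a NumPy array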
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ descript-audio-codec>=1.0.0
2
+ gradio>=5.25.2
3
+ huggingface-hub>=0.30.2
4
+ numpy>=2.2.4
5
+ pydantic>=2.11.3
6
+ soundfile>=0.13.1
7
+ torch>=2.6.0
8
+ torchaudio>=2.6.0