TaiYa1 committed on
Commit aaaebbd · verified · 1 parent: a249d99

Update tools/llama/generate.py

Files changed (1)
  1. tools/llama/generate.py +724 -724
tools/llama/generate.py CHANGED
@@ -396,7 +396,7 @@
     text: Optional[str] = None
 
 @torch.no_grad()
-@spaces.GPU(duration=120)
+@spaces.GPU
 def generate_long(
     *,
     model,
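
The only line that differs between the old and new versions of the file is the spaces.GPU decorator on generate_long: the explicit duration=120 argument is dropped, presumably falling back to the package's default GPU allocation window. Below is a minimal sketch of the two decorator forms, using the same spaces package the script already imports; the synthesize functions are hypothetical stand-ins for generate_long, not part of the repository.

import spaces


# Old form: each call requests a 120-second ZeroGPU allocation.
@spaces.GPU(duration=120)
def synthesize_with_budget(text: str) -> str:
    # GPU-bound generation would run here.
    return text


# New form: the bare decorator relies on the package's default duration.
@spaces.GPU
def synthesize(text: str) -> str:
    # GPU-bound generation would run here.
    return text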