Optimize graph

Files changed:
- app.py (+5, -10)
- tools/llama/generate.py (+37, -26)

app.py (CHANGED)
@@ -41,6 +41,9 @@ Related code are released under BSD-3-Clause License, and weights are released u
 
 We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
 我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
+
+The model running in this WebUI is Fish Speech V1 Medium SFT 4K.
+在此 WebUI 中运行的模型是 Fish Speech V1 Medium SFT 4K.
 """
 
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -76,7 +79,6 @@ def inference(
     reference_text,
     max_new_tokens,
     chunk_length,
-    top_k,
     top_p,
     repetition_penalty,
     temperature,
@@ -112,7 +114,6 @@ def inference(
         device=vqgan_model.device,
         max_new_tokens=max_new_tokens,
         text=text,
-        top_k=int(top_k) if top_k > 0 else None,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
@@ -194,10 +195,6 @@ def build_app():
                     step=8,
                 )
 
-                top_k = gr.Slider(
-                    label="Top-K", minimum=0, maximum=5, value=0, step=1
-                )
-
                 top_p = gr.Slider(
                     label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
                 )
@@ -264,7 +261,6 @@ def build_app():
                 reference_text,
                 max_new_tokens,
                 chunk_length,
-                top_k,
                 top_p,
                 repetition_penalty,
                 temperature,
@@ -310,8 +306,8 @@ if __name__ == "__main__":
     args.compile = True
     args.max_gradio_length = 1024
     args.tokenizer = "./checkpoints/fish-speech-1"
-    args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-
-    args.llama_config_name = "
+    args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-medium-v1-4k.pth"
+    args.llama_config_name = "dual_ar_2_codebook_medium"
     args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
     args.vqgan_config_name = "vqgan_pretrain"
 
@@ -343,7 +339,6 @@ if __name__ == "__main__":
         reference_text="",
         max_new_tokens=0,
         chunk_length=0,
-        top_k=0,  # 0 means no limit
        top_p=0.7,
         repetition_penalty=1.5,
         temperature=0.7,
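The app.py hunks remove the Top-K control end to end: Gradio passes the components listed in inputs positionally to the event handler, so the slider, the click inputs list, and the inference() signature have to change together in one commit. A minimal sketch of that wiring pattern, with hypothetical component and handler names rather than the app's actual layout:

import gradio as gr

# Hypothetical, stripped-down version of the wiring used in app.py: every
# component listed in `inputs` is passed positionally to the handler, so
# removing a control (e.g. a Top-K slider) means removing the matching
# parameter from the handler and from the inputs list at the same time.

def synthesize(text, top_p, repetition_penalty, temperature):
    return (
        f"text={text!r}, top_p={top_p}, "
        f"repetition_penalty={repetition_penalty}, temperature={temperature}"
    )

with gr.Blocks() as demo:
    text = gr.Textbox(label="Input text")
    top_p = gr.Slider(label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01)
    repetition_penalty = gr.Slider(
        label="Repetition Penalty", minimum=0, maximum=2, value=1.5, step=0.01
    )
    temperature = gr.Slider(label="Temperature", minimum=0, maximum=2, value=0.7, step=0.01)
    out = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    btn.click(
        synthesize,
        inputs=[text, top_p, repetition_penalty, temperature],  # order matches the signature
        outputs=out,
    )

if __name__ == "__main__":
    demo.launch()
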
tools/llama/generate.py (CHANGED)
@@ -42,11 +42,11 @@ def multinomial_sample_one_no_sync(
 def logits_to_probs(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature:
-
-
-
-
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
     if previous_tokens is not None:
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=0, index=previous_tokens)
@@ -55,11 +55,9 @@
     )
     logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    #
+    # Apply top-p sampling
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cum_probs = torch.cumsum(
-        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-    )
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
     sorted_indices_to_remove = cum_probs > top_p
     sorted_indices_to_remove[0] = False  # keep at least one option
     indices_to_remove = sorted_indices_to_remove.scatter(
@@ -69,11 +67,6 @@
 
     logits = logits / max(temperature, 1e-5)
 
-    # if top_k is not None:
-    #     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-    #     pivot = v.select(-1, -1).unsqueeze(-1)
-    #     logits = torch.where(logits < pivot, -float("Inf"), logits)
-
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
 
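Taken together, the three hunks above leave logits_to_probs applying a repetition penalty, top-p (nucleus) filtering, and temperature scaling, with the commented-out top-k branch deleted. A self-contained sketch of that transform follows. The penalty formula between the gather and the scatter is not visible in the diff, so the torch.where form below is an assumption, and sample_probs is a hypothetical standalone name rather than the repository's function:

import torch
from typing import Optional

def sample_probs(
    logits: torch.Tensor,                      # 1-D tensor of vocabulary logits
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 0.7,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
) -> torch.Tensor:
    # Repetition penalty: push down logits of tokens that were already generated.
    # (Assumed conventional formulation; the exact lines are not shown in the diff.)
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits = logits.scatter(dim=0, index=previous_tokens, src=score)

    # Top-p (nucleus) filtering: drop tokens outside the smallest set whose
    # cumulative probability exceeds top_p, always keeping the best token.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    # Temperature scaling, then normalize to probabilities.
    logits = logits / max(temperature, 1e-5)
    return torch.nn.functional.softmax(logits, dim=-1)

# Example: probabilities over a toy 5-token vocabulary.
probs = sample_probs(
    torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0]),
    previous_tokens=torch.tensor([0]),
)
print(probs)
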
@@ -449,7 +442,6 @@ def generate_long(
     text: str,
     num_samples: int = 1,
     max_new_tokens: int = 0,
-    top_k: int = None,
     top_p: int = 0.7,
     repetition_penalty: float = 1.5,
     temperature: float = 0.7,
@@ -462,6 +454,10 @@ def generate_long(
     prompt_tokens: Optional[torch.Tensor] = None,
     is_streaming: bool = False,
 ):
+    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+    assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
@@ -493,8 +489,18 @@ def generate_long(
     )
     logger.info(f"Encoded text: {text}")
 
+    # Move temperature, top_p, repetition_penalty to device
+    # This is important so that changing params doesn't trigger recompile
+    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+    repetition_penalty = torch.tensor(
+        repetition_penalty, device=device, dtype=torch.float
+    )
+
     for sample_idx in range(num_samples):
-        torch.cuda.
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
         global_encoded = []
         all_codes = []
         seg_idx = 0
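The tensor conversion above is the heart of the "Optimize graph" commit: torch.compile specializes a compiled graph on Python scalar arguments, so passing a new float for temperature or top_p could trigger a recompile, whereas a 0-d tensor on the model's device lets the value change while the graph stays fixed. A minimal sketch of the idea, using a hypothetical scale_logits function rather than the project's decode path:

import torch

# Hypothetical illustration of why the diff converts sampling hyperparameters to
# 0-d tensors on the model's device: a tensor argument only carries a new value
# through the already-compiled graph, while a fresh Python float may be treated
# as a new specialization and recompiled.

@torch.compile
def scale_logits(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # temperature arrives as a 0-d tensor; its value can change without recompiling.
    return logits / torch.clamp(temperature, min=1e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(1024, device=device)

for value in (0.5, 0.7, 1.0):
    temperature = torch.tensor(value, device=device, dtype=torch.float)
    out = scale_logits(logits, temperature)
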
@@ -540,7 +546,6 @@
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
-                top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,
             )
@@ -548,7 +553,9 @@
         if sample_idx == 0 and seg_idx == 0 and compile:
             logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
-        torch.cuda.
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
         t = time.perf_counter() - t0
 
         tokens_generated = y.size(1) - prompt_length
@@ -559,9 +566,11 @@
         logger.info(
             f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
         )
-
-
-
+
+        if torch.cuda.is_available():
+            logger.info(
+                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+            )
 
         # Put the generated tokens
         # since there is <im_end> and <eos> tokens, we remove last 2 tokens
@@ -654,7 +663,6 @@ def launch_thread_safe_queue(
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.5)
 @click.option("--temperature", type=float, default=0.7)
@@ -678,7 +686,6 @@ def main(
     prompt_tokens: Optional[Path],
     num_samples: int,
     max_new_tokens: int,
-    top_k: int,
     top_p: int,
     repetition_penalty: float,
     temperature: float,
@@ -702,7 +709,10 @@ def main(
     model, decode_one_token = load_model(
         config_name, checkpoint_path, device, precision, max_length, compile=compile
     )
-
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     prompt_tokens = (
@@ -713,7 +723,9 @@ def main(
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     torch.manual_seed(seed)
-
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
 
     generator = generate_long(
         model=model,
@@ -722,7 +734,6 @@ def main(
         text=text,
         num_samples=num_samples,
         max_new_tokens=max_new_tokens,
-        top_k=top_k,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
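The remaining hunks wrap every CUDA-only call (torch.cuda.synchronize(), torch.cuda.manual_seed(), and the max_memory_reserved() report) in torch.cuda.is_available() checks so the script also runs on CPU-only hosts. A small sketch of that guard pattern around a timed call; timed and generate_step are hypothetical helpers for illustration:

import time
import torch

# Sketch of the guard pattern the diff applies around CUDA-only calls, so that
# timing, seeding, and memory reporting degrade gracefully without a GPU.
# (generate_step stands in for the real decoding call.)

def timed(fn, *args, seed: int = 42, **kwargs):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.synchronize()          # make sure prior GPU work has finished
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()          # wait for the kernels before stopping the clock
    elapsed = time.perf_counter() - t0
    if torch.cuda.is_available():
        print(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    return result, elapsed

def generate_step(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.softmax(x, dim=-1)

out, seconds = timed(generate_step, torch.randn(8, 4096))
print(f"took {seconds:.4f} s")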