Spaces: Running on L4

Enable compile on A10G

Files changed:
- app.py (+2 -2)
- tools/llama/generate.py (+37 -23)

app.py CHANGED

@@ -251,7 +251,7 @@ def build_app():
                 # speaker,
             ],
             [audio, error],
-
+            concurrency_limit=1,
         )

    return app

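For context, `concurrency_limit=1` on a Gradio event serializes the handler: at most one generation runs at a time and further requests wait in the queue, so concurrent users cannot exhaust GPU memory. A minimal sketch of the pattern (Gradio 4.x; the handler and component names here are illustrative, not from this app):

    import gradio as gr

    def infer(text):
        # stand-in for the GPU-bound TTS inference handler
        return f"synthesized: {text}"

    with gr.Blocks() as app:
        text = gr.Textbox(label="Text")
        out = gr.Textbox(label="Result")
        btn = gr.Button("Generate")
        # at most one concurrent run of this event; extra clicks queue up
        btn.click(infer, [text], [out], concurrency_limit=1)

    app.queue().launch()
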
@@ -287,7 +287,7 @@ if __name__ == "__main__":
     args = parse_args()

     args.precision = torch.half if args.half else torch.bfloat16
-
+    args.compile = True

     logger.info("Loading Llama model...")
     llama_model, decode_one_token = load_llama_model(

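Hardcoding `args.compile = True` forces the compiled decode path regardless of the CLI flag, which is the point of the commit. A rough sketch of what a compile flag like this typically gates inside `load_llama_model` (the body below is an assumption for illustration, not code from this repo):

    import torch

    def decode_one_token(model, x, input_pos):
        # one autoregressive decoding step (signature illustrative)
        logits = model(x, input_pos)
        return logits.argmax(dim=-1)

    compile_flag = True  # what this commit hardcodes via args.compile

    if compile_flag:
        # one-time warmup cost, then faster steady-state per-token decoding
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

The trade-off is a slow first generation while the graph compiles, after which each decoding step runs with far less Python overhead.
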
tools/llama/generate.py CHANGED

@@ -14,7 +14,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer

-from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text.clean import clean_text

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

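CODEBOOK_PAD_TOKEN_ID joins the existing CODEBOOK_EOS_TOKEN_ID import because the codebook padding built in encode_tokens (see the zeros hunk below) now uses an explicit pad id instead of 0.
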
@@ -291,11 +291,11 @@ def encode_tokens(
 ):
     string = clean_text(string)

-    if speaker is not None:
-        string = f"[SPK: {speaker}] {string}"
+    if speaker is None:
+        speaker = "assistant"

     string = (
-        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
+        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
     )
     if bos:
         string = f"<|begin_of_sequence|>{string}"

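With this change the requested speaker is spliced into the chat template itself instead of being handled outside it, so the model is prompted to continue as that speaker. A standalone mirror of the new construction, for a quick sanity check:

    def render_prompt(string, speaker=None, bos=True):
        # mirrors the new encode_tokens string construction
        if speaker is None:
            speaker = "assistant"
        string = (
            f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
        )
        if bos:
            string = f"<|begin_of_sequence|>{string}"
        return string

    print(render_prompt("Hello world"))
    # <|begin_of_sequence|><|im_start|>user<|im_sep|>Hello world<|im_end|><|im_start|>assistant<|im_sep|>
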
@@ -309,7 +309,10 @@ def encode_tokens(
     tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

     # Codebooks
-    zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+    zeros = (
+        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+        * CODEBOOK_PAD_TOKEN_ID
+    )
     prompt = torch.cat((tokens, zeros), dim=0)

     if prompt_tokens is None:

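The prompt here is a (1 + num_codebooks) x T grid: row 0 carries text tokens and the remaining rows carry codebook tokens. Filling the codebook rows under pure text with CODEBOOK_PAD_TOKEN_ID instead of 0 keeps padding distinguishable from a real codebook token id of 0. A small sketch with made-up ids:

    import torch

    CODEBOOK_PAD_TOKEN_ID = 1025  # illustrative value, not the repo's constant
    num_codebooks = 4

    # pretend the text encodes to 3 token ids
    tokens = torch.tensor([[101, 205, 309]], dtype=torch.int)

    # codebook rows under text positions hold only padding
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)
    print(prompt.shape)  # torch.Size([5, 3]) -- 1 text row + 4 codebook rows
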
@@ -331,13 +334,23 @@ def encode_tokens(
     )
     data = data[:num_codebooks]

+    # Add eos token for each codebook
+    data = torch.cat(
+        (
+            data,
+            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
+            * CODEBOOK_EOS_TOKEN_ID,
+        ),
+        dim=1,
+    )
+
     # Since 1.0, we use <|semantic|>
     s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    main_token_ids = torch.tensor(
-        [[s0_token_id] * data.size(1)],
-        dtype=torch.int,
-        device=device,
+    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    main_token_ids = (
+        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
     )
+    main_token_ids[0, -1] = end_token_id

     data = torch.cat((main_token_ids, data), dim=0)
     prompt = torch.cat((prompt, data), dim=1)

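The effect on the appended reference tokens: every codebook row gains a trailing CODEBOOK_EOS_TOKEN_ID column, and the text row above switches from <|semantic|> to <|im_end|> at that final position, so the grid now encodes its own end. A sketch with made-up ids:

    import torch

    CODEBOOK_EOS_TOKEN_ID = 1026          # illustrative
    s0_token_id, end_token_id = 32000, 7  # illustrative <|semantic|> / <|im_end|> ids

    data = torch.tensor([[11, 12], [21, 22]], dtype=torch.int)  # 2 codebooks, 2 frames

    # append an EOS column to every codebook row
    data = torch.cat(
        (data, torch.ones((data.size(0), 1), dtype=torch.int) * CODEBOOK_EOS_TOKEN_ID),
        dim=1,
    )

    # text row: <|semantic|> everywhere, <|im_end|> over the EOS column
    main_token_ids = torch.ones((1, data.size(1)), dtype=torch.int) * s0_token_id
    main_token_ids[0, -1] = end_token_id

    print(torch.cat((main_token_ids, data), dim=0))
    # tensor([[32000, 32000,     7],
    #         [   11,    12,  1026],
    #         [   21,    22,  1026]], dtype=torch.int32)
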
@@ -450,6 +463,20 @@ def generate_long(
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
+
+    if use_prompt:
+        encoded.append(
+            encode_tokens(
+                tokenizer,
+                prompt_text,
+                prompt_tokens=prompt_tokens,
+                bos=True,
+                device=device,
+                speaker=speaker,
+                num_codebooks=model.config.num_codebooks,
+            )
+        )
+
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(

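Encoding the reference prompt up front and appending it as the first element of encoded replaces the post-loop concatenation removed in the next hunk: downstream code that joins the chunks in order now sees the prompt first without a separate torch.cat, and `bos=idx == 0 and not use_prompt` still guarantees exactly one BOS token at the start of the sequence.
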
@@ -457,25 +484,12 @@ def generate_long(
                 string=text,
                 bos=idx == 0 and not use_prompt,
                 device=device,
-                speaker=None,
+                speaker=speaker,
                 num_codebooks=model.config.num_codebooks,
             )
         )
         logger.info(f"Encoded text: {text}")

-    if use_prompt:
-        encoded_prompt = encode_tokens(
-            tokenizer,
-            prompt_text,
-            prompt_tokens=prompt_tokens,
-            bos=True,
-            device=device,
-            speaker=speaker,
-            num_codebooks=model.config.num_codebooks,
-        )
-
-        encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
-
     for sample_idx in range(num_samples):
         torch.cuda.synchronize()
         global_encoded = []