Update to fix Colab launch
audiocraft/models/musicgen.py
CHANGED
@@ -412,6 +412,38 @@ class MusicGen:
             gen_audio = self.compression_model.decode(gen_tokens, None)
         return gen_audio
 
+    # def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+    #                      prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+    #     """Generate discrete audio tokens given audio prompt and/or conditions.
+    #
+    #     Args:
+    #         attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
+    #         prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
+    #         progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    #     Returns:
+    #         torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+    #     """
+    #     def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+    #         print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
+    #
+    #     if prompt_tokens is not None:
+    #         assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
+    #             "Prompt is longer than audio to generate"
+    #
+    #     callback = None
+    #     if progress:
+    #         callback = _progress_callback
+    #
+    #     # generate by sampling from LM
+    #     with self.autocast:
+    #         gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
+    #
+    #     # generate audio
+    #     assert gen_tokens.dim() == 3
+    #     with torch.no_grad():
+    #         gen_audio = self.compression_model.decode(gen_tokens, None)
+    #     return gen_audio
+
     def to(self, device: str):
         self.compression_model.to(device)
         self.lm.to(device)
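
For context, a minimal sketch of exercising the patched class after this commit. This is an illustrative smoke test, not part of the diff: the checkpoint name ('melody'), the generation duration, the prompt text, and the output filename are all assumptions based on the audiocraft API of this era.

# Hypothetical smoke test for the patched MusicGen class; checkpoint name,
# generation parameters, and output filename are illustrative assumptions.
import torch
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('melody')    # load a pretrained checkpoint
model.set_generation_params(duration=8)      # seconds of audio to generate
# Exercises the to() method kept by this diff; it moves the compression
# model and the LM to the target device and returns None.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
wav = model.generate(['happy rock'])         # -> tensor of shape [B, C, T]
for idx, one_wav in enumerate(wav):
    # audio_write appends the .wav extension itself
    audio_write(f'sample_{idx}', one_wav.cpu(), model.sample_rate, strategy='loudness')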