Spaces:
Runtime error
Runtime error
hf
Browse files- app.py +9 -14
- app_onnx.py +12 -19
- midi_model.py +53 -23
- midi_tokenizer.py +29 -0
app.py
CHANGED
|
@@ -365,19 +365,19 @@ if __name__ == "__main__":
|
|
| 365 |
synthesizer = MidiSynthesizer(soundfont_path)
|
| 366 |
models_info = {
|
| 367 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
| 368 |
-
"skytnt/midi-model-tv2o-medium",
|
| 369 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
| 370 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
| 371 |
}
|
| 372 |
],
|
| 373 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
| 374 |
-
"asigalov61/Music-Llama",
|
| 375 |
],
|
| 376 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
| 377 |
-
"asigalov61/Music-Llama-Medium",
|
| 378 |
],
|
| 379 |
"generic pretrain model (tv1-medium) by skytnt": [
|
| 380 |
-
"skytnt/midi-model",
|
| 381 |
]
|
| 382 |
}
|
| 383 |
models = {}
|
|
@@ -388,20 +388,15 @@ if __name__ == "__main__":
|
|
| 388 |
torch.backends.cudnn.allow_tf32 = True
|
| 389 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 390 |
torch.backends.cuda.enable_flash_sdp(True)
|
| 391 |
-
for name, (repo_id,
|
| 392 |
-
|
| 393 |
-
model
|
| 394 |
-
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 395 |
-
state_dict = ckpt.get("state_dict", ckpt)
|
| 396 |
-
model.load_state_dict(state_dict, strict=False)
|
| 397 |
-
model.to(device="cpu", dtype=torch.float32).eval()
|
| 398 |
models[name] = model
|
| 399 |
for lora_name, lora_repo in loras.items():
|
| 400 |
-
model = MIDIModel
|
| 401 |
-
model.load_state_dict(state_dict, strict=False)
|
| 402 |
print(f"loading lora {lora_repo} for {name}")
|
| 403 |
model = model.load_merge_lora(lora_repo)
|
| 404 |
-
model.to(device="cpu", dtype=torch.float32)
|
| 405 |
models[f"{name} with {lora_name} lora"] = model
|
| 406 |
|
| 407 |
load_javascript()
|
|
|
|
| 365 |
synthesizer = MidiSynthesizer(soundfont_path)
|
| 366 |
models_info = {
|
| 367 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
| 368 |
+
"skytnt/midi-model-tv2o-medium", {
|
| 369 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
| 370 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
| 371 |
}
|
| 372 |
],
|
| 373 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
| 374 |
+
"asigalov61/Music-Llama", {}
|
| 375 |
],
|
| 376 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
| 377 |
+
"asigalov61/Music-Llama-Medium", {}
|
| 378 |
],
|
| 379 |
"generic pretrain model (tv1-medium) by skytnt": [
|
| 380 |
+
"skytnt/midi-model", {}
|
| 381 |
]
|
| 382 |
}
|
| 383 |
models = {}
|
|
|
|
| 388 |
torch.backends.cudnn.allow_tf32 = True
|
| 389 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 390 |
torch.backends.cuda.enable_flash_sdp(True)
|
| 391 |
+
for name, (repo_id, loras) in models_info.items():
|
| 392 |
+
model = MIDIModel.from_pretrained(repo_id)
|
| 393 |
+
model.to(device="cpu", dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
models[name] = model
|
| 395 |
for lora_name, lora_repo in loras.items():
|
| 396 |
+
model = MIDIModel.from_pretrained(repo_id)
|
|
|
|
| 397 |
print(f"loading lora {lora_repo} for {name}")
|
| 398 |
model = model.load_merge_lora(lora_repo)
|
| 399 |
+
model.to(device="cpu", dtype=torch.float32)
|
| 400 |
models[f"{name} with {lora_name} lora"] = model
|
| 401 |
|
| 402 |
load_javascript()
|
app_onnx.py
CHANGED
|
@@ -432,18 +432,12 @@ def hf_hub_download_retry(repo_id, filename):
|
|
| 432 |
raise err
|
| 433 |
|
| 434 |
|
| 435 |
-
def get_tokenizer(
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
else:
|
| 442 |
-
o = False
|
| 443 |
-
if tv not in ["v1", "v2"]:
|
| 444 |
-
raise ValueError(f"Unknown tokenizer version {tv}")
|
| 445 |
-
tokenizer = MIDITokenizer(tv)
|
| 446 |
-
tokenizer.set_optimise_midi(o)
|
| 447 |
return tokenizer
|
| 448 |
|
| 449 |
|
|
@@ -468,34 +462,33 @@ if __name__ == "__main__":
|
|
| 468 |
synthesizer = MidiSynthesizer(soundfont_path)
|
| 469 |
models_info = {
|
| 470 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
| 471 |
-
"skytnt/midi-model-tv2o-medium", "",
|
| 472 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
| 473 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
| 474 |
}
|
| 475 |
],
|
| 476 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
| 477 |
-
"asigalov61/Music-Llama", "",
|
| 478 |
],
|
| 479 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
| 480 |
-
"asigalov61/Music-Llama-Medium", "",
|
| 481 |
],
|
| 482 |
"generic pretrain model (tv1-medium) by skytnt": [
|
| 483 |
-
"skytnt/midi-model", "",
|
| 484 |
]
|
| 485 |
}
|
| 486 |
models = {}
|
| 487 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 488 |
device = "cuda"
|
| 489 |
|
| 490 |
-
for name, (repo_id, path,
|
| 491 |
model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
| 492 |
model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
| 493 |
-
tokenizer = get_tokenizer(
|
| 494 |
models[name] = [model_base_path, model_token_path, tokenizer]
|
| 495 |
for lora_name, lora_repo in loras.items():
|
| 496 |
model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
|
| 497 |
model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
|
| 498 |
-
tokenizer = get_tokenizer(config)
|
| 499 |
models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
|
| 500 |
|
| 501 |
load_javascript()
|
|
|
|
| 432 |
raise err
|
| 433 |
|
| 434 |
|
| 435 |
+
def get_tokenizer(repo_id):
|
| 436 |
+
config_path = hf_hub_download_retry(repo_id=repo_id, filename=f"config.json")
|
| 437 |
+
with open(config_path, "r") as f:
|
| 438 |
+
config = json.load(f)
|
| 439 |
+
tokenizer = MIDITokenizer(config["tokenizer"]["version"])
|
| 440 |
+
tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
return tokenizer
|
| 442 |
|
| 443 |
|
|
|
|
| 462 |
synthesizer = MidiSynthesizer(soundfont_path)
|
| 463 |
models_info = {
|
| 464 |
"generic pretrain model (tv2o-medium) by skytnt": [
|
| 465 |
+
"skytnt/midi-model-tv2o-medium", "", {
|
| 466 |
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
| 467 |
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
| 468 |
}
|
| 469 |
],
|
| 470 |
"generic pretrain model (tv2o-large) by asigalov61": [
|
| 471 |
+
"asigalov61/Music-Llama", "", {}
|
| 472 |
],
|
| 473 |
"generic pretrain model (tv2o-medium) by asigalov61": [
|
| 474 |
+
"asigalov61/Music-Llama-Medium", "", {}
|
| 475 |
],
|
| 476 |
"generic pretrain model (tv1-medium) by skytnt": [
|
| 477 |
+
"skytnt/midi-model", "", {}
|
| 478 |
]
|
| 479 |
}
|
| 480 |
models = {}
|
| 481 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 482 |
device = "cuda"
|
| 483 |
|
| 484 |
+
for name, (repo_id, path, loras) in models_info.items():
|
| 485 |
model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
| 486 |
model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
| 487 |
+
tokenizer = get_tokenizer(repo_id)
|
| 488 |
models[name] = [model_base_path, model_token_path, tokenizer]
|
| 489 |
for lora_name, lora_repo in loras.items():
|
| 490 |
model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
|
| 491 |
model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
|
|
|
|
| 492 |
models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
|
| 493 |
|
| 494 |
load_javascript()
|
midi_model.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
|
@@ -6,21 +7,57 @@ import torch.nn as nn
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import tqdm
|
| 8 |
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
| 9 |
-
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
| 10 |
-
from transformers.integrations import PeftAdapterMixin
|
| 11 |
|
| 12 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
| 13 |
|
| 14 |
config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
|
| 15 |
|
| 16 |
|
| 17 |
-
class MIDIModelConfig:
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
@staticmethod
|
| 26 |
def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
|
|
@@ -59,27 +96,20 @@ class MIDIModelConfig:
|
|
| 59 |
raise ValueError(f"Unknown model size {size}")
|
| 60 |
|
| 61 |
|
| 62 |
-
class MIDIModel(
|
|
|
|
|
|
|
| 63 |
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
| 64 |
-
super(MIDIModel, self).__init__()
|
| 65 |
self.tokenizer = config.tokenizer
|
| 66 |
self.net = LlamaModel(config.net_config)
|
| 67 |
self.net_token = LlamaModel(config.net_token_config)
|
| 68 |
self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
|
| 69 |
-
self.device = "cpu"
|
| 70 |
-
|
| 71 |
-
def to(self, *args, **kwargs):
|
| 72 |
-
if "device" in kwargs:
|
| 73 |
-
self.device = kwargs["device"]
|
| 74 |
-
return super(MIDIModel, self).to(*args, **kwargs)
|
| 75 |
-
|
| 76 |
-
def peft_loaded(self):
|
| 77 |
-
return self._hf_peft_config_loaded
|
| 78 |
|
| 79 |
def load_merge_lora(self, model_id):
|
| 80 |
peft_config = PeftConfig.from_pretrained(model_id)
|
| 81 |
model = LoraModel(self, peft_config, adapter_name="default")
|
| 82 |
-
adapter_state_dict = load_peft_weights(model_id, device=self.device)
|
| 83 |
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
| 84 |
return model.merge_and_unload()
|
| 85 |
|
|
@@ -164,7 +194,7 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
| 164 |
with bar:
|
| 165 |
while cur_len < max_len:
|
| 166 |
end = [False] * batch_size
|
| 167 |
-
hidden = self.forward(input_tensor[:,past_len:], cache=cache1)[:, -1]
|
| 168 |
next_token_seq = None
|
| 169 |
event_names = [""] * batch_size
|
| 170 |
cache2 = DynamicCache()
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import Union, Dict, Any
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
|
|
| 7 |
import torch.nn.functional as F
|
| 8 |
import tqdm
|
| 9 |
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
| 10 |
+
from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
|
|
|
|
| 11 |
|
| 12 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
| 13 |
|
| 14 |
config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]
|
| 15 |
|
| 16 |
|
| 17 |
+
class MIDIModelConfig(PretrainedConfig):
|
| 18 |
+
model_type = "midi_model"
|
| 19 |
+
|
| 20 |
+
def __init__(self,
|
| 21 |
+
tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict]=None,
|
| 22 |
+
net_config: Union[LlamaConfig, Dict]=None,
|
| 23 |
+
net_token_config: Union[LlamaConfig, Dict]=None,
|
| 24 |
+
**kwargs):
|
| 25 |
+
super().__init__(**kwargs)
|
| 26 |
+
if tokenizer:
|
| 27 |
+
if isinstance(tokenizer, dict):
|
| 28 |
+
self.tokenizer = MIDITokenizer(tokenizer["version"])
|
| 29 |
+
self.tokenizer.set_optimise_midi(tokenizer["optimise_midi"])
|
| 30 |
+
else:
|
| 31 |
+
self.tokenizer = tokenizer
|
| 32 |
+
else:
|
| 33 |
+
self.tokenizer = MIDITokenizer()
|
| 34 |
+
if net_config:
|
| 35 |
+
if isinstance(net_config, dict):
|
| 36 |
+
self.net_config = LlamaConfig(**net_config)
|
| 37 |
+
else:
|
| 38 |
+
self.net_config = net_config
|
| 39 |
+
else:
|
| 40 |
+
self.net_config = LlamaConfig()
|
| 41 |
+
if net_token_config:
|
| 42 |
+
if isinstance(net_token_config, dict):
|
| 43 |
+
self.net_token_config = LlamaConfig(**net_token_config)
|
| 44 |
+
else:
|
| 45 |
+
self.net_token_config = net_token_config
|
| 46 |
+
else:
|
| 47 |
+
self.net_token_config = LlamaConfig()
|
| 48 |
+
self.n_embd = self.net_token_config.hidden_size
|
| 49 |
+
|
| 50 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 51 |
+
d = super().to_dict()
|
| 52 |
+
d["tokenizer"] = self.tokenizer.to_dict()
|
| 53 |
+
return d
|
| 54 |
+
|
| 55 |
+
def __str__(self):
|
| 56 |
+
d = {
|
| 57 |
+
"net": self.net_config.to_json_string(use_diff=False),
|
| 58 |
+
"net_token": self.net_token_config.to_json_string(use_diff=False)
|
| 59 |
+
}
|
| 60 |
+
return json.dumps(d, indent=4)
|
| 61 |
|
| 62 |
@staticmethod
|
| 63 |
def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
|
|
|
|
| 96 |
raise ValueError(f"Unknown model size {size}")
|
| 97 |
|
| 98 |
|
| 99 |
+
class MIDIModel(PreTrainedModel):
|
| 100 |
+
config_class = MIDIModelConfig
|
| 101 |
+
|
| 102 |
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
| 103 |
+
super(MIDIModel, self).__init__(config, *args, **kwargs)
|
| 104 |
self.tokenizer = config.tokenizer
|
| 105 |
self.net = LlamaModel(config.net_config)
|
| 106 |
self.net_token = LlamaModel(config.net_token_config)
|
| 107 |
self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
def load_merge_lora(self, model_id):
|
| 110 |
peft_config = PeftConfig.from_pretrained(model_id)
|
| 111 |
model = LoraModel(self, peft_config, adapter_name="default")
|
| 112 |
+
adapter_state_dict = load_peft_weights(model_id, device=str(self.device))
|
| 113 |
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
| 114 |
return model.merge_and_unload()
|
| 115 |
|
|
|
|
| 194 |
with bar:
|
| 195 |
while cur_len < max_len:
|
| 196 |
end = [False] * batch_size
|
| 197 |
+
hidden = self.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
|
| 198 |
next_token_seq = None
|
| 199 |
event_names = [""] * batch_size
|
| 200 |
cache2 = DynamicCache()
|
midi_tokenizer.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import random
|
|
|
|
| 2 |
|
| 3 |
import PIL.Image
|
| 4 |
import numpy as np
|
|
@@ -33,6 +34,20 @@ class MIDITokenizerV1:
|
|
| 33 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
| 34 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def set_optimise_midi(self, optimise_midi=True):
|
| 37 |
self.optimise_midi = optimise_midi
|
| 38 |
|
|
@@ -519,6 +534,20 @@ class MIDITokenizerV2:
|
|
| 519 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
| 520 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
def set_optimise_midi(self, optimise_midi=True):
|
| 523 |
self.optimise_midi = optimise_midi
|
| 524 |
|
|
|
|
| 1 |
import random
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
|
| 4 |
import PIL.Image
|
| 5 |
import numpy as np
|
|
|
|
| 34 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
| 35 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
| 36 |
|
| 37 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 38 |
+
d = {
|
| 39 |
+
"version":self.version,
|
| 40 |
+
"optimise_midi":self.optimise_midi,
|
| 41 |
+
"vocab_size": self.vocab_size,
|
| 42 |
+
"events": self.events,
|
| 43 |
+
"event_parameters": self.event_parameters,
|
| 44 |
+
"max_token_seq": self.max_token_seq,
|
| 45 |
+
"pad_id": self.pad_id,
|
| 46 |
+
"bos_id": self.bos_id,
|
| 47 |
+
"eos_id": self.eos_id,
|
| 48 |
+
}
|
| 49 |
+
return d
|
| 50 |
+
|
| 51 |
def set_optimise_midi(self, optimise_midi=True):
|
| 52 |
self.optimise_midi = optimise_midi
|
| 53 |
|
|
|
|
| 534 |
self.parameter_ids = {p: allocate_ids(s) for p, s in self.event_parameters.items()}
|
| 535 |
self.max_token_seq = max([len(ps) for ps in self.events.values()]) + 1
|
| 536 |
|
| 537 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 538 |
+
d = {
|
| 539 |
+
"version":self.version,
|
| 540 |
+
"optimise_midi":self.optimise_midi,
|
| 541 |
+
"vocab_size": self.vocab_size,
|
| 542 |
+
"events": self.events,
|
| 543 |
+
"event_parameters": self.event_parameters,
|
| 544 |
+
"max_token_seq": self.max_token_seq,
|
| 545 |
+
"pad_id": self.pad_id,
|
| 546 |
+
"bos_id": self.bos_id,
|
| 547 |
+
"eos_id": self.eos_id,
|
| 548 |
+
}
|
| 549 |
+
return d
|
| 550 |
+
|
| 551 |
def set_optimise_midi(self, optimise_midi=True):
|
| 552 |
self.optimise_midi = optimise_midi
|
| 553 |
|