from typing import List, Optional, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
from transformers.utils import logging, TensorType, to_py_obj

try:
    from ariautils.midi import MidiDict
    from ariautils.tokenizer import AbsTokenizer
    from ariautils.tokenizer._base import Token
except ImportError:
    raise ImportError(
        "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
    )

logger = logging.get_logger(__name__)


class AriaTokenizer(PreTrainedTokenizer):
"""
Aria Tokenizer is NOT a BPE tokenizer. A midi file will be converted to a MidiDict (note: in fact, a MidiDict is not a single dict. It is more about a list of "notes") which represents a sequence of notes, stops, etc. And then, aria tokenizer is simply a dictionary that maps MidiDict to discrete indices according to a hard-coded rule.
For a FIM finetuned model, we also follow a simple FIM format to guide a piece of music to a (possibly very different) suffix according to the prompts:
<GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
This way, we expect a continuation that connects PROMPT and GUIDANCE.
"""

    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        self._tokenizer = AbsTokenizer()
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        # Special tokens come from the underlying ariautils tokenizer.
        bos_token = self._tokenizer.bos_tok
        eos_token = self._tokenizer.eos_tok
        pad_token = self._tokenizer.pad_tok
        unk_token = self._tokenizer.unk_tok
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        return {}

    def __setstate__(self, d):
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size."""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        return self._tokenizer.tok_to_id

    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer(midi_dict)

    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer(midi_dict)

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
"""It is impossible to rely on the parent method because the inputs are MidiDict(s) instead of strings. I do not like the idea of going hacky so that two entirely different types of inputs can marry. So here I reimplement __call__ with limited support of certain useful arguments. I do not expect any conflict with other "string-in-ids-out" tokenizers. If you have to mix up the API of string-based tokenizers and our midi-based tokenizer, there must be a problem with your design."""
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[int]] = []
        max_len_encoded = 0
        # TODO: if we decide to optimize batched tokenization in ariautils using some compiled
        # backend, we can change this loop accordingly.
        for md in midi_dicts:
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            max_len_encoded = max(max_len_encoded, len(tokens))
            all_tokens.append(tokens)
            all_attn_masks.append([True] * len(tokens))
        if pad_to_multiple_of is not None:
            # Round the target length up to the nearest multiple of `pad_to_multiple_of`.
            max_len_encoded = (
                (max_len_encoded + pad_to_multiple_of - 1) // pad_to_multiple_of
            ) * pad_to_multiple_of
        if padding:
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute the pad length before extending `tokens`, so the attention mask is
                # padded by the same amount.
                pad_len = max_len_encoded - len(tokens)
                tokens.extend([self.pad_token_id] * pad_len)
                attn_mask.extend([False] * pad_len)
        return BatchEncoding(
            {
                "input_ids": all_tokens,
                "attention_mask": all_attn_masks,
            },
            tensor_type=return_tensors,
        )

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        token_ids = to_py_obj(token_ids)
        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[int]], **kwargs
    ) -> List[MidiDict]:
        results = []
        for token_ids in token_ids_list:
            # Can we simply yield (without breaking all HF wrappers)?
            results.append(self.decode(token_ids))
        return results

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id."""
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) into a token (tuple or str)."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        raise NotImplementedError()
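

# A hedged usage sketch, not part of the tokenizer itself. It assumes placeholder MIDI file paths
# ("example.mid", "another.mid") and only exercises the methods defined above.
if __name__ == "__main__":
    tokenizer = AriaTokenizer()

    # Single file in, token ids out.
    enc = tokenizer.encode_from_file("example.mid")
    ids = enc["input_ids"][0]

    # Token ids back to a MidiDict.
    midi_dict = tokenizer.decode(ids)
    print(type(midi_dict), len(ids))

    # Batched encoding with padding; "attention_mask" marks the real (non-pad) positions.
    batch = tokenizer.encode_from_files(["example.mid", "another.mid"], padding=True)
    print(len(batch["input_ids"][0]), sum(batch["attention_mask"][0]))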