from typing import TYPE_CHECKING, List, Optional, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
from transformers.utils import logging, TensorType, to_py_obj

try:
    from ariautils.midi import MidiDict
    from ariautils.tokenizer import AbsTokenizer
    from ariautils.tokenizer._base import Token
except ImportError:
    raise ImportError(
        "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
    )

if TYPE_CHECKING:
    pass

logger = logging.get_logger(__name__)


class AriaTokenizer(PreTrainedTokenizer):
    """
    The Aria tokenizer is NOT a BPE tokenizer. A MIDI file is first converted to a MidiDict (note that, despite the name, a MidiDict is not a single dict; it is closer to a list of "note" entries) which represents a sequence of notes, stops, etc. The Aria tokenizer is then simply a lookup table that maps MidiDict entries to discrete indices according to a hard-coded rule.

    For a FIM-finetuned model, we also follow a simple FIM format that steers a piece of music toward a (possibly very different) suffix according to the prompt:
    <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
    In this way, we expect a continuation that connects the PROMPT to the GUIDANCE.
    """

    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        self._tokenizer = AbsTokenizer()

        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        bos_token = self._tokenizer.bos_tok
        eos_token = self._tokenizer.eos_tok
        pad_token = self._tokenizer.pad_tok
        unk_token = self._tokenizer.unk_tok

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        return {}

    def __setstate__(self, d):
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        return self._tokenizer.tok_to_id

    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer.tokenize(midi_dict)

    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer.tokenize(midi_dict)

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """It is impossible to rely on the parent method because the inputs are MidiDict(s) instead of strings. I do not like the idea of going hacky so that two entirely different types of inputs can marry. So here I reimplement __call__ with limited support of certain useful arguments. I do not expect any conflict with other "string-in-ids-out" tokenizers. If you have to mix up the API of string-based tokenizers and our midi-based tokenizer, there must be a problem with your design."""
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[int]] = []
        max_len_encoded = 0
        # TODO: if we decide to optimize batched tokenization on ariautils using some compiled backend, we can change this loop accordingly.
        for md in midi_dicts:
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            max_len_encoded = max(max_len_encoded, len(tokens))
            all_tokens.append(tokens)
            all_attn_masks.append([True] * len(tokens))

        if pad_to_multiple_of is not None:
            # Round the target length up to the nearest multiple of `pad_to_multiple_of`.
            max_len_encoded = (
                (max_len_encoded + pad_to_multiple_of - 1) // pad_to_multiple_of
            ) * pad_to_multiple_of
        if padding:
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute the pad length before extending `tokens`, otherwise the
                # attention mask would never be padded.
                pad_len = max_len_encoded - len(tokens)
                tokens.extend([self.pad_token_id] * pad_len)
                attn_mask.extend([False] * pad_len)

        return BatchEncoding(
            {
                "input_ids": all_tokens,
                "attention_masks": all_attn_masks,
            },
            tensor_type=return_tensors,
        )

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        token_ids = to_py_obj(token_ids)

        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[int]], **kwargs
    ) -> List[MidiDict]:
        results = []
        for token_ids in token_ids_list:
            # Can we simply yield (without breaking all HF wrappers)?
            results.append(self.decode(token_ids))
        return results

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id."""
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) in a token (tuple or str)."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        raise NotImplementedError()
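

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the tokenizer implementation).
# The MIDI path below is hypothetical; any file readable by
# `MidiDict.from_midi` should work, assuming ariautils is installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = AriaTokenizer()

    # Encode a single MIDI file into input ids and an attention mask.
    encoding = tokenizer.encode_from_file("example.mid", padding=True)
    print(len(encoding["input_ids"][0]), "tokens")

    # Round-trip: decode the ids back into a MidiDict.
    midi_dict = tokenizer.decode(encoding["input_ids"][0])
    print(type(midi_dict).__name__)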