mlinmg committed on
Commit
7eebd5c
·
verified ·
1 Parent(s): ad5c9de

Upload 8 files

Browse files
config.json ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "AstraMindAI/xtts2-gpt",
3
+ "architectures": [
4
+ "XttsGPT"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
8
+ "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT"
9
+ },
10
+ "audio_config": {
11
+ "fmax": 8000,
12
+ "fmin": 0,
13
+ "hop_length": 256,
14
+ "mel_channels": 80,
15
+ "mel_norms_file": null,
16
+ "n_fft": 1024,
17
+ "output_sample_rate": 24000,
18
+ "power": 1.0,
19
+ "sample_rate": 22050,
20
+ "win_length": 1024
21
+ },
22
+ "char_limits": {
23
+ "ar": 166,
24
+ "cs": 186,
25
+ "de": 253,
26
+ "en": 250,
27
+ "es": 239,
28
+ "fr": 273,
29
+ "hu": 224,
30
+ "it": 213,
31
+ "ja": 71,
32
+ "ko": 95,
33
+ "nl": 251,
34
+ "pl": 224,
35
+ "pt": 203,
36
+ "ru": 182,
37
+ "tr": 226,
38
+ "zh": 82
39
+ },
40
+ "duration_const": 102400,
41
+ "enable_redaction": false,
42
+ "gpt_batch_size": 1,
43
+ "gpt_checkpointing": false,
44
+ "gpt_code_stride_len": 1024,
45
+ "gpt_cond_chunk_len": 4,
46
+ "gpt_cond_len": 30,
47
+ "gpt_layers": 30,
48
+ "gpt_max_audio_tokens": 605,
49
+ "gpt_max_prompt_tokens": 70,
50
+ "gpt_max_text_tokens": 402,
51
+ "gpt_n_heads": 16,
52
+ "gpt_n_model_channels": 1024,
53
+ "gpt_num_audio_tokens": 1026,
54
+ "gpt_number_text_tokens": 6681,
55
+ "gpt_start_audio_token": 1024,
56
+ "gpt_start_text_token": null,
57
+ "gpt_stop_audio_token": 1025,
58
+ "gpt_stop_text_token": null,
59
+ "gpt_train_solo_embeddings": false,
60
+ "gpt_use_masking_gt_prompt_approach": true,
61
+ "gpt_use_perceiver_resampler": true,
62
+ "kv_cache": true,
63
+ "label_smoothing": 0.0,
64
+ "languages": [
65
+ "en",
66
+ "es",
67
+ "fr",
68
+ "de",
69
+ "it",
70
+ "pt",
71
+ "pl",
72
+ "tr",
73
+ "ru",
74
+ "nl",
75
+ "cs",
76
+ "ar",
77
+ "zh-cn",
78
+ "hu",
79
+ "ko",
80
+ "ja",
81
+ "hi"
82
+ ],
83
+ "max_ref_len": 30,
84
+ "model_type": "xtts_gpt",
85
+ "num_chars": 255,
86
+ "perceiver_cond_length_compression": 256,
87
+ "repetition_penalty": 5.0,
88
+ "sound_norm_refs": false,
89
+ "temperature": 0.75,
90
+ "top_p": 0.85,
91
+ "transformers_version": "4.45.1",
92
+ "vocab_size": 256,
93
+ "cond_d_vector_in_each_upsampling_layer": true,
94
+ "d_vector_dim": 512,
95
+ "decoder_input_dim": 1024,
96
+ "input_sample_rate": 22050,
97
+ "hifi_model_type": "xtts_hifigan",
98
+ "output_hop_length": 256,
99
+ "output_sample_rate": 24000,
100
+ "resblock_dilation_sizes": [
101
+ [
102
+ 1,
103
+ 3,
104
+ 5
105
+ ],
106
+ [
107
+ 1,
108
+ 3,
109
+ 5
110
+ ],
111
+ [
112
+ 1,
113
+ 3,
114
+ 5
115
+ ]
116
+ ],
117
+ "resblock_kernel_sizes": [
118
+ 3,
119
+ 7,
120
+ 11
121
+ ],
122
+ "speaker_encoder_config": {
123
+ "model_config": null,
124
+ "model_name": "speaker_encoder",
125
+ "preprocess_config": null,
126
+ "speaker_embedding_dim": 512,
127
+ "use_torch_spec": true
128
+ },
129
+ "upsample_initial_channel": 512,
130
+ "upsample_kernel_sizes": [
131
+ 16,
132
+ 16,
133
+ 4,
134
+ 4
135
+ ],
136
+ "upsample_rates": [
137
+ 8,
138
+ 8,
139
+ 2,
140
+ 2
141
+ ]
142
+ }
gpt_config.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Dict, Optional, List
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
@dataclass
class XTTSAudioConfig:
    """Audio-processing settings carried inside the XTTS GPT configuration.

    All fields have defaults, so ``XTTSAudioConfig()`` is a valid instance. The
    declaration order below is also the positional-argument order of the
    generated constructor and must not be rearranged.
    """

    # Input and output sample rates.
    sample_rate: int = 22050
    output_sample_rate: int = 24000

    # Spectrogram parameters.
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    fmin: int = 0
    fmax: int = 8000
    power: float = 1.0

    # Presumably a path to a mel-normalisation file; None when not provided.
    mel_norms_file: Optional[str] = None
22
+
23
+
24
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Every constructor argument is stored as an attribute of the same name.
    ``audio_config`` is always normalised to an ``XTTSAudioConfig`` instance, so it
    may be passed as an instance, as a plain dict (the form found in a serialized
    config) or as ``None`` (defaults). ``pad_token_id``, ``bos_token_id``,
    ``eos_token_id`` and any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "xtts_gpt"

    def __init__(
        self,
        # Model architecture
        vocab_size: int = 256,
        num_chars: int = 255,

        # GPT parameters
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,

        # Training parameters
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,

        # Generation parameters
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,

        # Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,

        # Constants and limits
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            audio_config: ``XTTSAudioConfig`` instance, a dict of its fields, or
                ``None`` for the defaults.
            char_limits: Per-language character limits; a default table is created
                when ``None``.
            languages: Language codes; a default list is created when ``None``.
            **kwargs: Passed through to ``PretrainedConfig``.

        All remaining arguments are stored verbatim under the same attribute name.
        """
        # Mutable defaults are created per instance so configs never share them.
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }

        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]

        # A serialized config stores audio_config as a plain dict; coerce it here so
        # every construction path ends up with a dataclass that to_dict() can handle.
        if audio_config is None:
            audio_config = XTTSAudioConfig()
        elif isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.num_chars = num_chars

        # GPT parameters
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings

        # Training parameters
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing

        # Generation parameters
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing
        self.audio_config = audio_config

        # Constants and limits
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self):
        """Convert config to a dictionary, with ``audio_config`` as a plain dict."""
        config_dict = super().to_dict()
        # Only a dataclass needs flattening; anything assigned later in another form
        # (e.g. a dict) is left as the parent class serialized it.
        if isinstance(self.audio_config, XTTSAudioConfig):
            config_dict["audio_config"] = asdict(self.audio_config)
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Create a config from a dictionary.

        Args:
            config_dict: Mapping of constructor arguments; ``audio_config`` may be a
                nested dict. The mapping itself is not modified.
            **kwargs: Forwarded to ``PretrainedConfig.from_dict`` (this is what
                ``from_pretrained`` passes, e.g. ``return_unused_kwargs``).

        Returns:
            Whatever ``PretrainedConfig.from_dict`` returns for this class.
        """
        # Work on a shallow copy so the caller's dictionary is left untouched;
        # __init__ takes care of turning a dict audio_config into the dataclass.
        return super().from_dict(dict(config_dict), **kwargs)

    def update_with_tokenizer(self, tokenizer=None):
        """Update text-token settings from ``tokenizer``; no-op when it is ``None``.

        The vocabulary size comes from ``tokenizer.get_vocab_size()`` when that
        method exists and from ``len(tokenizer)`` otherwise. The start/stop text
        tokens are read from ``bos_token_id`` / ``eos_token_id``.
        """
        if tokenizer is None:
            return

        vocab_size_fn = getattr(tokenizer, "get_vocab_size", None)
        if callable(vocab_size_fn):
            self.gpt_number_text_tokens = vocab_size_fn()
        else:
            # Tokenizers without get_vocab_size() are sized through len().
            self.gpt_number_text_tokens = len(tokenizer)
        self.gpt_start_text_token = tokenizer.bos_token_id
        self.gpt_stop_text_token = tokenizer.eos_token_id
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[START]",
3
+ "eos_token": "[STOP]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Tuple, Any
2
+ import os
3
+ from functools import cached_property
4
+
5
+ from transformers import PreTrainedTokenizerFast
6
+ from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
7
+ from tokenizers import Tokenizer, processors
8
+ from tokenizers.pre_tokenizers import WhitespaceSplit
9
+ from tokenizers.processors import TemplateProcessing
10
+ import torch
11
+ from hangul_romanize import Transliter
12
+ from hangul_romanize.rule import academic
13
+ import cutlet
14
+
15
+ from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
16
+ chinese_transliterate, korean_transliterate,
17
+ japanese_cleaners)
18
+
19
class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast Tokenizer implementation for XTTS model using HuggingFace's PreTrainedTokenizerFast.

    Before encoding, each text is cleaned with the language-specific cleaners,
    prefixed with a ``[lang]`` tag and has its spaces replaced by ``[SPACE]``.
    """

    def __init__(
            self,
            vocab_file: Optional[str] = None,
            tokenizer_object: Optional[Tokenizer] = None,
            unk_token: str = "[UNK]",
            pad_token: str = "[PAD]",
            bos_token: str = "[START]",
            eos_token: str = "[STOP]",
            clean_up_tokenization_spaces: bool = True,
            **kwargs
    ):
        """Create the tokenizer.

        Args:
            vocab_file: Tokenizer file loaded with ``Tokenizer.from_file`` when no
                ``tokenizer_object`` is given.
            tokenizer_object: Ready-made ``tokenizers.Tokenizer``; takes precedence
                over ``vocab_file``.
            unk_token, pad_token, bos_token, eos_token: Special token strings.
            clean_up_tokenization_spaces: Forwarded to the parent class.
            **kwargs: Forwarded to ``PreTrainedTokenizerFast``.
        """
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                # Falls back to id 0 when the pad token is not in the vocabulary.
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token
            )
            # Wrap every single sequence as "<bos> text <eos>".
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @staticmethod
    def _base_lang(lang: str) -> str:
        """Return ``lang`` without its region suffix (``"zh-cn"`` -> ``"zh"``)."""
        return lang.split("-")[0]

    @cached_property
    def katsu(self):
        """Lazily created ``cutlet.Cutlet`` instance used by the Japanese cleaner."""
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu

    def check_input_length(self, text: str, lang: str):
        """Check if input text length is within limits for language.

        Only prints a warning; the text is never modified here. Languages without
        an entry in ``char_limits`` use a limit of 250.
        """
        lang = self._base_lang(lang)  # remove region
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply text preprocessing for language.

        The region suffix of ``lang`` is ignored, so ``"zh-cn"`` is cleaned exactly
        like ``"zh"``. Languages without a dedicated cleaner use ``basic_cleaners``.
        """
        # Without this, region-qualified codes would miss the language branches
        # below and silently fall through to the basic cleaner.
        lang = self._base_lang(lang)
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
            self,
            batch_text_or_text_pairs,
            add_special_tokens: bool = True,
            padding_strategy = PaddingStrategy.DO_NOT_PAD,
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE,
            max_length: Optional[int] = 402,
            stride: int = 0,
            is_split_into_words: bool = False,
            pad_to_multiple_of: Optional[int] = None,
            return_tensors: Optional[str] = None,
            return_token_type_ids: Optional[bool] = None,
            return_attention_mask: Optional[bool] = None,
            return_overflowing_tokens: bool = False,
            return_special_tokens_mask: bool = False,
            return_offsets_mapping: bool = False,
            return_length: bool = False,
            verbose: bool = True,
            **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing.

        Accepts an extra ``lang`` keyword: one language code applied to every text,
        or one code per text (default ``"en"``). Items that are not ``str`` are
        passed to the parent implementation untouched.

        Raises:
            ValueError: If ``lang`` is a sequence whose length differs from the
                number of texts.
        """
        num_texts = len(batch_text_or_text_pairs)
        lang = kwargs.pop("lang", "en")
        if isinstance(lang, str):
            lang = [lang] * num_texts
        elif len(lang) != num_texts:
            # zip() below would otherwise drop the texts that have no language.
            raise ValueError(
                f"Number of texts ({num_texts}) must match number of language codes ({len(lang)})"
            )

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Region-qualified codes ("zh-cn") share cleaners and limits with
                # their base language.
                base_lang = self._base_lang(text_lang)

                # Check length and preprocess
                self.check_input_length(text, base_lang)
                processed_text = self.preprocess_text(text, base_lang)

                # Format text with language tag and spaces
                lang_code = "zh-cn" if base_lang == "zh" else base_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        # Call the parent class's encoding method with processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
            self,
            text: Union[str, List[str]],
            lang: Union[str, List[str]] = "en",
            add_special_tokens: bool = True,
            padding: Union[bool, str, PaddingStrategy] = True,  # Changed default to True
            truncation: Union[bool, str, TruncationStrategy] = True,  # Changed default to True
            max_length: Optional[int] = 402,
            stride: int = 0,
            return_tensors: Optional[str] = None,
            return_token_type_ids: Optional[bool] = None,
            return_attention_mask: Optional[bool] = True,  # Changed default to True
            **kwargs
    ):
        """
        Main tokenization method
        Args:
            text: Text or list of texts to tokenize
            lang: Language code applied to every text, or a list with one language
                code per text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (default True, which pads to ``max_length``)
            truncation: Truncation strategy (default True)
            max_length: Maximum length
            stride: Stride for truncation
            return_tensors: Format of output tensors ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return attention mask (default True)
        Raises:
            ValueError: If ``lang`` is a list whose length differs from the number
                of texts.
        """
        # Convert single string to list for batch processing
        if isinstance(text, str):
            text = [text]
        # A single language code applies to the whole batch; wrapping it in a
        # one-element list would make every multi-text call fail the check below.
        if isinstance(lang, str):
            lang = [lang] * len(text)

        # Ensure text and lang lists have same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Use the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
tokenizer_config.json ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[STOP]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[SPACE]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "259": {
28
+ "content": "[en]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "260": {
36
+ "content": "[de]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "261": {
44
+ "content": "[START]",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "262": {
52
+ "content": "[fr]",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "267": {
60
+ "content": "[ru]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "284": {
68
+ "content": "[es]",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "285": {
76
+ "content": "[it]",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "286": {
84
+ "content": "[pt]",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "293": {
92
+ "content": "[cs]",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "294": {
100
+ "content": "[pl]",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "295": {
108
+ "content": "[tr]",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "297": {
116
+ "content": "[nl]",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "5022": {
124
+ "content": "[ar]",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "5023": {
132
+ "content": "[zh-cn]",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "5412": {
140
+ "content": "[ja]",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "5753": {
148
+ "content": "[hu]",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "6152": {
156
+ "content": "[ko]",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
+ "6680": {
164
+ "content": "[hi]",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": true
170
+ },
171
+ "6681": {
172
+ "content": "[PAD]",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ }
179
+ },
180
+ "bos_token": "[START]",
181
+ "clean_up_tokenization_spaces": true,
182
+ "eos_token": "[STOP]",
183
+ "max_length": null,
184
+ "model_max_length": 1000000000000000019884624838656,
185
+ "pad_to_multiple_of": null,
186
+ "pad_token": "[PAD]",
187
+ "pad_token_type_id": 0,
188
+ "padding_side": "right",
189
+ "tokenizer_class": "XTTSTokenizer",
190
+ "unk_token": "[UNK]"
191
+ }
xtts2_gpt_modeling.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+ from array import array
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from typing import List, Optional, Union, Iterable, Tuple, Mapping
9
+
10
+ from transformers import PretrainedConfig
11
+ from vllm.attention import AttentionMetadata
12
+ from vllm.config import CacheConfig
13
+ from vllm.distributed import get_pp_group
14
+ from vllm.inputs import InputContext, INPUT_REGISTRY
15
+ from vllm.model_executor.layers.linear import ColumnParallelLinear
16
+ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
17
+ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
18
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
19
+ from vllm.model_executor.models.gpt2 import GPT2Block
20
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
21
+ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
22
+ from vllm.sequence import IntermediateTensors, SequenceData, VLLM_TOKEN_ID_ARRAY_TYPE
23
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
24
+
25
+ from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder # noqa
26
+ from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler # noqa
27
+
28
+ from TTS.TTS.tts.layers.xtts.gpt import LearnedPositionEmbeddings
29
+
30
# Constants for token calculation
# NOTE(review): the original comment described 8192 as the XTTS start_audio_token, but
# the config.json shipped with this upload sets gpt_start_audio_token to 1024 --
# confirm which id the placeholder is meant to carry.
_AUDIO_PLACEHOLDER_TOKEN = 8192  # Using XTTS start_audio_token as placeholder
# Audio tokens assumed per second of audio when sizing token budgets and dummy inputs.
_AUDIO_TOKENS_PER_SECOND = 6.25
# Not referenced by any code visible in this file.
_CODE_STRIDE_LEN = 1024
34
+
35
+
36
def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
    """Return the token budget reported for the "audio" modality.

    The budget is the model sequence length minus a 100-token text reserve, plus
    the tokens needed for 30 seconds of audio and 4 special tokens, and it never
    exceeds 1000.
    """
    token_cap = 1000
    text_reserve = 100
    special_tokens = 4
    audio_seconds = 30.0  # ~30 seconds of audio

    # NOTE(review): reads `max_seq_len` from the model config -- confirm that the
    # config object passed in by the runtime really exposes this attribute.
    context_share = ctx.model_config.max_seq_len - text_reserve
    audio_share = math.ceil(audio_seconds * _AUDIO_TOKENS_PER_SECOND)
    budget = context_share + audio_share + special_tokens

    return budget if budget < token_cap else token_cap
46
+
47
+
48
def dummy_seq_data_for_xtts(
        ctx: InputContext,
        seq_len: int,
        audio_count: int,
) -> SequenceData:
    """Build the dummy token sequence used when profiling XTTS.

    Each of the ``audio_count`` chunks is a run of placeholder tokens sized for
    5 seconds of audio followed by one ``0`` separator; the rest of the sequence,
    up to ``seq_len``, is filled with ``0``.
    """

    def as_token_array(values):
        # All pieces share the token-id array type expected by SequenceData.
        return array(VLLM_TOKEN_ID_ARRAY_TYPE, values)

    chunk_seconds = 5
    tokens_per_chunk = math.ceil(_AUDIO_TOKENS_PER_SECOND * chunk_seconds)

    one_chunk = as_token_array([_AUDIO_PLACEHOLDER_TOKEN]) * tokens_per_chunk + as_token_array([0])
    audio_part = one_chunk * audio_count

    # Multiplying by a non-positive count yields an empty array, so nothing is
    # appended when the audio part already reaches seq_len.
    filler = as_token_array([0]) * (seq_len - len(audio_part))

    return SequenceData(audio_part + filler)
68
+
69
+
70
def dummy_conditioning_for_xtts(
        ctx: InputContext,
        audio_count: int,
) -> dict:
    """Build dummy conditioning inputs for XTTS profiling.

    Returns a dict whose ``"cond_latents"`` entry holds ``audio_count`` tuples of
    an all-zero ``80 x 1024`` tensor and the integer ``22050``. ``ctx`` is unused.
    """
    dummy_items = []
    for _ in range(audio_count):
        # A separate tensor per item, so entries never alias each other.
        dummy_items.append((torch.zeros(80, 1024), 22050))
    return {"cond_latents": dummy_items}
78
+
79
+
80
def dummy_data_for_xtts(
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, dict]:
    """Assemble the full dummy input (token sequence, conditioning) for profiling.

    The number of audio items is taken from ``mm_counts["audio"]``.
    """
    n_audio = mm_counts["audio"]
    return (
        dummy_seq_data_for_xtts(ctx, seq_len, n_audio),
        dummy_conditioning_for_xtts(ctx, n_audio),
    )
90
+
91
+
92
def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
    """Wrap audio conditioning input as ``MultiModalInputs`` under ``"cond_latents"``.

    A single item is treated as a one-element list. Every item must be a tuple
    (expected to be ``(mel_spec, sample_rate)``).

    Raises:
        NotImplementedError: If any item is not a tuple.
    """
    items = data if isinstance(data, list) else [data]

    for item in items:
        if isinstance(item, tuple):
            continue
        raise NotImplementedError(f"Unsupported data type: {type(item)}")

    return MultiModalInputs({"cond_latents": items})
103
+
104
+
105
+
106
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_xtts_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
class XttsGPT(nn.Module, SupportsMultiModal):
    """XTTS GPT wrapper for vLLM: conditioning encoder, GPT core and output heads.

    NOTE(review): this class reads ``n_embd``, ``n_head``,
    ``use_perceiver_resampler``, ``num_audio_tokens`` and ``vocab_size`` from
    ``config`` -- confirm the config object handed in by the runtime exposes
    these names.
    """

    def __init__(
            self,
            config: PretrainedConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional["QuantizationConfig"] = None,
    ):
        """Build the conditioning modules, the GPT core and the two output heads."""
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # XTTS specific components
        self.conditioning_encoder = ConditioningEncoder(
            80, config.n_embd, num_attn_heads=config.n_head
        )

        # The perceiver is optional; get_style_emb() checks for the attribute.
        if config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=config.n_embd,
                depth=2,
                dim_context=config.n_embd,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        # Core GPT components following VLLM pattern
        self.gpt = XttsGPT2Model(
            config,
            cache_config,
            quant_config,
            prefix="gpt"
        )

        # Prediction heads
        self.text_head = ColumnParallelLinear(
            config.n_embd,
            config.vocab_size,
            bias=False,
            quant_config=quant_config,
            prefix="text_head"
        )

        self.mel_head = ColumnParallelLinear(
            config.n_embd,
            config.num_audio_tokens,
            bias=False,
            quant_config=quant_config,
            prefix="mel_head"
        )

        self.sampler = Sampler()

    def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor:
        """Get conditioning embeddings from mel spectrograms.

        Args:
            cond_input: Conditioning input; a 4-D tensor has its dimension 1
                squeezed before encoding.
            return_latent: When True, ``cond_input`` is treated as an already
                computed latent and only gains a new dimension at index 1.

        Returns:
            The conditioning tensor, passed through the perceiver resampler when
            one was built.
        """
        if not return_latent:
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)

            if hasattr(self, 'conditioning_perceiver'):
                conds = self.conditioning_perceiver(
                    conds.permute(0, 2, 1)
                ).transpose(1, 2)
        else:
            conds = cond_input.unsqueeze(1)
        return conds

    def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            cond_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass following VLLM pattern.

        Returns the hidden states produced by the GPT core. When ``cond_latents``
        is given it is concatenated in front of the token embeddings.
        """
        if cond_latents is not None:
            # NOTE(review): this branch relies on `get_input_embeddings()` and an
            # `inputs_embeds` argument on the core model; the XttsGPT2Model defined
            # in this file provides neither yet, so conditioning cannot run as is.
            # Combine conditioning with input embeddings
            input_embeds = self.gpt.get_input_embeddings()(input_ids)
            combined_embeds = torch.cat([cond_latents, input_embeds], dim=1)
            hidden_states = self.gpt(
                inputs_embeds=combined_embeds,
                position_ids=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        else:
            # Passed positionally: the core model names its second parameter
            # `position_ids`, so a `positions=` keyword would raise TypeError.
            hidden_states = self.gpt(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_tensors,
            )
        return hidden_states

    def compute_logits(  # useless but kept for compatibility
            self,
            hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Compute output logits.

        Applies both heads to the hidden states selected by
        ``sampling_metadata.selected_token_indices`` and concatenates text logits
        and mel logits along dimension 1.
        """
        selected_states = hidden_states[sampling_metadata.selected_token_indices]
        # ColumnParallelLinear returns an (output, bias) pair; only the output
        # tensor can be concatenated.
        text_logits, _ = self.text_head(selected_states)
        mel_logits, _ = self.mel_head(selected_states)
        return torch.cat([text_logits, mel_logits], dim=1)

    def sample(
            self,
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens using VLLM sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights following VLLM pattern.

        Checkpoint entries whose name is not a parameter of this module are
        skipped. Weights of ``c_attn`` / ``c_proj`` / ``c_fc`` layers are
        transposed before loading.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue

            param = params_dict[name]
            if "c_attn" in name or "c_proj" in name or "c_fc" in name:
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.t()

            # Parameters may carry their own loader; otherwise use the default one.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
244
+
245
+
246
class XttsGPT2Model(nn.Module):
    """VLLM-style implementation of GPT2 core architecture.

    NOTE(review): ``forward`` reads ``self.wte``, ``self.wpe``, ``self.ln_f``,
    ``self.start_layer`` and ``self.end_layer``, none of which are assigned in
    ``__init__`` (which creates ``text_embedding``, ``mel_embedding``,
    ``text_pos_embedding``, ``mel_pos_embedding``, ``h`` and ``final_norm``).
    The two methods need to be reconciled before this model can run.
    """

    def __init__(
            self,
            config: PretrainedConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional["QuantizationConfig"] = None,
            prefix: str = "",
    ):
        # Builds separate token and position embeddings for text and audio ("mel")
        # tokens, the stack of GPT2 blocks and the final layer norm.
        super().__init__()
        self.config = config
        self.text_embedding = VocabParallelEmbedding(config.number_text_tokens, config.n_embd)
        self.mel_embedding = VocabParallelEmbedding(config.num_audio_tokens, config.n_embd)

        # NOTE(review): the text branch is gated on `max_mel_seq_len`, not on
        # `max_text_seq_len` -- confirm this is intended.
        self.text_pos_embedding = (
            LearnedPositionEmbeddings(config.max_text_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings, dim=config.n_embd)
        )
        self.mel_pos_embedding = (
            LearnedPositionEmbeddings(config.max_mel_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings, dim=config.n_embd)
        )
        # Build gpt blocks
        self.h = nn.ModuleList([
            GPT2Block(
                config,
                cache_config,
                quant_config,
                prefix=f"{prefix}.h.{i}"
            ) for i in range(config.num_hidden_layers)
        ])

        self.final_norm = nn.LayerNorm(
            config.n_embd,
            eps=config.layer_norm_epsilon
        )

    def forward(  # TODO: this is not correct, align it with the correct implementation
            self,
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # On the first pipeline rank the hidden states are token plus position
        # embeddings; on later ranks they come from the previous rank.
        if get_pp_group().is_first_rank:
            # NOTE(review): `wte` / `wpe` are not defined by __init__.
            inputs_embeds = self.wte(input_ids)
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # NOTE(review): `start_layer` / `end_layer` are not defined by __init__.
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            # kv_caches is indexed relative to the first layer run by this rank.
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        # Ranks other than the last hand the hidden states to the next rank.
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # NOTE(review): `ln_f` is not defined by __init__ (the layer norm created
        # there is named `final_norm`).
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
xttsv2-gpt.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93fa43aaad29e232fa6c85f3d6c3285285c1fe4c89f9505d8153e231b12e1a50
3
+ size 1764117740