Yingxu He
committed on
Create vllm_meralion.py
vllm_meralion.py ADDED +435 -0
@@ -0,0 +1,435 @@
"""Inference-only MERaLiON AudioLLM model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union

import librosa
import numpy as np
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData

from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import maybe_prefix

from .modeling_meralion import MERaLiONSpeechEncoder

logger = init_logger(__name__)

# gemma2 ties word embedding by default
_KEYS_TO_MODIFY_MAPPING = {
    "text_decoder.model": "text_decoder",
}


# === Audio Inputs === #
class MERaLiONInputs(TypedDict):
    input_features: torch.Tensor
    """Shape:
    `(num_audios, num_mel_bins, 3000)`
    """

    feature_attention_mask: torch.Tensor
    """Shape: `(num_audios, 3000)`
    """


# === Audio Encoder === #
class MERaLiONSpeechAudioAdaper(nn.Module):
    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        speech_mlp_scale_factor = 15

        self.speech_mlp_scale_factor = speech_mlp_scale_factor
        self.mlp_adapter = nn.Sequential(
            nn.Linear(
                in_features=audio_hidden_size * speech_mlp_scale_factor,
                out_features=audio_hidden_size
            ),
            nn.SiLU(),
            nn.Dropout(0.1),
        )

        self.speech_llm_proj = nn.Sequential(
            nn.Linear(
                audio_hidden_size,
                audio_hidden_size * 4
            ),
            nn.SiLU(),
            nn.Dropout(0.1),

            nn.Linear(
                audio_hidden_size * 4,
                text_hidden_size
            ),
        )

    def forward(self, speech_embeds, **kwargs):
        B, T, C = speech_embeds.shape
        speech_embeds = self.mlp_adapter(
            speech_embeds.reshape(
                B,
                T // self.speech_mlp_scale_factor,
                C * self.speech_mlp_scale_factor,
            )
        )
        return self.speech_llm_proj(speech_embeds)


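# A minimal, illustrative sketch (not part of the model itself): it exercises
# the adapter's shape contract, assuming a Whisper-style speech encoder that
# turns 3000 mel frames into 1500 hidden frames, so the 15x pooling above
# yields the 100 tokens reported by `get_max_meralion_audio_tokens`. The hidden
# sizes below are placeholders, not values read from a real checkpoint config.
def _speech_audio_adapter_shape_sketch():
    adapter = MERaLiONSpeechAudioAdaper(audio_hidden_size=1280,
                                        text_hidden_size=2304)
    encoder_out = torch.randn(2, 1500, 1280)  # (B, T, C); T divisible by 15
    adapted = adapter(encoder_out)
    assert adapted.shape == (2, 1500 // 15, 2304)

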
def dummy_data_for_meralion(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    num_audios = mm_counts["audio"]
    max_tokens_per_audio = get_max_meralion_audio_tokens(ctx)
    max_llm_audio_tokens = max_tokens_per_audio * num_audios
    if seq_len - max_llm_audio_tokens - 2 < 0:
        raise RuntimeError(
            f"MERaLiON cannot process {num_audios} audios in a prompt, "
            "please increase max_model_len or reduce the audio limit via "
            "--limit-mm-per-prompt.")

    speech_token_index = ctx.model_config.hf_config.speech_token_index

    dummy_seqdata = SequenceData.from_prompt_token_counts(
        (speech_token_index, max_llm_audio_tokens),
        (0, seq_len - max_llm_audio_tokens),
    )
    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
    return DummyData(
        dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
            "audio":
            consecutive_placeholder_ranges(num_items=num_audios,
                                           item_size=max_tokens_per_audio)
        })


def get_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = True,
    **kwargs,
):
    """Gets a processor for the given model name via HuggingFace.

    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor

    try:
        processor = AutoProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors.
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    return processor


cached_get_processor = lru_cache(get_processor)


def get_max_meralion_audio_tokens(ctx: InputContext) -> int:
    """
    The max number of tokens after the speech audio adapter.
    """
    return 100


def input_processor_for_meralion(
        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return inputs

    audios = multi_modal_data["audio"]
    if not isinstance(audios, list):
        audios = [audios]

    if len(audios) == 0:
        return inputs

    processor = cached_get_processor(ctx.model_config.model)
    resampled_audios = [
        librosa.resample(audio,
                         orig_sr=sampling_rate,
                         target_sr=processor.feature_extractor.sampling_rate)
        for audio, sampling_rate in audios
    ]

    audio_input_lengths = np.array(
        [min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])

    audio_output_length = get_max_meralion_audio_tokens(ctx)
    speech_token_index = ctx.model_config.hf_config.speech_token_index

    input_ids = inputs['prompt_token_ids']

    new_input_ids = []
    audio_num = input_ids.count(speech_token_index)
    assert len(audio_input_lengths) == audio_num, \
        (f'The text input contains {audio_num} audio tokens, '
         f'but {len(audio_input_lengths)} audios were provided')
    start = 0
    for _ in range(audio_num):
        end = input_ids.index(speech_token_index, start)
        new_input_ids.extend(input_ids[start:end])  # text part

        new_input_ids.extend([speech_token_index] * audio_output_length)
        start = end + 1
    new_input_ids.extend(input_ids[start:])

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=inputs['prompt'],
        multi_modal_data=multi_modal_data,
    )


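# A small illustrative sketch (hypothetical token ids, not from the real
# tokenizer): the processor above replaces every occurrence of the speech
# placeholder token with a fixed run of `get_max_meralion_audio_tokens`
# copies, so each audio clip occupies a constant-length span in the prompt.
def _placeholder_expansion_sketch():
    speech_token_index = 999  # stand-in for hf_config.speech_token_index
    prompt_token_ids = [11, 22, speech_token_index, 33]
    expanded: List[int] = []
    for token_id in prompt_token_ids:
        if token_id == speech_token_index:
            expanded.extend([speech_token_index] * 100)
        else:
            expanded.append(token_id)
    assert len(expanded) == 3 + 100

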
def input_mapper_for_meralion(
    ctx: InputContext,
    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
    """Input mapper for MERaLiON."""
    if not isinstance(multi_modal_data, list):
        multi_modal_data = [multi_modal_data]

    if len(multi_modal_data) == 0:
        return MultiModalKwargs()

    processor = cached_get_processor(ctx.model_config.model)
    audio_feature_extractor = processor.feature_extractor
    if audio_feature_extractor is None:
        raise RuntimeError(
            "No HuggingFace audio_feature_extractor is available "
            "to process the audio object")

    try:
        resampled_audios = [
            librosa.resample(
                audio,
                orig_sr=sampling_rate,
                target_sr=processor.feature_extractor.sampling_rate)
            for audio, sampling_rate in multi_modal_data
        ]
        batch_data = audio_feature_extractor(resampled_audios,
                                             sampling_rate=16000,
                                             return_attention_mask=True,
                                             padding="max_length",
                                             return_tensors="pt").data
        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
    except Exception:
        logger.error("Failed to process audio (%s)", multi_modal_data)
        raise

    return MultiModalKwargs(batch_data)


@INPUT_REGISTRY.register_dummy_data(dummy_data_for_meralion)
@INPUT_REGISTRY.register_input_processor(input_processor_for_meralion)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
                                           input_mapper_for_meralion)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_meralion_audio_tokens)
class MERaLiONForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.speech_encoder = MERaLiONSpeechEncoder(config.speech_config)
        self.ln_speech = nn.LayerNorm(config.speech_config.d_model)
        self.speech_audio_adapter = MERaLiONSpeechAudioAdaper(
            config.speech_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config

        self.text_decoder = Gemma2Model(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.text_config.vocab_size
        if config.text_config.tie_word_embeddings:
            self.lm_head = self.text_decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
                                          config.text_config.hidden_size,
                                          quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.text_decoder.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self,
                                        mm_input: Union[torch.Tensor,
                                                        List[torch.Tensor]],
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[MERaLiONInputs]:
        input_features = kwargs.pop('input_features', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
        if input_features is None:
            return None
        input_features = self._validate_and_reshape_mm_tensor(
            input_features, 'input_features')
        feature_attention_mask = self._validate_and_reshape_mm_tensor(
            feature_attention_mask, 'feature_attention_mask')
        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_features)}")
        return MERaLiONInputs(input_features=input_features,
                              feature_attention_mask=feature_attention_mask)

    def _process_audio_input(self,
                             audio_input: MERaLiONInputs) -> torch.Tensor:

        input_features = audio_input["input_features"].to(
            self.speech_encoder.dtype)
        feature_attention_mask = audio_input["feature_attention_mask"]

        audio_outputs = self.speech_encoder(
            input_features, attention_mask=feature_attention_mask)
        audio_features = audio_outputs.last_hidden_state
        audio_features = self.ln_speech(audio_features)
        audio_features = self.speech_audio_adapter(audio_features)
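        # Flatten (num_audios, tokens_per_audio, hidden) to
        # (num_audios * tokens_per_audio, hidden) so each row lines up with
        # one speech placeholder position in the expanded prompt.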
        audio_features = audio_features.view(-1, audio_features.size(-1))

        return audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            audio_input = self._parse_and_validate_audio_input(**kwargs)

            if audio_input is None:
                inputs_embeds = None
            else:
                inputs_embeds = self.text_decoder.embed_tokens(input_ids)
                processed_audio_features = self._process_audio_input(
                    audio_input)
                # merge llm embeddings and audio features
                mask = (input_ids == self.config.speech_token_index)
                inputs_embeds[mask, :] = processed_audio_features

                input_ids = None

        hidden_states = self.text_decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if (self.config.text_config.tie_word_embeddings
                    and "lm_head.weight" in name):
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name or 'speech_encoder' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
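

# A hedged, illustrative sketch of offline inference (not part of the original
# commit): it assumes the caller supplies the checkpoint path and a prompt that
# already contains the speech placeholder expected by
# `input_processor_for_meralion`. The calls follow the usual vLLM multimodal
# pattern; exact argument names may differ across vLLM versions.
def _offline_inference_sketch(model_path: str, prompt: str,
                              audio: np.ndarray, sample_rate: int) -> str:
    from vllm import LLM, ModelRegistry, SamplingParams

    # Resolve the `MERaLiONForConditionalGeneration` architecture declared in
    # the checkpoint's config.json to the class defined above.
    ModelRegistry.register_model("MERaLiONForConditionalGeneration",
                                 MERaLiONForConditionalGeneration)

    llm = LLM(model=model_path,
              trust_remote_code=True,
              limit_mm_per_prompt={"audio": 1})
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"audio": (audio, sample_rate)},
        },
        SamplingParams(temperature=0.0, max_tokens=256),
    )
    return outputs[0].outputs[0].text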