Yingxu He committed on
Commit 128578f · verified · 1 Parent(s): 22e0652

Create vllm_meralion.py

Files changed (1)
  1. vllm_meralion.py +435 -0
vllm_meralion.py ADDED
@@ -0,0 +1,435 @@
+"""Inference-only MERaLiON AudioLLM model compatible with HuggingFace weights."""
+from functools import lru_cache
+from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
+
+import librosa
+import numpy as np
+import torch
+import torch.nn as nn
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.gemma2 import Gemma2Model
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.utils import consecutive_placeholder_ranges
+from vllm.sequence import IntermediateTensors, SequenceData
+
+from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
+from vllm.model_executor.models.utils import maybe_prefix
+
+from .modeling_meralion import MERaLiONSpeechEncoder
+
+logger = init_logger(__name__)
+
+# gemma2 ties word embedding by default
+_KEYS_TO_MODIFY_MAPPING = {
+    "text_decoder.model": "text_decoder",
+}
+
+
+# === Audio Inputs === #
+class MERaLiONInputs(TypedDict):
+    input_features: torch.Tensor
+    """Shape:
+    `(num_audios, num_mel_bins, 3000)`
+    """
+
+    feature_attention_mask: torch.Tensor
+    """Shape: `(num_audios, 3000)`
+    """
+
+
+# === Audio Encoder === #
+class MERaLiONSpeechAudioAdaper(nn.Module):
+    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
+        super(MERaLiONSpeechAudioAdaper, self).__init__()
+        speech_mlp_scale_factor = 15
+
+        self.speech_mlp_scale_factor = speech_mlp_scale_factor
+        self.mlp_adapter = nn.Sequential(
+            nn.Linear(
+                in_features=audio_hidden_size * speech_mlp_scale_factor,
+                out_features=audio_hidden_size
+            ),
+            nn.SiLU(),
+            nn.Dropout(0.1),
+        )
+
+        self.speech_llm_proj = nn.Sequential(
+            nn.Linear(
+                audio_hidden_size,
+                audio_hidden_size * 4
+            ),
+            nn.SiLU(),
+            nn.Dropout(0.1),
+
+            nn.Linear(
+                audio_hidden_size * 4,
+                text_hidden_size
+            ),
+        )
+
+    def forward(self, speech_embeds, **kwargs):
+        B, T, C = speech_embeds.shape
+        speech_embeds = self.mlp_adapter(
+            speech_embeds.reshape(
+                B,
+                T // self.speech_mlp_scale_factor,
+                C * self.speech_mlp_scale_factor,
+            )
+        )
+        return self.speech_llm_proj(speech_embeds)
+
+
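For reference, a minimal shape check of the adapter above; the hidden sizes are illustrative assumptions (1280 for a Whisper-large-style speech encoder, 3584 for a Gemma2-9B decoder). Stacking speech_mlp_scale_factor = 15 consecutive encoder frames per output token turns 1500 encoder frames into the 100 audio tokens reported by get_max_meralion_audio_tokens below.

import torch

adapter = MERaLiONSpeechAudioAdaper(audio_hidden_size=1280,
                                    text_hidden_size=3584)
speech_embeds = torch.randn(2, 1500, 1280)  # (num_audios, encoder frames, d_model)
projected = adapter(speech_embeds)
print(projected.shape)  # torch.Size([2, 100, 3584]) -- 1500 // 15 = 100 tokens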
+def dummy_data_for_meralion(ctx: InputContext, seq_len: int,
+                            mm_counts: Mapping[str, int]):
+    num_audios = mm_counts["audio"]
+    max_tokens_per_audio = get_max_meralion_audio_tokens(ctx)
+    max_llm_audio_tokens = max_tokens_per_audio * num_audios
+    if seq_len - max_llm_audio_tokens - 2 < 0:
+        raise RuntimeError(
+            f"MERaLiON cannot process {num_audios} audios in a prompt, "
+            "please increase max_model_len or reduce audio limit by "
+            "--limit-mm-per-prompt.")
+
+    speech_token_index = ctx.model_config.hf_config.speech_token_index
+
+    dummy_seqdata = SequenceData.from_prompt_token_counts(
+        (speech_token_index, max_llm_audio_tokens),
+        (0, seq_len - max_llm_audio_tokens),
+    )
+    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
+    return DummyData(
+        dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
+            "audio":
+            consecutive_placeholder_ranges(num_items=num_audios,
+                                           item_size=max_tokens_per_audio)
+        })
+
+
+def get_processor(
+    processor_name: str,
+    *args,
+    trust_remote_code: bool = True,
+    **kwargs,
+):
+    """Gets a processor for the given model name via HuggingFace.
+
+    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
+    """
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoProcessor
+
+    try:
+        processor = AutoProcessor.from_pretrained(
+            processor_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            **kwargs)
+    except ValueError as e:
+        # If the error pertains to the processor class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
+        if not trust_remote_code:
+            err_msg = (
+                "Failed to load the processor. If the processor is "
+                "a custom processor not yet available in the HuggingFace "
+                "transformers library, consider setting "
+                "`trust_remote_code=True` in LLM or using the "
+                "`--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
+
+    return processor
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+def get_max_meralion_audio_tokens(ctx: InputContext) -> int:
+    """
+    The maximum number of audio tokens produced per audio by the speech audio adapter.
+    """
+    return 100
+
+
+def input_processor_for_meralion(
+        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
+    multi_modal_data = inputs.get("multi_modal_data")
+    if multi_modal_data is None or "audio" not in multi_modal_data:
+        return inputs
+
+    audios = multi_modal_data["audio"]
+    if not isinstance(audios, list):
+        audios = [audios]
+
+    if len(audios) == 0:
+        return inputs
+
+    processor = cached_get_processor(ctx.model_config.model)
+    resampled_audios = [
+        librosa.resample(audio,
+                         orig_sr=sampling_rate,
+                         target_sr=processor.feature_extractor.sampling_rate)
+        for audio, sampling_rate in audios
+    ]
+
+    audio_input_lengths = np.array(
+        [min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
+
+    audio_output_length = get_max_meralion_audio_tokens(ctx)
+    speech_token_index = ctx.model_config.hf_config.speech_token_index
+
+    input_ids = inputs['prompt_token_ids']
+
+    new_input_ids = []
+    audio_num = input_ids.count(speech_token_index)
+    assert len(audio_input_lengths) == audio_num, \
+        (f'The text input contains {audio_num} audio tokens, '
+         f'but {len(audio_input_lengths)} audios were provided')
+    start = 0
+    for _ in range(audio_num):
+        end = input_ids.index(speech_token_index, start)
+        new_input_ids.extend(input_ids[start:end])  # text part
+
+        new_input_ids.extend([speech_token_index] * audio_output_length)
+        start = end + 1
+    new_input_ids.extend(input_ids[start:])
+
+    return token_inputs(
+        prompt_token_ids=new_input_ids,
+        prompt=inputs['prompt'],
+        multi_modal_data=multi_modal_data,
+    )
+
+
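To illustrate the expansion loop above: each occurrence of the speech placeholder in the prompt is replaced by 100 copies of itself, so the decoder sees one position per projected audio embedding. The token IDs and the speech_token_index value below are made up for the example; the real index comes from the HF config.

speech_token_index = 255999                    # assumed placeholder id
prompt_token_ids = [2, 9413, 255999, 107]      # <bos> text <speech> text
audio_output_length = 100                      # get_max_meralion_audio_tokens(ctx)

new_input_ids = []
start = 0
end = prompt_token_ids.index(speech_token_index, start)
new_input_ids.extend(prompt_token_ids[start:end])                # text before the audio
new_input_ids.extend([speech_token_index] * audio_output_length)
new_input_ids.extend(prompt_token_ids[end + 1:])                 # text after the audio

assert len(new_input_ids) == len(prompt_token_ids) - 1 + audio_output_length  # 103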
+def input_mapper_for_meralion(
+    ctx: InputContext,
+    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
+) -> MultiModalKwargs:
+    """Input mapper for MERaLiON."""
+    if not isinstance(multi_modal_data, list):
+        multi_modal_data = [multi_modal_data]
+
+    if len(multi_modal_data) == 0:
+        return MultiModalKwargs()
+
+    processor = cached_get_processor(ctx.model_config.model)
+    audio_feature_extractor = processor.feature_extractor
+    if audio_feature_extractor is None:
+        raise RuntimeError(
+            "No HuggingFace audio_feature_extractor is available "
+            "to process the audio object")
+
+    try:
+        resampled_audios = [
+            librosa.resample(
+                audio,
+                orig_sr=sampling_rate,
+                target_sr=processor.feature_extractor.sampling_rate)
+            for audio, sampling_rate in multi_modal_data
+        ]
+        batch_data = audio_feature_extractor(resampled_audios,
+                                             sampling_rate=16000,
+                                             return_attention_mask=True,
+                                             padding="max_length",
+                                             return_tensors="pt").data
+        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
+    except Exception:
+        logger.error("Failed to process audio (%s)", multi_modal_data)
+        raise
+
+    return MultiModalKwargs(batch_data)
+
+
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_meralion)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_meralion)
+@MULTIMODAL_REGISTRY.register_input_mapper("audio",
+                                           input_mapper_for_meralion)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "audio", get_max_meralion_audio_tokens)
+class MERaLiONForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                       SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.speech_encoder = MERaLiONSpeechEncoder(config.speech_config)
+        self.ln_speech = nn.LayerNorm(config.speech_config.d_model)
+        self.speech_audio_adapter = MERaLiONSpeechAudioAdaper(
+            config.speech_config.d_model, config.text_config.hidden_size)
+
+        self.quant_config = quant_config
+
+        self.text_decoder = Gemma2Model(
+            vllm_config=vllm_config.with_hf_config(config.text_config),
+            prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.text_config.vocab_size
+        if config.text_config.tie_word_embeddings:
+            self.lm_head = self.text_decoder.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
+                                          config.text_config.hidden_size,
+                                          quant_config=quant_config)
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.text_config.vocab_size,
+                                                logit_scale)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.text_decoder.make_empty_intermediate_tensors)
+
+    def _validate_and_reshape_mm_tensor(self,
+                                        mm_input: Union[torch.Tensor,
+                                                        List[torch.Tensor]],
+                                        name: str) -> torch.Tensor:
+        if not isinstance(mm_input, (torch.Tensor, list)):
+            raise ValueError(f"Incorrect type of {name}. "
+                             f"Got type: {type(mm_input)}")
+        if isinstance(mm_input, torch.Tensor):
+            return torch.concat(list(mm_input))
+        else:
+            return torch.concat(mm_input)
+
+    def _parse_and_validate_audio_input(
+            self, **kwargs: object) -> Optional[MERaLiONInputs]:
+        input_features = kwargs.pop('input_features', None)
+        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
+        if input_features is None:
+            return None
+        input_features = self._validate_and_reshape_mm_tensor(
+            input_features, 'input_features')
+        feature_attention_mask = self._validate_and_reshape_mm_tensor(
+            feature_attention_mask, 'feature_attention_mask')
+        if not isinstance(input_features, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio input features. "
+                             f"Got type: {type(input_features)}")
+        return MERaLiONInputs(input_features=input_features,
+                              feature_attention_mask=feature_attention_mask)
+
+    def _process_audio_input(self,
+                             audio_input: MERaLiONInputs) -> torch.Tensor:
+
+        input_features = audio_input["input_features"].to(self.speech_encoder.dtype)
+        feature_attention_mask = audio_input["feature_attention_mask"]
+
+        audio_outputs = self.speech_encoder(input_features,
+                                            attention_mask=feature_attention_mask)
+        audio_features = audio_outputs.last_hidden_state
+        audio_features = self.ln_speech(audio_features)
+        audio_features = self.speech_audio_adapter(audio_features)
+        audio_features = audio_features.view(-1, audio_features.size(-1))
+
+        return audio_features
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if intermediate_tensors is not None:
+            input_ids = None
+            inputs_embeds = None
+        else:
+            audio_input = self._parse_and_validate_audio_input(**kwargs)
+
+            if audio_input is None:
+                inputs_embeds = None
+            else:
+                inputs_embeds = self.text_decoder.embed_tokens(input_ids)
+                processed_audio_features = self._process_audio_input(audio_input)
+                # merge llm embeddings and audio features
+                mask = (input_ids == self.config.speech_token_index)
+                inputs_embeds[mask, :] = processed_audio_features
+
+                input_ids = None
+
+        hidden_states = self.text_decoder(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if (self.config.text_config.tie_word_embeddings
+                    and "lm_head.weight" in name):
+                continue
+            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+                if key_to_modify in name:
+                    name = name.replace(key_to_modify, new_key)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name or 'speech_encoder' in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
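
For context, a minimal offline-inference sketch of how this class might be served with vLLM. It assumes vllm_meralion.py and modeling_meralion.py are importable together as a package, that the installed vLLM version matches the APIs used in this file, and that the repository id, prompt template, and <SpeechHere> placeholder follow the MERaLiON model card; none of these are specified in this commit.

import librosa
from vllm import LLM, ModelRegistry, SamplingParams

from vllm_meralion import MERaLiONForConditionalGeneration  # package layout assumed

# Register the out-of-tree architecture so vLLM can resolve it from the HF config.
ModelRegistry.register_model("MERaLiONForConditionalGeneration",
                             MERaLiONForConditionalGeneration)

llm = LLM(model="MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",  # assumed repo id
          trust_remote_code=True,
          limit_mm_per_prompt={"audio": 1})

audio, sr = librosa.load("speech.wav", sr=16000)
outputs = llm.generate(
    {
        # Prompt template and <SpeechHere> placeholder assumed from the model card.
        "prompt": "<start_of_turn>user\n"
                  "Please transcribe this speech: <SpeechHere><end_of_turn>\n"
                  "<start_of_turn>model\n",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    SamplingParams(temperature=0, max_tokens=256),
)
print(outputs[0].outputs[0].text)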