YingxuHe committed · Commit 1ee1019 · 1 Parent(s): 6d20c84

update vllm serving guide

README.md CHANGED
@@ -446,7 +446,7 @@ libri_data = load_dataset("distil-whisper/librispeech_long", "clean", split="val
 audio_array = libri_data[0]["audio"]["array"]
 inputs = processor(text=chat_prompt, audios=audio_array)
 
-outputs = model.generate(**inputs, max_new_tokens=128)
+outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.1, top_p=0.9, repetition_penalty=1.1)
 generated_ids = outputs[:, inputs['input_ids'].size(1):]
 response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 ```
@@ -490,7 +490,7 @@ libri_data = load_dataset("distil-whisper/librispeech_long", "clean", split="val
 audio_array = [libri_data[0]["audio"]["array"]]*2
 inputs = processor(text=chat_prompt, audios=audio_array)
 
-outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.1, top_p=0.9, repetition_penalty=1.1)
+outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.1, top_p=0.9, repetition_penalty=1.1)
 generated_ids = outputs[:, inputs['input_ids'].size(1):]
 response = processor.batch_decode(generated_ids, skip_special_tokens=True)
 ```
@@ -527,9 +527,7 @@ def run_meralion(question: str):
 
 llm = LLM(model=model_name,
           tokenizer=model_name,
-          tokenizer_mode="slow",
-          max_model_len=4096,
-          max_num_seqs=5,
+          max_num_seqs=8,
           limit_mm_per_prompt={"audio": 1},
           trust_remote_code=True,
           dtype=torch.bfloat16
@@ -550,9 +548,15 @@ llm, prompt, stop_token_ids = run_meralion(question)
 
 # We set temperature to 0.2 so that outputs can be different
 # even when all prompts are identical when running batch inference.
-sampling_params = SamplingParams(temperature=0.2,
-                                 max_tokens=64,
-                                 stop_token_ids=stop_token_ids)
+sampling_params = SamplingParams(
+    temperature=0.1,
+    top_p=0.9,
+    top_k=50,
+    repetition_penalty=1.1,
+    seed=42,
+    max_tokens=1024,
+    stop_token_ids=None
+)
 
 mm_data = {"audio": [audio_asset.audio_and_sample_rate]}
 inputs = {"prompt": prompt, "multi_modal_data": mm_data}
@@ -569,7 +573,6 @@ for o in outputs:
 
 #### OpenAI Compatible Server
 
-
 **server**
 
 Here is an example to start the server via the `vllm serve` command.
@@ -577,7 +580,7 @@ Here is an example to start the server via the `vllm serve` command.
 ```bash
 export HF_TOKEN=your-hf-token
 
-vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer-mode slow --max-num-seqs 8 --trust-remote-code --dtype bfloat16
+vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --max-num-seqs 8 --trust-remote-code --dtype bfloat16
 ```
 
 **client**
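Pieced together from the hunks above, the updated offline-inference snippet reads roughly as follows. The surrounding variables (`llm`, `prompt`, `audio_asset`) come from the unchanged parts of the README, and the `llm.generate(...)` call plus print loop are inferred from the `for o in outputs:` context in the hunk header, so treat this as an illustrative sketch rather than the exact file contents:

```python
# Sketch of the updated batch-inference flow after this commit (illustrative).
sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
    seed=42,
    max_tokens=1024,
    stop_token_ids=None,
)

mm_data = {"audio": [audio_asset.audio_and_sample_rate]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}

# Assumed call pattern: vLLM's LLM.generate accepts a prompt dict with
# multi_modal_data together with SamplingParams.
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
    print(o.outputs[0].text)
```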
vllm_plugin_meralion/README.md ADDED
@@ -0,0 +1,115 @@
+## MERaLiON vLLM serving
+
+### Set up Environment
+
+MERaLiON-AudioLLM requires vLLM version `0.6.4.post1` and transformers version `4.46.3`.
+
+```bash
+pip install vllm==0.6.4.post1
+pip install transformers==4.46.3
+```
+
+As the [vLLM documentation](https://docs.vllm.ai/en/stable/models/adding_model.html#out-of-tree-model-integration) recommends,
+we provide a way to register our model via [vLLM plugins](https://docs.vllm.ai/en/stable/design/plugin_system.html#plugin-system).
+
+```bash
+pip install .
+```
+
+### Serving
+
+Here is an example to start the server via the `vllm serve` command.
+
+```bash
+export HF_TOKEN=<your-hf-token>
+
+vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --max-num-seqs 8 --trust-remote-code --dtype bfloat16 --port 8000
+```
+
+To call the server, you can use the [official OpenAI client](https://github.com/openai/openai-python):
+
+```python
+import base64
+
+from openai import OpenAI
+
+
+def get_client(api_key="EMPTY", base_url="http://localhost:8000/v1"):
+    client = OpenAI(
+        api_key=api_key,
+        base_url=base_url,
+    )
+
+    models = client.models.list()
+    model_name = models.data[0].id
+    return client, model_name
+
+
+def get_response(text_input, base64_audio_input, **params):
+    response_obj = client.chat.completions.create(
+        messages=[{
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": f"Text instruction: {text_input}"
+                },
+                {
+                    "type": "audio_url",
+                    "audio_url": {
+                        "url": f"data:audio/ogg;base64,{base64_audio_input}"
+                    },
+                },
+            ],
+        }],
+        **params
+    )
+    return response_obj
+
+
+# specify input and params
+possible_text_inputs = [
+    "Please transcribe this speech.",
+    "Please summarise the content of this speech.",
+    "Please follow the instruction in this speech."
+]
+
+audio_bytes = open("/path/to/wav/or/mp3/file", "rb").read()
+audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+
+# use the port number of the vLLM service.
+client, model_name = get_client(base_url="http://localhost:8000/v1")
+
+generation_parameters = dict(
+    model=model_name,
+    max_completion_tokens=1024,
+    temperature=0.1,
+    top_p=0.9,
+    extra_body={
+        "repetition_penalty": 1.1,
+        "top_k": 50,
+        "length_penalty": 1.0
+    },
+    seed=42
+)
+
+response_obj = get_response(possible_text_inputs[0], audio_base64, **generation_parameters)
+print(response_obj.choices[0].message.content)
+```
+
+Alternatively, you can try calling the server with a `curl` command.
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "messages": [
+            {"role": "user", "content": [{"type": "text", "text": "Text instruction: <your-command>"}, {"type":"audio_url", "audio_url": {"url": "data:audio/ogg;base64,<audio base64>"}}]}
+        ]
+    }'
+```
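As a quick sanity check that the plugin's entry point is visible after `pip install .`, one option is to load vLLM's general plugins manually and look for the MERaLiON architecture. This is an illustrative snippet, not part of the committed README, and it assumes vLLM's `vllm.plugins.load_general_plugins` helper and `ModelRegistry` API behave as in the documented plugin system:

```python
# Illustrative check (not in the commit): list registered architectures
# after loading plugins the same way the vLLM engine does at start-up.
from vllm import ModelRegistry
from vllm.plugins import load_general_plugins

load_general_plugins()  # resolves 'vllm.general_plugins' entry points
print([arch for arch in ModelRegistry.get_supported_archs() if "MERaLiON" in arch])
```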
vllm_plugin_meralion/set_up.py CHANGED
@@ -1,7 +1,7 @@
 from setuptools import setup
 
 setup(name='vllm_plugin_meralion',
-      version='0.1',
+      version='0.2',
       packages=['vllm_plugin_meralion'],
       entry_points={
          'vllm.general_plugins':
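For readers unfamiliar with the plugin mechanism: the `'vllm.general_plugins'` entry point declared in `set_up.py` points at a callable that registers the out-of-tree model class with vLLM. The commit does not show that callable, so the sketch below is only an assumption of what it roughly looks like, based on vLLM's documented out-of-tree registration API (`ModelRegistry.register_model`); the function and import path are illustrative, not taken from the repository.

```python
# Hypothetical sketch of the plugin's registration hook (names assumed).
# The 'vllm.general_plugins' entry point is expected to resolve to a
# callable along these lines.
from vllm import ModelRegistry


def register():
    from vllm_plugin_meralion.vllm_meralion import MERaLiONForConditionalGeneration

    if "MERaLiONForConditionalGeneration" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model("MERaLiONForConditionalGeneration",
                                     MERaLiONForConditionalGeneration)
```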
vllm_plugin_meralion/vllm_plugin_meralion/modeling_text_decoder.py CHANGED
@@ -1316,4 +1316,4 @@ class MERaLiONTextForTokenClassification(MERaLiONTextPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
vllm_plugin_meralion/vllm_plugin_meralion/vllm_meralion.py CHANGED
@@ -35,6 +35,12 @@ _KEYS_TO_MODIFY_MAPPING = {
     "text_decoder.model": "text_decoder",
 }
 
+# === Constants === #
+DEFAULT_SAMPLE_RATE = 16000
+FEATURE_CHUNK_SIZE = DEFAULT_SAMPLE_RATE * 30
+OUTPUT_CHUNK_SIZE = 100
+MAX_NUMBER_CHUNKS = 8
+
 
 # === Audio Inputs === #
 class MERaLiONInputs(TypedDict):
@@ -107,9 +113,9 @@ def dummy_data_for_meralion(ctx: InputContext, seq_len: int,
         (speech_token_index, max_llm_audio_tokens),
         (0, seq_len - max_llm_audio_tokens),
     )
-    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
+    dummy_audio = np.full((max_llm_audio_tokens * 15 * 2 * 160, ), 0.)
     return DummyData(
-        dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
+        dummy_seqdata, {"audio": [(dummy_audio, DEFAULT_SAMPLE_RATE)] * num_audios}, {
             "audio":
             consecutive_placeholder_ranges(num_items=num_audios,
                                            item_size=max_tokens_per_audio)
@@ -157,11 +163,33 @@ def get_processor(
 cached_get_processor = lru_cache(get_processor)
 
 
+def _get_number_chunks(audios: List[np.ndarray]):
+    audio_lengths = np.array([_.shape[0] for _ in audios])
+    number_chunks = (audio_lengths // FEATURE_CHUNK_SIZE) + 1
+    return np.clip(number_chunks, a_min=None, a_max=MAX_NUMBER_CHUNKS)
+
+
+def _get_feat_extract_output_lengths(audios: List[np.ndarray]):
+    return _get_number_chunks(audios) * OUTPUT_CHUNK_SIZE
+
+
+def _get_chunked_audios(audios: List[np.ndarray]):
+    audio_number_chunks = _get_number_chunks(audios)
+    chunked_resampled_audios = []
+
+    for audio_idx, audio in enumerate(audios):
+        for cid in range(audio_number_chunks[audio_idx]):
+            chunked_resampled_audios.append(
+                audio[cid * FEATURE_CHUNK_SIZE: (cid + 1) * FEATURE_CHUNK_SIZE].copy()
+            )
+    return chunked_resampled_audios
+
+
 def get_max_meralion_audio_tokens(ctx: InputContext) -> int:
     """
     The max number of tokens after speech audio adapter.
     """
-    return 100
+    return MAX_NUMBER_CHUNKS * OUTPUT_CHUNK_SIZE
 
 
 def input_processor_for_meralion(
@@ -184,26 +212,24 @@ def input_processor_for_meralion(
                          target_sr=processor.feature_extractor.sampling_rate)
         for audio, sampling_rate in audios
     ]
-
-    audio_input_lengths = np.array(
-        [min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
-
-    audio_output_length = get_max_meralion_audio_tokens(ctx)
+
+    audio_output_lengths = _get_feat_extract_output_lengths(resampled_audios)
     speech_token_index = ctx.model_config.hf_config.speech_token_index
 
     input_ids = inputs['prompt_token_ids']
 
     new_input_ids = []
     audio_num = input_ids.count(speech_token_index)
-    assert len(audio_input_lengths) == audio_num, \
+    assert len(audio_output_lengths) == audio_num, \
         (f'The text input contains {audio_num} audio tokens, '
-         f'but {len(audio_input_lengths)} audios provided')
+         f'but {len(audio_output_lengths)} audios provided')
     start = 0
-    for _ in range(audio_num):
+    for audio_idx in range(audio_num):
         end = input_ids.index(speech_token_index, start)
         new_input_ids.extend(input_ids[start:end])  # text part
 
-        new_input_ids.extend([speech_token_index] * audio_output_length)
+        new_input_ids.extend([speech_token_index] *
+                             audio_output_lengths[audio_idx])
         start = end + 1
     new_input_ids.extend(input_ids[start:])
 
@@ -240,6 +266,9 @@ def input_mapper_for_meralion(
                          target_sr=processor.feature_extractor.sampling_rate)
         for audio, sampling_rate in multi_modal_data
     ]
+
+    resampled_audios = _get_chunked_audios(resampled_audios)
+
     batch_data = audio_feature_extractor(resampled_audios,
                                          sampling_rate=16000,
                                          return_attention_mask=True,
@@ -291,6 +320,7 @@ class MERaLiONForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.text_config.vocab_size,
                                                 logit_scale)
+
         self.sampler = get_sampler()
 
         self.make_empty_intermediate_tensors = (
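To make the new chunking behaviour concrete, here is a small standalone restatement of the constants and the chunk-count formula from the diff above (illustrative only), showing how audio length maps to the number of speech placeholder tokens:

```python
# Standalone restatement of the chunking logic added in this commit.
DEFAULT_SAMPLE_RATE = 16000
FEATURE_CHUNK_SIZE = DEFAULT_SAMPLE_RATE * 30   # 480,000 samples = 30 s per chunk
OUTPUT_CHUNK_SIZE = 100                         # speech tokens per chunk
MAX_NUMBER_CHUNKS = 8


def number_chunks(num_samples: int) -> int:
    # Same formula as _get_number_chunks, for a single clip.
    return min(num_samples // FEATURE_CHUNK_SIZE + 1, MAX_NUMBER_CHUNKS)


seventy_seconds = 70 * DEFAULT_SAMPLE_RATE
print(number_chunks(seventy_seconds))                      # 3 chunks
print(number_chunks(seventy_seconds) * OUTPUT_CHUNK_SIZE)  # 300 placeholder tokens
print(MAX_NUMBER_CHUNKS * OUTPUT_CHUNK_SIZE)               # 800 = new max audio tokens
```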