update vllm serving guide
README.md
CHANGED
@@ -446,7 +446,7 @@ libri_data = load_dataset("distil-whisper/librispeech_long", "clean", split="val
 audio_array = libri_data[0]["audio"]["array"]
 inputs = processor(text=chat_prompt, audios=audio_array)

-outputs = model.generate(**inputs, max_new_tokens=
+outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.1, top_p=0.9, repetition_penalty=1.1)
 generated_ids = outputs[:, inputs['input_ids'].size(1):]
 response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 ```
@@ -490,7 +490,7 @@ libri_data = load_dataset("distil-whisper/librispeech_long", "clean", split="val
 audio_array = [libri_data[0]["audio"]["array"]]*2
 inputs = processor(text=chat_prompt, audios=audio_array)

-outputs = model.generate(**inputs, max_new_tokens=
+outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.1, top_p=0.9, repetition_penalty=1.1)
 generated_ids = outputs[:, inputs['input_ids'].size(1):]
 response = processor.batch_decode(generated_ids, skip_special_tokens=True)
 ```
@@ -527,9 +527,7 @@ def run_meralion(question: str):

 llm = LLM(model=model_name,
           tokenizer=model_name,
-
-          max_model_len=4096,
-          max_num_seqs=5,
+          max_num_seqs=8,
           limit_mm_per_prompt={"audio": 1},
           trust_remote_code=True,
           dtype=torch.bfloat16
@@ -550,9 +548,15 @@ llm, prompt, stop_token_ids = run_meralion(question)

 # We set temperature to 0.2 so that outputs can be different
 # even when all prompts are identical when running batch inference.
-sampling_params = SamplingParams(
-
-
+sampling_params = SamplingParams(
+    temperature=0.1,
+    top_p=0.9,
+    top_k=50,
+    repetition_penalty=1.1,
+    seed=42,
+    max_tokens=1024,
+    stop_token_ids=None
+)

 mm_data = {"audio": [audio_asset.audio_and_sample_rate]}
 inputs = {"prompt": prompt, "multi_modal_data": mm_data}
@@ -569,7 +573,6 @@ for o in outputs:

 #### OpenAI Compatible Server

-
 **server**

 Here is an example to start the server via the `vllm serve` command.
@@ -577,7 +580,7 @@ Here is an example to start the server via the `vllm serve` command.
 ```bash
 export HF_TOKEN=your-hf-token

-vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --
+vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --max-num-seqs 8 --trust-remote-code --dtype bfloat16
 ```

 **client**
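For context, the `sampling_params` and `inputs` built in the hunks above feed the offline batch-inference call that the README's `for o in outputs:` loop consumes. Below is a minimal sketch of that call, assuming the `llm`, `prompt`, and `audio_asset` objects from the surrounding README example; it is illustrative only and is not part of this commit.

```python
# Minimal sketch (not part of the commit): consuming the updated sampling
# parameters with the `inputs` dict defined in the README example above.
outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    # Each RequestOutput carries the generated continuation for one prompt.
    print(o.outputs[0].text)
```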
vllm_plugin_meralion/README.md
ADDED
@@ -0,0 +1,115 @@
+## MERaLiON vLLM serving
+
+### Set up Environment
+
+MERaLiON-AudioLLM requires vLLM version `0.6.4.post1` and transformers version `4.46.3`.
+
+```bash
+pip install vllm==0.6.4.post1
+pip install transformers==4.46.3
+```
+
+As the [vLLM documentation](https://docs.vllm.ai/en/stable/models/adding_model.html#out-of-tree-model-integration) recommends,
+we provide a way to register our model via [vLLM plugins](https://docs.vllm.ai/en/stable/design/plugin_system.html#plugin-system).
+
+
+```bash
+pip install .
+```
+
+### Serving
+
+Here is an example to start the server via the `vllm serve` command.
+
+```bash
+export HF_TOKEN=<your-hf-token>
+
+vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --max-num-seqs 8 --trust-remote-code --dtype bfloat16 --port 8000
+```
+
+To call the server, you can use the [official OpenAI client](https://github.com/openai/openai-python):
+
+```python
+import base64
+
+from openai import OpenAI
+
+
+def get_client(api_key="EMPTY", base_url="http://localhost:8000/v1"):
+    client = OpenAI(
+        api_key=api_key,
+        base_url=base_url,
+    )
+
+    models = client.models.list()
+    model_name = models.data[0].id
+    return client, model_name
+
+
+def get_response(text_input, base64_audio_input, **params):
+    response_obj = client.chat.completions.create(
+        messages=[{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": f"Text instruction: {text_input}"
+                },
+                {
+                    "type": "audio_url",
+                    "audio_url": {
+                        "url": f"data:audio/ogg;base64,{base64_audio_input}"
+                    },
+                },
+            ],
+        }],
+        **params
+    )
+    return response_obj
+
+
+# specify input and params
+possible_text_inputs = [
+    "Please transcribe this speech.",
+    "Please summarise the content of this speech.",
+    "Please follow the instruction in this speech."
+]
+
+audio_bytes = open("/path/to/wav/or/mp3/file", "rb").read()
+audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+
+# use the port number of the vllm service.
+client, model_name = get_client(base_url="http://localhost:8000/v1")
+
+generation_parameters = dict(
+    model=model_name,
+    max_completion_tokens=1024,
+    temperature=0.1,
+    top_p=0.9,
+    extra_body={
+        "repetition_penalty": 1.1,
+        "top_k": 50,
+        "length_penalty": 1.0
+    },
+    seed=42
+)
+
+
+response_obj = get_response(possible_text_inputs[0], audio_base64, **generation_parameters)
+print(response_obj.choices[0].message.content)
+```
+
+Alternatively, you can call the server with a curl command.
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "messages": [
+            {"role": "system", "content": [{"type": "text", "text": "Text instruction: <your-command>"}, {"type":"audio_url", "audio_url": {"url": "data:audio/ogg;base64,<audio base64>"}}]}
+        ]
+    }'
+
+```
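The client example in the added README reads audio from a file on disk. Below is a hedged sketch of building the same base64 payload from an in-memory waveform instead; the `soundfile` dependency, the use of `distil-whisper/librispeech_long`, and its `validation` split name are assumptions rather than something this commit specifies.

```python
# Sketch (assumption, not part of the commit): encode an in-memory waveform
# from a `datasets` sample to base64 for the OpenAI-compatible client above.
import base64
import io

import soundfile as sf
from datasets import load_dataset

libri_data = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = libri_data[0]["audio"]

# Write the float waveform to WAV bytes in memory, then base64-encode them.
buffer = io.BytesIO()
sf.write(buffer, sample["array"], samplerate=sample["sampling_rate"], format="WAV")
audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

# `audio_base64` can then be passed to get_response() from the README above.
```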
vllm_plugin_meralion/set_up.py
CHANGED
@@ -1,7 +1,7 @@
 from setuptools import setup

 setup(name='vllm_plugin_meralion',
-      version='0.
+      version='0.2',
       packages=['vllm_plugin_meralion'],
       entry_points={
           'vllm.general_plugins':
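The `vllm.general_plugins` entry point registered in `set_up.py` points at a callable inside the package whose body is not shown in this commit. A plausible sketch, following vLLM's out-of-tree model registration pattern, is given below; the function name `register` and its contents are assumptions, not the package's actual implementation.

```python
# Hypothetical sketch of the callable exposed via the 'vllm.general_plugins'
# entry point; the real implementation ships inside vllm_plugin_meralion.
def register():
    from vllm import ModelRegistry

    from vllm_plugin_meralion.vllm_meralion import MERaLiONForConditionalGeneration

    # Map the HF architecture name to the out-of-tree vLLM implementation.
    if "MERaLiONForConditionalGeneration" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model("MERaLiONForConditionalGeneration",
                                     MERaLiONForConditionalGeneration)
```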
vllm_plugin_meralion/vllm_plugin_meralion/modeling_text_decoder.py
CHANGED
@@ -1316,4 +1316,4 @@ class MERaLiONTextForTokenClassification(MERaLiONTextPreTrainedModel):
     logits=logits,
     hidden_states=outputs.hidden_states,
     attentions=outputs.attentions,
-)
+)
vllm_plugin_meralion/vllm_plugin_meralion/vllm_meralion.py
CHANGED
@@ -35,6 +35,12 @@ _KEYS_TO_MODIFY_MAPPING = {
     "text_decoder.model": "text_decoder",
 }

+# === Constants === #
+DEFAULT_SAMPLE_RATE = 16000
+FEATURE_CHUNK_SIZE = DEFAULT_SAMPLE_RATE * 30
+OUTPUT_CHUNK_SIZE = 100
+MAX_NUMBER_CHUNKS = 8
+

 # === Audio Inputs === #
 class MERaLiONInputs(TypedDict):
@@ -107,9 +113,9 @@ def dummy_data_for_meralion(ctx: InputContext, seq_len: int,
         (speech_token_index, max_llm_audio_tokens),
         (0, seq_len - max_llm_audio_tokens),
     )
-    dummy_audio = np.full((max_llm_audio_tokens *
+    dummy_audio = np.full((max_llm_audio_tokens * 15 * 2 * 160, ), 0.)
     return DummyData(
-        dummy_seqdata, {"audio": [(dummy_audio,
+        dummy_seqdata, {"audio": [(dummy_audio, DEFAULT_SAMPLE_RATE)] * num_audios}, {
             "audio":
             consecutive_placeholder_ranges(num_items=num_audios,
                                            item_size=max_tokens_per_audio)
@@ -157,11 +163,33 @@ def get_processor(
 cached_get_processor = lru_cache(get_processor)


+def _get_number_chunks(audios: List[np.ndarray]):
+    audio_lengths = np.array([_.shape[0] for _ in audios])
+    number_chunks = (audio_lengths // FEATURE_CHUNK_SIZE) + 1
+    return np.clip(number_chunks, a_min=None, a_max=MAX_NUMBER_CHUNKS)
+
+
+def _get_feat_extract_output_lengths(audios: List[np.ndarray]):
+    return _get_number_chunks(audios) * OUTPUT_CHUNK_SIZE
+
+
+def _get_chunked_audios(audios: List[np.ndarray]):
+    audio_number_chunks = _get_number_chunks(audios)
+    chunked_resampled_audios = []
+
+    for audio_idx, audio in enumerate(audios):
+        for cid in range(audio_number_chunks[audio_idx]):
+            chunked_resampled_audios.append(
+                audio[cid * FEATURE_CHUNK_SIZE: (cid + 1) * FEATURE_CHUNK_SIZE].copy()
+            )
+    return chunked_resampled_audios
+
+
 def get_max_meralion_audio_tokens(ctx: InputContext) -> int:
     """
     The max number of tokens after speech audio adapter.
     """
-    return
+    return MAX_NUMBER_CHUNKS * OUTPUT_CHUNK_SIZE


 def input_processor_for_meralion(
@@ -184,26 +212,24 @@ def input_processor_for_meralion(
                  target_sr=processor.feature_extractor.sampling_rate)
         for audio, sampling_rate in audios
     ]
-
-
-    [min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
-
-    audio_output_length = get_max_meralion_audio_tokens(ctx)
+
+    audio_output_lengths = _get_feat_extract_output_lengths(resampled_audios)
     speech_token_index = ctx.model_config.hf_config.speech_token_index

     input_ids = inputs['prompt_token_ids']

     new_input_ids = []
     audio_num = input_ids.count(speech_token_index)
-    assert len(
+    assert len(audio_output_lengths) == audio_num, \
         (f'The text input contains {audio_num} audio tokens, '
-         f'but {len(
+         f'but {len(audio_output_lengths)} audios provided')
     start = 0
-    for
+    for audio_idx in range(audio_num):
         end = input_ids.index(speech_token_index, start)
         new_input_ids.extend(input_ids[start:end]) # text part

-        new_input_ids.extend([speech_token_index] *
+        new_input_ids.extend([speech_token_index] *
+                             audio_output_lengths[audio_idx])
         start = end + 1
     new_input_ids.extend(input_ids[start:])

@@ -240,6 +266,9 @@ def input_mapper_for_meralion(
                  target_sr=processor.feature_extractor.sampling_rate)
         for audio, sampling_rate in multi_modal_data
     ]
+
+    resampled_audios = _get_chunked_audios(resampled_audios)
+
     batch_data = audio_feature_extractor(resampled_audios,
                                          sampling_rate=16000,
                                          return_attention_mask=True,
@@ -291,6 +320,7 @@ class MERaLiONForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.text_config.vocab_size,
                                                 logit_scale)
+
         self.sampler = get_sampler()

         self.make_empty_intermediate_tensors = (
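The new constants and helpers above encode a simple budget: audio is resampled to 16 kHz, split into 30-second chunks (480,000 samples), each chunk contributes 100 placeholder tokens, and at most 8 chunks are kept, which is why `get_max_meralion_audio_tokens` now returns 800. A quick sanity check of that arithmetic, mirroring `_get_number_chunks`:

```python
import numpy as np

DEFAULT_SAMPLE_RATE = 16000
FEATURE_CHUNK_SIZE = DEFAULT_SAMPLE_RATE * 30   # 480,000 samples per 30 s chunk
OUTPUT_CHUNK_SIZE = 100                         # placeholder tokens per chunk
MAX_NUMBER_CHUNKS = 8                           # hard cap -> 800 tokens max

# e.g. a 95-second clip at 16 kHz
audio = np.zeros(95 * DEFAULT_SAMPLE_RATE)

num_chunks = min(audio.shape[0] // FEATURE_CHUNK_SIZE + 1, MAX_NUMBER_CHUNKS)
print(num_chunks)                      # 4 chunks: three full 30 s chunks plus a 5 s remainder
print(num_chunks * OUTPUT_CHUNK_SIZE)  # 400 audio placeholder tokens for this clip
```

This also lines up with the dummy audio in `dummy_data_for_meralion`: 800 tokens times 15 * 2 * 160 samples per token is 240 seconds of 16 kHz audio, i.e. exactly `MAX_NUMBER_CHUNKS` 30-second chunks.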