
MERaLiON-AudioLLM vLLM Serving

MERaLiON-AudioLLM is trained on 30-second audio clips. This vLLM integration supports audio inputs of up to 4 minutes.

Set up Environment

MERaLiON-AudioLLM requires vLLM version 0.6.4.post1 and transformers version 4.46.3.

pip install vllm==0.6.4.post1
pip install transformers==4.46.3
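
To confirm the pinned versions are active in your environment, a quick check like the following can be run (a small sketch; it only prints the installed version strings):

import transformers
import vllm

# Expect 0.6.4.post1 and 4.46.3, matching the pins above.
print("vllm:", vllm.__version__)
print("transformers:", transformers.__version__)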

As the vLLM documentation recommends, we provide a plugin that registers our model with vLLM. Install it with:

pip install .
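
After installation, you can optionally confirm that the plugin registered the model with vLLM. This is a minimal sketch assuming the plugin adds an entry to vLLM's ModelRegistry; the exact architecture name is defined by the plugin package.

from vllm import ModelRegistry

# The MERaLiON architecture should appear in this list once the plugin is installed.
print(ModelRegistry.get_supported_archs())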

Offline Inference

Here is an example of offline inference using our custom vLLM class.

import torch
from vllm import ModelRegistry, LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"

llm = LLM(model=model_name,
          tokenizer=model_name,
          limit_mm_per_prompt={"audio": 1},
          trust_remote_code=True,
          dtype=torch.bfloat16
          )

audio_asset = AudioAsset("mary_had_lamb")

question= "Please trancribe this speech."
audio_in_prompt = "Given the following audio context: <SpeechHere>\n\n"

prompt = ("<start_of_turn>user\n"
          f"{audio_in_prompt}Text instruction: {question}<end_of_turn>\n"
          "<start_of_turn>model\n")

sampling_params = SamplingParams(
  temperature=1,
  top_p=0.9,
  top_k=50,
  repetition_penalty=1.1,
  seed=42,
  max_tokens=1024,
  stop_token_ids=None
)

mm_data = {"audio": [audio_asset.audio_and_sample_rate]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}

# batch inference
inputs = [inputs] * 2

outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
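
To run the same pipeline on your own recordings instead of the bundled AudioAsset, pass a (waveform, sample_rate) tuple as the audio input. Below is a minimal sketch assuming librosa is installed and /path/to/your/audio.wav is a placeholder; resampling to 16 kHz is an assumption based on the Whisper-based speech encoder.

import librosa

# Load a local file as a mono waveform, resampled to 16 kHz.
waveform, sample_rate = librosa.load("/path/to/your/audio.wav", sr=16000)

# Keep clips within the 4-minute limit mentioned above.
assert len(waveform) / sample_rate <= 240, "audio longer than the supported 4 minutes"

mm_data = {"audio": [(waveform, sample_rate)]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}

outputs = llm.generate([inputs], sampling_params=sampling_params)
print(outputs[0].outputs[0].text)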

Serving

Here is an example of starting the server via the vllm serve command.

export HF_TOKEN=<your-hf-token>

vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --trust-remote-code --dtype bfloat16 --port 8000
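
Once the server is running, you can confirm it is serving the model by querying the OpenAI-compatible /v1/models endpoint. A quick standard-library check (the port matches the --port flag above):

import json
import urllib.request

# Lists the models served by the vLLM instance on localhost:8000.
with urllib.request.urlopen("http://localhost:8000/v1/models") as response:
    print(json.load(response))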

To call the server, you can use the official OpenAI client:

import base64

from openai import OpenAI


def get_client(api_key="EMPTY", base_url="http://localhost:8000/v1"):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )

    models = client.models.list()
    model_name = models.data[0].id
    return client, model_name


def get_response(text_input, base64_audio_input, **params):
    response_obj = client.chat.completions.create(
        messages=[{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Text instruction: {text_input}"
                },
                {
                    "type": "audio_url",
                    "audio_url": {
                        "url": f"data:audio/ogg;base64,{base64_audio_input}"
                    },
                },
            ],
        }],
        **params
    )
    return response_obj


# specify input and params
possible_text_inputs = [
    "Please transcribe this speech.",
    "Please summarise the content of this speech.",
    "Please follow the instruction in this speech."
]

audio_bytes = open(f"/path/to/wav/or/mp3/file", "rb").read()
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

# use the port number of the vLLM service.
client, model_name = get_client(base_url="http://localhost:8000/v1")

generation_parameters = dict(
    model=model_name,
    max_completion_tokens=1024,
    temperature=1,
    top_p=0.9,
    extra_body={
        "repetition_penalty": 1.1,
        "top_k": 50,
        "length_penalty": 1.0
    },
    seed=42
)


response_obj = get_response(possible_text_inputs[0], audio_base64, **generation_parameters)
print(response_obj.choices[0].message.content)
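
The same endpoint also supports token streaming through the OpenAI client. Below is a brief sketch that reuses the client, model_name, audio_base64, and sampling parameters defined above; the streaming behaviour follows the standard OpenAI chat completions API.

stream = client.chat.completions.create(
    model=model_name,
    max_completion_tokens=1024,
    temperature=1,
    top_p=0.9,
    seed=42,
    stream=True,
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": f"Text instruction: {possible_text_inputs[0]}"},
            {"type": "audio_url", "audio_url": {"url": f"data:audio/ogg;base64,{audio_base64}"}},
        ],
    }],
)

# Print tokens as they arrive.
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)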

Alternatively, you can call the server with curl; see the example below. We recommend using the generation config in the JSON body to fully reproduce the model's performance.

curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
        "messages": [
            {"role": "user", 
            "content": [
                {"type": "text", "text": "Text instruction: <your-instruction>"},
                {"type": "audio_url", "audio_url": {"url": "data:audio/ogg;base64,<your-audio-base64-string>"}}
            ]
            }
        ],
        "max_completion_tokens": 1024,
        "temperature": 1, 
        "top_p": 9, 
        "seed": 42 
    }'

Inference Performance Benchmark

We report the average Time To First Token (TTFT, in ms) and Inter-Token Latency (ITL, in ms) for a vLLM instance running on a single H100 GPU and a single A100 GPU respectively.

Input: 120 speech recognition prompts for each combination of input audio length and concurrency.
Output: the corresponding outputs generated for these prompts.

Single NVIDIA H100 GPU (80GiB GPU memory)

| Concurrent requests | 30s TTFT (ms) | 30s ITL (ms) | 1min TTFT (ms) | 1min ITL (ms) | 2mins TTFT (ms) | 2mins ITL (ms) |
|---|---|---|---|---|---|---|
| 1 | 85.8 | 9.9 | 126.4 | 9.6 | 214.5 | 9.7 |
| 4 | 96.9 | 11.4 | 159.6 | 11.1 | 258.1 | 11.2 |
| 8 | 109.6 | 13.0 | 206.5 | 12.7 | 261.9 | 13.0 |
| 16 | 149.9 | 16.3 | 236.7 | 16.2 | 299.0 | 16.8 |

Single NVIDIA A100 GPU (40GiB GPU memory)

| Concurrent requests | 30s TTFT (ms) | 30s ITL (ms) | 1min TTFT (ms) | 1min ITL (ms) | 2mins TTFT (ms) | 2mins ITL (ms) |
|---|---|---|---|---|---|---|
| 1 | 162.6 | 18.0 | 195.0 | 18.3 | 309.9 | 18.6 |
| 4 | 159.1 | 21.1 | 226.9 | 21.2 | 329.5 | 21.6 |
| 8 | 176.5 | 25.2 | 305.4 | 24.8 | 352.5 | 25.5 |
| 16 | 196.0 | 32.0 | 329.4 | 31.9 | 414.7 | 33.4 |