MERaLiON-AudioLLM vLLM Serving

MERaLiON-AudioLLM is trained on 30-second audio clips. This vLLM integration supports audio inputs of up to 4 minutes.
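
If your recordings are longer than that, you can trim them before inference. Below is a minimal sketch, assuming librosa is installed and resampling to the 16 kHz that Whisper-style encoders expect; the constant name is ours, and the path is a placeholder.

import librosa

MAX_AUDIO_SECONDS = 4 * 60  # 4-minute limit stated above

# load, resample to 16 kHz, and keep only the first 4 minutes
audio, sample_rate = librosa.load("/path/to/audio/file", sr=16000)
audio = audio[: MAX_AUDIO_SECONDS * sample_rate]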

Set up Environment

MERaLiON-AudioLLM requires vLLM version 0.6.4.post1 and transformers version 4.46.3:

pip install vllm==0.6.4.post1
pip install transformers==4.46.3
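
To sanity-check that the pinned versions are the ones actually imported, you can run a quick sketch like this:

import transformers
import vllm

# both libraries expose their version strings as __version__
assert vllm.__version__ == "0.6.4.post1", vllm.__version__
assert transformers.__version__ == "4.46.3", transformers.__version__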

As the vLLM documentation recommends, we provide a way to register our model via vLLM plugins.

pip install .
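
For reference, vLLM discovers such plugins through the vllm.general_plugins entry-point group and calls the registered function at startup. Below is a minimal sketch of what the packaged setup.py might look like; the package and function names are illustrative, not necessarily those used in this repository.

# setup.py -- sketch only; names below are hypothetical
from setuptools import setup

setup(
    name="vllm_meralion_plugin",
    version="0.1",
    packages=["vllm_meralion_plugin"],
    entry_points={
        # vLLM scans this group at startup; the register function can call
        # ModelRegistry.register_model(...) to add the MERaLiON architecture.
        "vllm.general_plugins": [
            "register_meralion = vllm_meralion_plugin:register"
        ]
    },
)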

Offline Inference

Here is an example of offline inference using our custom vLLM class.

import torch
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

def run_meralion(question: str):
    model_name = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"

    llm = LLM(model=model_name,
              tokenizer=model_name,
              max_num_seqs=8,
              limit_mm_per_prompt={"audio": 1},
              trust_remote_code=True,
              dtype=torch.bfloat16
              )

    audio_in_prompt = "Given the following audio context: <SpeechHere>\n\n"

    prompt = ("<start_of_turn>user\n"
              f"{audio_in_prompt}Text instruction: {question}<end_of_turn>\n"
              "<start_of_turn>model\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids

audio_asset = AudioAsset("mary_had_lamb")
question= "Please trancribe this speech."

llm, prompt, stop_token_ids = run_meralion(question)

# We set temperature to 0.1 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(
  temperature=0.1,
  top_p=0.9,
  top_k=50,
  repetition_penalty=1.1,
  seed=42,
  max_tokens=1024,
  stop_token_ids=stop_token_ids
)

mm_data = {"audio": [audio_asset.audio_and_sample_rate]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}

# batch inference
inputs = [inputs] * 2

outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
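
To run the same pipeline on your own recording instead of the bundled asset, pass an (array, sample_rate) tuple as the audio entry. A sketch, assuming librosa is installed; the path is a placeholder:

import librosa

# 16 kHz matches what Whisper-style encoders expect
audio, sample_rate = librosa.load("/path/to/wav/or/mp3/file", sr=16000)

mm_data = {"audio": [(audio, sample_rate)]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
outputs = llm.generate(inputs, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)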

Serving

Here is an example of starting the server via the vllm serve command.

export HF_TOKEN=<your-hf-token>

vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION --max-num-seqs 8 --trust-remote-code --dtype bfloat16 --port 8000
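
Once the server is up, you can confirm the model is being served by listing the available models through the OpenAI-compatible /v1/models endpoint (a sketch, assuming the requests package is installed):

import requests

resp = requests.get("http://localhost:8000/v1/models")
print(resp.json())  # should list MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION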

To call the server, you can use the official OpenAI client:

import base64

from openai import OpenAI


def get_client(api_key="EMPTY", base_url="http://localhost:8000/v1"):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )

    models = client.models.list()
    model_name = models.data[0].id
    return client, model_name


def get_response(text_input, base64_audio_input, **params):
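    # Uses the module-level `client` created by get_client() further below.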
    response_obj = client.chat.completions.create(
        messages=[{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Text instruction: {text_input}"
                },
                {
                    "type": "audio_url",
                    "audio_url": {
                        "url": f"data:audio/ogg;base64,{base64_audio_input}"
                    },
                },
            ],
        }],
        **params
    )
    return response_obj


# specify inputs and generation parameters
possible_text_inputs = [
    "Please transcribe this speech.",
    "Please summarise the content of this speech.",
    "Please follow the instruction in this speech."
]

audio_bytes = open("/path/to/wav/or/mp3/file", "rb").read()
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

# use the port number of the vLLM service.
client, model_name = get_client(base_url="http://localhost:8000/v1")

generation_parameters = dict(
    model=model_name,
    max_completion_tokens=1024,
    temperature=0.1,
    top_p=0.9,
    extra_body={
        "repetition_penalty": 1.1,
        "top_k": 50,
        "length_penalty": 1.0
    },
    seed=42
)


response_obj = get_response(possible_text_inputs[0], audio_base64, **generation_parameters)
print(response_obj.choices[0].message.content)
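
For example, to run every instruction in possible_text_inputs against the same clip:

for text_input in possible_text_inputs:
    response_obj = get_response(text_input, audio_base64, **generation_parameters)
    print(response_obj.choices[0].message.content)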

Alternatively, you can call the server with curl, as in the example below. We recommend including the generation config in the JSON body to fully reproduce the model's performance.

curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
        "messages": [
            {"role": "user", 
            "content": [
                {"type": "text", "text": "Text instruction: <your-instruction>"},
                {"type": "audio_url", "audio_url": {"url": "data:audio/ogg;base64,<your-audio-base64-string>"}}
            ]
            }
        ],
        "max_completion_tokens": 1024,
        "temperature": 0.1, 
        "top_p": 0.9, 
        "seed": 42 
    }'