import os
import time
from asyncio.log import logger

import uvicorn
import gc
import json
import torch

from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

MODEL_PATH = "../llama-factory/merged_models/internlm2_5-7b-chat-1m_sft_bf16_p2_full"
MAX_MODEL_LENGTH = 8192
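
# Release cached GPU memory held by the engine when the FastAPI app shuts down.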
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
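
# Pydantic schemas mirroring the OpenAI chat-completion API: model listing,
# messages, tool/function calls, usage accounting, and the request/response
# bodies for /v1/chat/completions.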
class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class EmbeddingRequest(BaseModel):
    input: Union[List[str], str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    # Note: the default is the string "None" (capital N), not "none";
    # process_messages only skips tool injection for the literal "none",
    # so supplied tools are forwarded unless the client passes "none" explicitly.
    tool_choice: Optional[Union[str, dict]] = "None"
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]] = None
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[
        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
    ]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
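
# Guard against NaN/Inf logits: zero the score tensor and force a high value
# onto a single fallback token id so sampling cannot fail.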
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores
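
# Split the raw model output on the "<|assistant|>" separator and, when tools are
# enabled, parse the "function_name\n{arguments}" layout into a tool-call dict.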
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        if "\n" in response:
            metadata, content = response.split("\n", maxsplit=1)
        else:
            metadata, content = "", response
        if not metadata.strip():
            content = content.strip()
        else:
            if use_tool:
                parameters = eval(content.strip())
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False),
                }
            else:
                content = {"name": metadata.strip(), "content": content}
    return content
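
# Core generation coroutine: renders the chat template, builds vLLM SamplingParams,
# and asynchronously yields partial text plus token usage for each engine step.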
async def generate_stream_glm4(params):
    messages = params["messages"]
    tools = params["tools"]
    tool_choice = params["tool_choice"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 8192))

    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)
    # Use a unique request id so concurrent requests do not collide inside the engine.
    async for output in engine.generate(
        inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"
    ):
        output_len = len(output.outputs[0].token_ids)
        input_len = len(output.prompt_token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len,
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret
    gc.collect()
    torch.cuda.empty_cache()
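
# Convert OpenAI-style messages into the chat-template format expected by the model:
# inject the tool definitions as a system message, honour an explicit tool_choice,
# and map tool results onto the "observation" role.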
def process_messages(messages, tools=None, tool_choice="none"):
    _messages = messages
    messages = []
    msg_has_sys = False

    def filter_tools(tool_choice, tools):
        function_name = tool_choice.get("function", {}).get("name", None)
        if not function_name:
            return []
        filtered_tools = [
            tool
            for tool in tools
            if tool.get("function", {}).get("name") == function_name
        ]
        return filtered_tools

    if tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
            messages.append({"role": "system", "content": None, "tools": tools})
            msg_has_sys = True

    # When a specific tool is forced, pre-seed an assistant turn whose metadata
    # names that tool so the model continues with its arguments.
    if isinstance(tool_choice, dict) and tools:
        messages.append(
            {
                "role": "assistant",
                "metadata": tool_choice["function"]["name"],
                "content": "",
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "tool":
            messages.append({"role": "observation", "content": content})
        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                if "\n" in response:
                    metadata, sub_content = response.split("\n", maxsplit=1)
                else:
                    metadata, sub_content = "", response
                messages.append(
                    {"role": role, "metadata": metadata, "content": sub_content.strip()}
                )
        else:
            if role == "system" and msg_has_sys:
                msg_has_sys = False
                continue
            messages.append({"role": role, "content": content})
    return messages
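
# HTTP endpoints exposing an OpenAI-compatible surface: health check,
# model listing, and chat completions.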
@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="glm-4")
    return ModelList(data=[model_card])
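
# Main chat-completion endpoint: validates the request, builds generation params,
# then either streams SSE chunks (optionally detecting a tool call from the first
# buffered output) or collects the full generation and returns a single response.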
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = await anext(predict_stream_generator)
        # If the first buffered item is non-empty, hand the rest of the generator to SSE.
        if output:
            return EventSourceResponse(
                predict_stream_generator, media_type="text/event-stream"
            )
        logger.debug(f"First result output:\n{output}")

        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call")

        # If the buffered output parsed as a tool call, append it to the history and
        # stream the follow-up turn. No tools are executed in this demo, so the
        # tool response stays empty.
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)
            tool_response = ""
            if not gen_params.get("messages"):
                gen_params["messages"] = []
            gen_params["messages"].append(ChatMessage(role="assistant", content=output))
            gen_params["messages"].append(
                ChatMessage(role="tool", name=function_call.name, content=tool_response)
            )
            generate = predict(request.model, gen_params)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            generate = parse_output_text(request.model, output)
            return EventSourceResponse(generate, media_type="text/event-stream")

    # Non-streaming path: drain the generator and keep the final accumulated result.
    response = ""
    async for response in generate_stream_glm4(gen_params):
        pass

    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    function_call, finish_reason = None, "stop"
    if request.tools:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception:
            logger.warning(
                "Failed to parse tool call: the response may not be a function call "
                "(e.g. a cogview drawing) or may already have been answered."
            )

    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=(
            function_call if isinstance(function_call, FunctionCallResponse) else None
        ),
    )
    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        id="",  # the id is left empty for this open-source deployment
        choices=[choice_data],
        object="chat.completion",
        usage=usage,
    )
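
# Stream the follow-up turn after a tool call as OpenAI "chat.completion.chunk"
# events, emitting only the text delta of each step and a final stop chunk.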
async def predict(model_id: str, params: dict):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    previous_text = ""
    async for new_response in generate_stream_glm4(params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text) :]
        previous_text = decoded_unicode
        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning(
                    "Failed to parse tool call: the response may not be a tool call "
                    "or may already have been answered."
                )
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=(
                function_call
                if isinstance(function_call, FunctionCallResponse)
                else None
            ),
        )
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=delta, finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield "[DONE]"
async def predict_stream(model_id, gen_params):
    output = ""
    is_function_call = False
    has_send_first_chunk = False
    async for new_response in generate_stream_glm4(gen_params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(output) :]
        output = decoded_unicode

        if not is_function_call and len(output) > 7:
            is_function_call = output and "get_" in output
        if is_function_call:
            continue

        finish_reason = new_response["finish_reason"]
        if not has_send_first_chunk:
            message = DeltaMessage(
                content="",
                role="assistant",
                function_call=None,
            )
            choice_data = ChatCompletionResponseStreamChoice(
                index=0, delta=message, finish_reason=finish_reason
            )
            chunk = ChatCompletionResponse(
                model=model_id,
                id="",
                choices=[choice_data],
                created=int(time.time()),
                object="chat.completion.chunk",
            )
            yield "{}".format(chunk.model_dump_json(exclude_unset=True))

        send_msg = delta_text if has_send_first_chunk else output
        has_send_first_chunk = True
        message = DeltaMessage(
            content=send_msg,
            role="assistant",
            function_call=None,
        )
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=message, finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            id="",
            choices=[choice_data],
            created=int(time.time()),
            object="chat.completion.chunk",
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    if is_function_call:
        yield output
    else:
        yield "[DONE]"
async def parse_output_text(model_id: str, value: str):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role="assistant", content=value), finish_reason=None
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield "[DONE]"
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        tensor_parallel_size=1,
        dtype="bfloat16",
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        worker_use_ray=True,
        engine_use_ray=False,
        disable_log_requests=True,
        max_model_len=MAX_MODEL_LENGTH,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
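
    # Example request once the server is up, assuming the OpenAI-compatible route
    # layout used above and the default host/port from this uvicorn call:
    #
    #   curl http://localhost:8000/v1/chat/completions \
    #     -H "Content-Type: application/json" \
    #     -d '{"model": "glm-4", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'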