Spaces:
Runtime error
Runtime error
import time | |
from asyncio.log import logger | |
import re | |
import uvicorn | |
import gc | |
import json | |
import torch | |
import random | |
import string | |
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine | |
from fastapi import FastAPI, HTTPException, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from contextlib import asynccontextmanager | |
from typing import List, Literal, Optional, Union | |
from pydantic import BaseModel, Field | |
from transformers import AutoTokenizer, LogitsProcessor | |
from sse_starlette.sse import EventSourceResponse | |
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000 | |
import os | |
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat') | |
MAX_MODEL_LENGTH = 8192 | |
async def lifespan(app: FastAPI): | |
yield | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
torch.cuda.ipc_collect() | |
app = FastAPI(lifespan=lifespan) | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
def generate_id(prefix: str, k=29) -> str: | |
suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k)) | |
return f"{prefix}{suffix}" | |
class ModelCard(BaseModel): | |
id: str = "" | |
object: str = "model" | |
created: int = Field(default_factory=lambda: int(time.time())) | |
owned_by: str = "owner" | |
root: Optional[str] = None | |
parent: Optional[str] = None | |
permission: Optional[list] = None | |
class ModelList(BaseModel): | |
object: str = "list" | |
data: List[ModelCard] = ["glm-4"] | |
class FunctionCall(BaseModel): | |
name: Optional[str] = None | |
arguments: Optional[str] = None | |
class ChoiceDeltaToolCallFunction(BaseModel): | |
name: Optional[str] = None | |
arguments: Optional[str] = None | |
class UsageInfo(BaseModel): | |
prompt_tokens: int = 0 | |
total_tokens: int = 0 | |
completion_tokens: Optional[int] = 0 | |
class ChatCompletionMessageToolCall(BaseModel): | |
index: Optional[int] = 0 | |
id: Optional[str] = None | |
function: FunctionCall | |
type: Optional[Literal["function"]] = 'function' | |
class ChatMessage(BaseModel): | |
# “function” 字段解释: | |
# 使用较老的OpenAI API版本需要注意在这里添加 function 字段并在 process_messages函数中添加相应角色转换逻辑为 observation | |
role: Literal["user", "assistant", "system", "tool"] | |
content: Optional[str] = None | |
function_call: Optional[ChoiceDeltaToolCallFunction] = None | |
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None | |
class DeltaMessage(BaseModel): | |
role: Optional[Literal["user", "assistant", "system"]] = None | |
content: Optional[str] = None | |
function_call: Optional[ChoiceDeltaToolCallFunction] = None | |
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None | |
class ChatCompletionResponseChoice(BaseModel): | |
index: int | |
message: ChatMessage | |
finish_reason: Literal["stop", "length", "tool_calls"] | |
class ChatCompletionResponseStreamChoice(BaseModel): | |
delta: DeltaMessage | |
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] | |
index: int | |
class ChatCompletionResponse(BaseModel): | |
model: str | |
id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29)) | |
object: Literal["chat.completion", "chat.completion.chunk"] | |
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] | |
created: Optional[int] = Field(default_factory=lambda: int(time.time())) | |
system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9)) | |
usage: Optional[UsageInfo] = None | |
class ChatCompletionRequest(BaseModel): | |
model: str | |
messages: List[ChatMessage] | |
temperature: Optional[float] = 0.8 | |
top_p: Optional[float] = 0.8 | |
max_tokens: Optional[int] = None | |
stream: Optional[bool] = False | |
tools: Optional[Union[dict, List[dict]]] = None | |
tool_choice: Optional[Union[str, dict]] = None | |
repetition_penalty: Optional[float] = 1.1 | |
class InvalidScoreLogitsProcessor(LogitsProcessor): | |
def __call__( | |
self, input_ids: torch.LongTensor, scores: torch.FloatTensor | |
) -> torch.FloatTensor: | |
if torch.isnan(scores).any() or torch.isinf(scores).any(): | |
scores.zero_() | |
scores[..., 5] = 5e4 | |
return scores | |
def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]: | |
lines = output.strip().split("\n") | |
arguments_json = None | |
special_tools = ["cogview", "simple_browser"] | |
tools = {tool['function']['name'] for tool in tools} | |
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。 | |
##TODO 如果你希望做更多判断,可以在这里进行逻辑完善。 | |
if len(lines) >= 2 and lines[1].startswith("{"): | |
function_name = lines[0].strip() | |
arguments = "\n".join(lines[1:]).strip() | |
if function_name in tools or function_name in special_tools: | |
try: | |
arguments_json = json.loads(arguments) | |
is_tool_call = True | |
except json.JSONDecodeError: | |
is_tool_call = function_name in special_tools | |
if is_tool_call and use_tool: | |
content = { | |
"name": function_name, | |
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, | |
ensure_ascii=False) | |
} | |
if function_name == "simple_browser": | |
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)') | |
match = search_pattern.match(arguments) | |
if match: | |
content["arguments"] = json.dumps({ | |
"query": match.group(1), | |
"recency_days": int(match.group(2)) | |
}, ensure_ascii=False) | |
elif function_name == "cogview": | |
content["arguments"] = json.dumps({ | |
"prompt": arguments | |
}, ensure_ascii=False) | |
return content | |
return output.strip() | |
async def generate_stream_glm4(params): | |
messages = params["messages"] | |
tools = params["tools"] | |
tool_choice = params["tool_choice"] | |
temperature = float(params.get("temperature", 1.0)) | |
repetition_penalty = float(params.get("repetition_penalty", 1.0)) | |
top_p = float(params.get("top_p", 1.0)) | |
max_new_tokens = int(params.get("max_tokens", 8192)) | |
messages = process_messages(messages, tools=tools, tool_choice=tool_choice) | |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) | |
params_dict = { | |
"n": 1, | |
"best_of": 1, | |
"presence_penalty": 1.0, | |
"frequency_penalty": 0.0, | |
"temperature": temperature, | |
"top_p": top_p, | |
"top_k": -1, | |
"repetition_penalty": repetition_penalty, | |
"use_beam_search": False, | |
"length_penalty": 1, | |
"early_stopping": False, | |
"stop_token_ids": [151329, 151336, 151338], | |
"ignore_eos": False, | |
"max_tokens": max_new_tokens, | |
"logprobs": None, | |
"prompt_logprobs": None, | |
"skip_special_tokens": True, | |
} | |
sampling_params = SamplingParams(**params_dict) | |
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"): | |
output_len = len(output.outputs[0].token_ids) | |
input_len = len(output.prompt_token_ids) | |
ret = { | |
"text": output.outputs[0].text, | |
"usage": { | |
"prompt_tokens": input_len, | |
"completion_tokens": output_len, | |
"total_tokens": output_len + input_len | |
}, | |
"finish_reason": output.outputs[0].finish_reason, | |
} | |
yield ret | |
gc.collect() | |
torch.cuda.empty_cache() | |
def process_messages(messages, tools=None, tool_choice="none"): | |
_messages = messages | |
processed_messages = [] | |
msg_has_sys = False | |
def filter_tools(tool_choice, tools): | |
function_name = tool_choice.get('function', {}).get('name', None) | |
if not function_name: | |
return [] | |
filtered_tools = [ | |
tool for tool in tools | |
if tool.get('function', {}).get('name') == function_name | |
] | |
return filtered_tools | |
if tool_choice != "none": | |
if isinstance(tool_choice, dict): | |
tools = filter_tools(tool_choice, tools) | |
if tools: | |
processed_messages.append( | |
{ | |
"role": "system", | |
"content": None, | |
"tools": tools | |
} | |
) | |
msg_has_sys = True | |
if isinstance(tool_choice, dict) and tools: | |
processed_messages.append( | |
{ | |
"role": "assistant", | |
"metadata": tool_choice["function"]["name"], | |
"content": "" | |
} | |
) | |
for m in _messages: | |
role, content, func_call = m.role, m.content, m.function_call | |
tool_calls = getattr(m, 'tool_calls', None) | |
if role == "function": | |
processed_messages.append( | |
{ | |
"role": "observation", | |
"content": content | |
} | |
) | |
elif role == "tool": | |
processed_messages.append( | |
{ | |
"role": "observation", | |
"content": content, | |
"function_call": True | |
} | |
) | |
elif role == "assistant": | |
if tool_calls: | |
for tool_call in tool_calls: | |
processed_messages.append( | |
{ | |
"role": "assistant", | |
"metadata": tool_call.function.name, | |
"content": tool_call.function.arguments | |
} | |
) | |
else: | |
for response in content.split("\n"): | |
if "\n" in response: | |
metadata, sub_content = response.split("\n", maxsplit=1) | |
else: | |
metadata, sub_content = "", response | |
processed_messages.append( | |
{ | |
"role": role, | |
"metadata": metadata, | |
"content": sub_content.strip() | |
} | |
) | |
else: | |
if role == "system" and msg_has_sys: | |
msg_has_sys = False | |
continue | |
processed_messages.append({"role": role, "content": content}) | |
if not tools or tool_choice == "none": | |
for m in _messages: | |
if m.role == 'system': | |
processed_messages.insert(0, {"role": m.role, "content": m.content}) | |
break | |
return processed_messages | |
async def health() -> Response: | |
"""Health check.""" | |
return Response(status_code=200) | |
async def list_models(): | |
model_card = ModelCard(id="glm-4") | |
return ModelList(data=[model_card]) | |
async def create_chat_completion(request: ChatCompletionRequest): | |
if len(request.messages) < 1 or request.messages[-1].role == "assistant": | |
raise HTTPException(status_code=400, detail="Invalid request") | |
gen_params = dict( | |
messages=request.messages, | |
temperature=request.temperature, | |
top_p=request.top_p, | |
max_tokens=request.max_tokens or 1024, | |
echo=False, | |
stream=request.stream, | |
repetition_penalty=request.repetition_penalty, | |
tools=request.tools, | |
tool_choice=request.tool_choice, | |
) | |
logger.debug(f"==== request ====\n{gen_params}") | |
if request.stream: | |
predict_stream_generator = predict_stream(request.model, gen_params) | |
output = await anext(predict_stream_generator) | |
if output: | |
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream") | |
logger.debug(f"First result output:\n{output}") | |
function_call = None | |
if output and request.tools: | |
try: | |
function_call = process_response(output, request.tools, use_tool=True) | |
except: | |
logger.warning("Failed to parse tool call") | |
if isinstance(function_call, dict): | |
function_call = ChoiceDeltaToolCallFunction(**function_call) | |
generate = parse_output_text(request.model, output, function_call=function_call) | |
return EventSourceResponse(generate, media_type="text/event-stream") | |
else: | |
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream") | |
response = "" | |
async for response in generate_stream_glm4(gen_params): | |
pass | |
if response["text"].startswith("\n"): | |
response["text"] = response["text"][1:] | |
response["text"] = response["text"].strip() | |
usage = UsageInfo() | |
function_call, finish_reason = None, "stop" | |
tool_calls = None | |
if request.tools: | |
try: | |
function_call = process_response(response["text"], request.tools, use_tool=True) | |
except Exception as e: | |
logger.warning(f"Failed to parse tool call: {e}") | |
if isinstance(function_call, dict): | |
finish_reason = "tool_calls" | |
function_call_response = ChoiceDeltaToolCallFunction(**function_call) | |
function_call_instance = FunctionCall( | |
name=function_call_response.name, | |
arguments=function_call_response.arguments | |
) | |
tool_calls = [ | |
ChatCompletionMessageToolCall( | |
id=generate_id('call_', 24), | |
function=function_call_instance, | |
type="function")] | |
message = ChatMessage( | |
role="assistant", | |
content=None if tool_calls else response["text"], | |
function_call=None, | |
tool_calls=tool_calls, | |
) | |
logger.debug(f"==== message ====\n{message}") | |
choice_data = ChatCompletionResponseChoice( | |
index=0, | |
message=message, | |
finish_reason=finish_reason, | |
) | |
task_usage = UsageInfo.model_validate(response["usage"]) | |
for usage_key, usage_value in task_usage.model_dump().items(): | |
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) | |
return ChatCompletionResponse( | |
model=request.model, | |
choices=[choice_data], | |
object="chat.completion", | |
usage=usage | |
) | |
async def predict_stream(model_id, gen_params): | |
output = "" | |
is_function_call = False | |
has_send_first_chunk = False | |
created_time = int(time.time()) | |
function_name = None | |
response_id = generate_id('chatcmpl-', 29) | |
system_fingerprint = generate_id('fp_', 9) | |
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else None | |
async for new_response in generate_stream_glm4(gen_params): | |
decoded_unicode = new_response["text"] | |
delta_text = decoded_unicode[len(output):] | |
output = decoded_unicode | |
lines = output.strip().split("\n") | |
# 检查是否为工具 | |
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。 | |
##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。 | |
if not is_function_call and len(lines) >= 2: | |
first_line = lines[0].strip() | |
if first_line in tools: | |
is_function_call = True | |
function_name = first_line | |
# 工具调用返回 | |
if is_function_call: | |
if not has_send_first_chunk: | |
function_call = {"name": function_name, "arguments": ""} | |
tool_call = ChatCompletionMessageToolCall( | |
index=0, | |
id=generate_id('call_', 24), | |
function=FunctionCall(**function_call), | |
type="function" | |
) | |
message = DeltaMessage( | |
content=None, | |
role="assistant", | |
function_call=None, | |
tool_calls=[tool_call] | |
) | |
choice_data = ChatCompletionResponseStreamChoice( | |
index=0, | |
delta=message, | |
finish_reason=None | |
) | |
chunk = ChatCompletionResponse( | |
model=model_id, | |
id=response_id, | |
choices=[choice_data], | |
created=created_time, | |
system_fingerprint=system_fingerprint, | |
object="chat.completion.chunk" | |
) | |
yield "" | |
yield chunk.model_dump_json(exclude_unset=True) | |
has_send_first_chunk = True | |
function_call = {"name": None, "arguments": delta_text} | |
tool_call = ChatCompletionMessageToolCall( | |
index=0, | |
id=None, | |
function=FunctionCall(**function_call), | |
type="function" | |
) | |
message = DeltaMessage( | |
content=None, | |
role=None, | |
function_call=None, | |
tool_calls=[tool_call] | |
) | |
choice_data = ChatCompletionResponseStreamChoice( | |
index=0, | |
delta=message, | |
finish_reason=None | |
) | |
chunk = ChatCompletionResponse( | |
model=model_id, | |
id=response_id, | |
choices=[choice_data], | |
created=created_time, | |
system_fingerprint=system_fingerprint, | |
object="chat.completion.chunk" | |
) | |
yield chunk.model_dump_json(exclude_unset=True) | |
# 用户请求了 Function Call 但是框架还没确定是否为Function Call | |
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call: | |
continue | |
# 常规返回 | |
else: | |
finish_reason = new_response.get("finish_reason", None) | |
if not has_send_first_chunk: | |
message = DeltaMessage( | |
content="", | |
role="assistant", | |
function_call=None, | |
) | |
choice_data = ChatCompletionResponseStreamChoice( | |
index=0, | |
delta=message, | |
finish_reason=finish_reason | |
) | |
chunk = ChatCompletionResponse( | |
model=model_id, | |
id=response_id, | |
choices=[choice_data], | |
created=created_time, | |
system_fingerprint=system_fingerprint, | |
object="chat.completion.chunk" | |
) | |
yield chunk.model_dump_json(exclude_unset=True) | |
has_send_first_chunk = True | |
message = DeltaMessage( | |
content=delta_text, | |
role="assistant", | |
function_call=None, | |
) | |
choice_data = ChatCompletionResponseStreamChoice( | |
index=0, | |
delta=message, | |
finish_reason=finish_reason | |
) | |
chunk = ChatCompletionResponse( | |
model=model_id, | |
id=response_id, | |
choices=[choice_data], | |
created=created_time, | |
system_fingerprint=system_fingerprint, | |
object="chat.completion.chunk" | |
) | |
yield chunk.model_dump_json(exclude_unset=True) | |
# 工具调用需要额外返回一个字段以对齐 OpenAI 接口 | |
if is_function_call: | |
yield ChatCompletionResponse( | |
model=model_id, | |
id=response_id, | |
system_fingerprint=system_fingerprint, | |
choices=[ | |
ChatCompletionResponseStreamChoice( | |
index=0, | |
delta=DeltaMessage( | |
content=None, | |
role=None, | |
function_call=None, | |
), | |
finish_reason="tool_calls" | |
)], | |
created=created_time, | |
object="chat.completion.chunk", | |
usage=None | |
).model_dump_json(exclude_unset=True) | |
yield '[DONE]' | |
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None): | |
delta = DeltaMessage(role="assistant", content=value) | |
if function_call is not None: | |
delta.function_call = function_call | |
choice_data = ChatCompletionResponseStreamChoice( | |
index=0, | |
delta=delta, | |
finish_reason=None | |
) | |
chunk = ChatCompletionResponse( | |
model=model_id, | |
choices=[choice_data], | |
object="chat.completion.chunk" | |
) | |
yield "{}".format(chunk.model_dump_json(exclude_unset=True)) | |
yield '[DONE]' | |
if __name__ == "__main__": | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) | |
engine_args = AsyncEngineArgs( | |
model=MODEL_PATH, | |
tokenizer=MODEL_PATH, | |
# 如果你有多张显卡,可以在这里设置成你的显卡数量 | |
tensor_parallel_size=1, | |
dtype="bfloat16", | |
trust_remote_code=True, | |
# 占用显存的比例,请根据你的显卡显存大小设置合适的值,例如,如果你的显卡有80G,您只想使用24G,请按照24/80=0.3设置 | |
gpu_memory_utilization=0.9, | |
enforce_eager=True, | |
worker_use_ray=False, | |
engine_use_ray=False, | |
disable_log_requests=True, | |
max_model_len=MAX_MODEL_LENGTH, | |
) | |
engine = AsyncLLMEngine.from_engine_args(engine_args) | |
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) | |