# coding=utf-8
# Implements API for ChatGLM3-6B in OpenAI's format.
# (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8080/docs for the interactive API documentation.

import json
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from utils import process_response, generate_chatglm3, generate_stream_chatglm3


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Release GPU memory when the application shuts down.
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "observation"]
    content: Optional[str] = None
    metadata: Optional[str] = None
    tools: Optional[List[dict]] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    # Additional parameters for controlling when generation stops
    stop_token_ids: Optional[List[int]] = None
    repetition_penalty: Optional[float] = 1.1
    # Additional parameter for tool (function call) support
    return_function_call: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]
    history: Optional[List[dict]] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    # Function calling is enabled when the first message is a system message that carries tools.
    with_function_call = bool(request.messages[0].role == "system" and request.messages[0].tools is not None)

    # stop settings
    request.stop = request.stop or []
    if isinstance(request.stop, str):
        request.stop = [request.stop]
    request.stop_token_ids = request.stop_token_ids or []

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        stop_token_ids=request.stop_token_ids,
        stop=request.stop,
        repetition_penalty=request.repetition_penalty,
        with_function_call=with_function_call,
    )

    if request.stream:
        generate = predict(request.model, gen_params)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = generate_chatglm3(model, tokenizer, gen_params)
    usage = UsageInfo()

    finish_reason, history = "stop", None
    if with_function_call and request.return_function_call:
        history = [m.dict(exclude_none=True) for m in request.messages]
        content, history = process_response(response["text"], history)
        if isinstance(content, dict):
            # A dict means the model produced a tool invocation rather than plain text.
            message, finish_reason = ChatMessage(
                role="assistant",
                content=json.dumps(content, ensure_ascii=False),
            ), "function_call"
        else:
            message = ChatMessage(role="assistant", content=content)
    else:
        message = ChatMessage(role="assistant", content=response["text"])

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
        history=history
    )

    task_usage = UsageInfo.parse_obj(response["usage"])
    for usage_key, usage_value in task_usage.dict().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)


async def predict(model_id: str, params: dict):
    global model, tokenizer

    # First chunk only announces the assistant role.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    previous_text = ""
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        # Send only the text generated since the previous chunk.
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        if len(delta_text) == 0:
            delta_text = None

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=delta_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Final chunk carries an empty delta and the finish reason.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield '[DONE]'


if __name__ == "__main__":
    # Use raw strings so the Windows path backslashes are not treated as escape sequences.
    tokenizer = AutoTokenizer.from_pretrained(r"D:\git\model\chatglm3-6b-32k", trust_remote_code=True)
    model = AutoModel.from_pretrained(r"D:\git\model\chatglm3-6b-32k", trust_remote_code=True).cuda()
    # Multi-GPU support: use the following two lines instead of the line above,
    # and set num_gpus to the actual number of GPUs you have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()

    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
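
# ------------------------------------------------------------------
# Example client (a minimal sketch, not part of the server itself):
# it assumes the server above is running on http://localhost:8080 and
# that the `requests` package is installed. Save it as a separate
# script (e.g. a hypothetical client_example.py) and run it in
# another terminal.
#
#     import requests
#
#     payload = {
#         "model": "chatglm3-6b",  # echoed back in the response; not used for routing
#         "messages": [{"role": "user", "content": "Hello"}],
#         "temperature": 0.7,
#         "stream": False,
#     }
#     resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
#     print(resp.json()["choices"][0]["message"]["content"])
# ------------------------------------------------------------------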