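"""OpenAI-style API server for ChatGLM3.

Exposes `/v1/models` and `/v1/chat/completions` (with optional SSE streaming)
on top of a locally loaded ChatGLM3 checkpoint, so OpenAI-compatible clients
can be pointed at this process.
"""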
import json
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from utils import process_response, generate_chatglm3, generate_stream_chatglm3


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Serve requests; on shutdown, release cached GPU memory held by the model.
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
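
# The request/response models below follow the OpenAI Chat Completions API
# schema (`/v1/models`, `/v1/chat/completions`), so OpenAI-style clients can
# generally be used unchanged. `history`, `stop_token_ids`, `repetition_penalty`
# and `return_function_call` are ChatGLM3-specific extensions.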


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "observation"]
    content: Optional[str] = None
    metadata: Optional[str] = None
    tools: Optional[List[dict]] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = None
    repetition_penalty: Optional[float] = 1.1
    return_function_call: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]
    history: Optional[List[dict]] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    # Advertise the local model under the id "gpt-3.5-turbo" so OpenAI-style
    # clients can be used without changing their configured model name.
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    # Function-call mode is enabled when the first message is a system message
    # that carries tool definitions.
    with_function_call = bool(request.messages[0].role == "system" and request.messages[0].tools is not None)

    # Normalize stop conditions to lists.
    request.stop = request.stop or []
    if isinstance(request.stop, str):
        request.stop = [request.stop]
    request.stop_token_ids = request.stop_token_ids or []

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        stop_token_ids=request.stop_token_ids,
        stop=request.stop,
        repetition_penalty=request.repetition_penalty,
        with_function_call=with_function_call,
    )

    if request.stream:
        generate = predict(request.model, gen_params)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = generate_chatglm3(model, tokenizer, gen_params)
    usage = UsageInfo()

    finish_reason, history = "stop", None
    if with_function_call and request.return_function_call:
        # Let the model-specific parser decide whether the reply is a function
        # call; if so, return it as JSON and flag the finish reason accordingly.
        history = [m.dict(exclude_none=True) for m in request.messages]
        content, history = process_response(response["text"], history)
        if isinstance(content, dict):
            message, finish_reason = ChatMessage(
                role="assistant",
                content=json.dumps(content, ensure_ascii=False),
            ), "function_call"
        else:
            message = ChatMessage(role="assistant", content=content)
    else:
        message = ChatMessage(role="assistant", content=response["text"])

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
        history=history,
    )

    task_usage = UsageInfo.parse_obj(response["usage"])
    for usage_key, usage_value in task_usage.dict().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)
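
# Example (illustrative): calling the non-streaming endpoint with curl, assuming
# the server was started with the defaults in the __main__ block below
# (http://localhost:8080). Set "stream": true to receive SSE chunks instead.
#
#   curl http://localhost:8080/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-3.5-turbo",
#          "messages": [{"role": "user", "content": "Hello"}],
#          "stream": false}'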


async def predict(model_id: str, params: dict):
    global model, tokenizer

    # First chunk announces the assistant role, mirroring the OpenAI streaming
    # protocol.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None,
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True, ensure_ascii=False)

    # Emit only the newly generated suffix of the text on each iteration.
    previous_text = ""
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        if len(delta_text) == 0:
            delta_text = None

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=delta_text),
            finish_reason=None,
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield chunk.json(exclude_unset=True, ensure_ascii=False)

    # Final chunk carries the finish reason, followed by the [DONE] sentinel.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop",
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield chunk.json(exclude_unset=True, ensure_ascii=False)
    yield '[DONE]'
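
# Each SSE event body above is a JSON chunk of roughly this shape (illustrative):
#   {"model": "...", "object": "chat.completion.chunk",
#    "choices": [{"index": 0, "delta": {"content": "..."}, "finish_reason": null}]}
# and the stream is terminated by a literal "[DONE]" event, matching the
# OpenAI streaming convention.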


if __name__ == "__main__":
    # Use raw strings (or forward slashes) for Windows paths so backslash
    # sequences are not interpreted as escape characters.
    tokenizer = AutoTokenizer.from_pretrained(r"D:\git\model\chatglm3-6b-32k", trust_remote_code=True)
    model = AutoModel.from_pretrained(r"D:\git\model\chatglm3-6b-32k", trust_remote_code=True).cuda()
    model = model.eval()

    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
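
# Streaming consumption sketch (assumes the `requests` package; each SSE line
# arrives as b"data: <json chunk>" until the final b"data: [DONE]"):
#
#   import json, requests
#   with requests.post(
#       "http://localhost:8080/v1/chat/completions",
#       json={"model": "gpt-3.5-turbo",
#             "messages": [{"role": "user", "content": "Hello"}],
#             "stream": True},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line.startswith(b"data: ") and line != b"data: [DONE]":
#               print(json.loads(line[len(b"data: "):]))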