xulh
committed on
Commit
·
69d5d0e
1
Parent(s):
01ff4cd
代码初始化
Browse files- inference/inference.py +17 -26
inference/inference.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
import asyncio
|
2 |
import httpx
|
3 |
-
from fastapi import APIRouter, Header, HTTPException
|
4 |
from .apiModel import Payload
|
5 |
-
import
|
6 |
-
import torch
|
7 |
|
8 |
router = APIRouter()
|
9 |
|
@@ -28,32 +27,24 @@ async def fetch_model_response(payload: dict, headers: dict):
|
|
28 |
raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
|
29 |
|
30 |
|
31 |
-
@router.post("/
|
32 |
-
async def
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
37 |
-
# 设置请求头
|
38 |
-
pipeline = transformers.pipeline(
|
39 |
-
"text-generation",
|
40 |
-
model=model_id,
|
41 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
42 |
-
device_map="auto",
|
43 |
-
)
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
max_new_tokens=256,
|
53 |
-
)
|
54 |
|
55 |
-
|
56 |
-
|
57 |
|
58 |
|
59 |
@router.post("/api-inference/")
|
|
|
1 |
import asyncio
|
2 |
import httpx
|
3 |
+
from fastapi import APIRouter, Header, HTTPException, Body
|
4 |
from .apiModel import Payload
|
5 |
+
from huggingface_hub import InferenceClient
|
|
|
6 |
|
7 |
router = APIRouter()
|
8 |
|
|
|
27 |
raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
|
28 |
|
29 |
|
30 |
+
@router.post("/chat-completion/")
async def chat_completion(
    token: str = Body(...),
    messages: list = Body(...),
    model: str = Body("meta-llama/Llama-3.1-8B-Instruct"),
    max_tokens: int = Body(500),
):
    """Generate a chat completion via the Hugging Face Inference API.

    Args:
        token: Hugging Face API token used to authenticate the client.
        messages: Chat history; presumably a list of {"role", "content"}
            dicts as expected by the chat completions API — TODO confirm
            against the caller.
        model: Model repo id to query. Defaults to the previously
            hard-coded Llama 3.1 8B Instruct model, so existing callers
            are unaffected.
        max_tokens: Upper bound on generated tokens (previously fixed at 500).

    Returns:
        dict: the assistant's reply under the "message" key
        (first choice of the completion).

    Raises:
        HTTPException: status 500 carrying the underlying error message
            if the inference call fails for any reason.
    """
    try:
        # A fresh client per request: the API token arrives in the request
        # body, so the client cannot be shared at module level.
        client = InferenceClient(api_key=token)

        # Request a chat completion from the hosted model.
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
        )

        # Return only the first choice's message to the caller.
        return {"message": completion.choices[0].message}
    except Exception as e:
        # Boundary handler: surface any inference failure as an HTTP 500.
        # Chain the original exception so the real cause is preserved
        # in logs/tracebacks.
        raise HTTPException(
            status_code=500,
            detail=f"Error generating chat completion: {str(e)}",
        ) from e
|
48 |
|
49 |
|
50 |
@router.post("/api-inference/")
|