xulh
committed on
Commit
·
69d5d0e
1
Parent(s):
01ff4cd
代码初始化
Browse files- inference/inference.py +17 -26
inference/inference.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
import asyncio
|
| 2 |
import httpx
|
| 3 |
-
from fastapi import APIRouter, Header, HTTPException
|
| 4 |
from .apiModel import Payload
|
| 5 |
-
import
|
| 6 |
-
import torch
|
| 7 |
|
| 8 |
router = APIRouter()
|
| 9 |
|
|
@@ -28,32 +27,24 @@ async def fetch_model_response(payload: dict, headers: dict):
|
|
| 28 |
raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
|
| 29 |
|
| 30 |
|
| 31 |
-
@router.post("/
|
| 32 |
-
async def
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
| 37 |
-
# 设置请求头
|
| 38 |
-
pipeline = transformers.pipeline(
|
| 39 |
-
"text-generation",
|
| 40 |
-
model=model_id,
|
| 41 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
| 42 |
-
device_map="auto",
|
| 43 |
-
)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
max_new_tokens=256,
|
| 53 |
-
)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
| 58 |
|
| 59 |
@router.post("/api-inference/")
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import httpx
|
| 3 |
+
from fastapi import APIRouter, Header, HTTPException, Body
|
| 4 |
from .apiModel import Payload
|
| 5 |
+
from huggingface_hub import InferenceClient
|
|
|
|
| 6 |
|
| 7 |
router = APIRouter()
|
| 8 |
|
|
|
|
| 27 |
raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
|
| 28 |
|
| 29 |
|
| 30 |
+
@router.post("/chat-completion/")
|
| 31 |
+
async def chat_completion(token: str = Body(...), messages: list = Body(...)):
|
| 32 |
+
try:
|
| 33 |
+
# 创建 InferenceClient
|
| 34 |
+
client = InferenceClient(api_key=token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
# 使用 chat API 请求生成模型的回答
|
| 37 |
+
completion = client.chat.completions.create(
|
| 38 |
+
model="meta-llama/Llama-3.1-8B-Instruct",
|
| 39 |
+
messages=messages,
|
| 40 |
+
max_tokens=500
|
| 41 |
+
)
|
| 42 |
|
| 43 |
+
# 返回对话信息
|
| 44 |
+
return {"message": completion.choices[0].message}
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
except Exception as e:
|
| 47 |
+
raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
|
| 48 |
|
| 49 |
|
| 50 |
@router.post("/api-inference/")
|