# ymx/inference/inference.py
import asyncio
import httpx
from fastapi import APIRouter, Header, HTTPException, Body
from .apiModel import Payload
from huggingface_hub import InferenceClient

router = APIRouter()

API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct"
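# Note: the serverless Inference API expects a JSON payload of the form
# {"inputs": "<prompt>", "parameters": {...}} for text-generation models;
# the exact fields used here depend on the Payload model defined in .apiModel.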

# Query the model asynchronously with httpx.
async def fetch_model_response(payload: dict, headers: dict, retries: int = 3):
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(API_URL, headers=headers, json=payload)
            if response.status_code == 503 and retries > 0:
                # The model is still loading; wait and retry (bounded, to
                # avoid unbounded recursion).
                print("Model is loading, waiting...")
                await asyncio.sleep(20)  # wait 20 seconds
                return await fetch_model_response(payload, headers, retries - 1)
            response.raise_for_status()  # raises on any remaining error status
            return response.json()
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request error: {e}")
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=f"HTTP error: {e}")
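
# A minimal sketch of calling the helper directly; the prompt, parameters,
# and token below are placeholders, and the real payload shape comes from
# the Payload model in .apiModel:
#
#     data = await fetch_model_response(
#         {"inputs": "Hello!", "parameters": {"max_new_tokens": 50}},
#         {"Authorization": "Bearer hf_xxx"},
#     )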
@router.post("/chat-completion/")
async def chat_completion(token: str = Body(...), messages: list = Body(...)):
try:
# 创建 InferenceClient
client = InferenceClient(api_key=token)
messages.append({
"role": "system",
"content": "You are a multilingual chatbot capable of understanding questions in various languages and "
"providing accurate responses in the appropriate language."
})
# 使用 chat API 请求生成模型的回答
completion = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
max_tokens=500
)
# 返回对话信息
return {"message": completion.choices[0].message}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
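
# Because both parameters are declared with Body(...), FastAPI expects an
# embedded JSON body such as (token value hypothetical):
# {"token": "hf_xxx", "messages": [{"role": "user", "content": "Hi"}]}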
@router.post("/api-inference/")
async def api_inference(
authorization: str = Header(...),
item: Payload = None):
print("请求:", item)
# 设置请求头
headers = {"Authorization": authorization}
# 使用异步请求
response_data = await fetch_model_response(item.dict(), headers)
return response_data
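
# A minimal smoke test, assuming this module is importable on its own:
# it mounts the router on a throwaway FastAPI app and posts to
# /chat-completion/. The token is a placeholder; a real Hugging Face
# token is needed for the upstream call to succeed.
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(router)
    client = TestClient(app)
    resp = client.post(
        "/chat-completion/",
        json={
            "token": "hf_xxx",  # placeholder, not a real token
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    )
    print(resp.status_code, resp.json())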