# inference/inference.py
import asyncio
import httpx
from fastapi import APIRouter, Header, HTTPException
from .apiModel import Payload
import transformers
import torch
router = APIRouter()
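# Payload is defined in apiModel.py and is not shown here. A plausible
# sketch (purely an assumption, mirroring the {"inputs": ...} request body
# that the Hosted Inference API expects) would be:
#
#     class Payload(BaseModel):
#         inputs: str
#         parameters: dict | None = None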
API_URL = "https://api-inference.huggingface.co/models/cardiffnlp/meta-llama/Llama-3.1-8B-Instruct"
# 使用httpx异步请求
async def fetch_model_response(payload: dict, headers: dict, retries: int = 3) -> dict:
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(API_URL, headers=headers, json=payload)
            if response.status_code == 503 and retries > 0:
                # The model is still loading; wait and retry (bounded, to
                # avoid recursing forever on a persistent 503)
                print("Model is loading, waiting...")
                await asyncio.sleep(20)  # wait 20 seconds
                return await fetch_model_response(payload, headers, retries - 1)
            response.raise_for_status()  # raise on any other error status
            return response.json()
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request error: {e}")
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=f"HTTP error: {e}")
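# Illustrative use of the helper above (the token is a placeholder and the
# {"inputs": ...} shape follows the Hosted Inference API convention):
#
#     headers = {"Authorization": "Bearer hf_xxx"}
#     result = await fetch_model_response({"inputs": "Hello"}, headers)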
@router.post("/api-llama/")
async def api_inference(
authorization: str = Header(...),
item: Payload = None):
print("请求:", item)
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# 设置请求头
pipeline = transformers.pipeline(
"text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto",
)
messages = [
{"role": "system", "content": "你是一个万能聊天机器人,能准确回答每一个提出的问题"},
{"role": "user", "content": "你是谁?"},
]
outputs = pipeline(
messages,
max_new_tokens=256,
)
# 使用异步请求
return outputs
@router.post("/api-inference/")
async def api_inference(
authorization: str = Header(...),
item: Payload = None):
print("请求:", item)
# 设置请求头
headers = {"Authorization": authorization}
# 使用异步请求
response_data = await fetch_model_response(item.dict(), headers)
return response_data
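# Minimal smoke test when running this module directly (assumptions: a real
# Hugging Face token replaces the placeholder, and the payload follows the
# Hosted Inference API's {"inputs": ...} convention rather than apiModel.Payload).
if __name__ == "__main__":
    async def _demo():
        headers = {"Authorization": "Bearer hf_xxx"}  # placeholder token
        data = await fetch_model_response({"inputs": "Hello, who are you?"}, headers)
        print(data)

    asyncio.run(_demo())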