import asyncio

import httpx
import torch
import transformers
from fastapi import APIRouter, Header, HTTPException

from .apiModel import Payload

router = APIRouter()

# Hosted Hugging Face Inference API endpoint for Llama 3.1 8B Instruct.
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct"

async def fetch_model_response(payload: dict, headers: dict):
    """POST the payload to the Inference API and return the parsed JSON."""
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(API_URL, headers=headers, json=payload)
            # A 503 means the model is still loading on the inference server;
            # wait and retry. Note that this retry is unbounded.
            if response.status_code == 503:
                print("Model is loading, waiting...")
                await asyncio.sleep(20)
                return await fetch_model_response(payload, headers)
            response.raise_for_status()
            return response.json()
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request error: {e}")
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=response.status_code, detail=f"HTTP error: {e}")
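

# A minimal bounded-retry sketch, as an alternative to the unbounded recursion
# above (max_retries and its default are assumptions, not part of the original):
async def fetch_model_response_bounded(payload: dict, headers: dict, max_retries: int = 5):
    async with httpx.AsyncClient() as client:
        for _ in range(max_retries):
            response = await client.post(API_URL, headers=headers, json=payload)
            if response.status_code != 503:
                response.raise_for_status()
                return response.json()
            await asyncio.sleep(20)  # model still loading; wait before retrying
        raise HTTPException(status_code=503, detail="Model did not load in time")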


@router.post("/api-llama/")
async def api_llama(
        authorization: str = Header(...),
        item: Payload = None):
    """Generate a response with a locally loaded Llama 3.1 pipeline."""
    print("Request:", item)
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # NOTE: building the pipeline here reloads the model on every request,
    # which is slow; see the module-level sketch after this handler.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )

    # Fixed demo conversation; the incoming payload is currently only logged.
    messages = [
        {"role": "system", "content": "You are an all-purpose chatbot that can accurately answer every question asked."},
        {"role": "user", "content": "Who are you?"},
    ]

    outputs = pipeline(
        messages,
        max_new_tokens=256,
    )

    return outputs
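

# A minimal sketch of loading the pipeline once at module import instead of per
# request (kept commented out; it assumes the host has enough memory to keep
# the model resident for the app's lifetime):
#
#   PIPELINE = transformers.pipeline(
#       "text-generation",
#       model="meta-llama/Meta-Llama-3.1-8B-Instruct",
#       model_kwargs={"torch_dtype": torch.bfloat16},
#       device_map="auto",
#   )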


@router.post("/api-inference/")
async def api_inference(
        item: Payload,
        authorization: str = Header(...)):
    """Proxy the payload to the hosted Hugging Face Inference API."""
    print("Request:", item)

    # Forward the caller's Authorization header (Hugging Face API token).
    headers = {"Authorization": authorization}

    response_data = await fetch_model_response(item.dict(), headers)

    return response_data
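
# Illustrative client call for the proxy endpoint (the base URL and the
# "hf_..." token are placeholders; the JSON body is assumed to match the
# Payload model):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:8000/api-inference/",
#       headers={"Authorization": "Bearer hf_..."},
#       json={"inputs": "Hello!", "parameters": {"max_new_tokens": 64}},
#   )
#   print(r.json())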