xulh committed on
Commit
4060252
·
1 Parent(s): 587f672

代码初始化

Browse files
Files changed (2) hide show
  1. inference/inference.py +28 -9
  2. requirements.txt +1 -0
inference/inference.py CHANGED
@@ -1,22 +1,41 @@
1
- import requests
2
- from fastapi import APIRouter, Header
3
  from .apiModel import Payload
 
4
 
5
  router = APIRouter()
6
 
7
  API_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment-latest"
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
@router.post("/api-inference/")
async def api_inference(
        authorization: str = Header(...),
        item: Payload = None):
    """Forward the request payload to the Hugging Face inference API.

    Args:
        authorization: Value of the incoming ``Authorization`` header; it is
            passed through unchanged to the upstream API.
        item: Request body. Declared optional (defaults to ``None``), but a
            body is required to build the upstream request.

    Returns:
        The decoded JSON body of the upstream response.

    Raises:
        HTTPException: 422 when the request has no body.
    """
    # Local imports: this revision's module header only imports APIRouter
    # and Header from fastapi.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    print("请求:", item)

    # Without this guard a missing body fails later with AttributeError on
    # item.dict(), which surfaces as an opaque 500.
    if item is None:
        raise HTTPException(status_code=422, detail="Request body is required")

    headers = {"Authorization": authorization}

    # requests is synchronous; run it in the worker thread pool so the event
    # loop keeps serving other requests while the upstream call is in flight.
    response = await run_in_threadpool(
        requests.post,
        API_URL,
        headers=headers,
        json=item.dict(),
    )
    return response.json()
 
1
+ import httpx
2
+ from fastapi import APIRouter, Header, HTTPException
3
  from .apiModel import Payload
4
+ import time
5
 
6
  router = APIRouter()
7
 
8
  API_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment-latest"
9
 
10
 
# Asynchronous request to the inference API using httpx.
async def fetch_model_response(
        payload: dict,
        headers: dict,
        *,
        max_attempts: int = 5,
        retry_delay: float = 20):
    """POST ``payload`` to ``API_URL`` and return the decoded JSON response.

    A 503 response is treated as "model still loading": the call waits
    ``retry_delay`` seconds and tries again, up to ``max_attempts`` requests
    in total. All attempts share a single ``httpx.AsyncClient``.

    Args:
        payload: JSON-serialisable request body.
        headers: HTTP headers to send with the request.
        max_attempts: Upper bound on the number of requests sent (at least
            one request is always made).
        retry_delay: Seconds to sleep between attempts after a 503.

    Returns:
        The decoded JSON body of the upstream response.

    Raises:
        HTTPException: 500 when the request itself fails
            (``httpx.RequestError``); otherwise the upstream status code when
            the final response is an error status (including a 503 that
            persists after the last attempt).
    """
    # Local import: the module header imports `time` but not `asyncio`, so the
    # sleep below would otherwise raise NameError on the first 503.
    import asyncio

    attempts_left = max(1, max_attempts)

    async with httpx.AsyncClient() as client:
        try:
            # Bounded loop instead of unbounded recursion: no nested clients
            # are opened and a model that never loads cannot hang the request.
            while True:
                response = await client.post(API_URL, headers=headers, json=payload)
                attempts_left -= 1
                if response.status_code != 503 or attempts_left == 0:
                    break
                # Model is loading: wait, then retry.
                print("模型加载中,等待中...")
                await asyncio.sleep(retry_delay)

            # Raises httpx.HTTPStatusError for any 4xx/5xx response.
            response.raise_for_status()
            return response.json()
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"请求错误: {e}") from e
        except httpx.HTTPStatusError as e:
            # Take the status from the exception's own response rather than
            # relying on the local variable still being bound.
            raise HTTPException(
                status_code=e.response.status_code,
                detail=f"HTTP 错误: {e}",
            ) from e
27
+
28
+
@router.post("/api-inference/")
async def api_inference(
        authorization: str = Header(...),
        item: Payload = None):
    """Proxy an inference request to the upstream model API.

    Args:
        authorization: Value of the incoming ``Authorization`` header; it is
            forwarded unchanged to the upstream API.
        item: Request body. Declared optional (defaults to ``None``), but a
            body is required to build the upstream request.

    Returns:
        Whatever ``fetch_model_response`` returns for the forwarded payload.

    Raises:
        HTTPException: 422 when the request has no body; any HTTPException
            raised by ``fetch_model_response`` propagates unchanged.
    """
    print("请求:", item)

    # Without this guard a missing body fails with AttributeError on
    # item.dict(), which surfaces as an opaque 500.
    if item is None:
        raise HTTPException(status_code=422, detail="Request body is required")

    # Build the request headers.
    headers = {"Authorization": authorization}

    # Delegate the asynchronous upstream call.
    return await fetch_model_response(item.dict(), headers)
 
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  fastapi
 
2
  requests
3
  uvicorn[standard]
 
1
  fastapi
2
+ httpx
3
  requests
4
  uvicorn[standard]