xulh committed
Commit 69d5d0e · 1 Parent(s): 01ff4cd

Code initialization

Files changed (1):
  inference/inference.py  +17 -26
inference/inference.py CHANGED
@@ -1,9 +1,8 @@
 import asyncio
 import httpx
-from fastapi import APIRouter, Header, HTTPException
+from fastapi import APIRouter, Header, HTTPException, Body
 from .apiModel import Payload
-import transformers
-import torch
+from huggingface_hub import InferenceClient
 
 router = APIRouter()
 
@@ -28,32 +27,24 @@ async def fetch_model_response(payload: dict, headers: dict):
         raise HTTPException(status_code=response.status_code, detail=f"HTTP error: {e}")
 
 
-@router.post("/api-llama/")
-async def api_inference(
-        authorization: str = Header(...),
-        item: Payload = None):
-    print("Request:", item)
-    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-    # Set the request headers
-    pipeline = transformers.pipeline(
-        "text-generation",
-        model=model_id,
-        model_kwargs={"torch_dtype": torch.bfloat16},
-        device_map="auto",
-    )
+@router.post("/chat-completion/")
+async def chat_completion(token: str = Body(...), messages: list = Body(...)):
+    try:
+        # Create the InferenceClient
+        client = InferenceClient(api_key=token)
 
-    messages = [
-        {"role": "system", "content": "You are an all-purpose chatbot that can accurately answer every question you are asked"},
-        {"role": "user", "content": "Who are you?"},
-    ]
+        # Use the chat API to generate the model's reply
+        completion = client.chat.completions.create(
+            model="meta-llama/Llama-3.1-8B-Instruct",
+            messages=messages,
+            max_tokens=500
+        )
 
-    outputs = pipeline(
-        messages,
-        max_new_tokens=256,
-    )
+        # Return the chat message
+        return {"message": completion.choices[0].message}
 
-    # Use an asynchronous request
-    return outputs
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
 
 
 @router.post("/api-inference/")