xulh committed on
Commit
3ae1a20
·
1 Parent(s): 9cad8c3

代码初始化

Browse files
Files changed (1) hide show
  1. inference/inference.py +27 -1
inference/inference.py CHANGED
@@ -27,7 +27,7 @@ async def fetch_model_response(payload: dict, headers: dict):
27
  raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
28
 
29
 
30
- @router.post("/chat-completion/")
31
  async def chat_completion(token: str = Body(...), messages: list = Body(...)):
32
  try:
33
  # 创建 InferenceClient
@@ -53,6 +53,32 @@ async def chat_completion(token: str = Body(...), messages: list = Body(...)):
53
  raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @router.post("/api-inference/")
57
  async def api_inference(
58
  authorization: str = Header(...),
 
27
  raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
28
 
29
 
30
+ @router.post("/chat-completion-academic/")
31
  async def chat_completion(token: str = Body(...), messages: list = Body(...)):
32
  try:
33
  # 创建 InferenceClient
 
53
  raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
54
 
55
 
56
+ @router.post("/chat-completion/")
57
+ async def chat_completion(token: str = Body(...), messages: list = Body(...)):
58
+ try:
59
+ # 创建 InferenceClient
60
+ client = InferenceClient(api_key=token)
61
+
62
+ messages.append({
63
+ "role": "system",
64
+ "content": "You are a multilingual chatbot capable of understanding questions in various languages and "
65
+ "providing accurate responses in the appropriate language."
66
+ })
67
+
68
+ # 使用 chat API 请求生成模型的回答
69
+ completion = client.chat.completions.create(
70
+ model="google/gemma-2-2b-it",
71
+ messages=messages,
72
+ max_tokens=500
73
+ )
74
+
75
+ # 返回对话信息
76
+ return {"message": completion.choices[0].message}
77
+
78
+ except Exception as e:
79
+ raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
80
+
81
+
82
  @router.post("/api-inference/")
83
  async def api_inference(
84
  authorization: str = Header(...),