xulh
commited on
Commit
·
3ae1a20
1
Parent(s):
9cad8c3
代码初始化
Browse files- inference/inference.py +27 -1
inference/inference.py
CHANGED
|
@@ -27,7 +27,7 @@ async def fetch_model_response(payload: dict, headers: dict):
|
|
| 27 |
raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
|
| 28 |
|
| 29 |
|
| 30 |
-
@router.post("/chat-completion/")
|
| 31 |
async def chat_completion(token: str = Body(...), messages: list = Body(...)):
|
| 32 |
try:
|
| 33 |
# 创建 InferenceClient
|
|
@@ -53,6 +53,32 @@ async def chat_completion(token: str = Body(...), messages: list = Body(...)):
|
|
| 53 |
raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
@router.post("/api-inference/")
|
| 57 |
async def api_inference(
|
| 58 |
authorization: str = Header(...),
|
|
|
|
| 27 |
raise HTTPException(status_code=response.status_code, detail=f"HTTP 错误: {e}")
|
| 28 |
|
| 29 |
|
| 30 |
+
@router.post("/chat-completion-academic/")
|
| 31 |
async def chat_completion(token: str = Body(...), messages: list = Body(...)):
|
| 32 |
try:
|
| 33 |
# 创建 InferenceClient
|
|
|
|
| 53 |
raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
|
| 54 |
|
| 55 |
|
| 56 |
+
@router.post("/chat-completion/")
|
| 57 |
+
async def chat_completion(token: str = Body(...), messages: list = Body(...)):
|
| 58 |
+
try:
|
| 59 |
+
# 创建 InferenceClient
|
| 60 |
+
client = InferenceClient(api_key=token)
|
| 61 |
+
|
| 62 |
+
messages.append({
|
| 63 |
+
"role": "system",
|
| 64 |
+
"content": "You are a multilingual chatbot capable of understanding questions in various languages and "
|
| 65 |
+
"providing accurate responses in the appropriate language."
|
| 66 |
+
})
|
| 67 |
+
|
| 68 |
+
# 使用 chat API 请求生成模型的回答
|
| 69 |
+
completion = client.chat.completions.create(
|
| 70 |
+
model="google/gemma-2-2b-it",
|
| 71 |
+
messages=messages,
|
| 72 |
+
max_tokens=500
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# 返回对话信息
|
| 76 |
+
return {"message": completion.choices[0].message}
|
| 77 |
+
|
| 78 |
+
except Exception as e:
|
| 79 |
+
raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
@router.post("/api-inference/")
|
| 83 |
async def api_inference(
|
| 84 |
authorization: str = Header(...),
|