# Spaces: Sleeping  (page-status text captured with this file; not part of the program)
import asyncio
import functools
import logging
from datetime import datetime, timezone
from itertools import cycle
from typing import List, Optional

import openai
import requests
import uvicorn
from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app import config

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow cross-origin requests: any origin, method and header, with credentials
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API key configuration
API_KEYS = config.settings.API_KEYS

# Endless round-robin iterator over the configured keys. Handlers advance it
# with next() while holding key_lock.
# NOTE(review): next() on a cycle over an empty API_KEYS raises StopIteration;
# confirm the configuration always provides at least one key.
key_cycle = cycle(API_KEYS)
key_lock = asyncio.Lock()


class ChatRequest(BaseModel):
    """Request body accepted by the chat completion handler."""

    # Conversation messages; each entry is an unvalidated dict.
    messages: List[dict]
    # Model identifier used when the client does not supply one.
    model: str = "llama-3.2-90b-text-preview"
    # Sampling temperature; defaults to 0.7.
    temperature: Optional[float] = 0.7
    # When true, the reply is streamed as server-sent events.
    stream: Optional[bool] = False


async def verify_authorization(authorization: str = Header(None)):
    """Validate a ``Bearer`` Authorization header and return its token.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The token that follows the ``"Bearer "`` prefix.

    Raises:
        HTTPException: 401 when the header is missing, does not start with
            ``"Bearer "``, or carries a token that is not listed in
            ``config.settings.ALLOWED_TOKENS``.
    """
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    scheme = "Bearer "
    if not authorization.startswith(scheme):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )

    # Drop only the leading scheme. str.replace() would also delete any later
    # "Bearer " substring that happens to be part of the token itself.
    token = authorization[len(scheme):]
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token


def get_gemini_models(api_key, timeout=30):
    """Fetch the Gemini model list and convert it to the OpenAI list format.

    Args:
        api_key: Key sent to the API as the ``key`` query parameter.
        timeout: Seconds to wait for the upstream server before giving up.

    Returns:
        The dict produced by ``convert_to_openai_format``, or ``None`` when
        the request fails, the status is not 200, or the body is not JSON.
    """
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    try:
        # params= URL-encodes the key; the timeout keeps a stalled upstream
        # from blocking the caller indefinitely.
        response = requests.get(
            f"{base_url}/models", params={"key": api_key}, timeout=timeout
        )
    except requests.RequestException as e:
        logger.error("Request failed: %s", e)
        return None

    if response.status_code != 200:
        logger.error("Error: %s", response.status_code)
        logger.error("%s", response.text)
        return None

    try:
        gemini_models = response.json()
    except (requests.RequestException, ValueError) as e:
        # A 200 response with a non-JSON body is treated like any other failure.
        logger.error("Request failed: %s", e)
        return None
    return convert_to_openai_format(gemini_models)


def convert_to_openai_format(gemini_models):
    """Convert a Gemini ``models`` listing into an OpenAI-style model list.

    Args:
        gemini_models: Dict with a ``"models"`` list whose entries each have a
            ``"name"`` key. ``None``, a missing key, or an empty list all
            yield an empty result.

    Returns:
        ``{"object": "list", "data": [...]}`` with one entry per input model.

    Raises:
        KeyError: If a model entry has no ``"name"`` key.
    """
    # Take the timestamp once so every entry in one listing shares the same
    # "created" value instead of drifting across loop iterations.
    created = int(datetime.now(timezone.utc).timestamp())
    models = (gemini_models or {}).get("models") or []
    return {
        "object": "list",
        "data": [
            {
                # The last path segment of the name serves as the ID
                "id": model["name"].split("/")[-1],
                "object": "model",
                # The current time is used as the creation timestamp
                "created": created,
                # Assumes every Gemini model is owned by Google
                "owned_by": "google",
                # No permission data is taken from the Gemini listing
                "permission": [],
                "root": model["name"],
                # No parent-model data is taken from the Gemini listing
                "parent": None,
            }
            for model in models
        ],
    }


async def list_models(authorization: str = Header(None)):
    """Return the upstream model list in OpenAI format.

    NOTE(review): no route decorator is visible on this handler here; confirm
    how it is registered on ``app``.

    Args:
        authorization: Raw ``Authorization`` header, checked by
            ``verify_authorization``.

    Returns:
        The dict returned by ``get_gemini_models``.

    Raises:
        HTTPException: 401 from ``verify_authorization``; 502 when the
            upstream lookup yields no result; 500 on any unexpected error.
    """
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
    logger.info(f"Using API key: {api_key[:8]}...")

    try:
        # get_gemini_models does blocking HTTP I/O, so run it in the default
        # executor instead of stalling the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, get_gemini_models, api_key)
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # get_gemini_models signals failure by returning None; do not report that
    # as success or send a null body with status 200.
    if response is None:
        logger.error("Error listing models: upstream request failed")
        raise HTTPException(
            status_code=502, detail="Failed to retrieve models list"
        )

    logger.info("Successfully retrieved models list")
    return response


async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    """Forward a chat completion request through the OpenAI client.

    NOTE(review): no route decorator is visible on this handler here; confirm
    how it is registered on ``app``.

    Args:
        request: Parsed request body.
        authorization: Raw ``Authorization`` header, checked by
            ``verify_authorization``.

    Returns:
        A ``StreamingResponse`` of ``data: <json>`` server-sent events when
        ``request.stream`` is truthy, otherwise the client's completion object.

    Raises:
        HTTPException: 401 from ``verify_authorization``; 500 when the
            upstream call fails.
    """
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
    logger.info(f"Using API key: {api_key[:8]}...")

    # "stream" is a declared model field, so it always exists; it may be None
    # because it is Optional, hence the explicit bool().
    stream = bool(request.stream)

    try:
        logger.info(f"Chat completion request - Model: {request.model}")
        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        create = functools.partial(
            client.chat.completions.create,
            model=request.model,
            messages=request.messages,
            temperature=request.temperature,
            stream=stream,
        )
        # openai.OpenAI is the synchronous client: run the blocking call in
        # the default executor so the event loop stays responsive.
        response = await asyncio.get_running_loop().run_in_executor(None, create)
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if stream:
        logger.info("Streaming response enabled")

        # A plain (non-async) generator: StreamingResponse iterates sync
        # iterators in a thread pool, so waiting on each upstream chunk does
        # not block the event loop.
        def generate():
            for chunk in response:
                yield f"data: {chunk.model_dump_json()}\n\n"

        return StreamingResponse(content=generate(), media_type="text/event-stream")

    logger.info("Chat completion successful")
    return response


async def health_check():
    """Log the call and return a fixed ``{"status": "healthy"}`` payload."""
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


# When executed as a script, serve the app on all interfaces at port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)