"""OpenAI-compatible API endpoints."""
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from fastapi import APIRouter, HTTPException, Depends
import time
import json
import asyncio
import uuid
from datetime import datetime


class ChatMessage(BaseModel):
    """OpenAI-compatible chat message."""

    role: str = Field(..., description="The role of the message author (system/user/assistant)")
    content: str = Field(..., description="The content of the message")
    name: Optional[str] = Field(None, description="The name of the author")


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request.

    Of these fields, ``OpenAICompatibleAPI`` forwards ``temperature``,
    ``top_p``, ``max_tokens`` and ``stream`` to the reasoning engine and echoes
    ``model`` back in the response. ``n``, ``stop``, ``presence_penalty``,
    ``frequency_penalty`` and ``user`` are accepted for client compatibility
    but are not used.
    """

    model: str = Field(..., description="Model to use")
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
    n: Optional[int] = Field(1, description="Number of completions")
    stream: Optional[bool] = Field(False, description="Whether to stream responses")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
    user: Optional[str] = Field(None, description="User identifier")


class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response."""

    id: str = Field(..., description="Unique identifier for the completion")
    object: str = Field("chat.completion", description="Object type")
    created: int = Field(..., description="Unix timestamp of creation")
    model: str = Field(..., description="Model used")
    choices: List[Dict] = Field(..., description="Completion choices")
    usage: Dict[str, int] = Field(..., description="Token usage statistics")


class OpenAICompatibleAPI:
    """OpenAI-compatible API implementation.

    Registers ``POST /v1/chat/completions`` and ``GET /v1/models`` on
    ``self.router`` (a FastAPI ``APIRouter``). Answers are produced by
    ``reasoning_engine``.
    """

    def __init__(self, reasoning_engine):
        """Store the engine, create the router and register the routes.

        Args:
            reasoning_engine: Object exposing an awaitable
                ``reason(query=..., context=...)`` whose result has an
                ``answer`` attribute (presumably a string; TODO confirm).
        """
        self.reasoning_engine = reasoning_engine
        self.router = APIRouter()
        # Fixed once so /v1/models reports a stable creation time rather than
        # a fresh timestamp on every request.
        self._models_created = int(time.time())
        self.setup_routes()

    def setup_routes(self):
        """Setup API routes on ``self.router``."""

        @self.router.post("/v1/chat/completions")
        async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
            """Answer the most recent user message with the reasoning engine.

            NOTE: ``stream`` is forwarded to the engine, but this endpoint
            always returns a single non-streamed JSON response.

            Raises:
                HTTPException: 400 if there is no user message (or the latest
                    one is empty); HTTP errors raised by the engine propagate
                    unchanged; any other failure becomes a 500 whose detail is
                    the exception text.
            """
            try:
                # Convert chat history to context
                context = self._prepare_context(request.messages)

                # The last user message becomes the engine's query.
                user_message = next(
                    (msg.content for msg in reversed(request.messages) if msg.role == "user"),
                    None,
                )
                if not user_message:
                    raise HTTPException(status_code=400, detail="No user message found")

                # Process with reasoning engine
                result = await self.reasoning_engine.reason(
                    query=user_message,
                    context={
                        "chat_history": context,
                        "temperature": request.temperature,
                        "top_p": request.top_p,
                        "max_tokens": request.max_tokens,
                        "stream": request.stream,
                    },
                )
                answer = result.answer

                # The whole history is sent to the engine, so count all of it
                # as prompt tokens, not just the final user message.
                prompt_tokens = sum(self._estimate_tokens(msg.content) for msg in request.messages)
                completion_tokens = self._estimate_tokens(answer)

                # Format response. A random id avoids collisions between
                # requests handled in the same millisecond.
                response = {
                    "id": f"chatcmpl-{uuid.uuid4().hex}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": answer,
                        },
                        "finish_reason": "stop",
                    }],
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                }
                return ChatCompletionResponse(**response)
            except HTTPException:
                # Deliberate HTTP errors (e.g. the 400 above) must keep their
                # status code instead of being rewrapped as a 500.
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.router.get("/v1/models")
        async def list_models():
            """List available models."""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "venture-gpt-1",
                        "object": "model",
                        "created": self._models_created,
                        "owned_by": "venture-ai",
                        "permission": [],
                        "root": "venture-gpt-1",
                        "parent": None,
                    }
                ],
            }

    def _prepare_context(self, messages: List[ChatMessage]) -> List[Dict]:
        """Convert request messages into the chat-history format for the engine.

        Each entry has ``role``, ``content``, ``name`` and ``timestamp``. The
        timestamp is the server's local (naive, ISO-8601) time when the
        request was processed, not when the message was written. All messages
        in one request share the same value.
        """
        timestamp = datetime.now().isoformat()
        return [
            {
                "role": msg.role,
                "content": msg.content,
                "name": msg.name,
                "timestamp": timestamp,
            }
            for msg in messages
        ]

    def _estimate_tokens(self, text: Optional[str]) -> int:
        """Estimate token count for a text.

        Uses a rough ~4 characters per token heuristic. Returns 0 for empty or
        ``None`` text and at least 1 for any non-empty text.
        """
        if not text:
            return 0
        # Simple estimation: ~4 characters per token; a short non-empty string
        # still costs at least one token.
        return max(1, len(text) // 4)