Spaces:
Sleeping
Sleeping
# Standard library
import asyncio
import json
import os
import re
from typing import List, Optional, Dict, Any, Tuple

# Third-party
import pandas as pd
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from litellm import completion
from pydantic import BaseModel

# Local
from classifiers.llm import LLMClassifier
from client import get_client, initialize_client
from process import improve_classification
from utils import validate_results
# Load environment variables (e.g. OPENAI_API_KEY) from a .env file, if present.
load_dotenv()

app: FastAPI = FastAPI()

# Configure CORS so browser clients on other origins can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize client with API key from environment.
# NOTE: this module fails fast — it raises at import time when the key is
# missing or client initialization fails, so the server refuses to start
# misconfigured rather than erroring on the first request.
# NOTE(review): no @app.get/@app.post route decorators are visible in this
# file — presumably routes are registered elsewhere; confirm.
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
if api_key:
    success: bool
    message: str
    success, message = initialize_client(api_key)
    if not success:
        raise RuntimeError(f"Failed to initialize OpenAI client: {message}")

client = get_client()
if not client:
    raise RuntimeError("OpenAI client not initialized. Please set OPENAI_API_KEY environment variable.")

# Initialize the LLM classifier shared by all handlers in this module.
classifier: LLMClassifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
class TextInput(BaseModel):
    """Request body for classifying a single text."""

    # The text to classify.
    text: str
    # Optional fixed category set; passed straight to the classifier
    # (semantics of None are defined by LLMClassifier.classify_async).
    categories: Optional[List[str]] = None
class BatchTextInput(BaseModel):
    """Request body for classifying several texts in one call."""

    # Texts to classify, in order; results are returned in the same order.
    texts: List[str]
    # Optional fixed category set shared by all texts in the batch.
    categories: Optional[List[str]] = None
class ClassificationResponse(BaseModel):
    """Classification result for one text."""

    # Category chosen by the classifier.
    category: str
    # Classifier confidence. NOTE(review): scale (0-1 vs 0-100) is not
    # visible here — validate_classifications compares against 70,
    # suggesting 0-100; confirm against LLMClassifier.
    confidence: float
    # Human-readable rationale produced by the classifier.
    explanation: str
class BatchClassificationResponse(BaseModel):
    """Per-text results for a batch classification request, in input order."""

    results: List[ClassificationResponse]
class CategorySuggestionResponse(BaseModel):
    """Category names suggested by the classifier for a set of texts."""

    categories: List[str]
class ModelInfoResponse(BaseModel):
    """Static description of the model configuration in use."""

    # NOTE(review): field names starting with "model_" collide with
    # pydantic v2's protected "model_" namespace and emit a warning —
    # confirm the pydantic version and consider ConfigDict(protected_namespaces=()).
    model_name: str
    model_version: str
    # Values below are hard-coded in get_model_info, not read from the model.
    max_tokens: int
    temperature: float
class HealthResponse(BaseModel):
    """Health-check payload."""

    # Always "healthy" when the handler is reachable (see health_check).
    status: str
    # True when the OpenAI client was initialized at startup.
    model_ready: bool
    # True when OPENAI_API_KEY was present in the environment.
    api_key_configured: bool
class ValidationSample(BaseModel):
    """One already-classified row submitted for validation."""

    text: str
    # Category previously assigned by the classifier.
    assigned_category: str
    # Confidence of that assignment; compared against 70 downstream,
    # so presumably on a 0-100 scale — confirm.
    confidence: float
class ValidationRequest(BaseModel):
    """Request body for validating a batch of classification results."""

    samples: List[ValidationSample]
    # Category set the samples were classified against.
    current_categories: List[str]
    # Names of the text columns to consider (forwarded to utils.validate_results).
    text_columns: List[str]
class ValidationResponse(BaseModel):
    """Validation outcome: free-form report plus best-effort structured extras."""

    # Free-form report produced by utils.validate_results.
    validation_report: str
    # Accuracy parsed out of the report, as a 0-1 fraction; None if not found.
    accuracy_score: Optional[float] = None
    # Low-confidence samples flagged for review.
    misclassifications: Optional[List[Dict[str, Any]]] = None
    # Generic improvement suggestions.
    suggested_improvements: Optional[List[str]] = None
class ImprovementRequest(BaseModel):
    """Request body for the classification-improvement endpoint."""

    df: Dict[str, Any]  # JSON representation of the DataFrame
    # Validation report the improvement should act on.
    validation_report: str
    text_columns: List[str]
    # NOTE(review): a plain string here, while elsewhere categories are
    # List[str] — presumably a serialized/comma-joined form expected by
    # process.improve_classification; confirm.
    categories: str
    classifier_type: str
    show_explanations: bool
    # Forwarded as the `file` argument of improve_classification.
    file_path: str
class ImprovementResponse(BaseModel):
    """Response body for the classification-improvement endpoint."""

    # JSON representation of the improved DataFrame, or None when the
    # improvement step produced no DataFrame. The endpoint explicitly passes
    # None in that case ("... if improved_df is not None else None"), so this
    # field must be optional — as a required Dict, pydantic validation would
    # raise and the handler would report a spurious 500.
    improved_df: Optional[Dict[str, Any]] = None
    # Validation report produced after the improvement pass.
    new_validation_report: str
    # Whether the improvement step succeeded.
    success: bool
    # Category set after the improvement pass.
    updated_categories: List[str]
async def health_check() -> HealthResponse:
    """Report API health: client readiness and API-key presence."""
    # Compute the readiness flags up front for clarity.
    ready = client is not None
    key_present = api_key is not None
    return HealthResponse(
        status="healthy",
        model_ready=ready,
        api_key_configured=key_present,
    )
async def get_model_info() -> ModelInfoResponse:
    """Describe the current model configuration (version/limits are hard-coded)."""
    info = {
        "model_name": classifier.model,
        "model_version": "1.0",
        "max_tokens": 200,
        "temperature": 0,
    }
    return ModelInfoResponse(**info)
async def classify_text(text_input: TextInput) -> ClassificationResponse:
    """Classify one text and return its category, confidence and explanation."""
    try:
        # The classifier works on batches, so wrap the single text in a
        # one-element list and take the first (and only) result.
        batch: List[Dict[str, Any]] = await classifier.classify_async(
            [text_input.text],
            text_input.categories,
        )
        first = batch[0]
        return ClassificationResponse(
            category=first["category"],
            confidence=first["confidence"],
            explanation=first["explanation"],
        )
    except Exception as exc:
        # Surface any classifier failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))
async def classify_batch(batch_input: BatchTextInput) -> BatchClassificationResponse:
    """Classify multiple texts in a single request"""
    try:
        raw: List[Dict[str, Any]] = await classifier.classify_async(
            batch_input.texts,
            batch_input.categories,
        )
        # Re-shape each raw classifier dict into the response model.
        responses: List[ClassificationResponse] = []
        for item in raw:
            responses.append(
                ClassificationResponse(
                    category=item["category"],
                    confidence=item["confidence"],
                    explanation=item["explanation"],
                )
            )
        return BatchClassificationResponse(results=responses)
    except Exception as exc:
        # Any failure is reported as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))
async def suggest_categories(texts: List[str]) -> CategorySuggestionResponse:
    """Ask the classifier to propose a category set for the given texts.

    NOTE(review): this reaches into the classifier's private
    ``_suggest_categories_async`` — consider exposing a public method on
    LLMClassifier instead.
    """
    try:
        categories: List[str] = await classifier._suggest_categories_async(texts)
        return CategorySuggestionResponse(categories=categories)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def validate_classifications(validation_request: ValidationRequest) -> ValidationResponse:
    """Validate classification results and provide improvement suggestions.

    Builds a DataFrame from the submitted samples, runs the report-producing
    validation in ``utils.validate_results``, then extracts a best-effort
    accuracy figure, flags low-confidence rows, and attaches generic
    improvement suggestions.
    """
    try:
        # Convert samples to the column layout validate_results expects.
        df = pd.DataFrame(
            [
                {
                    "text": sample.text,
                    "Category": sample.assigned_category,
                    "Confidence": sample.confidence,
                }
                for sample in validation_request.samples
            ]
        )

        # Use the validate_results function from utils.
        validation_report: str = validate_results(df, validation_request.text_columns, client)

        # Best-effort extraction of an accuracy percentage such as
        # "Accuracy: 85%" from the free-form report. (The previous code split
        # on "accuracy" and fed text like ": 85" to float(), which always
        # raised, so the score was never extracted; a targeted regex and a
        # narrow except fix that.)
        accuracy_score: Optional[float] = None
        match = re.search(
            r"accuracy[^0-9]*([0-9]+(?:\.[0-9]+)?)\s*%",
            validation_report,
            re.IGNORECASE,
        )
        if match:
            try:
                accuracy_score = float(match.group(1)) / 100
            except ValueError:
                pass  # malformed number in the report; leave accuracy unknown

        # Flag low-confidence rows for review.
        # NOTE(review): the threshold 70 assumes a 0-100 confidence scale — confirm.
        misclassifications: List[Dict[str, Any]] = [
            {"text": sample.text, "current_category": sample.assigned_category}
            for sample in validation_request.samples
            if sample.confidence < 70
        ]

        # Static suggestions; not derived from the report.
        suggested_improvements: List[str] = [
            "Review low confidence classifications",
            "Consider adding more training examples",
            "Refine category definitions",
        ]

        return ValidationResponse(
            validation_report=validation_report,
            accuracy_score=accuracy_score,
            misclassifications=misclassifications,
            suggested_improvements=suggested_improvements,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def improve_classification_endpoint(request: ImprovementRequest) -> ImprovementResponse:
    """Improve classification based on validation report"""
    try:
        # Convert the JSON (dict-of-dicts) payload back into a DataFrame.
        df = pd.DataFrame.from_dict(request.df)
        # Delegate to process.improve_classification; it returns a 4-tuple of
        # (improved DataFrame or None, new report, success flag, categories).
        improved_df, new_validation, success, updated_categories = await improve_classification(
            df=df,
            validation_report=request.validation_report,
            text_columns=request.text_columns,
            categories=request.categories,
            classifier_type=request.classifier_type,
            show_explanations=request.show_explanations,
            file=request.file_path
        )
        # Serialize the improved frame back to a dict; None propagates when
        # no frame was produced. NOTE(review): ImprovementResponse.improved_df
        # is declared as a required Dict — passing None here would fail
        # pydantic validation and surface as a 500; confirm intent.
        improved_df_json = improved_df.to_dict() if improved_df is not None else None
        return ImprovementResponse(
            improved_df=improved_df_json,
            new_validation_report=new_validation,
            success=success,
            updated_categories=updated_categories
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    # Development entry point: bind on all interfaces with auto-reload.
    # The "server:app" string implies this file is named server.py.
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)