Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from pydantic import BaseModel | |
from typing import List, Optional | |
import sys | |
import os | |
import logging | |
# Put the parent of this file's directory (the ``src`` directory) on the
# import path so the ``generation`` package imported below can be found.
_src_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_src_dir)
from generation.medical_generator import MedicalTextGenerator  # noqa: E402
# Configure root logging at INFO so the module logger's messages are emitted.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# The ASGI application. title/description/version are the metadata FastAPI
# places in the OpenAPI schema it generates.
app = FastAPI(
    version="1.0.0",
    title="Synthex Medical Text Generator API",
    description="API for generating synthetic medical records",
)
# Let browser clients on any origin call the API with any method and headers.
# Credentials are deliberately OFF: the FastAPI/Starlette docs state that
# credentialed CORS cannot be combined with a wildcard origin, and Starlette
# handles that combination by reflecting the caller's Origin, which grants
# credentialed access to every site. None of the endpoints in this module
# read cookies or auth headers, so no client needs credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the web front-end assets under the /static URL prefix.
# NOTE(review): "src/web" is relative, so it is resolved against the process
# working directory, not this file's location; the server must be started
# from the project root. Confirm this matches the deployment command.
_web_assets = StaticFiles(directory="src/web")
app.mount("/static", _web_assets, name="static")

# A single generator, created at import time and shared by every request
# handler in this module.
generator = MedicalTextGenerator()
class GenerationRequest(BaseModel):
    """Request body for a synthetic-record generation call."""
    record_type: str  # must be a key of generator.templates; checked by generate_records
    quantity: int = 1  # number of records; generate_records accepts only 1-10
    use_gemini: bool = False  # forwarded as use_gemini= to generator.batch_generate
    include_metadata: bool = True  # NOTE(review): not read by generate_records in this file
class MedicalRecord(BaseModel):
    """Schema of a single generated record, as listed in GenerationResponse.

    NOTE(review): generate_records returns the generator's records as-is
    without building this model; whether they are validated against it
    depends on the route's response_model, which is not visible here.
    """
    id: str
    type: str
    text: str
    timestamp: str  # kept as a plain string; no date format is enforced here
    source: str  # presumably which backend produced the text (template vs Gemini) — confirm
class GenerationResponse(BaseModel):
    """Response schema for a generation call."""
    records: List[MedicalRecord]
    total_generated: int  # generate_records sets this to len(records)
async def read_root():
    """Return the bundled HTML front end as a file response."""
    # NOTE(review): no route decorator is visible in this source; confirm this
    # coroutine is registered with ``app`` (presumably at "/").
    # The path is relative, so it is resolved against the process working
    # directory rather than this file's location.
    index_html = "src/web/index.html"
    return FileResponse(index_html)
async def get_record_types():
    """List the record types the shared generator has templates for.

    Returns:
        ``{"record_types": [...]}`` holding the keys of
        ``generator.templates`` in their iteration order.
    """
    # NOTE(review): no route decorator is visible in this source; confirm how
    # this coroutine is registered with ``app``.
    type_names = list(generator.templates.keys())
    return {"record_types": type_names}
async def generate_records(request: GenerationRequest):
    """Generate a batch of synthetic medical records.

    Args:
        request: The record type to generate, how many (1-10), and whether
            ``generator.batch_generate`` should use Gemini.

    Returns:
        ``{"records": [...], "total_generated": int}`` where ``records`` is
        whatever ``generator.batch_generate`` returned and
        ``total_generated`` is its length.

    Raises:
        HTTPException: 400 for an unknown ``record_type`` or a ``quantity``
            outside 1..10; 500 when generation itself raises.
    """
    # NOTE(review): no route decorator is visible in this source; confirm how
    # this coroutine is registered with ``app``.

    # Validate before the ``try`` below. HTTPException subclasses Exception,
    # so raising it inside the try caused the generic handler to turn every
    # 400 into a 500.
    if request.record_type not in generator.templates:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid record type. Available types: {list(generator.templates.keys())}",
        )
    if not 1 <= request.quantity <= 10:
        raise HTTPException(status_code=400, detail="Quantity must be between 1 and 10")

    # NOTE(review): request.include_metadata is accepted but never used here.
    # NOTE(review): batch_generate is called synchronously (no await) from an
    # async handler, so it blocks the event loop while it runs. Consider
    # fastapi.concurrency.run_in_threadpool if the generator is thread-safe.
    try:
        records = generator.batch_generate(
            record_type=request.record_type,
            count=request.quantity,
            use_gemini=request.use_gemini,
        )
        total_generated = len(records)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the log and
        # surface the message to the client as a 500.
        logger.exception("Error generating records: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "records": records,
        "total_generated": total_generated,
    }
if __name__ == "__main__":
    # Direct-run entry point. The import is deferred so uvicorn is only needed
    # when this file is executed as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")