pagezyhf's picture
pagezyhf HF Staff
tesst
bdc22f6
raw
history blame
3.68 kB
import logging
import os
import sys

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from optimum.neuron import utils
# Configure the root logger once at import time: INFO and above, timestamped,
# written to stdout. basicConfig wraps the given stream in a StreamHandler.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)
# Module-level logger used by the request handlers in this file.
logger = logging.getLogger(__name__)
app = FastAPI()

# Resolve asset directories relative to this source file (not the current
# working directory) so the app finds them however it is launched.
_base_dir = os.path.dirname(os.path.abspath(__file__))

static_dir = os.path.join(_base_dir, "static")
logger.info(f"Static directory path: {static_dir}")

templates_dir = os.path.join(_base_dir, "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Serve /static/* straight from disk and load Jinja2 templates for rendering.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)
@app.get("/health")
async def health_check():
    """Liveness probe for the service.

    Returns:
        The dict ``{"status": "healthy"}``, serialized to JSON by FastAPI.
    """
    logger.info("Health check endpoint called")
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/")
async def home(request: Request):
    """Render the landing page from the ``index.html`` template.

    The incoming request is passed into the template context under the
    ``"request"`` key.
    """
    logger.info("Home page requested")
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)
@app.get("/api/models")
async def get_model_list():
    """List the unique models that have cached inference configurations.

    Returns:
        A JSONResponse with a list of ``{"id", "name", "type"}`` dicts, where
        ``id``/``name`` are ``"<org>/<model_id>"`` and ``type`` is the
        architecture reported with the first occurrence of that model.
        On any failure, a 500 JSONResponse with ``{"error", "type"}``.
    """
    logger.info("Fetching model list")
    try:
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")
        # get_hub_cached_models is a plain (synchronous) call that queries the
        # Hub; run it in the thread pool so it does not block the event loop
        # and stall concurrent requests such as /health.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info(f"Found {len(model_list)} models")

        models = []
        seen_models = set()
        # Each entry is unpacked as (architecture, org, model_id); the same
        # model can appear more than once, so keep only its first occurrence.
        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id in seen_models:
                continue
            seen_models.add(full_model_id)
            models.append({
                "id": full_model_id,
                "name": full_model_id,
                "type": architecture,
            })

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level request boundary: log with traceback, report as JSON 500.
        logger.exception(f"Error fetching models: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Return the cached inference configurations for one model.

    Args:
        model_id: Hub model id; the ``:path`` converter allows the
            ``org/name`` form with a slash.

    Returns:
        A JSONResponse ``{"configurations": [...]}``; the list is empty when
        nothing is cached. On any failure, a 500 JSONResponse ``{"error"}``.
    """
    logger.info(f"Fetching configurations for model: {model_id}")
    try:
        # Synchronous Hub lookup: offload it so the event loop stays free.
        configs = await run_in_threadpool(
            utils.get_hub_cached_entries, model_id=model_id, mode="inference"
        )
        # Normalize a falsy result (e.g. None) to an empty list *before*
        # len(), so "no configurations" yields [] instead of a 500.
        configs = configs or []
        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    except Exception as e:
        # Top-level request boundary: log with traceback, report as JSON 500.
        logger.exception(f"Error fetching configurations for model {model_id}: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
@app.get("/static/{path:path}")
async def static_files(path: str):
    """Serve a file from the static directory, or a JSON 404.

    NOTE(review): the ``/static`` mount is registered earlier in this module,
    and Starlette matches routes in registration order, so that mount likely
    handles these requests before this route does. This handler is kept safe
    in case the ordering changes.

    Args:
        path: Client-supplied path relative to the static directory
            (untrusted input).

    Returns:
        A FileResponse for a regular file inside the static directory;
        otherwise a 404 JSONResponse ``{"error": "File not found"}``.
    """
    logger.info(f"Static file requested: {path}")
    root = os.path.realpath(static_dir)
    file_path = os.path.realpath(os.path.join(root, path))
    # Reject anything that resolves outside the static root: "../" segments,
    # absolute paths (which os.path.join would otherwise honor), and symlinks
    # pointing elsewhere. Also reject directories, which FileResponse can't serve.
    inside_root = file_path.startswith(root + os.sep)
    if not inside_root or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})
    return FileResponse(file_path)