import os

os.environ['HF_HOME'] = '/tmp/.cache/huggingface'  # Use /tmp in Spaces; must be set before transformers is imported
os.makedirs(os.environ['HF_HOME'], exist_ok=True)  # Ensure the cache directory exists

import warnings

import numpy
import torch
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from huggingface_hub import login
from pydantic import BaseModel
from transformers import AutoTokenizer
from transformers import logging as hf_logging

from qwen_classifier.config import HF_REPO, DEVICE
from qwen_classifier.evaluate import evaluate_batch  # Your existing function
from qwen_classifier.globals import global_model, global_tokenizer
from qwen_classifier.model import QwenClassifier
from qwen_classifier.predict import predict_single  # Your existing function

print(numpy.__version__)  # Sanity check: log which NumPy version the Space resolved
app = FastAPI(title="Qwen Classifier")

# Prefer the HF_REPO environment variable (set in the Space settings); fall back to the packaged default
hf_repo = os.getenv("HF_REPO") or HF_REPO

debug = False
if not debug:
    # Hide the benign "Some weights of the model checkpoint ..." warning from transformers
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    hf_logging.set_verbosity_error()
else:
    hf_logging.set_verbosity_info()
    warnings.simplefilter("default")
# Landing page with usage hints
@app.get("/", response_class=HTMLResponse)
def home():
    return """
    <html>
        <head>
            <title>Qwen Classifier</title>
        </head>
        <body>
            <h1>Qwen Classifier API</h1>
            <p>Available endpoints:</p>
            <ul>
                <li><strong>POST /predict</strong> - Classify text</li>
                <li><strong>POST /evaluate</strong> - Evaluate batch text predictions from a zip file</li>
                <li><strong>GET /health</strong> - Check API status</li>
            </ul>
            <p>Try it: <code>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</code></p>
        </body>
    </html>
    """
@app.on_event("startup")
async def load_model():
    global global_model, global_tokenizer
    # Warm up the GPU if one is present (a bare .cuda() would crash on CPU-only hardware)
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from the Space's secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")
    # Authenticate with the Hub
    login(token=hf_token)
    # Load model and tokenizer (cached under HF_HOME, i.e. /tmp/.cache/huggingface)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global_model = QwenClassifier.from_pretrained(hf_repo).to(device)
    global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    print("Model loaded successfully!")
class PredictionRequest(BaseModel):
    text: str  # The request body must supply 'text' as a string

class EvaluationRequest(BaseModel):
    file_path: str  # The request body must supply 'file_path' as a string
@app.post("/predict")
async def predict(request: PredictionRequest):  # FastAPI validates the input automatically
    return predict_single(request.text, hf_repo, backend="local")

@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):  # FastAPI validates the input automatically
    return evaluate_batch(request.file_path, hf_repo, backend="local")
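# Example calls against the two POST endpoints above, reusing the demo URL from
# the landing page; the zip path in the /evaluate payload is a hypothetical
# placeholder for wherever the archive lives on the server:
#
#   curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict \
#        -H "Content-Type: application/json" -d '{"text": "your text"}'
#   curl -X POST https://keivanr-qwen-classifier-demo.hf.space/evaluate \
#        -H "Content-Type: application/json" -d '{"file_path": "/tmp/batch.zip"}'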
@app.get("/health")
def health_check():
    return {"status": "healthy", "model": "loaded"}