from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
import os
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="DeepSeek R1 Chat API",
    description="DeepSeek R1 model hosted on Hugging Face Spaces",
    version="1.0.0"
)


# Request/Response models
class ChatRequest(BaseModel):
    message: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9


class ChatResponse(BaseModel):
    response: str
    status: str
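
# Example request body (illustrative; only "message" is required, the rest fall back to the defaults above):
#   {"message": "Hello!", "max_length": 256, "temperature": 0.7, "top_p": 0.9}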

# Global variables for model and tokenizer
model = None
tokenizer = None


@app.on_event("startup")
async def load_model():
    """Load the DeepSeek model on startup"""
    global model, tokenizer
    try:
        logger.info("Loading DeepSeek R1 model...")
        # Use a smaller DeepSeek model that fits in Spaces
        model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )

        # Add a pad token if the tokenizer doesn't define one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load the model with settings appropriate for Spaces hardware
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise


@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "message": "DeepSeek R1 Chat API is running!",
        "status": "healthy",
        "model_loaded": model is not None
    }


@app.get("/health")
async def health_check():
    """Detailed health check"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat endpoint for DeepSeek model"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    try:
        # Prepare the prompt
        prompt = f"User: {request.message}\nAssistant:"

        # Tokenize the input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024
        )

        # Move tensors to the appropriate device
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}

        # Generate the response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Decode the full generated sequence
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the assistant's reply
        if "Assistant:" in full_response:
            response = full_response.split("Assistant:")[-1].strip()
        else:
            response = full_response[len(prompt):].strip()

        return ChatResponse(response=response, status="success")
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


@app.post("/generate", response_model=ChatResponse)
async def generate(request: ChatRequest):
    """Alternative generation endpoint"""
    return await chat(request)


@app.get("/model-info")
async def model_info():
    """Get model information"""
    if model is None:
        return {"status": "Model not loaded"}
    return {
        "model_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "model_type": type(model).__name__,
        "tokenizer_type": type(tokenizer).__name__,
        "vocab_size": tokenizer.vocab_size if tokenizer else None,
        "device": str(next(model.parameters()).device) if model else None
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
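
# Minimal client sketch (illustrative only; assumes the app is running locally
# on port 7860, e.g. via `python app.py`. Swap the base URL for your Space's
# public URL once deployed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "Hello, who are you?"},
#       timeout=120,
#   )
#   print(resp.json()["response"])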