GodfreyOwino committed
Commit 55f7a10 · verified · 1 Parent(s): fe4e62b

Create app.py

Files changed (1)
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="DeepSeek R1 Chat API",
    description="DeepSeek R1 model hosted on Hugging Face Spaces",
    version="1.0.0"
)

# Request/response models
class ChatRequest(BaseModel):
    message: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9

class ChatResponse(BaseModel):
    response: str
    status: str

# Global references to the model and tokenizer, populated at startup
model = None
tokenizer = None

# NOTE: on_event is deprecated in newer FastAPI; a lifespan handler is the modern replacement.
@app.on_event("startup")
async def load_model():
    """Load the DeepSeek model on startup."""
    global model, tokenizer

    try:
        logger.info("Loading DeepSeek R1 model...")

        # Use a small distilled DeepSeek R1 model that fits in Spaces
        model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

        # Load the tokenizer; left padding suits decoder-only generation
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left"
        )

        # Add a pad token if the tokenizer doesn't define one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load the model with settings appropriate for Spaces hardware
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True
        )

        logger.info("Model loaded successfully!")

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise

@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "message": "DeepSeek R1 Chat API is running!",
        "status": "healthy",
        "model_loaded": model is not None
    }

@app.get("/health")
async def health_check():
    """Detailed health check."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
    }

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat endpoint for the DeepSeek model."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        # Build the prompt
        prompt = f"User: {request.message}\nAssistant:"

        # Tokenize the input, truncating long prompts
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024
        )

        # Move tensors to the GPU if one is available
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}

        # Generate the response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Decode the full sequence (prompt + completion)
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the assistant's part of the response
        if "Assistant:" in full_response:
            response = full_response.split("Assistant:")[-1].strip()
        else:
            response = full_response[len(prompt):].strip()

        return ChatResponse(response=response, status="success")

    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")

@app.post("/generate")
async def generate(request: ChatRequest):
    """Alternative generation endpoint; delegates to /chat."""
    return await chat(request)

@app.get("/model-info")
async def model_info():
    """Get model information."""
    if model is None:
        return {"status": "Model not loaded"}

    return {
        "model_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "model_type": type(model).__name__,
        "tokenizer_type": type(tokenizer).__name__,
        "vocab_size": tokenizer.vocab_size if tokenizer else None,
        "device": str(next(model.parameters()).device) if model else None
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
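
For reference, a minimal client sketch for the API above. The base URL is an assumption: the app listens on port 7860 locally, but a deployed Space would expose a different public URL.

import requests

BASE_URL = "http://localhost:7860"  # hypothetical; replace with your Space's URL

def ask(message: str, max_length: int = 256) -> str:
    """Send a message to the /chat endpoint and return the model's reply."""
    resp = requests.post(
        f"{BASE_URL}/chat",
        json={"message": message, "max_length": max_length},
        timeout=120,  # generation can be slow on CPU-only Spaces
    )
    resp.raise_for_status()
    return resp.json()["response"]

if __name__ == "__main__":
    print(ask("What is the capital of Kenya?"))

The response body matches the ChatResponse model defined in app.py, i.e. {"response": "...", "status": "success"}.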
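
The commit adds only app.py; for the Space to build, a requirements.txt along these lines is implied. The package list is inferred from the imports (low_cpu_mem_usage and device_map additionally require accelerate); versions are left unpinned because any pin would be a guess.

fastapi
uvicorn
transformers
torch
accelerate
pydantic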