sudipta26889 committed
Commit a9cd681 · Parent: 23672b1

Fix CUDA initialization for HF Spaces Stateless GPU environment

Files changed (1): app.py (+52 −18)
app.py CHANGED
@@ -23,8 +23,6 @@ except ImportError:
     pass
 
 import gradio as gr
-import torch
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
 # Try to import MCPClient with fallback
 try:
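
The two deleted imports are the heart of the fix: per the commit's rationale, importing `torch` at module scope can initialize CUDA in the main process, which the Stateless GPU (ZeroGPU) runtime on HF Spaces does not allow. Below is a minimal sketch of the deferred-import pattern the commit moves to; the names are illustrative, not from app.py:

```python
# Sketch: defer heavy imports so the main process never touches CUDA.
_model = None

def _load_model():
    """Lazily import torch; CUDA is only initialized in the calling process."""
    global _model
    if _model is None:
        import torch  # deferred import: safe inside a GPU worker
        _model = torch.nn.Linear(4, 4)  # illustrative stand-in for a real model load
    return _model
```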
@@ -78,6 +76,7 @@ gpt_oss_tokenizer = None
 gpt_oss_model = None
 _initialized = False
 _init_lock = asyncio.Lock()
+_model_loading_lock = asyncio.Lock()
 
 def _current_system_prompt(style: str) -> str:
     """Get the system prompt with style suffix."""
@@ -100,29 +99,57 @@ def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPClient:
         mcp_client = MCPClient(model=model_id, provider=provider, api_key=api_key)
     return mcp_client
 
-def get_gpt_oss_model_and_tokenizer():
-    """Get or create GPT-OSS-20B model and tokenizer."""
+async def get_gpt_oss_model_and_tokenizer():
+    """Get or create GPT-OSS-20B model and tokenizer with proper CUDA handling."""
     global gpt_oss_tokenizer, gpt_oss_model
-    if gpt_oss_tokenizer is None or gpt_oss_model is None:
+
+    # Check if already loaded
+    if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
+        return gpt_oss_tokenizer, gpt_oss_model
+
+    # Use lock to prevent multiple simultaneous loads
+    async with _model_loading_lock:
+        # Double-check after acquiring lock
+        if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
+            return gpt_oss_tokenizer, gpt_oss_model
+
         try:
+            # Import here to avoid CUDA initialization in main process
+            import torch
+            from transformers import AutoTokenizer, AutoModelForCausalLM
+
             print("🔄 Loading GPT-OSS-20B tokenizer...")
             gpt_oss_tokenizer = AutoTokenizer.from_pretrained(
                 "openai/gpt-oss-20b",
                 trust_remote_code=True,
             )
+
             print("🔄 Loading GPT-OSS-20B model...")
+            # For HF Spaces with Stateless GPU, use specific device mapping
+            device_map = "auto"
+            if os.environ.get("SPACE_ZERO_GPU"):
+                device_map = "cpu"  # Force CPU for ZeroGPU spaces
+
             gpt_oss_model = AutoModelForCausalLM.from_pretrained(
                 "openai/gpt-oss-20b",
-                torch_dtype="auto",
-                device_map="auto",
+                torch_dtype=torch.float16,  # Use float16 for memory efficiency
+                device_map=device_map,
                 trust_remote_code=True,
                 low_cpu_mem_usage=True,
             )
+
+            # Set model to evaluation mode
+            gpt_oss_model.eval()
+
             print("✅ GPT-OSS-20B loaded successfully!")
+            return gpt_oss_tokenizer, gpt_oss_model
+
         except Exception as e:
             print(f"❌ Failed to load GPT-OSS-20B: {e}")
+            # Reset globals on error
+            gpt_oss_tokenizer = None
+            gpt_oss_model = None
             raise e
-    return gpt_oss_tokenizer, gpt_oss_model
 
 async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
     """Initialize MCP server connection."""
@@ -208,7 +235,9 @@ async def stream_answer(
     # Handle GPT-OSS-20B
     if USE_GPT_OSS:
         try:
-            tokenizer, model = get_gpt_oss_model_and_tokenizer()
+            # Lazy load model only when needed
+            tokenizer, model = await get_gpt_oss_model_and_tokenizer()
+            import torch  # needed locally: the module-level torch import was removed
 
             # Convert messages to GPT-OSS format with reasoning
             gpt_oss_messages = []
@@ -228,18 +257,23 @@ async def stream_answer(
                 tokenize=True,
                 return_dict=True,
                 return_tensors="pt",
-            ).to(model.device)
+            )
+
+            # Move inputs to model device
+            if hasattr(model, 'device'):
+                inputs = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
 
             # Generate with timeout protection
             try:
-                outputs = model.generate(
-                    **inputs,
-                    max_new_tokens=512,
-                    do_sample=True,
-                    temperature=0.7,
-                    pad_token_id=tokenizer.eos_token_id,
-                    max_time=60.0,  # 60 second timeout
-                )
+                with torch.no_grad():  # Disable gradients for inference
+                    outputs = model.generate(
+                        **inputs,
+                        max_new_tokens=512,
+                        do_sample=True,
+                        temperature=0.7,
+                        pad_token_id=tokenizer.eos_token_id,
+                        max_time=60.0,  # 60 second timeout
+                    )
             except Exception as gen_error:
                 yield {
                     "delta": f"❌ Generation Error: {str(gen_error)}",
 