sudipta26889 committed
Commit ccd721b · 1 Parent(s): f9d9584

Switch to local GPT-OSS-20B loading with CPU-only approach to avoid CUDA issues

Files changed (2):
  1. app.py +107 -125
  2. requirements.txt +4 -1
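
The heart of the change: app.py stops calling the hosted Inference API (and dropping down to smaller fallback models) and instead loads openai/gpt-oss-20b locally with transformers, pinned to the CPU. A condensed, self-contained sketch of that load-and-generate pattern, using only the model id and settings that appear in the diff below (illustrative, and very memory-hungry for a 20B checkpoint in float32):

# Condensed sketch of the CPU-only pattern introduced in this commit.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "openai/gpt-oss-20b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,   # float32 for CPU compatibility
    device_map="cpu",            # never initialize CUDA
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval()

messages = [{"role": "user", "content": "Reasoning: high\n\nWhat is Gradio?"}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))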
app.py CHANGED
@@ -1,8 +1,8 @@
 # app.py
 # Hugging Face Space: Gradio Docs Chat with GPT-OSS-20B and MCP Integration
 # Features:
-# • GPT-OSS-20B with harmony format for excellent reasoning
-# • Fallback to reliable smaller models when GPT-OSS-20B is paused
+# • GPT-OSS-20B with local transformers loading for fast inference
+# • CPU-only loading to avoid CUDA initialization issues
 # • MCP tool-calling for Gradio docs access
 # • Streaming responses with live tool logs
 # • Optional "Concise / Detailed" answer styles
@@ -26,7 +26,6 @@ except ImportError:
     pass
 
 import gradio as gr
-import requests
 
 # Try to import MCPClient with fallback
 try:
@@ -56,13 +55,8 @@ GRADIO_DOCS_MCP_SSE = os.environ.get(
     "https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse",
 )
 
-# Model configuration - primary and fallback models
-PRIMARY_MODEL = "openai/gpt-oss-20b"
-FALLBACK_MODELS = [
-    "microsoft/DialoGPT-medium",
-    "microsoft/DialoGPT-large",
-    "microsoft/DialoGPT-small"
-]
+# Model configuration - local GPT-OSS-20B loading
+MODEL_ID = "openai/gpt-oss-20b"
 PROVIDER = os.environ.get("CHAT_PROVIDER", "auto")
 HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
 
@@ -81,9 +75,11 @@ DETAILED_SUFFIX = " Provide a detailed, step-by-step answer with short code wher
 # Model Clients (lazy initialization)
 # ----------------------------
 mcp_client: Optional[MCPClient] = None
+gpt_oss_tokenizer = None
+gpt_oss_model = None
 _initialized = False
 _init_lock = asyncio.Lock()
-_current_model = PRIMARY_MODEL
+_model_loading_lock = asyncio.Lock()
 
 def _current_system_prompt(style: str) -> str:
     """Get the system prompt with style suffix."""
@@ -91,8 +87,10 @@ def _current_system_prompt(style: str) -> str:
 
 def _reset_clients():
     """Reset all global clients."""
-    global mcp_client, _initialized
+    global mcp_client, gpt_oss_tokenizer, gpt_oss_model, _initialized
     mcp_client = None
+    gpt_oss_tokenizer = None
+    gpt_oss_model = None
     _initialized = False
 
 def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPClient:
@@ -104,108 +102,105 @@ def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPC
     mcp_client = MCPClient(model=model_id, provider=provider, api_key=api_key)
     return mcp_client
 
-async def call_inference_api(messages: List[Dict[str, Any]], model_id: str) -> str:
-    """Call model via HF Inference API with fallback support."""
-    if not HF_TOKEN:
-        raise ValueError("HF_TOKEN or HUGGING_FACE_HUB_TOKEN required for inference API")
+async def get_gpt_oss_model_and_tokenizer():
+    """Get or create GPT-OSS-20B model and tokenizer with CPU-only loading."""
+    global gpt_oss_tokenizer, gpt_oss_model
 
-    # Convert messages to appropriate format based on model
-    if "gpt-oss" in model_id.lower():
-        # GPT-OSS format with reasoning
-        formatted_messages = []
+    # Check if already loaded
+    if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
+        return gpt_oss_tokenizer, gpt_oss_model
+
+    # Use lock to prevent multiple simultaneous loads
+    async with _model_loading_lock:
+        # Double-check after acquiring lock
+        if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
+            return gpt_oss_tokenizer, gpt_oss_model
+
+        try:
+            # Import here to avoid CUDA initialization in main process
+            import torch
+            from transformers import AutoTokenizer, AutoModelForCausalLM
+
+            print("🔄 Loading GPT-OSS-20B tokenizer...")
+            gpt_oss_tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_ID,
+                trust_remote_code=True,
+            )
+
+            print("🔄 Loading GPT-OSS-20B model (CPU-only)...")
+            # Force CPU-only loading to avoid CUDA initialization issues
+            gpt_oss_model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                torch_dtype=torch.float32,  # Use float32 for CPU compatibility
+                device_map="cpu",  # Force CPU loading
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            )
+
+            # Set model to evaluation mode
+            gpt_oss_model.eval()
+
+            print("✅ GPT-OSS-20B loaded successfully on CPU!")
+            return gpt_oss_tokenizer, gpt_oss_model
+
+        except Exception as e:
+            print(f"❌ Failed to load GPT-OSS-20B: {e}")
+            # Reset globals on error
+            gpt_oss_tokenizer = None
+            gpt_oss_model = None
+            raise e
+
+async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
+    """Generate response using local GPT-OSS-20B model."""
+    try:
+        # Lazy load model only when needed
+        tokenizer, model = await get_gpt_oss_model_and_tokenizer()
+
+        # Convert messages to GPT-OSS format with reasoning
+        gpt_oss_messages = []
         for msg in messages:
             if msg["role"] == "system":
-                formatted_messages.append({
+                gpt_oss_messages.append({
                     "role": "user",
                     "content": f"Reasoning: high\n\n{msg['content']}"
                 })
             else:
-                formatted_messages.append(msg)
-    else:
-        # Standard chat format for other models
-        formatted_messages = messages
-
-    # Prepare the request payload
-    payload = {
-        "inputs": formatted_messages,
-        "parameters": {
-            "max_new_tokens": 512,
-            "temperature": 0.7,
-            "do_sample": True,
-            "return_full_text": False,
-        }
-    }
-
-    headers = {
-        "Authorization": f"Bearer {HF_TOKEN}",
-        "Content-Type": "application/json"
-    }
-
-    # Make the API call
-    try:
-        response = requests.post(
-            f"https://api-inference.huggingface.co/models/{model_id}",
-            headers=headers,
-            json=payload,
-            timeout=120  # 2 minute timeout
+                gpt_oss_messages.append(msg)
+
+        # Apply chat template and generate
+        inputs = tokenizer.apply_chat_template(
+            gpt_oss_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
         )
 
-        if response.status_code == 200:
-            result = response.json()
-            if isinstance(result, list) and len(result) > 0:
-                return result[0].get("generated_text", "")
-            elif isinstance(result, dict):
-                return result.get("generated_text", "")
-            else:
-                return str(result)
-        else:
-            error_msg = f"API Error {response.status_code}: {response.text}"
-            print(f"❌ {error_msg}")
-            raise Exception(error_msg)
-
-    except requests.exceptions.Timeout:
-        raise Exception("Request timed out after 120 seconds")
-    except requests.exceptions.RequestException as e:
-        raise Exception(f"Request failed: {str(e)}")
-
-def reset_to_primary_model():
-    """Reset to use the primary model on next request."""
-    global _current_model
-    _current_model = PRIMARY_MODEL
-    print(f"🔄 Reset to primary model: {PRIMARY_MODEL}")
-
-async def call_model_with_fallback(messages: List[Dict[str, Any]]) -> Tuple[str, str]:
-    """Call model with automatic fallback to smaller models."""
-    global _current_model
-
-    # Try current model first (could be primary or a previously successful fallback)
-    try:
-        print(f"🔄 Trying current model: {_current_model}")
-        result = await call_inference_api(messages, _current_model)
-        return result, _current_model
-    except Exception as e:
-        error_msg = str(e)
-        print(f"❌ {_current_model} failed: {error_msg}")
+        # Generate with timeout protection
+        try:
+            import torch
+            with torch.no_grad():  # Disable gradients for inference
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    temperature=0.7,
+                    pad_token_id=tokenizer.eos_token_id,
+                    max_time=60.0,  # 60 second timeout
+                )
+        except Exception as gen_error:
+            raise Exception(f"Generation Error: {str(gen_error)}")
 
-    # If current model fails, try all models in order (primary + fallbacks)
-    all_models = [PRIMARY_MODEL] + FALLBACK_MODELS
+        # Decode the generated text
+        generated_text = tokenizer.decode(
+            outputs[0][inputs["input_ids"].shape[-1]:],
+            skip_special_tokens=True
+        )
 
-    for model in all_models:
-        if model == _current_model:  # Skip the one we just tried
-            continue
-
-        try:
-            print(f"🔄 Trying model: {model}")
-            result = await call_inference_api(messages, model)
-            _current_model = model  # Update current model
-            print(f"✅ Successfully using model: {model}")
-            return result, model
-        except Exception as model_error:
-            print(f"❌ {model} failed: {str(model_error)}")
-            continue
+        return generated_text
 
-    # If all models fail, provide a helpful error message
-    raise Exception(f"All models failed. Primary model ({PRIMARY_MODEL}) and fallback models are unavailable. Please try again later.")
+    except Exception as e:
+        raise Exception(f"GPT-OSS-20B Error: {str(e)}")
 
 async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
     """Initialize MCP server connection."""
@@ -288,15 +283,11 @@ async def stream_answer(
     tool_log: List[str] = []
     citations: List[Tuple[str, Optional[str]]] = []
 
-    # Handle GPT-OSS-20B via Inference API with fallback
+    # Handle GPT-OSS-20B via local model
     if USE_GPT_OSS:
        try:
-            # Call the inference API with fallback
-            generated_text, used_model = await call_model_with_fallback(messages_for_llm)
-
-            # Add model info to tool log
-            if used_model != PRIMARY_MODEL:
-                _append_log(tool_log, f"⚠️ Using fallback model: {used_model}")
+            # Generate response using local model
+            generated_text = await generate_with_gpt_oss(messages_for_llm)
 
             # Stream character by character
             for char in generated_text:
@@ -308,7 +299,7 @@ async def stream_answer(
 
         except Exception as e:
             yield {
-                "delta": f"❌ Model Error: {str(e)}",
+                "delta": f"❌ {str(e)}",
                 "tool_log": _format_tool_log(tool_log),
                 "citations": _format_citations(citations),
             }
@@ -437,7 +428,7 @@ async def stream_answer(
 with gr.Blocks(fill_height=True) as demo:
     gr.Markdown(
         "# 🤖 Gradio Docs Chat\n"
-        "Ask anything about **Gradio**. Powered by GPT-OSS-20B with automatic fallback to reliable models."
+        "Ask anything about **Gradio**. Powered by GPT-OSS-20B with local transformers loading."
     )
 
     with gr.Row():
@@ -465,12 +456,10 @@ with gr.Blocks(fill_height=True) as demo:
                 value="Detailed",
             )
             model_info = gr.Markdown(
-                f"**Primary Model:** `{PRIMARY_MODEL}` \n"
-                f"**Current Model:** `{_current_model}` \n"
+                f"**Model:** `{MODEL_ID}` (Local Loading) \n"
                 f"**Provider:** `{PROVIDER}` \n"
-                "_(Auto-fallback to smaller models if primary is paused)_"
+                "_(CPU-only loading to avoid CUDA issues)_"
             )
-            reset_model_btn = gr.Button("🔄 Reset to Primary Model", variant="secondary", size="sm")
 
             with gr.Accordion("🛠 Tool Activity (live)", open=True):
                 tool_log_md = gr.Markdown("_No tool activity yet._")
@@ -487,28 +476,21 @@ with gr.Blocks(fill_height=True) as demo:
 
         messages_for_llm = to_llm_messages(history_msgs[:-1], user_msg, style_choice)
 
-        async for chunk in stream_answer(messages_for_llm, PRIMARY_MODEL, PROVIDER, HF_TOKEN):
+        async for chunk in stream_answer(messages_for_llm, MODEL_ID, PROVIDER, HF_TOKEN):
            delta = chunk.get("delta", "")
            if delta:
                history_msgs[-1]["content"] += delta
            yield history_msgs, gr.update(value=chunk.get("tool_log", "")), gr.update(value=chunk.get("citations", ""))
 
-    def on_reset_model():
-        """Reset to primary model and update UI."""
-        reset_to_primary_model()
-        return gr.update(value=f"**Primary Model:** `{PRIMARY_MODEL}` \n**Current Model:** `{_current_model}` \n**Provider:** `{PROVIDER}` \n_(Auto-fallback to smaller models if primary is paused)_")
-
    # Wire up event handlers
    msg.submit(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
    send_btn.click(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
-    reset_model_btn.click(on_reset_model, outputs=[model_info])
 
 # ----------------------------
 # Launch App
 # ----------------------------
-print(f"🚀 Starting Gradio Docs Chat with GPT-OSS-20B + Fallback Models")
-print(f"📍 Primary Model: {PRIMARY_MODEL}")
-print(f"📍 Fallback Models: {', '.join(FALLBACK_MODELS)}")
+print(f"🚀 Starting Gradio Docs Chat with GPT-OSS-20B (Local Loading)")
+print(f"📍 Model: {MODEL_ID}")
 print(f"🔗 MCP Server: {GRADIO_DOCS_MCP_SSE}")
 
 demo = demo.queue(max_size=32)
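
The model/tokenizer cache above is guarded by _model_loading_lock with a second check inside the lock, so concurrent chat requests trigger at most one expensive load. A minimal, self-contained illustration of that double-checked lazy-load pattern (the names here are generic stand-ins, not from app.py):

# Double-checked lazy loading behind an asyncio.Lock: many concurrent callers,
# exactly one slow initialization. Mirrors get_gpt_oss_model_and_tokenizer().
import asyncio

_cache = None
_lock = asyncio.Lock()

async def get_resource():
    global _cache
    if _cache is not None:           # fast path: already loaded
        return _cache
    async with _lock:
        if _cache is not None:       # double check after acquiring the lock
            return _cache
        await asyncio.sleep(1)       # stand-in for the slow model load
        _cache = "loaded"
        return _cache

async def main():
    results = await asyncio.gather(*(get_resource() for _ in range(10)))
    print(results)                   # the 1-second "load" ran only once

asyncio.run(main())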
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 gradio>=5.0.0
 huggingface_hub>=0.34.0
 python-dotenv>=1.0.1
-requests>=2.31.0
+transformers>=4.40.0
+torch>=2.0.0
+accelerate>=0.20.0
+safetensors>=0.4.0
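
The heavier runtime stack (transformers, torch, accelerate, safetensors) replaces requests. An optional, hypothetical sanity check, not part of the commit, to confirm the new dependencies import cleanly and that the runtime is effectively CPU-only:

# Hypothetical dependency check (not part of this commit).
import accelerate
import safetensors
import torch
import transformers

print("transformers:", transformers.__version__)
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("accelerate:", accelerate.__version__, "| safetensors:", safetensors.__version__)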