sudipta26889 committed
Commit 8fd3d30 · Parent: 079d1c0

Switch to Qwen3-30B model with fixed duplicate torch_dtype parameter and improved CPU loading

Files changed (1): app.py (+28 -36)
app.py CHANGED
@@ -60,8 +60,8 @@ GRADIO_DOCS_MCP_SSE = os.environ.get(
     "https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse",
 )
 
-# Model configuration - local GPT-OSS-20B loading
-MODEL_ID = "openai/gpt-oss-20b"
+# Model configuration - local Qwen model loading (more efficient than GPT-OSS-20B)
+MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
 PROVIDER = os.environ.get("CHAT_PROVIDER", "auto")
 HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
 
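(For context on the new comment: the "A3B" suffix marks Qwen3-30B-A3B as a mixture-of-experts checkpoint that activates only about 3B of its roughly 30B parameters per token, which is the efficiency the comment alludes to; the total memory footprint, however, is still that of a 30B model.)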
@@ -130,22 +130,18 @@ async def get_gpt_oss_model_and_tokenizer():
         torch.cuda.is_available = lambda: False
         torch.cuda.device_count = lambda: 0
 
-        print("🔄 Loading GPT-OSS-20B tokenizer...")
-        gpt_oss_tokenizer = AutoTokenizer.from_pretrained(
-            MODEL_ID,
-            trust_remote_code=True,
-        )
+        print("🔄 Loading Qwen tokenizer...")
+        gpt_oss_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
-        print("🔄 Loading GPT-OSS-20B model (CPU-only)...")
-        # Strict CPU-only loading configuration
+        print("🔄 Loading Qwen model (CPU-only)...")
+        # Clean CPU-only loading configuration (no duplicate parameters)
         gpt_oss_model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             torch_dtype=torch.float32,  # Use float32 for CPU compatibility
             device_map=None,  # Don't use device mapping
             trust_remote_code=True,
             low_cpu_mem_usage=True,
-            # Force CPU placement
-            **{"torch_dtype": torch.float32, "device": "cpu"}
+
         )
 
         # Explicitly move to CPU
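The "duplicate torch_dtype parameter" named in the commit message is visible in the removed lines above: torch_dtype was passed once as an explicit keyword and again through the unpacked **{...} dict, which Python rejects before the call ever reaches transformers. A minimal sketch with a hypothetical load() stand-in (not the real from_pretrained signature) reproduces the failure:

def load(model_id, torch_dtype=None, **extra):
    """Stand-in for AutoModelForCausalLM.from_pretrained."""
    return model_id, torch_dtype, extra

# Mirrors the removed call: torch_dtype supplied both ways.
load("demo-model", torch_dtype="float32",
     **{"torch_dtype": "float32", "device": "cpu"})
# TypeError: load() got multiple values for keyword argument 'torch_dtype'

The fix keeps the single explicit torch_dtype=torch.float32 and leaves device placement to the explicit CPU move in the unchanged context below.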
@@ -154,11 +150,11 @@ async def get_gpt_oss_model_and_tokenizer():
         # Set model to evaluation mode
         gpt_oss_model.eval()
 
-        print("✅ GPT-OSS-20B loaded successfully on CPU!")
+        print("✅ Qwen model loaded successfully on CPU!")
         return gpt_oss_tokenizer, gpt_oss_model
 
     except Exception as e:
-        print(f"❌ Failed to load GPT-OSS-20B: {e}")
+        print(f"❌ Failed to load Qwen model: {e}")
         # Reset globals on error
         gpt_oss_tokenizer = None
         gpt_oss_model = None
@@ -170,36 +166,34 @@ async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
         # Lazy load model only when needed
         tokenizer, model = await get_gpt_oss_model_and_tokenizer()
 
-        # Convert messages to GPT-OSS format with reasoning
-        gpt_oss_messages = []
+        # Convert messages to Qwen format
+        qwen_messages = []
         for msg in messages:
             if msg["role"] == "system":
-                gpt_oss_messages.append({
+                # Convert system message to user message for Qwen
+                qwen_messages.append({
                     "role": "user",
-                    "content": f"Reasoning: high\n\n{msg['content']}"
+                    "content": f"System: {msg['content']}"
                 })
             else:
-                gpt_oss_messages.append(msg)
+                qwen_messages.append(msg)
 
         # Apply chat template and generate
-        inputs = tokenizer.apply_chat_template(
-            gpt_oss_messages,
+        text = tokenizer.apply_chat_template(
+            qwen_messages,
+            tokenize=False,
             add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
         )
 
-        # Ensure inputs are on CPU
-        inputs = {k: v.to("cpu") if hasattr(v, "to") else v for k, v in inputs.items()}
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
         # Generate with timeout protection
         try:
             import torch
             with torch.no_grad():  # Disable gradients for inference
-                outputs = model.generate(
-                    **inputs,
-                    max_new_tokens=512,
+                generated_ids = model.generate(
+                    **model_inputs,
+                    max_new_tokens=512,  # Reduced from 16384 for faster response
                     do_sample=True,
                     temperature=0.7,
                     pad_token_id=tokenizer.eos_token_id,
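The rewritten template step follows the two-phase pattern shown on the Qwen model cards: render the conversation to a prompt string first (tokenize=False), then tokenize that string in an ordinary tokenizer call. A minimal sketch of the same two steps, assuming tokenizer and model are the objects loaded above and using a made-up user message:

# Step 1: chat template -> prompt string (special tokens inserted, no tensors yet).
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What does gr.Interface do?"}],
    tokenize=False,
    add_generation_prompt=True,  # open the assistant turn for the model to fill
)

# Step 2: plain tokenizer call -> input_ids / attention_mask on the model's device.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

A side benefit over the removed tokenize=True/return_dict=True form is that the rendered prompt string is available for logging before any tensors are built.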
@@ -209,15 +203,13 @@ async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
             raise Exception(f"Generation Error: {str(gen_error)}")
 
         # Decode the generated text
-        generated_text = tokenizer.decode(
-            outputs[0][inputs["input_ids"].shape[-1]:],
-            skip_special_tokens=True
-        )
+        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
 
         return generated_text
 
     except Exception as e:
-        raise Exception(f"GPT-OSS-20B Error: {str(e)}")
+        raise Exception(f"Qwen Model Error: {str(e)}")
 
 async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
     """Initialize MCP server connection."""
@@ -445,7 +437,7 @@ async def stream_answer(
 with gr.Blocks(fill_height=True) as demo:
     gr.Markdown(
         "# 🤖 Gradio Docs Chat\n"
-        "Ask anything about **Gradio**. Powered by GPT-OSS-20B with local transformers loading."
+        "Ask anything about **Gradio**. Powered by Qwen3-30B with local transformers loading."
     )
 
     with gr.Row():
@@ -475,7 +467,7 @@ with gr.Blocks(fill_height=True) as demo:
             model_info = gr.Markdown(
                 f"**Model:** `{MODEL_ID}` (Local Loading)  \n"
                 f"**Provider:** `{PROVIDER}`  \n"
-                "_(CPU-only loading to avoid CUDA issues)_"
+                "_(CPU-only loading for stable inference)_"
             )
 
             with gr.Accordion("🛠 Tool Activity (live)", open=True):
@@ -506,7 +498,7 @@ with gr.Blocks(fill_height=True) as demo:
 # ----------------------------
 # Launch App
 # ----------------------------
-print(f"🚀 Starting Gradio Docs Chat with GPT-OSS-20B (Local Loading)")
+print(f"🚀 Starting Gradio Docs Chat with Qwen3-30B (Local Loading)")
 print(f"📝 Model: {MODEL_ID}")
 print(f"🔗 MCP Server: {GRADIO_DOCS_MCP_SSE}")
 