Spestly committed
Commit 6263300 · verified · 1 Parent(s): a2633ea

Update app.py

Files changed (1)
  1. app.py +112 -259
app.py CHANGED
@@ -4,11 +4,6 @@ import torch
  import time
  import spaces
  import re
- import logging
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
 
  # Model configurations
  MODELS = {
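Note: this hunk drops the module-level logger in favor of bare print() calls (seen in the new generate_response below). On Spaces both end up in the container logs, but timestamps and log levels are lost. A minimal sketch of what keeping a logger would look like (illustrative only, not part of this commit):

    import logging

    # Same information as the print() calls, but with level and timestamp.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info("Loading %s ...", "Spestly/Athena-R3X-4B")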
@@ -23,110 +18,72 @@ MODELS = {
      "Athena-1 7B": "Spestly/Athena-1-7B"
  }
 
- # Models that need the enable_thinking parameter
- THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]
-
- # Cache for loaded models
- loaded_models = {}
-
- @spaces.GPU
- def load_model(model_id):
-     """Load model and tokenizer once and cache them"""
-     try:
-         if model_id not in loaded_models:
-             logger.info(f"🚀 Loading {model_id}...")
-             start_time = time.time()
-
-             tokenizer = AutoTokenizer.from_pretrained(model_id)
-             if tokenizer.pad_token is None:
-                 tokenizer.pad_token = tokenizer.eos_token
-
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_id,
-                 torch_dtype=torch.float16,
-                 device_map="auto",
-                 trust_remote_code=True
-             )
-
-             load_time = time.time() - start_time
-             logger.info(f"✅ Model loaded in {load_time:.2f}s")
-             loaded_models[model_id] = (model, tokenizer, load_time)
-
-         return loaded_models[model_id]
-     except Exception as e:
-         logger.error(f"Error loading model {model_id}: {str(e)}")
-         raise gr.Error(f"Failed to load model {model_id}. Please try another model.")
-
  @spaces.GPU
  def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
-     """Generate response using the specified model"""
-     try:
-         model, tokenizer, _ = load_model(model_id)
-
-         # Build messages in proper chat format
-         messages = []
-         system_prompt = (
-             "You are Athena, a helpful, harmless, and honest AI assistant. "
-             "You provide clear, accurate, and concise responses to user questions. "
-             "You are knowledgeable across many domains and always aim to be respectful and helpful. "
-             "You are finetuned by Aayan Mishra"
-         )
-         messages.append({"role": "system", "content": system_prompt})
-
-         # Add conversation history
-         for msg in conversation:
-             messages.append(msg)
-
-         # Add current user message
-         messages.append({"role": "user", "content": user_message})
-
-         # Check if this model needs the enable_thinking parameter
-         if model_id in THINKING_ENABLED_MODELS:
-             prompt = tokenizer.apply_chat_template(
-                 messages,
-                 tokenize=False,
-                 add_generation_prompt=True,
-                 enable_thinking=True
-             )
-         else:
-             prompt = tokenizer.apply_chat_template(
-                 messages,
-                 tokenize=False,
-                 add_generation_prompt=True
-             )
-
-         inputs = tokenizer(prompt, return_tensors="pt")
-         device = next(model.parameters()).device
-         inputs = {k: v.to(device) for k, v in inputs.items()}
-
-         generation_start = time.time()
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=max_length,
-                 temperature=temperature,
-                 do_sample=True,
-                 top_p=0.9,
-                 pad_token_id=tokenizer.eos_token_id,
-                 eos_token_id=tokenizer.eos_token_id
-             )
-
-         generation_time = time.time() - generation_start
-         response = tokenizer.decode(
-             outputs[0][inputs['input_ids'].shape[-1]:],
-             skip_special_tokens=True
-         ).strip()
-
-         logger.info(f"Generation time: {generation_time:.2f}s")
-         return response, generation_time
-
-     except Exception as e:
-         logger.error(f"Error in generate_response: {str(e)}")
-         raise gr.Error(f"Error generating response: {str(e)}")
+     """Generate response using ZeroGPU - all CUDA operations happen here"""
+     print(f"🚀 Loading {model_id}...")
+     start_time = time.time()
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+     load_time = time.time() - start_time
+     print(f"✅ Model loaded in {load_time:.2f}s")
+
+     # Build messages in proper chat format (OpenAI-style messages)
+     messages = []
+     system_prompt = (
+         "You are Athena, a helpful, harmless, and honest AI assistant. "
+         "You provide clear, accurate, and concise responses to user questions. "
+         "You are knowledgeable across many domains and always aim to be respectful and helpful. "
+         "You are finetuned by Aayan Mishra"
+     )
+     messages.append({"role": "system", "content": system_prompt})
+
+     # Add conversation history
+     for msg in conversation:
+         messages.append(msg)
+
+     # Add current user message
+     messages.append({"role": "user", "content": user_message})
+
+     prompt = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+     inputs = tokenizer(prompt, return_tensors="pt")
+     device = next(model.parameters()).device
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     generation_start = time.time()
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_length,
+             temperature=temperature,
+             do_sample=True,
+             top_p=0.9,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+     generation_time = time.time() - generation_start
+     response = tokenizer.decode(
+         outputs[0][inputs['input_ids'].shape[-1]:],
+         skip_special_tokens=True
+     ).strip()
+     print(f"Generation time: {generation_time:.2f}s")
+     return response, load_time, generation_time
 
  def format_response_with_thinking(response):
      """Format response to handle <think></think> tags"""
+     # Check if response contains thinking tags
      if '<think>' in response and '</think>' in response:
+         # Split the response into parts
          pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
          match = re.search(pattern, response, re.DOTALL)
 
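Note: with load_model and its loaded_models cache gone, the new generate_response reloads the tokenizer and model weights on every request (from the local HF cache after the first download). That fits ZeroGPU's short-lived workers, but repeated calls in a still-warm worker now pay the full load each time. The THINKING_ENABLED_MODELS branch is also gone, so apply_chat_template is no longer called with enable_thinking=True for Athena-R3X-4B; whether that model still emits <think> blocks now depends on its chat template's default. A minimal cache that would fit the new code (hypothetical sketch; the names _model_cache and _load are not in this commit):

    # Hypothetical: reuse weights already loaded in this worker process.
    _model_cache = {}

    def _load(model_id):
        if model_id not in _model_cache:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            _model_cache[model_id] = (model, tokenizer)
        return _model_cache[model_id]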
@@ -135,66 +92,53 @@ def format_response_with_thinking(response):
              thinking_content = match.group(3).strip()
              after_thinking = match.group(4).strip()
 
+             # Create HTML with collapsible thinking section
              html = f"{before_thinking}\n"
              html += f'<div class="thinking-container">'
-             html += f'<button class="thinking-toggle"><div class="thinking-icon"></div> Thinking completed <span class="dropdown-arrow">▼</span></button>'
+             html += f'<button class="thinking-toggle" onclick="this.nextElementSibling.classList.toggle(\'hidden\'); this.textContent = this.textContent === \'Show reasoning\' ? \'Hide reasoning\' : \'Show reasoning\'">Show reasoning</button>'
              html += f'<div class="thinking-content hidden">{thinking_content}</div>'
              html += f'</div>\n'
              html += after_thinking
 
              return html
 
+     # If no thinking tags, return the original response
      return response
 
- def validate_input(message):
-     """Validate user input"""
-     if not message or not message.strip():
-         raise gr.Error("Message cannot be empty")
-     if len(message) > 2000:
-         raise gr.Error("Message too long (max 2000 characters)")
-     return message
-
  def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
      """Process a new message and update the chat history"""
+     if not message.strip():
+         return "", history, conversation_state
+
+     model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
      try:
-         # Validate input
-         message = validate_input(message)
+         # Print debug info to help diagnose issues
+         print(f"Processing message: {message}")
+         print(f"Selected model: {model_name} ({model_id})")
 
-         # Get model ID
-         model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
-
-         # Show generating message
-         yield "", history + [(message, "Generating response...")], conversation_state, gr.update(visible=True)
-
-         # Generate response
-         response, generation_time = generate_response(
+         response, load_time, generation_time = generate_response(
              model_id, conversation_state, message, max_length, temperature
          )
 
-         # Update conversation state
+         # Update the conversation state with the raw response
          conversation_state.append({"role": "user", "content": message})
          conversation_state.append({"role": "assistant", "content": response})
 
-         # Limit conversation history to last 10 exchanges
-         if len(conversation_state) > 20: # 10 user + 10 assistant messages
-             conversation_state = conversation_state[-20:]
-
          # Format the response for display
          formatted_response = format_response_with_thinking(response)
 
          # Update the visible chat history
-         updated_history = history[:-1] + [(message, formatted_response)]
-
-         yield "", updated_history, conversation_state, gr.update(visible=False)
+         history.append((message, formatted_response))
+         print(f"Response added to history. Current length: {len(history)}")
+
+         return "", history, conversation_state
      except Exception as e:
-         logger.error(f"Error in chat_submit: {str(e)}")
+         import traceback
+         print(f"Error in chat_submit: {str(e)}")
+         print(traceback.format_exc())
          error_message = f"Error: {str(e)}"
-         yield error_message, history, conversation_state, gr.update(visible=False)
-
- def clear_conversation():
-     """Clear the conversation history"""
-     return [], [], gr.update(visible=False)
+         history.append((message, error_message))
+         return "", history, conversation_state
 
  css = """
  .message {
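Note: chat_submit changes from a generator (yield) to a plain function, so the interim "Generating response..." placeholder and the progress widget disappear, and gr.Error is replaced by appending the error text to the chat. The 20-message cap on conversation_state is also dropped, so the prompt grows without bound over a long session. Restoring the cap inside the new function would be small (hypothetical sketch; MAX_MESSAGES is not in this commit):

    MAX_MESSAGES = 20  # 10 user + 10 assistant messages, as in the old code
    if len(conversation_state) > MAX_MESSAGES:
        del conversation_state[:-MAX_MESSAGES]  # trim in place, keep the most recent tail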
@@ -206,140 +150,47 @@ css = """
      margin: 10px 0;
  }
  .thinking-toggle {
-     background-color: rgba(30, 30, 40, 0.8);
-     border: none;
-     border-radius: 25px;
-     padding: 8px 15px;
+     background-color: #f1f1f1;
+     border: 1px solid #ddd;
+     border-radius: 4px;
+     padding: 5px 10px;
      cursor: pointer;
-     font-size: 0.95em;
-     margin-bottom: 8px;
-     color: white;
-     display: flex;
-     align-items: center;
-     gap: 8px;
-     box-shadow: 0 2px 5px rgba(0,0,0,0.2);
-     transition: background-color 0.2s;
-     width: auto;
-     max-width: 280px;
- }
- .thinking-toggle:hover {
-     background-color: rgba(40, 40, 50, 0.9);
- }
- .thinking-icon {
-     width: 16px;
-     height: 16px;
-     border-radius: 50%;
-     background-color: #6366f1;
-     position: relative;
-     overflow: hidden;
- }
- .thinking-icon::after {
-     content: "";
-     position: absolute;
-     top: 50%;
-     left: 50%;
-     width: 60%;
-     height: 60%;
-     background-color: #a5b4fc;
-     transform: translate(-50%, -50%);
-     border-radius: 50%;
- }
- .dropdown-arrow {
-     font-size: 0.7em;
-     margin-left: auto;
-     transition: transform 0.3s;
+     font-size: 0.9em;
+     margin-bottom: 5px;
+     color: #555;
  }
  .thinking-content {
-     background-color: rgba(30, 30, 40, 0.8);
-     border-left: 2px solid #6366f1;
-     padding: 15px;
+     background-color: #f9f9f9;
+     border-left: 3px solid #ccc;
+     padding: 10px;
      margin-top: 5px;
-     margin-bottom: 15px;
      font-size: 0.95em;
-     color: #e2e8f0;
+     color: #555;
      font-family: monospace;
      white-space: pre-wrap;
      overflow-x: auto;
-     border-radius: 5px;
-     line-height: 1.5;
  }
  .hidden {
      display: none;
  }
- .progress-container {
-     text-align: center;
-     margin: 10px 0;
-     color: #6366f1;
- }
  """
 
- js = """
- function setupThinkingToggle() {
-     document.querySelectorAll('.thinking-toggle').forEach(button => {
-         if (!button.dataset.listenerAdded) {
-             button.addEventListener('click', function() {
-                 const content = this.nextElementSibling;
-                 content.classList.toggle('hidden');
-                 const arrow = this.querySelector('.dropdown-arrow');
-                 arrow.textContent = content.classList.contains('hidden') ? '▼' : '▲';
-             });
-             button.dataset.listenerAdded = 'true';
-         }
-     });
- }
-
- document.addEventListener('DOMContentLoaded', () => {
-     setupThinkingToggle();
-
-     const observer = new MutationObserver((mutations) => {
-         setupThinkingToggle();
-     });
-
-     observer.observe(document.body, {
-         childList: true,
-         subtree: true
-     });
- });
- """
+ theme = gr.themes.Soft()
 
- # Create Gradio interface
- with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
+ with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
      gr.Markdown("# 🚀 Athena Playground Chat")
      gr.Markdown("*Powered by HuggingFace ZeroGPU*")
 
      # State to keep track of the conversation for the model
      conversation_state = gr.State([])
 
-     # Hidden progress indicator
-     progress = gr.HTML(
-         """<div class="progress-container">Generating response...</div>""",
-         visible=False
-     )
-
-     # Chatbot component
-     chatbot = gr.Chatbot(
-         height=500,
-         label="Athena",
-         render_markdown=True,
-         elem_classes=["chatbot"]
-     )
+     chatbot = gr.Chatbot(height=500, label="Athena", render_markdown=True)
 
-     # Input and send button row
      with gr.Row():
-         user_input = gr.Textbox(
-             label="Your message",
-             scale=8,
-             autofocus=True,
-             placeholder="Type your message here...",
-             lines=2
-         )
-         send_btn = gr.Button(
-             value="Send",
-             scale=1,
-             variant="primary"
-         )
+         user_input = gr.Textbox(label="Your message", scale=8, autofocus=True, placeholder="Type your message here...")
+         send_btn = gr.Button(value="Send", scale=1, variant="primary")
 
-     # Clear button
+     # Clear button for resetting the conversation
      clear_btn = gr.Button("Clear Conversation")
 
      # Configuration controls
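Note: the removed js block attached the toggle behaviour globally via a MutationObserver; the new markup inlines the handler on the button itself (see the format_response_with_thinking hunk above), so gr.Blocks no longer needs the js= argument. The inline f-string is dense; an equivalent way to build the same markup in Python (hypothetical refactor; toggle_js and button_html are not in this commit):

    # Same behaviour as the committed one-liner, split for readability.
    toggle_js = (
        "this.nextElementSibling.classList.toggle('hidden'); "
        "this.textContent = this.textContent === 'Show reasoning' "
        "? 'Hide reasoning' : 'Show reasoning'"
    )
    button_html = f'<button class="thinking-toggle" onclick="{toggle_js}">Show reasoning</button>'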
@@ -362,25 +213,28 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
          info="Higher values = more creative responses"
      )
 
-     # Connect the interface components
-     submit_event = user_input.submit(
-         fn=chat_submit,
+     # Function to clear the conversation
+     def clear_conversation():
+         return [], []
+
+     # Connect the interface components - note the specific ordering
+     user_input.submit(
+         chat_submit,
          inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
-         outputs=[user_input, chatbot, conversation_state, progress]
+         outputs=[user_input, chatbot, conversation_state]
      )
 
-     send_click = send_btn.click(
-         fn=chat_submit,
+     # Make sure send button uses the exact same function with the same parameter ordering
+     send_btn.click(
+         chat_submit,
          inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
-         outputs=[user_input, chatbot, conversation_state, progress]
+         outputs=[user_input, chatbot, conversation_state]
      )
 
-     clear_btn.click(
-         fn=clear_conversation,
-         outputs=[chatbot, conversation_state, progress]
-     )
+     # Connect clear button
+     clear_btn.click(clear_conversation, outputs=[chatbot, conversation_state])
 
-     # Examples
+     # Add examples if desired
      gr.Examples(
          examples=[
              "What is artificial intelligence?",
@@ -388,15 +242,14 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
              "Write a short poem about technology",
              "What are some ethical concerns about AI?"
          ],
-         inputs=user_input
+         inputs=[user_input]
      )
 
      gr.Markdown("""
      ### About the Thinking Tags
      Some Athena models (particularly R3X series) include reasoning in `<think></think>` tags.
-     Click on "Thinking completed" to view the model's thought process behind its answers.
+     Click "Show reasoning" to see the model's thought process behind its answers.
      """)
 
  if __name__ == "__main__":
-     demo.queue()
-     demo.launch(debug=True)
+     demo.launch(debug=True) # Enable debug mode for better error reporting
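Note: demo.queue() is also dropped, so requests are no longer serialized through Gradio's queue; on ZeroGPU the queue is commonly what meters concurrent GPU calls, so this may be worth re-checking. For local testing the pipeline can be exercised directly, since the spaces package documents @spaces.GPU as having no effect outside a ZeroGPU environment (hypothetical smoke test, not part of the commit):

    if __name__ == "__main__":
        # Exercise the new three-value return without launching the UI.
        reply, load_s, gen_s = generate_response(
            "Spestly/Athena-R3X-4B", [], "Hello!", max_length=32, temperature=0.7
        )
        print(f"loaded in {load_s:.1f}s, generated in {gen_s:.1f}s: {reply}")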
 