Spestly committed (verified)
Commit a2633ea · 1 Parent(s): 2a46096

Update app.py

Files changed (1)
  1. app.py +156 -128
app.py CHANGED
@@ -4,6 +4,11 @@ import torch
 import time
 import spaces
 import re
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 # Model configurations
 MODELS = {
@@ -21,82 +26,107 @@ MODELS = {
 # Models that need the enable_thinking parameter
 THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]

+# Cache for loaded models
+loaded_models = {}
+
+@spaces.GPU
+def load_model(model_id):
+    """Load model and tokenizer once and cache them"""
+    try:
+        if model_id not in loaded_models:
+            logger.info(f"🚀 Loading {model_id}...")
+            start_time = time.time()
+
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True
+            )
+
+            load_time = time.time() - start_time
+            logger.info(f"✅ Model loaded in {load_time:.2f}s")
+            loaded_models[model_id] = (model, tokenizer, load_time)
+
+        return loaded_models[model_id]
+    except Exception as e:
+        logger.error(f"Error loading model {model_id}: {str(e)}")
+        raise gr.Error(f"Failed to load model {model_id}. Please try another model.")
+
 @spaces.GPU
 def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
-    """Generate response using ZeroGPU - all CUDA operations happen here"""
-    print(f"🚀 Loading {model_id}...")
-    start_time = time.time()
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True
-    )
-    load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f}s")
+    """Generate response using the specified model"""
+    try:
+        model, tokenizer, _ = load_model(model_id)

-    # Build messages in proper chat format (OpenAI-style messages)
-    messages = []
-    system_prompt = (
-        "You are Athena, a helpful, harmless, and honest AI assistant. "
-        "You provide clear, accurate, and concise responses to user questions. "
-        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
-        "You are finetuned by Aayan Mishra"
-    )
-    messages.append({"role": "system", "content": system_prompt})
+        # Build messages in proper chat format
+        messages = []
+        system_prompt = (
+            "You are Athena, a helpful, harmless, and honest AI assistant. "
+            "You provide clear, accurate, and concise responses to user questions. "
+            "You are knowledgeable across many domains and always aim to be respectful and helpful. "
+            "You are finetuned by Aayan Mishra"
+        )
+        messages.append({"role": "system", "content": system_prompt})

-    # Add conversation history
-    for msg in conversation:
-        messages.append(msg)
+        # Add conversation history
+        for msg in conversation:
+            messages.append(msg)

-    # Add current user message
-    messages.append({"role": "user", "content": user_message})
+        # Add current user message
+        messages.append({"role": "user", "content": user_message})

-    # Check if this model needs the enable_thinking parameter
-    if model_id in THINKING_ENABLED_MODELS:
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=True
-        )
-    else:
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
+        # Check if this model needs the enable_thinking parameter
+        if model_id in THINKING_ENABLED_MODELS:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=True
+            )
+        else:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )

-    inputs = tokenizer(prompt, return_tensors="pt")
-    device = next(model.parameters()).device
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-    generation_start = time.time()
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=max_length,
-            temperature=temperature,
-            do_sample=True,
-            top_p=0.9,
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id
-        )
-    generation_time = time.time() - generation_start
-    response = tokenizer.decode(
-        outputs[0][inputs['input_ids'].shape[-1]:],
-        skip_special_tokens=True
-    ).strip()
-    print(f"Generation time: {generation_time:.2f}s")
-    return response, load_time, generation_time
+        inputs = tokenizer(prompt, return_tensors="pt")
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        generation_start = time.time()
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_length,
+                temperature=temperature,
+                do_sample=True,
+                top_p=0.9,
+                pad_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id
+            )
+
+        generation_time = time.time() - generation_start
+        response = tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[-1]:],
+            skip_special_tokens=True
+        ).strip()
+
+        logger.info(f"Generation time: {generation_time:.2f}s")
+        return response, generation_time
+
+    except Exception as e:
+        logger.error(f"Error in generate_response: {str(e)}")
+        raise gr.Error(f"Error generating response: {str(e)}")

 def format_response_with_thinking(response):
     """Format response to handle <think></think> tags"""
-    # Check if response contains thinking tags
     if '<think>' in response and '</think>' in response:
-        # Split the response into parts
         pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
         match = re.search(pattern, response, re.DOTALL)

@@ -105,7 +135,6 @@ def format_response_with_thinking(response):
             thinking_content = match.group(3).strip()
             after_thinking = match.group(4).strip()

-            # Create HTML with collapsible thinking section
             html = f"{before_thinking}\n"
             html += f'<div class="thinking-container">'
             html += f'<button class="thinking-toggle"><div class="thinking-icon"></div> Thinking completed <span class="dropdown-arrow">▼</span></button>'
@@ -115,43 +144,57 @@ def format_response_with_thinking(response):

             return html

-    # If no thinking tags, return the original response
     return response

+def validate_input(message):
+    """Validate user input"""
+    if not message or not message.strip():
+        raise gr.Error("Message cannot be empty")
+    if len(message) > 2000:
+        raise gr.Error("Message too long (max 2000 characters)")
+    return message
+
 def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
     """Process a new message and update the chat history"""
-    # For debugging - print when the function is called
-    print(f"chat_submit function called with message: '{message}'")
-
-    if not message or not message.strip():
-        print("Empty message, returning without processing")
-        return "", history, conversation_state
-
-    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
     try:
-        response, load_time, generation_time = generate_response(
+        # Validate input
+        message = validate_input(message)
+
+        # Get model ID
+        model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
+
+        # Show generating message
+        yield "", history + [(message, "Generating response...")], conversation_state, gr.update(visible=True)
+
+        # Generate response
+        response, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )

-        # Update the conversation state with the raw response
+        # Update conversation state
        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})

+        # Limit conversation history to last 10 exchanges
+        if len(conversation_state) > 20:  # 10 user + 10 assistant messages
+            conversation_state = conversation_state[-20:]
+
        # Format the response for display
        formatted_response = format_response_with_thinking(response)

        # Update the visible chat history
-        history.append((message, formatted_response))
-        print(f"Response added to history. Current length: {len(history)}")
+        updated_history = history[:-1] + [(message, formatted_response)]
+
+        yield "", updated_history, conversation_state, gr.update(visible=False)

-        return "", history, conversation_state
     except Exception as e:
-        import traceback
-        print(f"Error in chat_submit: {str(e)}")
-        print(traceback.format_exc())
+        logger.error(f"Error in chat_submit: {str(e)}")
        error_message = f"Error: {str(e)}"
-        history.append((message, error_message))
-        return "", history, conversation_state
+        yield error_message, history, conversation_state, gr.update(visible=False)
+
+def clear_conversation():
+    """Clear the conversation history"""
+    return [], [], gr.update(visible=False)

 css = """
 .message {
@@ -223,53 +266,39 @@ css = """
 .hidden {
     display: none;
 }
+.progress-container {
+    text-align: center;
+    margin: 10px 0;
+    color: #6366f1;
+}
 """

-# Add JavaScript to make the thinking buttons work
 js = """
 function setupThinkingToggle() {
    document.querySelectorAll('.thinking-toggle').forEach(button => {
-        if (!button.hasEventListener) {
+        if (!button.dataset.listenerAdded) {
            button.addEventListener('click', function() {
                const content = this.nextElementSibling;
                content.classList.toggle('hidden');
                const arrow = this.querySelector('.dropdown-arrow');
-                if (content.classList.contains('hidden')) {
-                    arrow.textContent = '▼';
-                    arrow.style.transform = '';
-                } else {
-                    arrow.textContent = '▲';
-                    arrow.style.transform = 'rotate(0deg)';
-                }
+                arrow.textContent = content.classList.contains('hidden') ? '▼' : '▲';
            });
+            button.dataset.listenerAdded = 'true';
        }
    });
 }

-// Setup a mutation observer to watch for changes in the DOM
-const observer = new MutationObserver(function(mutations) {
-    setupThinkingToggle();
-});
-
-// Start observing after DOM is loaded
 document.addEventListener('DOMContentLoaded', () => {
    setupThinkingToggle();
-    setTimeout(() => {
-        const chatbot = document.querySelector('.chatbot');
-        if (chatbot) {
-            observer.observe(chatbot, {
-                childList: true,
-                subtree: true,
-                characterData: true
-            });
-        } else {
-            observer.observe(document.body, {
-                childList: true,
-                subtree: true
-            });
-        }
-    }, 1000);
+
+    const observer = new MutationObserver((mutations) => {
+        setupThinkingToggle();
+    });
+
+    observer.observe(document.body, {
+        childList: true,
+        subtree: true
+    });
 });
 """

@@ -281,6 +310,12 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
     # State to keep track of the conversation for the model
     conversation_state = gr.State([])

+    # Hidden progress indicator
+    progress = gr.HTML(
+        """<div class="progress-container">Generating response...</div>""",
+        visible=False
+    )
+
     # Chatbot component
     chatbot = gr.Chatbot(
         height=500,
@@ -327,28 +362,22 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
         info="Higher values = more creative responses"
     )

-    # Function to clear the conversation
-    def clear_conversation():
-        return [], []
-
-    # Connect the interface components with explicit handlers
-    submit_click = user_input.submit(
+    # Connect the interface components
+    submit_event = user_input.submit(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
-        outputs=[user_input, chatbot, conversation_state]
+        outputs=[user_input, chatbot, conversation_state, progress]
    )

-    # Connect send button explicitly
    send_click = send_btn.click(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
-        outputs=[user_input, chatbot, conversation_state]
+        outputs=[user_input, chatbot, conversation_state, progress]
    )

-    # Clear conversation
    clear_btn.click(
        fn=clear_conversation,
-        outputs=[chatbot, conversation_state]
+        outputs=[chatbot, conversation_state, progress]
    )

     # Examples
@@ -369,6 +398,5 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
     """)

 if __name__ == "__main__":
-    # Enable queue and debugging
     demo.queue()
     demo.launch(debug=True)
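
The main behavioral change in this commit is that chat_submit becomes a generator: it yields once to show an interim "Generating response..." row (with the progress indicator visible), then yields again to replace that row with the final formatted answer. A minimal standalone sketch of that two-phase yield pattern in Gradio, using hypothetical slow_echo/box names that are not part of this commit:

    import time
    import gradio as gr

    def slow_echo(message, history):
        # First yield: clear the textbox and show a placeholder row
        yield "", history + [(message, "Generating response...")]
        time.sleep(1)  # stand-in for model inference
        # Second yield: the placeholder row is replaced by the real reply
        yield "", history + [(message, f"Echo: {message}")]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        box = gr.Textbox()
        box.submit(slow_echo, inputs=[box, chatbot], outputs=[box, chatbot])

    demo.queue()  # generator handlers stream their yields through the queue
    demo.launch()

Because the first yield appends the placeholder row, the final update uses history[:-1] in the committed code (or, as here, rebuilds from the original history) so the placeholder is overwritten rather than duplicated.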