Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -4,6 +4,11 @@ import torch
 import time
 import spaces
 import re

 # Model configurations
 MODELS = {
@@ -21,82 +26,107 @@ MODELS = {
 # Models that need the enable_thinking parameter
 THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]

 @spaces.GPU
 def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
-    """Generate response using
-
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True
-    )
-    load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f}s")

-    [… remaining removed lines of the old generate_response body not shown …]

 def format_response_with_thinking(response):
     """Format response to handle <think></think> tags"""
-    # Check if response contains thinking tags
     if '<think>' in response and '</think>' in response:
-        # Split the response into parts
         pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
         match = re.search(pattern, response, re.DOTALL)

@@ -105,7 +135,6 @@ def format_response_with_thinking(response):
         thinking_content = match.group(3).strip()
         after_thinking = match.group(4).strip()

-        # Create HTML with collapsible thinking section
         html = f"{before_thinking}\n"
         html += f'<div class="thinking-container">'
         html += f'<button class="thinking-toggle"><div class="thinking-icon"></div> Thinking completed <span class="dropdown-arrow">▼</span></button>'
@@ -115,43 +144,57 @@

         return html

-    # If no thinking tags, return the original response
     return response

 def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
     """Process a new message and update the chat history"""
-    # For debugging - print when the function is called
-    print(f"chat_submit function called with message: '{message}'")
-
-    if not message or not message.strip():
-        print("Empty message, returning without processing")
-        return "", history, conversation_state
-
-    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
     try:
-
             model_id, conversation_state, message, max_length, temperature
         )

-        # Update
         conversation_state.append({"role": "user", "content": message})
         conversation_state.append({"role": "assistant", "content": response})

         # Format the response for display
         formatted_response = format_response_with_thinking(response)

         # Update the visible chat history
-        history
-
-        return "", history, conversation_state
     except Exception as e:
-
-        print(f"Error in chat_submit: {str(e)}")
-        print(traceback.format_exc())
         error_message = f"Error: {str(e)}"
-        history.
-

 css = """
 .message {
@@ -223,53 +266,39 @@ css = """
 .hidden {
     display: none;
 }
 """

-# Add JavaScript to make the thinking buttons work
 js = """
 function setupThinkingToggle() {
     document.querySelectorAll('.thinking-toggle').forEach(button => {
-        if (!button.
             button.addEventListener('click', function() {
                 const content = this.nextElementSibling;
                 content.classList.toggle('hidden');
                 const arrow = this.querySelector('.dropdown-arrow');
-
-                    arrow.textContent = '▼';
-                    arrow.style.transform = '';
-                } else {
-                    arrow.textContent = '▲';
-                    arrow.style.transform = 'rotate(0deg)';
-                }
             });
-            button.
         }
     });
 }

-// Setup a mutation observer to watch for changes in the DOM
-const observer = new MutationObserver(function(mutations) {
-    setupThinkingToggle();
-});
-
-// Start observing after DOM is loaded
 document.addEventListener('DOMContentLoaded', () => {
     setupThinkingToggle();
-
-    [… remaining removed lines not shown …]
-        observer.observe(document.body, {
-            childList: true,
-            subtree: true
-        });
-        }
-    }, 1000);
 });
 """

@@ -281,6 +310,12 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
     # State to keep track of the conversation for the model
     conversation_state = gr.State([])

     # Chatbot component
     chatbot = gr.Chatbot(
         height=500,
@@ -327,28 +362,22 @@
         info="Higher values = more creative responses"
     )

-    #
-
-        return [], []
-
-    # Connect the interface components with explicit handlers
-    submit_click = user_input.submit(
         fn=chat_submit,
         inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
-        outputs=[user_input, chatbot, conversation_state]
     )

-    # Connect send button explicitly
     send_click = send_btn.click(
         fn=chat_submit,
         inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
-        outputs=[user_input, chatbot, conversation_state]
     )

-    # Clear conversation
     clear_btn.click(
         fn=clear_conversation,
-        outputs=[chatbot, conversation_state]
     )

     # Examples
@@ -369,6 +398,5 @@
     """)

 if __name__ == "__main__":
-    # Enable queue and debugging
     demo.queue()
     demo.launch(debug=True)

@@ -4,6 +4,11 @@ import torch
 import time
 import spaces
 import re
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 # Model configurations
 MODELS = {
@@ -21,82 +26,107 @@ MODELS = {
 # Models that need the enable_thinking parameter
 THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]

+# Cache for loaded models
+loaded_models = {}
+
+@spaces.GPU
+def load_model(model_id):
+    """Load model and tokenizer once and cache them"""
+    try:
+        if model_id not in loaded_models:
+            logger.info(f"🚀 Loading {model_id}...")
+            start_time = time.time()
+
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True
+            )
+
+            load_time = time.time() - start_time
+            logger.info(f"✅ Model loaded in {load_time:.2f}s")
+            loaded_models[model_id] = (model, tokenizer, load_time)
+
+        return loaded_models[model_id]
+    except Exception as e:
+        logger.error(f"Error loading model {model_id}: {str(e)}")
+        raise gr.Error(f"Failed to load model {model_id}. Please try another model.")
+
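The module-level loaded_models dict gives a simple cache-on-first-use: only the first request for a given model_id pays the from_pretrained cost, and later calls reuse the stored (model, tokenizer, load_time) tuple. A minimal standalone sketch of the same pattern (the expensive_load helper and the timings are illustrative, not part of app.py):

import time

_cache = {}

def expensive_load(name):
    """Stand-in for AutoModelForCausalLM.from_pretrained / AutoTokenizer.from_pretrained."""
    time.sleep(1.0)  # pretend this is a slow download and load
    return f"model object for {name}"

def get_model(name):
    if name not in _cache:          # only the first call pays the loading cost
        _cache[name] = expensive_load(name)
    return _cache[name]

t0 = time.time(); get_model("Spestly/Athena-R3X-4B"); print(f"first call:  {time.time() - t0:.2f}s")
t0 = time.time(); get_model("Spestly/Athena-R3X-4B"); print(f"second call: {time.time() - t0:.2f}s")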
 @spaces.GPU
 def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
+    """Generate response using the specified model"""
+    try:
+        model, tokenizer, _ = load_model(model_id)

+        # Build messages in proper chat format
+        messages = []
+        system_prompt = (
+            "You are Athena, a helpful, harmless, and honest AI assistant. "
+            "You provide clear, accurate, and concise responses to user questions. "
+            "You are knowledgeable across many domains and always aim to be respectful and helpful. "
+            "You are finetuned by Aayan Mishra"
+        )
+        messages.append({"role": "system", "content": system_prompt})

+        # Add conversation history
+        for msg in conversation:
+            messages.append(msg)

+        # Add current user message
+        messages.append({"role": "user", "content": user_message})

+        # Check if this model needs the enable_thinking parameter
+        if model_id in THINKING_ENABLED_MODELS:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=True
+            )
+        else:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )

+        inputs = tokenizer(prompt, return_tensors="pt")
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        generation_start = time.time()
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_length,
+                temperature=temperature,
+                do_sample=True,
+                top_p=0.9,
+                pad_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id
+            )
+
+        generation_time = time.time() - generation_start
+        response = tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[-1]:],
+            skip_special_tokens=True
+        ).strip()
+
+        logger.info(f"Generation time: {generation_time:.2f}s")
+        return response, generation_time
+
+    except Exception as e:
+        logger.error(f"Error in generate_response: {str(e)}")
+        raise gr.Error(f"Error generating response: {str(e)}")

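model.generate returns the prompt tokens and the newly generated tokens as one sequence, which is why the decode above slices from inputs['input_ids'].shape[-1] onward. A small illustration of that slice with made-up token IDs:

import torch

prompt_ids = torch.tensor([[101, 7592, 2088, 102]])                            # pretend prompt tokens
full_output = torch.tensor([[101, 7592, 2088, 102, 2023, 2003, 1996, 3437]])   # prompt + new tokens

new_tokens = full_output[0][prompt_ids.shape[-1]:]
print(new_tokens.tolist())  # [2023, 2003, 1996, 3437], only the generated part gets decoded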
 def format_response_with_thinking(response):
     """Format response to handle <think></think> tags"""
     if '<think>' in response and '</think>' in response:
         pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
         match = re.search(pattern, response, re.DOTALL)

@@ -105,7 +135,6 @@ def format_response_with_thinking(response):
         thinking_content = match.group(3).strip()
         after_thinking = match.group(4).strip()

         html = f"{before_thinking}\n"
         html += f'<div class="thinking-container">'
         html += f'<button class="thinking-toggle"><div class="thinking-icon"></div> Thinking completed <span class="dropdown-arrow">▼</span></button>'

@@ -115,43 +144,57 @@

         return html

     return response

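For reference, this is how the <think> pattern above splits a response; the sample text is invented:

import re

pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
response = "Sure.<think>The user wants a haiku, so keep it to 5-7-5.</think>Here is a haiku about rain."

match = re.search(pattern, response, re.DOTALL)
before_thinking = match.group(1).strip()   # "Sure."  (text shown before the collapsible block)
thinking_content = match.group(3).strip()  # the hidden reasoning placed inside the toggle
after_thinking = match.group(4).strip()    # "Here is a haiku about rain."  (the visible answer)
print(before_thinking, thinking_content, after_thinking, sep="\n")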
+def validate_input(message):
+    """Validate user input"""
+    if not message or not message.strip():
+        raise gr.Error("Message cannot be empty")
+    if len(message) > 2000:
+        raise gr.Error("Message too long (max 2000 characters)")
+    return message
+
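Raising gr.Error inside a handler lets Gradio surface the message to the user instead of failing silently. A quick illustration of the helper's behaviour with made-up inputs:

import gradio as gr

def validate_input(message):
    if not message or not message.strip():
        raise gr.Error("Message cannot be empty")
    if len(message) > 2000:
        raise gr.Error("Message too long (max 2000 characters)")
    return message

print(validate_input("Hello Athena"))   # passes through unchanged
try:
    validate_input("   ")
except gr.Error as err:
    print(err)                          # "Message cannot be empty"; shown as an error popup in the UI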
 def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
     """Process a new message and update the chat history"""
     try:
+        # Validate input
+        message = validate_input(message)
+
+        # Get model ID
+        model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
+
+        # Show generating message
+        yield "", history + [(message, "Generating response...")], conversation_state, gr.update(visible=True)
+
+        # Generate response
+        response, generation_time = generate_response(
             model_id, conversation_state, message, max_length, temperature
         )

+        # Update conversation state
         conversation_state.append({"role": "user", "content": message})
         conversation_state.append({"role": "assistant", "content": response})

+        # Limit conversation history to last 10 exchanges
+        if len(conversation_state) > 20:  # 10 user + 10 assistant messages
+            conversation_state = conversation_state[-20:]
+
         # Format the response for display
         formatted_response = format_response_with_thinking(response)

         # Update the visible chat history
+        updated_history = history[:-1] + [(message, formatted_response)]
+
+        yield "", updated_history, conversation_state, gr.update(visible=False)

     except Exception as e:
+        logger.error(f"Error in chat_submit: {str(e)}")
         error_message = f"Error: {str(e)}"
+        yield error_message, history, conversation_state, gr.update(visible=False)
+
+def clear_conversation():
+    """Clear the conversation history"""
+    return [], [], gr.update(visible=False)

css = """
|
200 |
.message {
|
|
|
266 |
.hidden {
|
267 |
display: none;
|
268 |
}
|
269 |
+
.progress-container {
|
270 |
+
text-align: center;
|
271 |
+
margin: 10px 0;
|
272 |
+
color: #6366f1;
|
273 |
+
}
|
274 |
"""
|
275 |
|
|
|
 js = """
 function setupThinkingToggle() {
     document.querySelectorAll('.thinking-toggle').forEach(button => {
+        if (!button.dataset.listenerAdded) {
             button.addEventListener('click', function() {
                 const content = this.nextElementSibling;
                 content.classList.toggle('hidden');
                 const arrow = this.querySelector('.dropdown-arrow');
+                arrow.textContent = content.classList.contains('hidden') ? '▼' : '▲';
             });
+            button.dataset.listenerAdded = 'true';
         }
     });
 }

 document.addEventListener('DOMContentLoaded', () => {
     setupThinkingToggle();
+
+    const observer = new MutationObserver((mutations) => {
+        setupThinkingToggle();
+    });
+
+    observer.observe(document.body, {
+        childList: true,
+        subtree: true
+    });
 });
 """

@@ -281,6 +310,12 @@ with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
     # State to keep track of the conversation for the model
     conversation_state = gr.State([])

+    # Hidden progress indicator
+    progress = gr.HTML(
+        """<div class="progress-container">Generating response...</div>""",
+        visible=False
+    )
+
     # Chatbot component
     chatbot = gr.Chatbot(
         height=500,
@@ -327,28 +362,22 @@
         info="Higher values = more creative responses"
     )

+    # Connect the interface components
+    submit_event = user_input.submit(
         fn=chat_submit,
         inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
+        outputs=[user_input, chatbot, conversation_state, progress]
     )

     send_click = send_btn.click(
         fn=chat_submit,
         inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
+        outputs=[user_input, chatbot, conversation_state, progress]
     )

     clear_btn.click(
         fn=clear_conversation,
+        outputs=[chatbot, conversation_state, progress]
     )

     # Examples
@@ -369,6 +398,5 @@
     """)

 if __name__ == "__main__":
     demo.queue()
     demo.launch(debug=True)