import re
import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model configurations
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B",
}


@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate a response using ZeroGPU -- all CUDA operations happen here."""
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")

    # Build messages in proper chat format (OpenAI-style messages)
    system_prompt = (
        "You are Athena, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
        "You are finetuned by Aayan Mishra"
    )
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(conversation)  # conversation history
    messages.append({"role": "user", "content": user_message})  # current turn

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    generation_time = time.time() - generation_start

    # Decode only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()
    print(f"Generation time: {generation_time:.2f}s")
    return response, load_time, generation_time
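
# A minimal sketch of in-process model caching -- an addition, not part of the
# original app, and the helper name _load_cached is ours. It would avoid
# re-initializing the tokenizer and weights when a warm ZeroGPU worker serves
# several requests. It is deliberately not wired into generate_response above,
# since ZeroGPU workers are short-lived and the original per-call loading is
# kept as-is.
_MODEL_CACHE = {}


def _load_cached(model_id):
    """Return a (tokenizer, model) pair, loading it only on first use."""
    if model_id not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        _MODEL_CACHE[model_id] = (tokenizer, model)
    return _MODEL_CACHE[model_id]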

def format_response_with_thinking(response):
    """Format a response, turning <think>...</think> blocks into collapsible HTML."""
    # Check if the response contains a complete thinking block
    if "<think>" in response and "</think>" in response:
        # Split the response into the text before, inside, and after the block
        pattern = r"(.*?)(<think>(.*?)</think>)(.*)"
        match = re.search(pattern, response, re.DOTALL)
        if match:
            before_thinking = match.group(1).strip()
            thinking_content = match.group(3).strip()
            after_thinking = match.group(4).strip()

            # Create HTML with a collapsible thinking section: the button
            # toggles the "hidden" class on the content block (see the CSS)
            html = f"{before_thinking}\n"
            html += '<div class="thinking-container">'
            html += (
                '<button class="thinking-toggle" '
                'onclick="this.nextElementSibling.classList.toggle(\'hidden\')">'
                "Show reasoning</button>"
            )
            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
            html += "</div>\n"
            html += after_thinking
            return html
    # If no thinking tags, return the original response
    return response
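
# Hypothetical round-trip of the formatter above (the input string is invented
# for illustration):
#
#   format_response_with_thinking("<think>2 + 2 = 4</think>The answer is 4.")
#   # -> a collapsible "Show reasoning" block containing "2 + 2 = 4",
#   #    followed by "The answer is 4."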

def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
    """Process a new message and update the chat history."""
    if not message.strip():
        return "", history, conversation_state

    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
    try:
        # Debug info to help diagnose issues
        print(f"Processing message: {message}")
        print(f"Selected model: {model_name} ({model_id})")

        response, load_time, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )

        # Update the conversation state with the raw response
        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})

        # Format the response for display and update the visible chat history
        formatted_response = format_response_with_thinking(response)
        history.append((message, formatted_response))
        print(f"Response added to history. Current length: {len(history)}")
        return "", history, conversation_state
    except Exception as e:
        import traceback

        print(f"Error in chat_submit: {str(e)}")
        print(traceback.format_exc())
        error_message = f"Error: {str(e)}"
        history.append((message, error_message))
        return "", history, conversation_state


css = """
.message { padding: 10px; margin: 5px; border-radius: 10px; }
.thinking-container { margin: 10px 0; }
.thinking-toggle {
    background-color: #f1f1f1;
    border: 1px solid #ddd;
    border-radius: 4px;
    padding: 5px 10px;
    cursor: pointer;
    font-size: 0.9em;
    margin-bottom: 5px;
    color: #555;
}
.thinking-content {
    background-color: #f9f9f9;
    border-left: 3px solid #ccc;
    padding: 10px;
    margin-top: 5px;
    font-size: 0.95em;
    color: #555;
    font-family: monospace;
    white-space: pre-wrap;
    overflow-x: auto;
}
.hidden { display: none; }
"""

theme = gr.themes.Soft()

with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    # State that tracks the conversation in the format the model expects
    conversation_state = gr.State([])

    chatbot = gr.Chatbot(height=500, label="Athena", render_markdown=True)
    with gr.Row():
        user_input = gr.Textbox(
            label="Your message",
            scale=8,
            autofocus=True,
            placeholder="Type your message here...",
        )
        send_btn = gr.Button(value="Send", scale=1, variant="primary")

    # Clear button for resetting the conversation
    clear_btn = gr.Button("Clear Conversation")

    # Configuration controls
    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="Athena-R3X 4B",
            info="Select which Athena model to use",
        )
        max_length = gr.Slider(
            32, 8192, value=512,
            label="📝 Max Tokens",
            info="Maximum number of tokens to generate",
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses",
        )

    # Reset both the visible history and the model-side conversation state
    def clear_conversation():
        return [], []

    # The Enter key and the Send button call the same function with the same
    # input ordering, so both paths behave identically
    user_input.submit(
        chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state],
    )
    send_btn.click(
        chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state],
    )
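
    # Added caution, not in the original code: depending on the Gradio
    # version, gr.Chatbot sanitizes message HTML and may strip the inline
    # onclick handler behind the "Show reasoning" toggle. If the button does
    # nothing, constructing the Chatbot with sanitize_html=False (supported
    # in recent Gradio releases) should keep the markup intact.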
    # Connect the clear button
    clear_btn.click(clear_conversation, outputs=[chatbot, conversation_state])

    # Example prompts
    gr.Examples(
        examples=[
            "What is artificial intelligence?",
            "Can you explain quantum computing?",
            "Write a short poem about technology",
            "What are some ethical concerns about AI?",
        ],
        inputs=[user_input],
    )

    gr.Markdown("""
### About the Thinking Tags
Some Athena models (particularly the R3X series) include their reasoning in
`<think>` tags. Click "Show reasoning" to see the model's thought process
behind its answers.
""")

if __name__ == "__main__":
    demo.launch(debug=True)  # Enable debug mode for better error reporting
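
# Dependency note (added here; the source does not pin versions): this Space
# needs gradio, transformers, torch, accelerate (required by device_map="auto"),
# and the spaces package that provides the @spaces.GPU ZeroGPU decorator.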