import gradio as gr
import torch
from beeper_model import BeeperRoseGPT, generate
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors

# ----------------------------
# 🔧 Model versions configuration
# ----------------------------
MODEL_VERSIONS = {
    "Beeper v4 (Advanced)": {
        "repo_id": "AbstractPhil/beeper-rose-v4",
        "model_file": "beeper_rose_final.safetensors",
| "description": "Beeper v4 with nearly 40% the full corpus training - the most capable version currently." | |
    },
    "Beeper v3 (Multi-Concept)": {
        "repo_id": "AbstractPhil/beeper-rose-v3",
        "model_file": "beeper_rose_final.safetensors",
        "description": "Beeper v3 with 30+ epochs including reasoning, math, and ethics"
    },
    "Beeper v2 (Extended)": {
        "repo_id": "AbstractPhil/beeper-rose-v2",
        "model_file": "beeper_final.safetensors",
        "description": "Beeper v2 with extended training (~15 epochs)"
    },
    "Beeper v1 (Original)": {
        "repo_id": "AbstractPhil/beeper-rose-tinystories-6l-512d-ctx512",
        "model_file": "beeper_rose_final.safetensors",
        "description": "Original Beeper trained on TinyStories"
    },
}
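# Adding another checkpoint only requires a new entry with the same three fields
# (hypothetical example, not a published repo):
# MODEL_VERSIONS["Beeper v5 (Hypothetical)"] = {
#     "repo_id": "AbstractPhil/beeper-rose-v5",          # assumed name, for illustration only
#     "model_file": "beeper_rose_final.safetensors",
#     "description": "Placeholder entry showing the expected fields",
# }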
# Base configuration
config = {
    "context": 512,
    "vocab_size": 8192,
    "dim": 512,
    "n_heads": 8,
    "n_layers": 6,
    "mlp_ratio": 4.0,
    "temperature": 0.9,
    "top_k": 40,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    "presence_penalty": 0.6,
    "frequency_penalty": 0.0,
    "resid_dropout": 0.1,
    "dropout": 0.0,
    "grad_checkpoint": False,
    "tokenizer_path": "beeper.tokenizer.json"
}
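# Note: the architecture fields (context=512, dim=512, n_heads=8, n_layers=6, vocab_size=8192)
# are assumed to match all of the published checkpoints above (the v1 repo name encodes
# 6l-512d-ctx512); the sampling fields are only defaults and are overridden per request below.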
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global model and tokenizer variables
infer = None
tok = None
current_version = None

def load_model_version(version_name):
    """Load the selected model version."""
    global infer, tok, current_version

    if current_version == version_name and infer is not None:
        return f"Already loaded: {version_name}"

    version_info = MODEL_VERSIONS[version_name]
    try:
        # Download model and tokenizer files
        model_file = hf_hub_download(
            repo_id=version_info["repo_id"],
            filename=version_info["model_file"]
        )
        tokenizer_file = hf_hub_download(
            repo_id=version_info["repo_id"],
            filename="tokenizer.json"
        )
        # Initialize model
        infer = BeeperRoseGPT(config).to(device)
        # Load safetensors
        state_dict = load_safetensors(model_file, device=str(device))
        infer.load_state_dict(state_dict)
        infer.eval()
        # Load tokenizer
        tok = Tokenizer.from_file(tokenizer_file)
        current_version = version_name
        return f"Successfully loaded: {version_name}"
    except Exception as e:
        return f"Error loading {version_name}: {str(e)}"

# Load default model on startup - try v4 first, fall back to v3
try:
    load_status = load_model_version("Beeper v4 (Advanced)")
    if "Error" in load_status:
        print(f"v4 not ready yet: {load_status}")
        load_status = load_model_version("Beeper v3 (Multi-Concept)")
except Exception:
    load_status = load_model_version("Beeper v3 (Multi-Concept)")
print(load_status)

# ----------------------------
# 💬 Gradio Chat Wrapper
# ----------------------------
def beeper_reply(message, history, model_version, temperature=None, top_k=None, top_p=None, max_new_tokens=80):
    global infer, tok, current_version

    # Load model if version changed
    if model_version != current_version:
        status = load_model_version(model_version)
        if "Error" in status:
            return f"⚠️ {status}"

    # Check if model is loaded
    if infer is None or tok is None:
        return "⚠️ Model not loaded. Please select a version and try again."

    # Use defaults if not provided
    if temperature is None:
        temperature = 0.9
    if top_k is None:
        top_k = 40
    if top_p is None:
        top_p = 0.9

    # Try Q&A format since she has some in corpus
    if "?" in message:
        prompt = f"Q: {message}\nA:"
    elif message.lower().strip() in ["hi", "hello", "hey"]:
        prompt = "The little robot said hello. She said, \""
    elif "story" in message.lower():
        prompt = "Once upon a time, there was a robot. "
    else:
        # Simple continuation
        prompt = message + ". "
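    # Routing example (for illustration): "Can you tell me a story?" becomes
    # "Q: Can you tell me a story?\nA:" because the "?" branch is checked before the
    # "story" branch; a bare "hi" is wrapped in the robot-greeting continuation instead.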
    # Generate response; penalties are raised relative to the base config to reduce repetition
    response = generate(
        model=infer,
        tok=tok,
        cfg=config,
        prompt=prompt,
        max_new_tokens=max_new_tokens,   # Kept short to avoid rambling
        temperature=float(temperature),  # User-selected temperature from the slider
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=1.1,          # Penalize repeated tokens
        presence_penalty=0.8,            # Higher presence penalty than the base config
        frequency_penalty=0.1,           # Small frequency penalty
        device=device,
        detokenize=True
    )
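    # Note: the presence and frequency penalties passed above (0.8 and 0.1) intentionally
    # override the milder defaults stored in config (0.6 and 0.0) to further curb repetition.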
    # Aggressive cleanup
    # Remove the prompt completely
    if response.startswith(prompt):
        response = response[len(prompt):]

    # Remove Q&A format artifacts
    response = response.replace("Q:", "").replace("A:", "")

    # Split on newlines and take first non-empty line
    lines = response.split('\n')
    for line in lines:
        clean_line = line.strip()
        if clean_line and not clean_line.startswith(message[:10]):
            response = clean_line
            break

    # If response still contains the user message, try to extract after it
    if message.lower()[:20] in response.lower()[:50]:
        # Find where the echo ends
        words_in_message = message.split()
        for i in range(min(5, len(words_in_message)), 0, -1):
            pattern = ' '.join(words_in_message[:i])
            if pattern.lower() in response.lower():
                idx = response.lower().find(pattern.lower()) + len(pattern)
                response = response[idx:].strip()
                break

    # Remove any remaining "User" or "Beeper" artifacts
    for artifact in ["User:", "Beeper:", "U ser:", "Beep er:", "User ", "Beeper "]:
        response = response.replace(artifact, "")

    # Ensure we have something
    if not response or len(response) < 3:
        responses = [
            "I like robots and stories!",
            "That's interesting!",
            "I want to play in the park.",
            "The robot was happy.",
            "Yes, I think so too!"
        ]
        import random
        response = random.choice(responses)

    # Clean ending
    response = response.strip()
    if response and response[-1] not in '.!?"':
        response = response.rsplit('.', 1)[0] + '.' if '.' in response else response + '.'

    return response[:200]  # Cap length

# ----------------------------
# 🖼️ Interface
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Beeper - A Rose-based Tiny Language Model
        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_VERSIONS.keys()),
                value="Beeper v3 (Multi-Concept)",  # Default to v3 since v4 might not be ready
                label="Select Beeper Version",
                info="Choose which version of Beeper to chat with"
            )
        with gr.Column(scale=7):
            version_info = gr.Markdown("**Current:** Beeper v3 with 30+ epochs including reasoning, math, and ethics")

    # Update version info when dropdown changes
    def update_version_info(version_name):
        info = MODEL_VERSIONS[version_name]["description"]
        return f"**Current:** {info}"

    model_dropdown.change(
        fn=update_version_info,
        inputs=[model_dropdown],
        outputs=[version_info]
    )

    # Chat interface
    chatbot = gr.Chatbot(label="Chat with Beeper", type="tuples", height=400)
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")

    with gr.Row():
        with gr.Column(scale=2):
            temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
        with gr.Column(scale=2):
            top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top-k")
        with gr.Column(scale=2):
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            max_new_tokens_slider = gr.Slider(20, 512, value=128, step=1, label="Max-new-tokens")

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    # Examples
    gr.Examples(
        examples=[
            ["Hello Beeper! How are you today?"],
            ["Can you tell me a story about a robot?"],
            ["What do you like to do for fun?"],
            ["What makes you happy?"],
            ["Tell me about your dreams"],
        ],
        inputs=msg
    )

    # Handle chat
    def respond(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens):
        if not chat_history:
            chat_history = []
        response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens)
        chat_history.append([message, response])
        return "", chat_history
    msg.submit(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
        [msg, chatbot]
    )
    submit.click(
        respond,
        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
        [msg, chatbot]
    )
    clear.click(lambda: None, None, chatbot, queue=False)


if __name__ == "__main__":
    demo.launch()
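# Local run (a sketch; assumes this script is the Space's app.py and that beeper_model.py,
# providing BeeperRoseGPT and generate, sits alongside it):
#   pip install gradio torch tokenizers safetensors huggingface_hub
#   python app.py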