import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

# Model constants
MODEL_ID = "FlameF0X/Snowflake-G0-Release"  # HF repo when published
MAX_LENGTH = 384
TEMPERATURE_MIN = 0.1
TEMPERATURE_MAX = 2.0
TEMPERATURE_DEFAULT = 0.7
TOP_P_MIN = 0.1
TOP_P_MAX = 1.0
TOP_P_DEFAULT = 0.9
TOP_K_MIN = 1
TOP_K_MAX = 100
TOP_K_DEFAULT = 40
MAX_NEW_TOKENS_MIN = 16
MAX_NEW_TOKENS_MAX = 1024
MAX_NEW_TOKENS_DEFAULT = 256

# CSS for the app
css = """
.gradio-container {
    background-color: #1e1e2f !important;
    color: #e0e0e0 !important;
}
.header {
    background-color: #2b2b3c;
    padding: 20px;
    margin-bottom: 20px;
    border-radius: 10px;
    text-align: center;
}
.header h1 {
    color: #66ccff;
    margin-bottom: 10px;
}
.snowflake-icon {
    font-size: 24px;
    margin-right: 10px;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.9em;
    color: #999;
}
.parameter-section {
    background-color: #2a2a3a;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.parameter-section h3 {
    margin-top: 0;
    color: #66ccff;
}
.example-section {
    background-color: #223344;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.example-section h3 {
    margin-top: 0;
    color: #66ffaa;
}
"""

# Helper functions
def load_model_and_tokenizer():
    """Load the model, tokenizer, and generation pipeline into module globals."""
    global model, tokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Some causal LMs ship without a pad token; fall back to EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )

    pipeline = TextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,  # return only the completion, not the echoed prompt
        max_length=MAX_LENGTH,   # fallback cap; per-call max_new_tokens takes precedence
    )
    return model, tokenizer, pipeline


def generate_text(
    prompt,
    temperature=TEMPERATURE_DEFAULT,
    top_p=TOP_P_DEFAULT,
    top_k=TOP_K_DEFAULT,
    max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
    history=None,
):
    """Generate a completion for `prompt` and append the exchange to `history`."""
    if history is None:
        history = []

    # Add the current prompt to the history
    history.append({"role": "user", "content": prompt})

    try:
        # Generate the response
        outputs = pipeline(
            prompt,
            do_sample=temperature > 0,  # always true here, since the slider min is 0.1
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1,
        )
        response = outputs[0]["generated_text"]

        # Add the model response to the history
        history.append({"role": "assistant", "content": response})

        # Format the chat history for display
        formatted_history = []
        for entry in history:
            role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
            formatted_history.append(f"{role_prefix}{entry['content']}")

        return response, history, "\n\n".join(formatted_history)
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        history.append({"role": "assistant", "content": f"[ERROR] {error_msg}"})
        return error_msg, history, str(history)


def clear_conversation():
    return "", [], ""


def apply_preset_example(example, history):
    return example, history


# Example prompts
examples = [
    "Write a short story about a snowflake that comes to life.",
    "Explain the concept of artificial neural networks to a 10-year-old.",
    "What are some interesting applications of natural language processing?",
    "Write a haiku about programming.",
    "Create a dialogue between two AI researchers discussing the future of language models.",
]


# Main function
def create_demo():
    # NOTE: parts of the original Blocks layout were lost; the markup and widget
    # wiring below are a minimal reconstruction consistent with the helper
    # functions, CSS classes, and constants defined above.
    with gr.Blocks(css=css) as demo:
        # Header
        gr.HTML("""
            <div class="header">
                <h1><span class="snowflake-icon">❄️</span>Snowflake-G0-Release</h1>
                <p>Experience the capabilities of the Snowflake-G0-Release language model</p>
            </div>
        """)

        state = gr.State([])  # running conversation history

        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Type your prompt here...",
                    lines=4,
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear conversation")
                response = gr.Textbox(label="Response", lines=6)
                conversation = gr.Textbox(label="Conversation", lines=12)
            with gr.Column(scale=1, elem_classes="parameter-section"):
                gr.HTML("<h3>Generation Parameters</h3>")
                temperature = gr.Slider(
                    TEMPERATURE_MIN, TEMPERATURE_MAX,
                    value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature",
                )
                top_p = gr.Slider(
                    TOP_P_MIN, TOP_P_MAX,
                    value=TOP_P_DEFAULT, step=0.05, label="Top-p",
                )
                top_k = gr.Slider(
                    TOP_K_MIN, TOP_K_MAX,
                    value=TOP_K_DEFAULT, step=1, label="Top-k",
                )
                max_new_tokens = gr.Slider(
                    MAX_NEW_TOKENS_MIN, MAX_NEW_TOKENS_MAX,
                    value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens",
                )

        with gr.Column(elem_classes="example-section"):
            gr.HTML("<h3>Example Prompts</h3>")
            example_selector = gr.Dropdown(choices=examples, label="Pick an example")
            example_selector.change(
                apply_preset_example,
                inputs=[example_selector, state],
                outputs=[prompt, state],
            )

        generate_btn.click(
            generate_text,
            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, state],
            outputs=[response, state, conversation],
        )
        clear_btn.click(clear_conversation, outputs=[prompt, state, conversation])

        # Footer
        gr.HTML('<div class="footer">Snowflake-G0-Release demo</div>')

    return demo
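
# A minimal sketch, not part of the original script: a programmatic way to
# exercise generate_text() without launching the UI. The helper name and the
# default prompt are illustrative assumptions; call load_model_and_tokenizer()
# before using it.
def _smoke_test(prompt="Write one sentence about snow."):
    """Generate a short completion with default sampling settings."""
    response, _, _ = generate_text(prompt, max_new_tokens=32)
    return response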
# Load the model and launch the app. If loading fails (for example, the HF repo
# is not published yet), report the error instead of dying with a bare traceback.
if __name__ == "__main__":
    try:
        load_model_and_tokenizer()
        demo = create_demo()
        demo.launch()
    except Exception as e:
        print(f"There was a problem loading the Snowflake-G0-Release model: {str(e)}")
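
# Usage note (environment assumptions): running this file with `python app.py`
# (or whatever the script is saved as) starts a local Gradio server, by default
# at http://127.0.0.1:7860. Without a CUDA device the model loads in float32 on
# CPU, so the first generation may be noticeably slow.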