import streamlit as st
import json
import os
from datetime import datetime

from utils import add_log

# Initialize huggingface_models in session state if not present
if 'huggingface_models' not in st.session_state:
    st.session_state.huggingface_models = [
        "Salesforce/codegen-350M-mono",
        "Salesforce/codegen-2B-mono",
        "gpt2",
        "EleutherAI/gpt-neo-125M"
    ]

# Handle missing dependencies: fall back to lightweight stand-ins so the app
# can still run in demo mode when torch/transformers are not installed.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

    # Mock classes for demo purposes. They mirror only the small slice of the
    # transformers API that this module actually uses.
    class DummyTokenizer:
        @classmethod
        def from_pretrained(cls, model_name):
            return cls()

        def __call__(self, text, **kwargs):
            # Return a dict shaped like a BatchEncoding: one row of fake
            # token ids per input string.
            return {"input_ids": [list(range(10))] * (1 if isinstance(text, str) else len(text))}

        def decode(self, token_ids, **kwargs):
            return "# Generated code placeholder\n\ndef example_function():\n    return 'Hello world!'"

        @property
        def eos_token(self):
            return "[EOS]"

        @property
        def eos_token_id(self):
            return 0

        @property
        def pad_token(self):
            return None

        @pad_token.setter
        def pad_token(self, value):
            pass

    class DummyModel:
        @classmethod
        def from_pretrained(cls, model_name):
            return cls()

        def generate(self, input_ids, **kwargs):
            # Return a single fake sequence of token ids.
            return [[1, 2, 3, 4, 5]]

        @property
        def config(self):
            class Config:
                eos_token_id = 0
                pad_token_id = 0
            return Config()

    # Set aliases so the rest of the module can use the same names
    # whether or not transformers is available.
    AutoTokenizer = DummyTokenizer
    AutoModelForCausalLM = DummyModel
def list_available_huggingface_models():
    """
    List available code generation models from Hugging Face.

    Returns:
        list: List of model names
    """
    # Return the list stored in session state
    return st.session_state.huggingface_models
def get_model_and_tokenizer(model_name):
    """
    Load a model and tokenizer from the Hugging Face Hub.

    Args:
        model_name: Name of the model to load

    Returns:
        tuple: (model, tokenizer), or (None, None) if loading fails
    """
    try:
        add_log(f"Loading model and tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-style tokenizers often ship without a pad token; padding in
        # generate_code() would fail without one, so fall back to EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        add_log(f"Model and tokenizer loaded successfully: {model_name}")
        return model, tokenizer
    except Exception as e:
        add_log(f"Error loading model {model_name}: {str(e)}", "ERROR")
        return None, None
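# Example usage (a minimal sketch; the model ID is illustrative and any Hub
# model with a causal LM head should work the same way):
#
#   model, tokenizer = get_model_and_tokenizer("Salesforce/codegen-350M-mono")
#   if model is None:
#       st.error("Model failed to load; check the log for details.")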
def save_trained_model(model_id, model, tokenizer):
    """
    Save trained model information to session state.

    Args:
        model_id: Identifier for the model
        model: The trained model
        tokenizer: The model's tokenizer

    Returns:
        bool: Success status
    """
    try:
        # Make sure the container exists before writing to it.
        if 'trained_models' not in st.session_state:
            st.session_state.trained_models = {}
        st.session_state.trained_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'info': {
                'id': model_id,
                'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
        }
        add_log(f"Model {model_id} saved to session state")
        return True
    except Exception as e:
        add_log(f"Error saving model {model_id}: {str(e)}", "ERROR")
        return False
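# Example usage (hypothetical flow, assuming the model was fine-tuned
# elsewhere in the app):
#
#   if save_trained_model("codegen-finetuned-v1", model, tokenizer):
#       st.success("Model saved for this session.")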
def list_trained_models():
    """
    List all trained models stored in session state.

    Returns:
        list: List of model IDs
    """
    if 'trained_models' in st.session_state:
        return list(st.session_state.trained_models.keys())
    return []
def generate_code(model_id, prompt, max_length=100, temperature=0.7, top_p=0.9):
    """
    Generate code using a trained model.

    Args:
        model_id: ID of the model to use
        prompt: Input prompt for code generation
        max_length: Maximum length of the generated text
        temperature: Sampling temperature
        top_p: Nucleus sampling probability

    Returns:
        str: Generated code, or an error message
    """
    try:
        if model_id not in st.session_state.get('trained_models', {}):
            return "Error: Model not found. Please select a valid model."

        model_data = st.session_state.trained_models[model_id]
        model = model_data['model']
        tokenizer = model_data['tokenizer']

        if TRANSFORMERS_AVAILABLE:
            # Tokenize the prompt
            inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            # Generate text; do_sample=True is required for temperature and
            # top_p to have any effect.
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id
                )
            # Decode the generated text
            generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
        else:
            # Demo mode: return canned code so the UI stays usable without
            # torch/transformers installed.
            inputs = tokenizer(prompt)
            outputs = model.generate(inputs["input_ids"])
            generated_code = tokenizer.decode(outputs[0])
            # Tailor the canned output to the prompt where possible.
            if "fibonacci" in prompt.lower():
                generated_code = "def fibonacci(n):\n    if n <= 0:\n        return 0\n    elif n == 1:\n        return 1\n    else:\n        return fibonacci(n-1) + fibonacci(n-2)\n"
            elif "sort" in prompt.lower():
                generated_code = "def bubble_sort(arr):\n    n = len(arr)\n    for i in range(n):\n        for j in range(0, n-i-1):\n            if arr[j] > arr[j+1]:\n                arr[j], arr[j+1] = arr[j+1], arr[j]\n    return arr\n"

        # Causal LMs echo the prompt; strip it so only the completion remains.
        if generated_code.startswith(prompt):
            generated_code = generated_code[len(prompt):]

        return generated_code
    except Exception as e:
        add_log(f"Error generating code: {str(e)}", "ERROR")
        return f"Error generating code: {str(e)}"
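# Example usage (illustrative values; the model ID must match one previously
# stored via save_trained_model):
#
#   code = generate_code("codegen-finetuned-v1",
#                        "def fibonacci(n):",
#                        max_length=128, temperature=0.8, top_p=0.95)
#   st.code(code, language="python")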
def get_model_info(model_id):
    """
    Get stored information about a model.

    Args:
        model_id: ID of the model

    Returns:
        dict: Model information, or None if the model is unknown
    """
    if 'trained_models' in st.session_state and model_id in st.session_state.trained_models:
        return st.session_state.trained_models[model_id]['info']
    return None
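# End-to-end sketch (hypothetical; assumes it runs inside a Streamlit page
# that imports this module, since session state only exists there):
#
#   model, tokenizer = get_model_and_tokenizer("gpt2")
#   if model is not None:
#       save_trained_model("gpt2-base", model, tokenizer)
#   for mid in list_trained_models():
#       info = get_model_info(mid)
#       st.write(f"{info['id']} (saved {info['created_at']})")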