"""Custom inference endpoint for NeMo Skills code-execution models."""

import logging
import os
import traceback
from typing import Any, Dict, List

from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.inference.server.code_execution_model import get_code_execution_model
from nemo_skills.prompt.utils import get_prompt

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference."""

    def __init__(self) -> None:
        """Set up lazy-initialized state; no heavy work happens here.

        Components (sandbox, model, prompt) are created on first call so
        that constructing the handler is cheap and cannot fail.
        """
        self.model = None
        self.prompt = None
        self.initialized = False

        # Prompt configuration comes from the environment so the same
        # deployment image can serve different prompt setups.
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self) -> None:
        """Lazily initialize the sandbox, model, and prompt (idempotent).

        Raises:
            RuntimeError: if any component fails to initialize. The original
                exception is chained so the full traceback is preserved.
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000,
            )

            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )

            self.initialized = True
            logger.info("All components initialized successfully")
        except Exception as e:
            # BUG FIX: the original swallowed the failure with a bare
            # warning (discarding the exception and traceback) and did not
            # re-raise, so __call__ would later crash with a confusing
            # AttributeError on self.model.generate. Log with traceback and
            # re-raise; __call__'s top-level handler converts this into the
            # same error-dict response callers already expect.
            logger.exception("Failed to initialize the model")
            raise RuntimeError("Component initialization failed") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process inference requests.

        Args:
            data: Request payload. Expected keys:
                - inputs: str or list of str — the input prompts/problems.
                - parameters: dict (optional) — generation parameters that
                  override the defaults below.

        Returns:
            One dict per prompt with "generated_text" and
            "code_rounds_executed"; on failure, a single-element list
            containing "error" and an empty "generated_text".
        """
        try:
            # Initialize components if not already done (raises on failure,
            # handled by the except below).
            self._initialize_components()

            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Accept either a single prompt or a batch.
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or list of strings")

            # If a prompt template is configured, format the raw problems
            # through it before generation.
            if self.prompt is not None:
                prompts = [
                    self.prompt.fill({"problem": text, "total_code_executions": 8})
                    for text in prompts
                ]

            # Code-execution arguments derived from the prompt template.
            extra_generate_params: Dict[str, Any] = (
                self.prompt.get_code_execution_args() if self.prompt is not None else {}
            )

            # Defaults; user-supplied parameters override them, and the
            # prompt's code-execution args take final precedence.
            generation_params: Dict[str, Any] = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info("Processing %d prompt(s)", len(prompts))

            outputs = self.model.generate(prompts=prompts, **generation_params)

            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            logger.info("Successfully processed %d request(s)", len(results))
            return results
        except Exception as e:
            # Top-level boundary: convert any failure into the error-dict
            # response shape instead of letting the serving layer 500.
            logger.error("Error processing request: %s", e)
            logger.error(traceback.format_exc())
            return [{"error": str(e), "generated_text": ""}]