import ast
import traceback

import astunparse
import openai
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter

from utils import chat_completion_request, generate_messages


class LMP:
    """Language Model Program: prompts an LLM for code to answer a query, then executes it."""

    def __init__(self, name, cfg, lmp_fgen, fixed_vars, variable_vars, md_logger):
        self._name = name
        self._cfg = cfg
        self._md_logger = md_logger

        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()

        self._stop_tokens = list(self._cfg['stop'])

        self._lmp_fgen = lmp_fgen

        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self.exec_hist = ''

    def clear_exec_hist(self):
        self.exec_hist = ''

    def build_prompt(self, query, context=''):
        if len(self._variable_vars) > 0:
            variable_vars_imports_str = f"from utils import {', '.join(self._variable_vars.keys())}"
        else:
            variable_vars_imports_str = ''
        prompt = self._base_prompt.replace('{variable_vars_imports}', variable_vars_imports_str)

        if self._cfg['maintain_session']:
            prompt += f'\n{self.exec_hist}'

        if context != '':
            prompt += f'\n{context}'

        use_query = f'{self._cfg["query_prefix"]}{query}{self._cfg["query_suffix"]}'
        prompt += f'\n{use_query}'

        return prompt, use_query

    def __call__(self, query, context='', **kwargs):
        prompt, use_query = self.build_prompt(query, context=context)

        response = chat_completion_request(
            generate_messages(prompt),
            # stop=self._stop_tokens,
            temperature=self._cfg['temperature'],
            model=self._cfg['model'],
            max_tokens=self._cfg['max_tokens'],
            # proxy=openai.proxy,
        )
        code_str = response.json()['choices'][0]['message']['content']

        if self._cfg['include_context'] and context != '':
            to_exec = f'{context}\n{code_str}'
            to_log = f'{context}\n{use_query}\n{code_str}'
        else:
            to_exec = code_str
            to_log = f'{use_query}\n{to_exec}'

        to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter())
        print(f'LMP {self._name} generated code:\n{to_log_pretty}')
        self._md_logger.log_text(f'LMP {self._name} Generated Code:')
        self._md_logger.log_code(to_log)

        # Define any not-yet-existing functions the generated code calls.
        new_fs = self._lmp_fgen.create_new_fs_from_code(code_str)
        self._variable_vars.update(new_fs)

        gvars = merge_dicts([self._fixed_vars, self._variable_vars])
        lvars = kwargs

        if not self._cfg['debug_mode']:
            exec_safe(to_exec, gvars, lvars)

        self.exec_hist += f'\n{to_exec}'

        if self._cfg['maintain_session']:
            self._variable_vars.update(lvars)

        if self._cfg['has_return'] and not self._cfg['debug_mode']:
            return lvars[self._cfg['return_val_name']]


class LMPFGen:
    """Generates Python function definitions for undefined functions referenced in LMP-generated code."""

    def __init__(self, cfg, fixed_vars, variable_vars, md_logger):
        self._cfg = cfg

        self._stop_tokens = list(self._cfg['stop'])
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self._md_logger = md_logger

        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()

    def create_f_from_sig(self, f_name, f_sig, other_vars=None, fix_bugs=False, return_src=False):
        print(f'Creating function: {f_sig}')

        use_query = f'{self._cfg["query_prefix"]}{f_sig}{self._cfg["query_suffix"]}'
        prompt = f'{self._base_prompt}\n{use_query}'

        response = chat_completion_request(
            generate_messages(prompt),
            temperature=self._cfg['temperature'],
            model=self._cfg['model'],
            max_tokens=self._cfg['max_tokens'],
            # proxy=openai.proxy,
        )
        f_src = response.json()['choices'][0]['message']['content']

        if fix_bugs:
            f_src = openai.Edit.create(
                model='code-davinci-edit-001',
                input='# ' + f_src,
                temperature=0,
                instruction='Fix the bug if there is one. Improve readability. Keep same inputs and outputs. Only small changes. No comments.',
            )['choices'][0]['text'].strip()

        if other_vars is None:
            other_vars = {}
        gvars = merge_dicts([self._fixed_vars, self._variable_vars, other_vars])
        lvars = {}

        exec_safe(f_src, gvars, lvars)

        f = lvars[f_name]

        to_print = f'{use_query}\n{f_src}'
        to_print_pretty = highlight(to_print, PythonLexer(), TerminalFormatter())
        print(f'LMPFGen generated code:\n{to_print_pretty}')
        self._md_logger.log_text('Generated Function:')
        self._md_logger.log_code(to_print)

        if return_src:
            return f, f_src
        return f

    def create_new_fs_from_code(self, code_str, other_vars=None, fix_bugs=False, return_src=False):
        fs, f_assigns = {}, {}
        f_parser = FunctionParser(fs, f_assigns)
        f_parser.visit(ast.parse(code_str))
        # Prefer full assignment statements over bare call signatures when both were found.
        for f_name, f_assign in f_assigns.items():
            if f_name in fs:
                fs[f_name] = f_assign

        if other_vars is None:
            other_vars = {}

        new_fs = {}
        srcs = {}
        for f_name, f_sig in fs.items():
            all_vars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
            if not var_exists(f_name, all_vars):
                f, f_src = self.create_f_from_sig(f_name, f_sig, new_fs, fix_bugs=fix_bugs, return_src=True)

                # Recursively define child functions in the function body if needed.
                f_def_body = astunparse.unparse(ast.parse(f_src).body[0].body)
                child_fs, child_f_srcs = self.create_new_fs_from_code(
                    f_def_body, other_vars=all_vars, fix_bugs=fix_bugs, return_src=True
                )

                if len(child_fs) > 0:
                    new_fs.update(child_fs)
                    srcs.update(child_f_srcs)

                    # Redefine the parent so the newly created child functions are in scope.
                    gvars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
                    lvars = {}
                    exec_safe(f_src, gvars, lvars)

                    f = lvars[f_name]

                new_fs[f_name], srcs[f_name] = f, f_src

        if return_src:
            return new_fs, srcs
        return new_fs


class FunctionParser(ast.NodeTransformer):
    """AST pass that records function-call signatures and call-result assignments."""

    def __init__(self, fs, f_assigns):
        super().__init__()
        self._fs = fs
        self._f_assigns = f_assigns

    def visit_Call(self, node):
        self.generic_visit(node)
        if isinstance(node.func, ast.Name):
            f_sig = astunparse.unparse(node).strip()
            f_name = astunparse.unparse(node.func).strip()
            self._fs[f_name] = f_sig
        return node

    def visit_Assign(self, node):
        self.generic_visit(node)
        if isinstance(node.value, ast.Call):
            assign_str = astunparse.unparse(node).strip()
            f_name = astunparse.unparse(node.value.func).strip()
            self._f_assigns[f_name] = assign_str
        return node


def var_exists(name, all_vars):
    # A name "exists" if it evaluates without error in the given scope.
    try:
        eval(name, all_vars)
    except Exception:
        exists = False
    else:
        exists = True
    return exists


def merge_dicts(dicts):
    # Later dicts take precedence over earlier ones on key collisions.
    return {k: v for d in dicts for k, v in d.items()}


def exec_safe(code_str, gvars=None, lvars=None):
    # Crude guardrails: reject generated code that imports modules or touches
    # dunder attributes, and disable nested exec/eval inside the executed code.
    banned_phrases = ['import', '__']
    for phrase in banned_phrases:
        assert phrase not in code_str

    if gvars is None:
        gvars = {}
    if lvars is None:
        lvars = {}
    empty_fn = lambda *args, **kwargs: None
    custom_gvars = merge_dicts([
        gvars,
        {'exec': empty_fn, 'eval': empty_fn}
    ])
    try:
        exec(code_str, custom_gvars, lvars)
    except Exception:
        traceback.print_exc()
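

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The cfg values, prompt text, and _StdoutLogger stand-in below are
# hypothetical assumptions about how a caller might wire up LMPFGen and LMP;
# actually running it still requires the project's `utils` module and valid
# API credentials.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import pathlib

    class _StdoutLogger:
        # Stand-in for the md_logger interface used above (log_text/log_code).
        def log_text(self, text):
            print(text)

        def log_code(self, code):
            print(code)

    # Write a tiny placeholder prompt so the sketch is self-contained.
    pathlib.Path('prompts').mkdir(exist_ok=True)
    pathlib.Path('prompts/example_prompt.txt').write_text(
        '{variable_vars_imports}\n# Write Python code for each query.\n'
    )

    cfg = {
        'prompt_path': 'prompts/example_prompt.txt',  # hypothetical prompt file
        'stop': ['# Query:'],
        'query_prefix': '# Query: ',
        'query_suffix': '.',
        'maintain_session': True,
        'include_context': True,
        'debug_mode': False,
        'has_return': False,
        'return_val_name': 'ret_val',
        'model': 'gpt-3.5-turbo',
        'temperature': 0,
        'max_tokens': 512,
    }

    logger = _StdoutLogger()
    lmp_fgen = LMPFGen(cfg, fixed_vars={}, variable_vars={}, md_logger=logger)
    lmp = LMP('example', cfg, lmp_fgen, fixed_vars={'print': print},
              variable_vars={}, md_logger=logger)
    lmp('say hello')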