Spaces:
Runtime error
Runtime error
from time import sleep | |
import ast | |
import astunparse | |
import openai | |
from openai.error import RateLimitError, APIConnectionError | |
from pygments import highlight | |
from pygments.lexers import PythonLexer | |
from pygments.formatters import TerminalFormatter | |
class LMP:
    """Language-model program: turns a natural-language query into code and runs it.

    The base prompt is read from ``cfg['prompt_path']``. The model's completion is
    executed with ``exec_safe`` against ``fixed_vars`` merged with ``variable_vars``.
    Functions the generated code calls that do not exist yet are created first
    through ``lmp_fgen.create_new_fs_from_code``.
    """

    def __init__(self, name, cfg, lmp_fgen, fixed_vars, variable_vars):
        self._name = name
        self._cfg = cfg
        # Explicit encoding so prompt loading does not depend on the platform locale.
        with open(self._cfg['prompt_path'], 'r', encoding='utf-8') as f:
            self._base_prompt = f.read()
        self._stop_tokens = list(self._cfg['stop'])
        self._lmp_fgen = lmp_fgen
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        # Source of every snippet passed to exec so far; replayed in the prompt
        # when cfg['maintain_session'] is set.
        self.exec_hist = ''

    def clear_exec_hist(self):
        """Forget all previously executed code."""
        self.exec_hist = ''

    def build_prompt(self, query, context=''):
        """Assemble the prompt for ``query``.

        The ``{variable_vars_imports}`` placeholder in the base prompt is replaced
        by a ``from utils import ...`` line listing the current variable names.
        The execution history (if ``maintain_session``), ``context`` (if
        non-empty) and the decorated query are then appended, one per line.

        Returns:
            ``(prompt, use_query)`` where ``use_query`` is the query wrapped in
            ``cfg['query_prefix']`` / ``cfg['query_suffix']``.
        """
        if self._variable_vars:
            variable_vars_imports_str = f"from utils import {', '.join(self._variable_vars.keys())}"
        else:
            variable_vars_imports_str = ''
        prompt = self._base_prompt.replace('{variable_vars_imports}', variable_vars_imports_str)

        if self._cfg['maintain_session']:
            prompt += f'\n{self.exec_hist}'

        if context != '':
            prompt += f'\n{context}'

        use_query = f'{self._cfg["query_prefix"]}{query}{self._cfg["query_suffix"]}'
        prompt += f'\n{use_query}'

        return prompt, use_query

    def _complete(self, prompt):
        """Return the stripped completion text for ``prompt``.

        Retries indefinitely, sleeping 10 s between attempts, on
        ``RateLimitError`` / ``APIConnectionError``; any other error propagates.
        """
        while True:
            try:
                return openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens'],
                )['choices'][0]['text'].strip()
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API got err {e}')
                print('Retrying after 10s.')
                sleep(10)

    def __call__(self, query, context='', **kwargs):
        """Generate code for ``query``, execute it and optionally return a result.

        Args:
            query: natural-language instruction appended to the prompt.
            context: optional code placed before the query in the prompt; it is
                also prepended to the executed code when ``cfg['include_context']``.
            **kwargs: initial local variables for the executed code.

        Returns:
            ``lvars[cfg['return_val_name']]`` when ``cfg['has_return']`` is set,
            otherwise None. In ``debug_mode`` the code is not executed, so the
            value is looked up with ``.get`` and is None unless supplied via kwargs.
        """
        prompt, use_query = self.build_prompt(query, context=context)
        code_str = self._complete(prompt)

        if self._cfg['include_context'] and context != '':
            to_exec = f'{context}\n{code_str}'
            to_log = f'{context}\n{use_query}\n{code_str}'
        else:
            to_exec = code_str
            to_log = f'{use_query}\n{to_exec}'

        to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter())
        print(f'LMP {self._name} exec:\n\n{to_log_pretty}\n')

        # Create any functions the generated code calls that do not exist yet.
        new_fs = self._lmp_fgen.create_new_fs_from_code(code_str)
        self._variable_vars.update(new_fs)

        gvars = merge_dicts([self._fixed_vars, self._variable_vars])
        lvars = kwargs  # names assigned by the executed code are written here

        debug_mode = self._cfg['debug_mode']
        if not debug_mode:
            exec_safe(to_exec, gvars, lvars)

        self.exec_hist += f'\n{to_exec}'

        if self._cfg['maintain_session']:
            self._variable_vars.update(lvars)

        if self._cfg['has_return']:
            return_val_name = self._cfg['return_val_name']
            if debug_mode:
                # Nothing was executed, so the return variable is normally absent;
                # avoid a spurious KeyError in debug runs.
                return lvars.get(return_val_name)
            return lvars[return_val_name]
class LMPFGen:
    """Function generator: asks the completion model to implement missing functions.

    Given a call signature such as ``foo(a, b)``, it prompts the model with the
    base prompt from ``cfg['prompt_path']`` followed by the signature, executes
    the returned source with ``exec_safe`` and returns the defined function.
    """

    def __init__(self, cfg, fixed_vars, variable_vars):
        self._cfg = cfg
        self._stop_tokens = list(self._cfg['stop'])
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        # Explicit encoding so prompt loading does not depend on the platform locale.
        with open(self._cfg['prompt_path'], 'r', encoding='utf-8') as f:
            self._base_prompt = f.read()

    def _complete(self, prompt):
        """Return the stripped completion text for ``prompt``.

        Retries indefinitely, sleeping 10 s between attempts, on
        ``RateLimitError`` / ``APIConnectionError``; any other error propagates.
        """
        while True:
            try:
                return openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens'],
                )['choices'][0]['text'].strip()
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API got err {e}')
                print('Retrying after 10s.')
                sleep(10)

    def create_f_from_sig(self, f_name, f_sig, other_vars=None, fix_bugs=False, return_src=False):
        """Generate, execute and return the function named ``f_name``.

        Args:
            f_name: name the generated source is expected to define.
            f_sig: call (or assignment) source used as the query, e.g. ``foo(a, b)``.
            other_vars: extra globals visible to the generated function.
            fix_bugs: if True, pass the source through ``openai.Edit.create`` once.
            return_src: if True, return ``(f, f_src)`` instead of ``f``.

        Raises:
            KeyError: if the generated source does not define ``f_name``.
            Errors from ``exec_safe`` (e.g. its banned-phrase check) propagate.
        """
        print(f'Creating function: {f_sig}')

        use_query = f'{self._cfg["query_prefix"]}{f_sig}{self._cfg["query_suffix"]}'
        prompt = f'{self._base_prompt}\n{use_query}'
        f_src = self._complete(prompt)

        if fix_bugs:
            # NOTE(review): the '# ' prefix turns the first generated line into a
            # comment before editing — kept as-is; confirm this is intentional.
            f_src = openai.Edit.create(
                model='code-davinci-edit-001',
                input='# ' + f_src,
                temperature=0,
                instruction='Fix the bug if there is one. Improve readability. Keep same inputs and outputs. Only small changes. No comments.',
            )['choices'][0]['text'].strip()

        if other_vars is None:
            other_vars = {}
        gvars = merge_dicts([self._fixed_vars, self._variable_vars, other_vars])
        lvars = {}
        exec_safe(f_src, gvars, lvars)
        f = lvars[f_name]

        to_print = highlight(f'{use_query}\n{f_src}', PythonLexer(), TerminalFormatter())
        print(f'LMP FGEN created:\n\n{to_print}\n')

        if return_src:
            return f, f_src
        return f

    def create_new_fs_from_code(self, code_str, other_vars=None, fix_bugs=False, return_src=False):
        """Generate every function called in ``code_str`` that is not defined yet.

        Calls to bare names are collected with ``FunctionParser``. When such a
        call is the value of an assignment, the whole assignment (``x = foo(a)``)
        is used as the signature instead. Each missing function is generated with
        ``create_f_from_sig``, and its body is scanned recursively for further
        missing functions.

        Returns:
            ``{name: function}``, or ``({name: function}, {name: source})`` when
            ``return_src`` is True.
        """
        fs, f_assigns = {}, {}
        f_parser = FunctionParser(fs, f_assigns)
        f_parser.visit(ast.parse(code_str))
        # Prefer the assignment form as the signature when one exists.
        for f_name, f_assign in f_assigns.items():
            if f_name in fs:
                fs[f_name] = f_assign

        if other_vars is None:
            other_vars = {}

        new_fs = {}
        srcs = {}
        for f_name, f_sig in fs.items():
            all_vars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
            if var_exists(f_name, all_vars):
                continue

            # Pass the full visible scope (including other_vars): if no child
            # functions are created below, this f is kept as-is, so its globals
            # must already contain every name it may use.
            f, f_src = self.create_f_from_sig(f_name, f_sig, all_vars, fix_bugs=fix_bugs, return_src=True)

            # Recursively define child functions used in the body. f itself is put
            # in the child scope so a self-recursive body does not trigger
            # regenerating f_name again (which would recurse without bound).
            f_def_body = astunparse.unparse(ast.parse(f_src).body[0].body)
            child_fs, child_f_srcs = self.create_new_fs_from_code(
                f_def_body,
                other_vars=merge_dicts([all_vars, {f_name: f}]),
                fix_bugs=fix_bugs,
                return_src=True,
            )

            if child_fs:
                new_fs.update(child_fs)
                srcs.update(child_f_srcs)

                # Redefine parent f so the newly created child functions are in scope.
                gvars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
                lvars = {}
                exec_safe(f_src, gvars, lvars)
                f = lvars[f_name]

            new_fs[f_name], srcs[f_name] = f, f_src

        if return_src:
            return new_fs, srcs
        return new_fs
class FunctionParser(ast.NodeTransformer):
    """Walk an AST and record function calls and call-valued assignments.

    Both mappings passed in are filled in place; no node is modified.

    * ``fs`` maps the name of every call whose callee is a bare name to the
      source text of that call. The last such call visited wins.
    * ``f_assigns`` maps the callee's source text to the full assignment
      statement, for every assignment whose right-hand side is a call.
    """

    def __init__(self, fs, f_assigns):
        super().__init__()
        self._fs = fs
        self._f_assigns = f_assigns

    def visit_Call(self, node):
        # Visit children first, so an enclosing call is recorded after (and
        # therefore overrides) a same-named call nested in its arguments.
        self.generic_visit(node)
        callee = node.func
        if not isinstance(callee, ast.Name):
            return node
        self._fs[callee.id] = astunparse.unparse(node).strip()
        return node

    def visit_Assign(self, node):
        self.generic_visit(node)
        rhs = node.value
        if not isinstance(rhs, ast.Call):
            return node
        self._f_assigns[astunparse.unparse(rhs.func).strip()] = astunparse.unparse(node).strip()
        return node
def var_exists(name, all_vars): | |
try: | |
eval(name, all_vars) | |
except: | |
exists = False | |
else: | |
exists = True | |
return exists | |
def merge_dicts(dicts):
    """Combine a sequence of dicts into one new dict; later dicts win on key clashes.

    None of the inputs is modified.
    """
    merged = {}
    for mapping in dicts:
        merged.update(mapping.items())
    return merged
def exec_safe(code_str, gvars=None, lvars=None):
    """Execute ``code_str`` after a coarse banned-phrase check.

    Source containing ``import`` or ``__`` is rejected. This is a plain
    substring test, so an identifier such as ``important`` is rejected too.
    The code then runs via ``exec`` with a copy of ``gvars`` in which ``exec``
    and ``eval`` are replaced by no-op functions returning None. Names assigned
    by the code are written into ``lvars``, which is mutated in place; ``gvars``
    itself is not modified.

    NOTE(security): this is NOT a sandbox. Builtins other than exec/eval (e.g.
    ``open``) remain reachable, so only run code you would be willing to run
    directly.

    Raises:
        AssertionError: if a banned phrase occurs in ``code_str``.
    """
    banned_phrases = ('import', '__')
    for phrase in banned_phrases:
        if phrase in code_str:
            # Raised explicitly rather than via ``assert`` so the check survives
            # ``python -O``; the AssertionError type is kept for existing callers.
            raise AssertionError(f'exec_safe: banned phrase {phrase!r} found in code')

    if gvars is None:
        gvars = {}
    if lvars is None:
        lvars = {}

    def _noop(*args, **kwargs):
        return None

    # Globals shadow builtins, so these entries disable exec/eval inside code_str.
    custom_gvars = {**gvars, 'exec': _noop, 'eval': _noop}
    exec(code_str, custom_gvars, lvars)