import base64 from io import BytesIO import os from pprint import pprint import queue import re from subprocess import PIPE import jupyter_client from PIL import Image import streamlit as st from streamlit.delta_generator import DeltaGenerator from client import get_client from conversation import postprocess_text, preprocess_text, Conversation, Role IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo') SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。' MAX_LENGTH = 8192 TRUNCATE_LENGTH = 1024 client = get_client() class CodeKernel(object): def __init__(self, kernel_name='kernel', kernel_id=None, kernel_config_path="", python_path=None, ipython_path=None, init_file_path="./startup.py", verbose=1): self.kernel_name = kernel_name self.kernel_id = kernel_id self.kernel_config_path = kernel_config_path self.python_path = python_path self.ipython_path = ipython_path self.init_file_path = init_file_path self.verbose = verbose if python_path is None and ipython_path is None: env = None else: env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path} # Initialize the backend kernel self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env) if self.kernel_config_path: self.kernel_manager.load_connection_file() self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) print("Backend kernel started with the configuration: {}".format( self.kernel_config_path)) else: self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) print("Backend kernel started with the configuration: {}".format( self.kernel_manager.connection_file)) if verbose: pprint(self.kernel_manager.get_connection_info()) # Initialize the code kernel self.kernel = self.kernel_manager.blocking_client() # self.kernel.load_connection_file() self.kernel.start_channels() print("Code kernel started.") def execute(self, code): self.kernel.execute(code) try: shell_msg = self.kernel.get_shell_msg(timeout=30) io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content'] while True: msg_out = io_msg_content ### Poll the message try: io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content'] if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle': break except queue.Empty: break return shell_msg, msg_out except Exception as e: print(e) return None def execute_interactive(self, code, verbose=False): shell_msg = self.kernel.execute_interactive(code) if shell_msg is queue.Empty: if verbose: print("Timeout waiting for shell message.") self.check_msg(shell_msg, verbose=verbose) return shell_msg def inspect(self, code, verbose=False): msg_id = self.kernel.inspect(code) shell_msg = self.kernel.get_shell_msg(timeout=30) if shell_msg is queue.Empty: if verbose: print("Timeout waiting for shell message.") self.check_msg(shell_msg, verbose=verbose) return shell_msg def get_error_msg(self, msg, verbose=False) -> str | None: if msg['content']['status'] == 'error': try: error_msg = msg['content']['traceback'] except: try: error_msg = msg['content']['traceback'][-1].strip() except: error_msg = "Traceback Error" if verbose: print("Error: ", error_msg) return error_msg return None def check_msg(self, msg, verbose=False): status = msg['content']['status'] if status == 'ok': if verbose: print("Execution succeeded.") elif status == 'error': for line in msg['content']['traceback']: if verbose: print(line) def shutdown(self): # Shutdown the backend kernel self.kernel_manager.shutdown_kernel() print("Backend kernel shutdown.") # Shutdown the code kernel self.kernel.shutdown() print("Code kernel shutdown.") def restart(self): # Restart the backend kernel self.kernel_manager.restart_kernel() # print("Backend kernel restarted.") def interrupt(self): # Interrupt the backend kernel self.kernel_manager.interrupt_kernel() # print("Backend kernel interrupted.") def is_alive(self): return self.kernel.is_alive() def b64_2_img(data): buff = BytesIO(base64.b64decode(data)) return Image.open(buff) def clean_ansi_codes(input_string): ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]') return ansi_escape.sub('', input_string) def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]: res = "" res_type = None code = code.replace("<|observation|>", "") code = code.replace("<|assistant|>interpreter", "") code = code.replace("<|assistant|>", "") code = code.replace("<|user|>", "") code = code.replace("<|system|>", "") msg, output = kernel.execute(code) if msg['metadata']['status'] == "timeout": return res_type, 'Timed out' elif msg['metadata']['status'] == 'error': return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True))) if 'text' in output: res_type = "text" res = output['text'] elif 'data' in output: for key in output['data']: if 'text/plain' in key: res_type = "text" res = output['data'][key] elif 'image/png' in key: res_type = "image" res = output['data'][key] break if res_type == "image": return res_type, b64_2_img(res) elif res_type == "text" or res_type == "traceback": res = res return res_type, res @st.cache_resource def get_kernel(): kernel = CodeKernel() return kernel def extract_code(text: str) -> str: pattern = r'```([^\n]*)\n(.*?)```' matches = re.findall(pattern, text, re.DOTALL) return matches[-1][1] # Append a conversation into history, while show it in a new markdown block def append_conversation( conversation: Conversation, history: list[Conversation], placeholder: DeltaGenerator | None=None, ) -> None: history.append(conversation) conversation.show(placeholder) def main(top_p: float, temperature: float, prompt_text: str): if 'ci_history' not in st.session_state: st.session_state.ci_history = [] history: list[Conversation] = st.session_state.ci_history for conversation in history: conversation.show() if prompt_text: prompt_text = prompt_text.strip() role = Role.USER append_conversation(Conversation(role, prompt_text), history) input_text = preprocess_text( SYSTEM_PROMPT, None, history, ) print("=== Input:") print(input_text) print("=== History:") print(history) placeholder = st.container() message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") markdown_placeholder = message_placeholder.empty() for _ in range(5): output_text = '' for response in client.generate_stream( system=SYSTEM_PROMPT, tools=None, history=history, do_sample=True, max_length=MAX_LENGTH, temperature=temperature, top_p=top_p, stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)], ): token = response.token if response.token.special: print("=== Output:") print(output_text) match token.text.strip(): case '<|user|>': append_conversation(Conversation( Role.ASSISTANT, postprocess_text(output_text), ), history, markdown_placeholder) return # Initiate tool call case '<|assistant|>': append_conversation(Conversation( Role.ASSISTANT, postprocess_text(output_text), ), history, markdown_placeholder) message_placeholder = placeholder.chat_message(name="interpreter", avatar="assistant") markdown_placeholder = message_placeholder.empty() output_text = '' continue case '<|observation|>': code = extract_code(output_text) print("Code:", code) display_text = output_text.split('interpreter')[-1].strip() append_conversation(Conversation( Role.INTERPRETER, postprocess_text(display_text), ), history, markdown_placeholder) message_placeholder = placeholder.chat_message(name="observation", avatar="user") markdown_placeholder = message_placeholder.empty() output_text = '' with markdown_placeholder: with st.spinner('Executing code...'): try: res_type, res = execute(code, get_kernel()) except Exception as e: st.error(f'Error when executing code: {e}') return print("Received:", res_type, res) if res_type == 'text' and len(res) > TRUNCATE_LENGTH: res = res[:TRUNCATE_LENGTH] + ' [TRUNCATED]' append_conversation(Conversation( Role.OBSERVATION, '[Image]' if res_type == 'image' else postprocess_text(res), tool=None, image=res if res_type == 'image' else None, ), history, markdown_placeholder) message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") markdown_placeholder = message_placeholder.empty() output_text = '' break case _: st.error(f'Unexpected special token: {token.text.strip()}') break output_text += response.token.text display_text = output_text.split('interpreter')[-1].strip() markdown_placeholder.markdown(postprocess_text(display_text + '▌')) else: append_conversation(Conversation( Role.ASSISTANT, postprocess_text(output_text), ), history, markdown_placeholder) return