kakuguo's picture
Upload 52 files
afd4069
raw
history blame contribute delete
No virus
12.7 kB
import base64
from io import BytesIO
import os
from pprint import pprint
import queue
import re
from subprocess import PIPE
import jupyter_client
from PIL import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
# Kernel name handed to jupyter_client.KernelManager; can be overridden through
# the IPYKERNEL environment variable.
IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo')
# System prompt (written in Chinese) sent with every generation request. In
# short: the assistant is called ChatGLM, is attached to a computer without
# network access, may run Python code and should fix code that errors, and
# user-uploaded files are stored under /mnt/data/ by default.
SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。'
# Forwarded as `max_length` to client.generate_stream.
MAX_LENGTH = 8192
# Text observations longer than this many characters are cut off and suffixed
# with ' [TRUNCATED]' before being stored in the history.
TRUNCATE_LENGTH = 1024
# Model client shared by the whole module; created when the module is executed.
client = get_client()
class CodeKernel(object):
def __init__(self,
kernel_name='kernel',
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
self.python_path = python_path
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose
if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_manager.connection_file))
if verbose:
pprint(self.kernel_manager.get_connection_info())
# Initialize the code kernel
self.kernel = self.kernel_manager.blocking_client()
# self.kernel.load_connection_file()
self.kernel.start_channels()
print("Code kernel started.")
def execute(self, code):
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
while True:
msg_out = io_msg_content
### Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
break
except queue.Empty:
break
return shell_msg, msg_out
except Exception as e:
print(e)
return None
def execute_interactive(self, code, verbose=False):
shell_msg = self.kernel.execute_interactive(code)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def inspect(self, code, verbose=False):
msg_id = self.kernel.inspect(code)
shell_msg = self.kernel.get_shell_msg(timeout=30)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def get_error_msg(self, msg, verbose=False) -> str | None:
if msg['content']['status'] == 'error':
try:
error_msg = msg['content']['traceback']
except:
try:
error_msg = msg['content']['traceback'][-1].strip()
except:
error_msg = "Traceback Error"
if verbose:
print("Error: ", error_msg)
return error_msg
return None
def check_msg(self, msg, verbose=False):
status = msg['content']['status']
if status == 'ok':
if verbose:
print("Execution succeeded.")
elif status == 'error':
for line in msg['content']['traceback']:
if verbose:
print(line)
def shutdown(self):
# Shutdown the backend kernel
self.kernel_manager.shutdown_kernel()
print("Backend kernel shutdown.")
# Shutdown the code kernel
self.kernel.shutdown()
print("Code kernel shutdown.")
def restart(self):
# Restart the backend kernel
self.kernel_manager.restart_kernel()
# print("Backend kernel restarted.")
def interrupt(self):
# Interrupt the backend kernel
self.kernel_manager.interrupt_kernel()
# print("Backend kernel interrupted.")
def is_alive(self):
return self.kernel.is_alive()
def b64_2_img(data):
    """Decode a base64-encoded image payload and open it with PIL."""
    raw_bytes = base64.b64decode(data)
    return Image.open(BytesIO(raw_bytes))
def clean_ansi_codes(input_string):
    """Return ``input_string`` with ANSI CSI escape sequences removed."""
    # CSI introducer (single-byte 0x9B or ESC + '['), followed by parameter
    # bytes, intermediate bytes and one final byte.
    csi_pattern = r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]'
    return re.sub(csi_pattern, '', input_string)
def execute(code, kernel: "CodeKernel") -> "tuple[str | None, str | Image.Image]":
    """Run ``code`` on ``kernel`` and normalise the result.

    Chat-format special tokens are stripped from ``code`` before it is sent
    to the kernel.

    Returns:
        ``(res_type, res)`` where ``res_type`` is ``"text"``, ``"image"`` or
        ``None``. ``None`` is used for timeouts, errors and executions that
        produced no recognised output. ``res`` is a string, or a PIL image
        when ``res_type`` is ``"image"``.

    Raises:
        RuntimeError: if the kernel returned no result at all, which
            ``kernel.execute`` does when it fails or times out.
    """
    # Order matters: "<|assistant|>interpreter" must go before "<|assistant|>".
    for marker in ("<|observation|>", "<|assistant|>interpreter",
                   "<|assistant|>", "<|user|>", "<|system|>"):
        code = code.replace(marker, "")

    result = kernel.execute(code)
    if result is None:
        # kernel.execute swallows its own exceptions and returns None;
        # unpacking that would only raise an unhelpful TypeError.
        raise RuntimeError("Kernel returned no result (execution failed or timed out)")
    msg, output = result

    status = msg['metadata']['status']
    if status == "timeout":
        return None, 'Timed out'
    if status == 'error':
        error = kernel.get_error_msg(msg, verbose=True)
        # The traceback normally arrives as a list of lines, but tolerate a
        # plain string or a missing value as well.
        if isinstance(error, (list, tuple)):
            error = '\n'.join(error)
        return None, clean_ansi_codes(error or '')

    res_type = None
    res = ""
    if 'text' in output:
        res_type = "text"
        res = output['text']
    elif 'data' in output:
        for key in output['data']:
            if 'text/plain' in key:
                res_type = "text"
                res = output['data'][key]
            elif 'image/png' in key:
                # An image wins over any plain-text representation.
                res_type = "image"
                res = output['data'][key]
                break

    if res_type == "image":
        return res_type, b64_2_img(res)
    return res_type, res
@st.cache_resource
def get_kernel():
    """Return a CodeKernel instance, cached via ``st.cache_resource``."""
    return CodeKernel()
def extract_code(text: str) -> str:
    """Return the body of the last fenced code block in ``text``.

    A block is an opening triple-backtick line (with an optional language
    tag), the code, and a closing triple backtick. Returns an empty string
    when ``text`` contains no such block.
    """
    pattern = r'```([^\n]*)\n(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL)
    if not matches:
        # Output without a fenced block used to raise IndexError here.
        return ''
    # Each match is (language_tag, body); only the body is wanted.
    return matches[-1][1]
# Append a conversation to the history and show it in a new markdown block.
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Record ``conversation`` in ``history`` and render it.

    ``placeholder`` is forwarded unchanged to ``conversation.show``.
    """
    history.append(conversation)
    conversation.show(placeholder)
def main(top_p: float, temperature: float, prompt_text: str):
    """Render the code-interpreter chat and answer ``prompt_text``.

    The stored history is replayed first. When ``prompt_text`` is non-empty
    it is appended as a user turn and the model is streamed for up to five
    rounds; each round ends either with a final answer or with a piece of
    code that is executed and fed back as an observation.

    Args:
        top_p: Forwarded to ``client.generate_stream``.
        temperature: Forwarded to ``client.generate_stream``.
        prompt_text: The user's new message; nothing is generated when empty.
    """
    # The history lives in the Streamlit session state under 'ci_history'.
    if 'ci_history' not in st.session_state:
        st.session_state.ci_history = []
    history: list[Conversation] = st.session_state.ci_history
    # Replay everything said so far.
    for conversation in history:
        conversation.show()
    if prompt_text:
        prompt_text = prompt_text.strip()
        role = Role.USER
        append_conversation(Conversation(role, prompt_text), history)
        # input_text is only printed below; generation itself receives
        # `history` directly.
        input_text = preprocess_text(
            SYSTEM_PROMPT,
            None,
            history,
        )
        print("=== Input:")
        print(input_text)
        print("=== History:")
        print(history)
        placeholder = st.container()
        message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
        markdown_placeholder = message_placeholder.empty()
        # At most five generation rounds per user prompt.
        for _ in range(5):
            output_text = ''
            for response in client.generate_stream(
                system=SYSTEM_PROMPT,
                tools=None,
                history=history,
                do_sample=True,
                max_length=MAX_LENGTH,
                temperature=temperature,
                top_p=top_p,
                stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
            ):
                token = response.token
                if response.token.special:
                    print("=== Output:")
                    print(output_text)
                    # Every case below leaves this iteration (return,
                    # continue or break), so special tokens never reach the
                    # text-accumulation code after the match.
                    match token.text.strip():
                        case '<|user|>':
                            # The model handed the turn back: store the
                            # answer and stop.
                            append_conversation(Conversation(
                                Role.ASSISTANT,
                                postprocess_text(output_text),
                            ), history, markdown_placeholder)
                            return
                        # Initiate tool call
                        case '<|assistant|>':
                            # Store the text so far, then stream what
                            # follows into a fresh "interpreter" message.
                            append_conversation(Conversation(
                                Role.ASSISTANT,
                                postprocess_text(output_text),
                            ), history, markdown_placeholder)
                            message_placeholder = placeholder.chat_message(name="interpreter", avatar="assistant")
                            markdown_placeholder = message_placeholder.empty()
                            output_text = ''
                            continue
                        case '<|observation|>':
                            # The model finished writing code: extract it,
                            # run it and record the result.
                            code = extract_code(output_text)
                            print("Code:", code)
                            # Display only the text after the last
                            # 'interpreter' marker.
                            display_text = output_text.split('interpreter')[-1].strip()
                            append_conversation(Conversation(
                                Role.INTERPRETER,
                                postprocess_text(display_text),
                            ), history, markdown_placeholder)
                            message_placeholder = placeholder.chat_message(name="observation", avatar="user")
                            markdown_placeholder = message_placeholder.empty()
                            output_text = ''
                            with markdown_placeholder:
                                with st.spinner('Executing code...'):
                                    try:
                                        res_type, res = execute(code, get_kernel())
                                    except Exception as e:
                                        # Report the failure and end this
                                        # run of main().
                                        st.error(f'Error when executing code: {e}')
                                        return
                            print("Received:", res_type, res)
                            # Cap long text results before storing them.
                            if res_type == 'text' and len(res) > TRUNCATE_LENGTH:
                                res = res[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
                            append_conversation(Conversation(
                                Role.OBSERVATION,
                                '[Image]' if res_type == 'image' else postprocess_text(res),
                                tool=None,
                                image=res if res_type == 'image' else None,
                            ), history, markdown_placeholder)
                            message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
                            markdown_placeholder = message_placeholder.empty()
                            output_text = ''
                            # Leave the stream; the outer loop starts the
                            # next round with the observation in history.
                            break
                        case _:
                            st.error(f'Unexpected special token: {token.text.strip()}')
                            break
                # Ordinary token: accumulate it and refresh the live
                # preview, with a cursor glyph appended.
                output_text += response.token.text
                display_text = output_text.split('interpreter')[-1].strip()
                markdown_placeholder.markdown(postprocess_text(display_text + '▌'))
            else:
                # The stream ended without a break: treat the accumulated
                # text as the final assistant answer.
                append_conversation(Conversation(
                    Role.ASSISTANT,
                    postprocess_text(output_text),
                ), history, markdown_placeholder)
                return