Spaces:
Running
on
L4
Running
on
L4
| import json | |
| import tempfile | |
| import requests | |
| import streamlit as st | |
| from lagent.schema import AgentStatusCode | |
| from pyvis.network import Network | |
| # Function to create the network graph | |
def create_network_graph(nodes, adjacency_list):
    """Build a pyvis ``Network`` from the search nodes and their edges.

    Args:
        nodes: Mapping of node id to node data. The ``'root'`` and
            ``'response'`` entries take their tooltip from an optional
            ``'content'`` key (falling back to the node id); every other
            entry must provide ``node_data['detail']['content']``.
        adjacency_list: Mapping of source node id to an iterable of
            neighbor dicts, each carrying a ``'name'`` key.

    Returns:
        The populated ``Network`` with the physics control panel enabled.
    """
    graph = Network(height='500px',
                    width='60%',
                    bgcolor='white',
                    font_color='black')

    plain_nodes = ['root', 'response']
    for name, info in nodes.items():
        tooltip = (info.get('content', name)
                   if name in plain_nodes else info['detail']['content'])
        graph.add_node(name,
                       label=name,
                       title=tooltip,
                       color='#FF5733',
                       size=25)

    for source, targets in adjacency_list.items():
        for target in targets:
            # Edges pointing at nodes that were never registered are skipped.
            if target['name'] not in nodes:
                continue
            graph.add_edge(source, target['name'])

    graph.show_buttons(filter_=['physics'])
    return graph
| # Function to draw the graph and return the HTML file path | |
def draw_graph(net):
    """Save ``net`` to a fresh temporary ``.html`` file and return its path.

    The file is reserved with ``NamedTemporaryFile(delete=False)`` rather
    than the deprecated ``tempfile.mktemp``, so the name cannot be claimed
    by another process between picking it and writing to it.

    Args:
        net: Object exposing ``save_graph(path)``.

    Returns:
        Path of the written HTML file. The file is not deleted here; the
        caller owns it.
    """
    # Close the handle before saving so ``save_graph`` can reopen the path
    # for writing on every platform.
    with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as tmp:
        path = tmp.name
    net.save_graph(path)
    return path
def streaming(raw_response):
    """Yield ``(response, current_node)`` pairs from a streamed reply.

    The reply is read line by line (split on ``\\n``). Empty lines, lines
    consisting only of ``\\r`` and ``': ping - '`` keep-alive lines are
    skipped. An optional ``'data: '`` prefix is removed before the rest of
    the line is parsed as JSON.

    Args:
        raw_response: Object exposing ``iter_lines(chunk_size,
            decode_unicode, delimiter)`` that yields ``bytes``.

    Yields:
        Tuples of the ``'response'`` and ``'current_node'`` fields of each
        parsed JSON object.
    """
    data_prefix = 'data: '
    ping_prefix = ': ping - '
    raw_lines = raw_response.iter_lines(chunk_size=8192,
                                        decode_unicode=False,
                                        delimiter=b'\n')
    for raw_line in raw_lines:
        if not raw_line:
            continue
        text = raw_line.decode('utf-8')
        if text == '\r' or text.startswith(ping_prefix):
            continue
        if text.startswith(data_prefix):
            text = text[len(data_prefix):]
        payload = json.loads(text)
        yield (payload['response'], payload['current_node'])
| # Initialize Streamlit session state | |
if 'queries' not in st.session_state:
    # First run for this session: give every history slot its own list.
    for _state_key in ('queries', 'responses', 'graphs_html', 'nodes_list',
                       'adjacency_list_list', 'history',
                       'already_used_keys'):
        st.session_state[_state_key] = []

# Page-wide layout and heading.
st.set_page_config(layout='wide')
st.title('MindSearch-思索')
| # Function to update chat | |
def update_chat(query):
    """Send ``query`` to the backend and render the streamed reply live.

    The query is echoed as a user chat message. If the same query text was
    already asked in this session nothing else happens. Otherwise the query
    is posted to ``http://localhost:8002/solve`` and each streamed update
    refreshes the graph, the planner column and the searcher column. When
    the stream ends, the final graph HTML, nodes, adjacency list and inner
    steps are stored in ``st.session_state``.

    Args:
        query: Text entered by the user.

    Raises:
        requests.HTTPError: If the backend answers with an error status.
    """
    with st.chat_message('user'):
        st.write(query)
    if query in st.session_state['queries']:
        return

    st.session_state['queries'].append(query)
    st.session_state['responses'].append([])
    if 'session_info_temp' not in st.session_state:
        st.session_state['session_info_temp'] = ''

    # Defaults recorded when the stream yields no planner update at all, so
    # the bookkeeping after the loop never touches unbound names.
    history = None
    nodes = {}
    adjacency_list = {}
    graph_html = None

    # Multi-turn conversations are not supported yet.
    message = [dict(role='user', content=query)]
    url = 'http://localhost:8002/solve'
    headers = {'Content-Type': 'application/json'}
    data = {'inputs': message}
    # The context manager releases the streamed connection even when
    # rendering fails part-way through.
    with requests.post(url,
                       headers=headers,
                       data=json.dumps(data),
                       timeout=20,
                       stream=True) as raw_response:
        raw_response.raise_for_status()
        for agent_return, node_name in streaming(raw_response):
            if node_name and node_name in ['root', 'response']:
                continue
            nodes = agent_return['nodes']
            adjacency_list = agent_return['adj']
            response = agent_return['response']
            history = agent_return['inner_steps']
            graph_html = _load_graph_html(nodes, adjacency_list)

            graph_slot = _ensure_placeholder('graph_placeholder')
            expander_slot = _ensure_placeholder('expander_placeholder')
            if graph_html:
                with expander_slot.expander('Show Graph', expanded=False):
                    graph_slot._html(graph_html, height=500)

            container_slot = _ensure_placeholder('container_placeholder')
            with container_slot.container():
                columns_slot = _ensure_placeholder('columns_placeholder')
                col1, col2 = columns_slot.columns([2, 1])
                with col1:
                    _update_planner_panel(agent_return, node_name, response)
                with col2:
                    _update_searcher_panel(nodes, node_name)

    # Flush planner text that was still being streamed when the reply ended.
    if st.session_state['session_info_temp']:
        st.session_state['responses'][-1].append(
            st.session_state['session_info_temp'])
        st.session_state['session_info_temp'] = ''
    st.session_state['graphs_html'].append(graph_html)
    st.session_state['nodes_list'].append(nodes)
    st.session_state['adjacency_list_list'].append(adjacency_list)
    st.session_state['history'] = history


def _ensure_placeholder(key):
    """Return the placeholder stored under ``key``, creating it if absent."""
    if key not in st.session_state:
        st.session_state[key] = st.empty()
    return st.session_state[key]


def _load_graph_html(nodes, adjacency_list):
    """Return the graph as HTML text, or ``None`` when ``nodes`` is empty."""
    import contextlib
    import os

    if not nodes:
        return None
    graph_html_path = draw_graph(create_network_graph(nodes, adjacency_list))
    try:
        with open(graph_html_path, encoding='utf-8') as f:
            return f.read()
    finally:
        # One file is written per streamed update; remove it once read so
        # temporary files do not pile up.
        with contextlib.suppress(OSError):
            os.remove(graph_html_path)


def _update_planner_panel(agent_return, node_name, response):
    """Refresh the planner column for one streamed update."""
    planner_slot = _ensure_placeholder('planner_placeholder')
    if 'session_info_temp' not in st.session_state:
        st.session_state['session_info_temp'] = ''

    if node_name:
        # Searcher update: keep showing the in-progress planner text, or
        # the last finished planner step (empty when there is none yet).
        shown = st.session_state['session_info_temp']
        if not shown:
            finished = st.session_state['responses'][-1]
            shown = finished[-1] if finished else ''
        planner_slot.markdown(shown)
        return

    state = agent_return['state']
    if state in [AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING]:
        st.session_state['session_info_temp'] = response
    elif state == AgentStatusCode.PLUGIN_START:
        # Keep the text before the first code fence and append the new
        # fenced block after it.
        thought = st.session_state['session_info_temp'].split('```')[0]
        if agent_return['response'].startswith('```'):
            st.session_state['session_info_temp'] = (thought + '\n' +
                                                     response)
    elif state == AgentStatusCode.PLUGIN_RETURN:
        assert agent_return['inner_steps'][-1]['role'] == 'environment'
        st.session_state['session_info_temp'] += (
            '\n' + agent_return['inner_steps'][-1]['content'])

    planner_slot.markdown(st.session_state['session_info_temp'])
    if state == AgentStatusCode.PLUGIN_RETURN:
        # The step is complete: archive it and start a fresh buffer.
        st.session_state['responses'][-1].append(
            st.session_state['session_info_temp'])
        st.session_state['session_info_temp'] = ''


def _update_searcher_panel(nodes, node_name):
    """Refresh the searcher column for one streamed update."""
    selectbox_slot = _ensure_placeholder('selectbox_placeholder')
    searcher_slot = _ensure_placeholder('searcher_placeholder')
    if not node_name:
        return

    selected_node_key = f"selected_node_{len(st.session_state['queries'])}_{node_name}"
    if selected_node_key not in st.session_state:
        st.session_state[selected_node_key] = node_name
    if selected_node_key not in st.session_state['already_used_keys']:
        # A widget key may only be used once per run, so the selectbox is
        # created only the first time this node shows up.
        node_names = list(nodes.keys())
        selected_node = selectbox_slot.selectbox(
            'Select a node:',
            node_names,
            key=f'key_{selected_node_key}',
            index=node_names.index(node_name))
        st.session_state['already_used_keys'].append(selected_node_key)
    else:
        selected_node = node_name
    st.session_state[selected_node_key] = selected_node

    if selected_node not in nodes:
        return
    detail = nodes[selected_node]['detail']
    node_info_key = f'{selected_node}_info'
    if 'node_info_temp' not in st.session_state:
        st.session_state['node_info_temp'] = f'### {detail["content"]}'
    if node_info_key not in st.session_state:
        st.session_state[node_info_key] = []

    state = detail['state']
    if state in [AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING]:
        st.session_state['node_info_temp'] = detail['response']
    elif state == AgentStatusCode.PLUGIN_START:
        thought = st.session_state['node_info_temp'].split('```')[0]
        if detail['response'].startswith('```'):
            st.session_state['node_info_temp'] = (thought + '\n' +
                                                  detail['response'])
    elif state == AgentStatusCode.PLUGIN_END:
        thought = st.session_state['node_info_temp'].split('```')[0]
        if isinstance(detail['response'], dict):
            pretty = json.dumps(detail['response'],
                                ensure_ascii=False,
                                indent=4)
            st.session_state['node_info_temp'] = (
                thought + '\n' + f'```json\n{pretty}\n```')
    elif state == AgentStatusCode.PLUGIN_RETURN:
        assert detail['inner_steps'][-1]['role'] == 'environment'
        st.session_state[node_info_key].append(
            ('thought', st.session_state['node_info_temp']))
        st.session_state[node_info_key].append(
            ('observation', detail['inner_steps'][-1]['content']))

    searcher_slot.markdown(st.session_state['node_info_temp'])
    if state == AgentStatusCode.END:
        st.session_state[node_info_key].append(
            ('answer', st.session_state['node_info_temp']))
        st.session_state['node_info_temp'] = ''
def display_chat_history():
    """Re-render the graph, planner answer and node details of the last turn.

    Nothing is drawn when no query has been recorded yet or when the latest
    recorded turn has no graph. All per-turn lists are read at the index of
    the latest recorded turn, so the graph and node selector always belong
    to the same turn as the planner answer.
    """
    graphs = st.session_state['graphs_html']
    if not st.session_state['queries'] or not graphs:
        return
    # ``graphs_html`` and ``nodes_list`` grow together at the end of a turn,
    # so the last graph index is also a valid ``nodes_list`` index.
    i = len(graphs) - 1
    graph_html = graphs[i]
    if not graph_html:
        return
    nodes = st.session_state['nodes_list'][i]

    with st.session_state['expander_placeholder'].expander('Show Graph',
                                                           expanded=False):
        st.session_state['graph_placeholder']._html(graph_html, height=500)

    with st.session_state['container_placeholder'].container():
        col1, col2 = st.session_state['columns_placeholder'].columns([2, 1])
        with col1:
            answers = st.session_state['responses'][-1]
            st.session_state['planner_placeholder'].markdown(
                answers[-1] if answers else '')
        with col2:
            used_keys = st.session_state['already_used_keys']
            if not used_keys:
                # No searcher node was shown during the turn, so there is
                # nothing to select.
                return
            selected_node_key = used_keys[-1]
            node_names = list(nodes.keys())
            remembered = st.session_state[selected_node_key]
            st.session_state['selectbox_placeholder'] = st.empty()
            selected_node = st.session_state[
                'selectbox_placeholder'].selectbox(
                    'Select a node:',
                    node_names,
                    key=f'replay_key_{i}',
                    # Fall back to the first entry if the remembered node
                    # is not part of this turn's graph.
                    index=(node_names.index(remembered)
                           if remembered in node_names else 0))
            st.session_state[selected_node_key] = selected_node

            if selected_node in ['root', 'response']:
                return
            if selected_node not in nodes:
                return
            node_info_key = f'{selected_node}_info'
            if node_info_key not in st.session_state:
                return
            for item in st.session_state[node_info_key]:
                if item[0] in ['thought', 'answer']:
                    st.session_state['searcher_placeholder'] = st.empty()
                    st.session_state['searcher_placeholder'].markdown(
                        item[1])
                elif item[0] == 'observation':
                    st.session_state['observation_expander'] = st.empty()
                    with st.session_state['observation_expander'].expander(
                            'Results'):
                        st.write(item[1])
def clean_history():
    """Reset the stored conversation and drop cached UI state.

    Every history list in ``st.session_state`` is replaced by a new empty
    list, and every entry whose key ends with ``'placeholder'`` or
    ``'_info'`` is removed.
    """
    for history_key in ('queries', 'responses', 'graphs_html', 'nodes_list',
                        'adjacency_list_list', 'history',
                        'already_used_keys'):
        st.session_state[history_key] = []

    # Iterate over a snapshot of the keys: deleting entries from a mapping
    # while iterating that same mapping is unsafe.
    for key in list(st.session_state):
        if key.endswith(('placeholder', '_info')):
            del st.session_state[key]
| # Main function to run the Streamlit app | |
def main():
    """Lay out the page controls and process the current interaction.

    Draws the sidebar title, the chat input and the "Clear History" button.
    History is cleared when the button is pressed, a submitted query is
    forwarded to ``update_chat``, and the stored history is rendered last.
    """
    st.sidebar.title('Model Control')
    input_column, button_column = st.columns([4, 1])
    with input_column:
        user_input = st.chat_input('Enter your query:')
    with button_column:
        clear_pressed = st.button('Clear History')
        if clear_pressed:
            clean_history()

    if user_input:
        update_chat(user_input)
    display_chat_history()


if __name__ == '__main__':
    main()