File size: 8,509 Bytes
61286f7
 
 
 
 
 
 
 
 
b51c72b
61286f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b51c72b
61286f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b51c72b
 
61286f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import json
import re
import sys
import io
import contextlib
import warnings
from typing import Optional, List, Any, Tuple
from PIL import Image
import streamlit as st
import pandas as pd
import base64
from io import BytesIO
from together import Together
from e2b_code_interpreter import Sandbox

# Hide UserWarning messages raised from modules whose name starts with
# "pydantic" so they do not clutter the app's console output.
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic")

# Captures the body of a ```python fenced block. DOTALL lets "." span newlines
# so multi-line code is captured; the lazy quantifier stops at the first
# closing fence.
pattern = re.compile(r"```python\n(.*?)\n```", flags=re.DOTALL)

def code_interpret(e2b_code_interpreter: Sandbox, code: str) -> Optional[List[Any]]:
    """Run ``code`` through the sandbox's ``run_code`` and return its results.

    While the call is in flight, Python warnings are ignored and anything
    written to this process's stdout/stderr is captured; the captured text is
    re-printed afterwards under a labelled header.

    Args:
        e2b_code_interpreter: Sandbox object exposing ``run_code``.
        code: Python source to execute.

    Returns:
        The execution's ``results``, or ``None`` when the execution reports a
        truthy ``error`` (which is printed to stderr).
    """
    with st.spinner('Executing code in E2B sandbox...'):
        captured_out, captured_err = io.StringIO(), io.StringIO()

        # Contexts are entered in this order and unwound in reverse on exit.
        with contextlib.ExitStack() as stack:
            stack.enter_context(contextlib.redirect_stdout(captured_out))
            stack.enter_context(contextlib.redirect_stderr(captured_err))
            stack.enter_context(warnings.catch_warnings())
            warnings.simplefilter("ignore")
            execution = e2b_code_interpreter.run_code(code)

        err_text = captured_err.getvalue()
        if err_text:
            print("[Code Interpreter Warnings/Errors]", file=sys.stderr)
            print(err_text, file=sys.stderr)

        out_text = captured_out.getvalue()
        if out_text:
            print("[Code Interpreter Output]", file=sys.stdout)
            print(out_text, file=sys.stdout)

        if not execution.error:
            return execution.results

        print(f"[Code Interpreter ERROR] {execution.error}", file=sys.stderr)
        return None

# Tolerant fence matcher used only by match_code_blocks. It accepts a
# "python" or "py" language tag in any letter case, optional spaces/tabs after
# the tag, and either LF or CRLF line endings. The module-level `pattern` is
# intentionally left untouched for any other users of that name.
_PYTHON_FENCE_RE = re.compile(
    r"```(?:python|py)[ \t]*\r?\n(.*?)\r?\n```",
    re.DOTALL | re.IGNORECASE,
)


def match_code_blocks(llm_response: str) -> str:
    """Extract the first fenced Python code block from an LLM reply.

    Args:
        llm_response: Raw text of the model's reply. ``None`` or an empty
            string is tolerated and treated as "no code".

    Returns:
        The code between the first opening Python fence and its closing
        fence (fences excluded), or ``""`` when no such block is present.
    """
    # Guard first: searching a regex against None raises TypeError.
    if not llm_response:
        return ""

    found = _PYTHON_FENCE_RE.search(llm_response)
    return found.group(1) if found else ""

def chat_with_llm(e2b_code_interpreter: Sandbox, user_message: str, dataset_path: str) -> Tuple[Optional[List[Any]], str]:
    """Ask the Together AI model about the dataset and run any code it returns.

    The system prompt tells the model where the dataset lives; the reply is
    scanned with ``match_code_blocks`` and, if code is found, it is executed
    via ``code_interpret``.

    Args:
        e2b_code_interpreter: Sandbox passed through to ``code_interpret``.
        user_message: The user's question.
        dataset_path: Path of the dataset as it should appear in the prompt.

    Returns:
        ``(results, reply_text)`` where ``results`` is whatever
        ``code_interpret`` returned, or ``None`` when the reply contained no
        matching code (a Streamlit warning is shown in that case).
    """
    # The prompt embeds the dataset path so generated code reads the right file.
    system_prompt = f"""You're a Python data scientist and data visualization expert. You are given a dataset at path '{dataset_path}' and also the user's query.
You need to analyze the dataset and answer the user's query with a response and you run Python code to solve them.
IMPORTANT: Always use the dataset path variable '{dataset_path}' in your code when reading the CSV file."""

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    with st.spinner('Getting response from Together AI LLM model...'):
        together_client = Together(api_key=st.session_state.together_api_key)
        completion = together_client.chat.completions.create(
            model=st.session_state.model_name,
            messages=conversation,
        )

        reply = completion.choices[0].message
        extracted_code = match_code_blocks(reply.content)

        # Nothing to execute: surface a warning and return the text alone.
        if not extracted_code:
            st.warning("Failed to match any Python code in model's response")
            return None, reply.content

        return code_interpret(e2b_code_interpreter, extracted_code), reply.content

def upload_dataset(code_interpreter: Sandbox, uploaded_file) -> str:
    """Write the uploaded file into the sandbox and return its sandbox path.

    Args:
        code_interpreter: Sandbox whose ``files.write`` receives the data.
        uploaded_file: File-like object with a ``name`` attribute (as returned
            by ``st.file_uploader`` in ``main``).

    Returns:
        The path the file was written to, ``./<uploaded_file.name>``.

    Raises:
        Exception: Whatever ``files.write`` raised; the error is first shown
            with ``st.error`` and then re-raised unchanged.
    """
    dataset_path = f"./{uploaded_file.name}"

    # main() parses this same object with pandas for the preview, which leaves
    # the read cursor at end-of-file. Rewind so the full content is uploaded
    # regardless of how the transport layer reads the stream.
    is_seekable = getattr(uploaded_file, "seekable", None)
    if callable(is_seekable) and is_seekable():
        uploaded_file.seek(0)

    try:
        code_interpreter.files.write(dataset_path, uploaded_file)
    except Exception as error:
        st.error(f"Error during file upload: {error}")
        # Bare raise keeps the original traceback intact.
        raise
    return dataset_path


def _render_result(result: Any) -> None:
    """Display one code-interpreter result with a matching Streamlit widget.

    Checks run in order: a truthy ``png`` attribute (base64 text), a ``figure``
    attribute, a ``show`` attribute, a pandas object, then a generic
    ``st.write`` fallback.
    """
    if getattr(result, 'png', None):
        # ``png`` is base64 text: decode it and hand Streamlit a PIL image.
        png_data = base64.b64decode(result.png)
        image = Image.open(BytesIO(png_data))
        st.image(image, caption="Generated Visualization", use_container_width=True)
    elif hasattr(result, 'figure'):  # presumably a matplotlib figure wrapper
        st.pyplot(result.figure)
    elif hasattr(result, 'show'):  # presumably a plotly figure
        st.plotly_chart(result)
    elif isinstance(result, (pd.DataFrame, pd.Series)):
        st.dataframe(result)
    else:
        st.write(result)


def main():
    """Main Streamlit application.

    Renders the sidebar (API keys and model choice), a CSV uploader with a
    preview, and a question box. When "Analyze" is pressed with both API keys
    present, the CSV is uploaded to an E2B sandbox, the question is sent to
    the selected Together AI model, and the model's reply plus any execution
    results are displayed.
    """
    st.set_page_config(page_title="📊 AI Data Visualization Agent", page_icon="📊", layout="wide")

    st.title("📊 AI Data Visualization Agent")
    st.write("Upload your dataset and ask questions about it!")

    # Initialize session state variables
    for state_key in ('together_api_key', 'e2b_api_key', 'model_name'):
        if state_key not in st.session_state:
            st.session_state[state_key] = ''

    # Sidebar for API keys and model configuration
    with st.sidebar:
        st.header("🔑 API Keys and Model Configuration")
        st.session_state.together_api_key = st.text_input("Together AI API Key", type="password")
        st.info("💡 Everyone gets a free $1 credit by Together AI - AI Acceleration Cloud platform")
        st.markdown("[Get Together AI API Key](https://api.together.ai/signin)")

        st.session_state.e2b_api_key = st.text_input("Enter E2B API Key", type="password")
        st.markdown("[Get E2B API Key](https://e2b.dev/docs/legacy/getting-started/api-key)")

        # Display label -> Together AI model identifier
        model_options = {
            "Meta-Llama 3.1 405B": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "DeepSeek V3": "deepseek-ai/DeepSeek-V3",
            "Qwen 2.5 7B": "Qwen/Qwen2.5-7B-Instruct-Turbo",
            "Meta-Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct-Turbo"
        }
        selected_label = st.selectbox(
            "Select Model",
            options=list(model_options.keys()),
            index=0  # Default to first option
        )
        # Store only the model identifier; the label never enters session state.
        st.session_state.model_name = model_options[selected_label]

    # Main content layout
    col1, col2 = st.columns([1, 2])  # Split the main content into two columns

    with col1:
        st.header("📂 Upload Dataset")
        uploaded_file = st.file_uploader("Choose a CSV file", type="csv", key="file_uploader")

        if uploaded_file is not None:
            # Display dataset with toggle
            df = pd.read_csv(uploaded_file)
            st.write("### Dataset Preview")
            show_full = st.checkbox("Show full dataset")
            if show_full:
                st.dataframe(df)
            else:
                st.write("Preview (first 5 rows):")
                st.dataframe(df.head())

    # Nothing further to render until a dataset has been provided.
    if uploaded_file is None:
        return

    with col2:
        st.header("❓ Ask a Question")
        query = st.text_area(
            "What would you like to know about your data?",
            "Can you compare the average cost for two people between different categories?",
            height=100
        )

        if not st.button("Analyze", type="primary", key="analyze_button"):
            return

        if not st.session_state.together_api_key or not st.session_state.e2b_api_key:
            st.error("Please enter both API keys in the sidebar.")
            return

        with Sandbox(api_key=st.session_state.e2b_api_key) as code_interpreter:
            # pd.read_csv above consumed the stream; rewind it so the sandbox
            # receives the whole file rather than whatever follows the cursor.
            uploaded_file.seek(0)
            dataset_path = upload_dataset(code_interpreter, uploaded_file)

            # Pass dataset_path to chat_with_llm
            code_results, llm_response = chat_with_llm(code_interpreter, query, dataset_path)

            # Display LLM's text response
            st.header("🤖 AI Response")
            st.write(llm_response)

            # Display results/visualizations
            if code_results:
                st.header("📊 Analysis Results")
                for result in code_results:
                    _render_result(result)

# Run the app only when this file is executed as the entry script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()