jzou19950715 commited on
Commit
be8a1ca
·
verified ·
1 Parent(s): 05c2a98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -123
app.py CHANGED
@@ -3,139 +3,195 @@ from typing import List, Optional, Union
3
 
4
  import gradio as gr
5
  import pandas as pd
6
- from dotenv import load_dotenv
7
- from pandas import DataFrame
8
  from smolagents import CodeAgent, LiteLLMModel, tool
9
 
10
- # Load environment variables
11
- load_dotenv()
12
 
13
- def create_agent():
14
- """Create a CodeAgent instance with GPT-4 backend."""
15
- model = LiteLLMModel(model_id="gpt-4o-mini")
16
-
17
- @tool
18
- def read_csv(filepath: str) -> DataFrame:
19
- """
20
- Read a CSV file and return a pandas DataFrame.
21
-
22
- Args:
23
- filepath: Path to the CSV file
24
- """
25
- return pd.read_csv(filepath)
26
-
27
- @tool
28
- def read_excel(filepath: str) -> DataFrame:
29
- """
30
- Read an Excel file and return a pandas DataFrame.
31
-
32
- Args:
33
- filepath: Path to the Excel file
34
- """
35
- return pd.read_excel(filepath)
36
 
37
- agent = CodeAgent(
38
- tools=[read_csv, read_excel],
39
- model=model,
40
- additional_authorized_imports=[
41
- "pandas",
42
- "numpy",
43
- "matplotlib",
44
- "seaborn",
45
- "plotly",
46
- "sklearn",
47
- "scipy",
48
- ],
49
- max_steps=5,
50
- verbosity_level=1
51
- )
52
- return agent
53
 
54
- def process_request(
55
- files: Union[str, List[str]],
56
- user_query: str,
57
- api_key: str = "",
58
- temperature: float = 0.7,
59
- history: Optional[List[tuple]] = None
60
- ) -> tuple:
61
  """
62
- Process user request with uploaded files and query.
63
 
64
  Args:
65
- files: Path or list of paths to uploaded files
66
- user_query: Natural language query from user
67
- api_key: Optional API key for GPT-4
68
- temperature: Model temperature
69
- history: Chat history
70
-
71
- Returns:
72
- Tuple of (output, error, new_history)
73
  """
74
- if api_key:
75
- os.environ["OPENAI_API_KEY"] = api_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  try:
78
- # Create agent instance
79
- agent = create_agent()
80
 
81
- # Build context from files
82
- file_context = ""
83
- if isinstance(files, str):
84
- files = [files]
85
-
86
- for file in files:
87
- filename = os.path.basename(file)
88
- file_context += f"File uploaded: {filename}\n"
89
-
90
- # Build complete prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  prompt = f"""
92
- {file_context}
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- User request: {user_query}
95
 
96
  Please analyze the data and provide:
97
- 1. Code to perform the analysis
98
- 2. Explanation of approach
99
- 3. Visualizations if relevant
100
  4. Key insights and findings
101
- """
102
 
103
- # Execute agent
104
- result = agent.run(prompt)
105
-
106
- # Update history
107
- new_history = history or []
108
- new_history.append((user_query, result))
109
 
110
- return result, None, new_history
 
 
111
 
112
  except Exception as e:
113
- return None, str(e), history
114
 
115
- # Create Gradio interface
116
  def create_interface():
117
- """Create Gradio interface for the AI coding assistant."""
118
 
119
- with gr.Blocks(title="AI Coding Assistant") as interface:
120
  gr.Markdown("""
121
- # AI Coding Assistant
122
- Upload data files and ask questions in natural language to get code, analysis and visualizations.
 
 
 
 
 
 
 
 
 
 
123
  """)
124
 
125
  with gr.Row():
126
  with gr.Column():
127
- files = gr.File(
128
- label="Upload Data Files",
129
- file_types=[".csv", ".xlsx", ".xls"],
130
- multiple=True
131
  )
132
  query = gr.Textbox(
133
  label="What would you like to analyze?",
134
- placeholder="e.g., Create a scatter plot comparing column A vs B"
 
135
  )
136
  api_key = gr.Textbox(
137
- label="API Key (Optional)",
138
- placeholder="Your OpenAI API key",
139
  type="password"
140
  )
141
  temperature = gr.Slider(
@@ -145,39 +201,29 @@ def create_interface():
145
  value=0.7,
146
  step=0.1
147
  )
148
- submit = gr.Button("Analyze")
149
-
150
  with gr.Column():
151
  output = gr.Markdown(label="Output")
152
- error = gr.Markdown(label="Errors")
153
-
154
- # Hidden state for chat history
155
- history = gr.State([])
156
 
157
  # Handle submissions
158
- submit.click(
159
- process_request,
160
- inputs=[files, query, api_key, temperature, history],
161
- outputs=[output, error, history]
162
  )
163
 
164
- # Add examples
165
  gr.Examples(
166
  examples=[
167
- [
168
- None,
169
- "Create a scatter plot showing the relationship between column A and B, with a trend line",
170
- ],
171
- [
172
- None,
173
- "Calculate summary statistics and identify any outliers in the numerical columns",
174
- ],
175
- [
176
- None,
177
- "Perform clustering analysis on the data and visualize the clusters",
178
- ],
179
  ],
180
- inputs=[files, query],
181
  )
182
 
183
  return interface
 
3
 
4
  import gradio as gr
5
  import pandas as pd
 
 
6
  from smolagents import CodeAgent, LiteLLMModel, tool
7
 
 
 
8
 
9
+ # Tool definitions to showcase smolagents capabilities
10
+ @tool
11
+ def search_web(query: str) -> str:
12
+ """Simulate web search (for demo purposes)"""
13
+ return f"Simulated web search results for: {query}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ @tool
16
+ def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
17
+ """
18
+ Analyze a pandas DataFrame based on specified analysis type.
19
+
20
+ Args:
21
+ df: DataFrame to analyze
22
+ analysis_type: Type of analysis to perform
23
+ """
24
+ if analysis_type == "summary":
25
+ return str(df.describe())
26
+ elif analysis_type == "info":
27
+ return str(df.info())
28
+ return "Unknown analysis type"
 
 
29
 
30
+ @tool
31
+ def plot_data(df: pd.DataFrame, plot_type: str) -> None:
 
 
 
 
 
32
  """
33
+ Create plots from DataFrame.
34
 
35
  Args:
36
+ df: DataFrame to plot
37
+ plot_type: Type of plot to create
 
 
 
 
 
 
38
  """
39
+ import matplotlib.pyplot as plt
40
+ import seaborn as sns
41
+
42
+ if plot_type == "correlation":
43
+ plt.figure(figsize=(10, 8))
44
+ sns.heatmap(df.corr(), annot=True)
45
+ plt.title("Correlation Heatmap")
46
+ elif plot_type == "distribution":
47
+ df.hist(figsize=(15, 10))
48
+ plt.tight_layout()
49
+
50
+ def process_files(files: List[gr.File]) -> Optional[pd.DataFrame]:
51
+ """Process uploaded files into a DataFrame."""
52
+ if not files:
53
+ return None
54
+
55
+ dfs = []
56
+ for file in files:
57
+ try:
58
+ if file.name.endswith('.csv'):
59
+ df = pd.read_csv(file.name)
60
+ elif file.name.endswith(('.xlsx', '.xls')):
61
+ df = pd.read_excel(file.name)
62
+ else:
63
+ continue
64
+ dfs.append(df)
65
+ except Exception as e:
66
+ print(f"Error reading {file.name}: {str(e)}")
67
+
68
+ if not dfs:
69
+ return None
70
+
71
+ return pd.concat(dfs) if len(dfs) > 1 else dfs[0]
72
+
73
+ def analyze_data(
74
+ files: List[gr.File],
75
+ query: str,
76
+ api_key: str,
77
+ temperature: float = 0.7,
78
+ ) -> str:
79
+ """Process user request and generate analysis using smolagents."""
80
+
81
+ if not api_key:
82
+ return "Error: Please provide an API key."
83
+
84
+ if not files:
85
+ return "Error: Please upload at least one file."
86
 
87
  try:
88
+ # Set up the environment
89
+ os.environ["OPENAI_API_KEY"] = api_key
90
 
91
+ # Create model and agent
92
+ model = LiteLLMModel(
93
+ model_id="gpt-4o-mini",
94
+ temperature=temperature
95
+ )
96
+
97
+ # Create agent with various tools to showcase capabilities
98
+ agent = CodeAgent(
99
+ tools=[search_web, analyze_dataframe, plot_data],
100
+ model=model,
101
+ additional_authorized_imports=[
102
+ "pandas",
103
+ "numpy",
104
+ "matplotlib",
105
+ "seaborn",
106
+ "plotly",
107
+ "sklearn",
108
+ "scipy"
109
+ ],
110
+ max_steps=5,
111
+ verbosity_level=1
112
+ )
113
+
114
+ # Process uploaded files
115
+ df = process_files(files)
116
+ if df is None:
117
+ return "Error: Could not process uploaded files."
118
+
119
+ # Build context
120
+ file_info = "\n".join([
121
+ "Uploaded files:",
122
+ *[f"- {f.name}" for f in files],
123
+ f"\nDataFrame Shape: {df.shape}",
124
+ f"Columns: {', '.join(df.columns)}",
125
+ "\nColumn Types:",
126
+ *[f"- {col}: {dtype}" for col, dtype in df.dtypes.items()]
127
+ ])
128
+
129
+ # Build prompt
130
  prompt = f"""
131
+ {file_info}
132
+
133
+ The data has been loaded into a pandas DataFrame called 'df'.
134
+ Available tools:
135
+ - search_web: Search for relevant information
136
+ - analyze_dataframe: Perform basic DataFrame analysis
137
+ - plot_data: Create various plots
138
+
139
+ Additional capabilities:
140
+ - Full pandas, numpy, matplotlib, seaborn access
141
+ - Machine learning with sklearn
142
+ - Statistical analysis with scipy
143
 
144
+ User request: {query}
145
 
146
  Please analyze the data and provide:
147
+ 1. A clear explanation of your approach
148
+ 2. Code for the analysis
149
+ 3. Visualizations where relevant
150
  4. Key insights and findings
 
151
 
152
+ Make use of the available tools and libraries to provide comprehensive analysis.
153
+ """
 
 
 
 
154
 
155
+ # Run analysis
156
+ result = agent.run(prompt, additional_args={"df": df})
157
+ return result
158
 
159
  except Exception as e:
160
+ return f"Error occurred: {str(e)}"
161
 
 
162
  def create_interface():
163
+ """Create Gradio interface."""
164
 
165
+ with gr.Blocks(title="AI Agent Testing Interface") as interface:
166
  gr.Markdown("""
167
+ # AI Agent Testing Interface
168
+
169
+ Test the capabilities of AI agents using smolagents library. Upload data files and ask questions in natural language.
170
+
171
+ **Features:**
172
+ - Data analysis and visualization
173
+ - Machine learning capabilities
174
+ - Web search simulation
175
+ - Statistical analysis
176
+ - Custom tool integration
177
+
178
+ **Note**: Requires your own API key for GPT-4.
179
  """)
180
 
181
  with gr.Row():
182
  with gr.Column():
183
+ file = gr.File(
184
+ label="Upload Data Files (CSV/Excel)",
185
+ file_types=[".csv", ".xlsx", ".xls"]
 
186
  )
187
  query = gr.Textbox(
188
  label="What would you like to analyze?",
189
+ placeholder="e.g., Analyze the relationships between variables and create visualizations",
190
+ lines=3
191
  )
192
  api_key = gr.Textbox(
193
+ label="API Key (Required)",
194
+ placeholder="Your API key",
195
  type="password"
196
  )
197
  temperature = gr.Slider(
 
201
  value=0.7,
202
  step=0.1
203
  )
204
+ analyze_btn = gr.Button("Analyze")
205
+
206
  with gr.Column():
207
  output = gr.Markdown(label="Output")
 
 
 
 
208
 
209
  # Handle submissions
210
+ analyze_btn.click(
211
+ analyze_data,
212
+ inputs=[file, query, api_key, temperature],
213
+ outputs=output
214
  )
215
 
216
+ # Example queries
217
  gr.Examples(
218
  examples=[
219
+ [None, "Perform comprehensive exploratory data analysis including distributions, correlations, and key statistics"],
220
+ [None, "Create visualizations showing relationships between numeric variables"],
221
+ [None, "Identify and analyze outliers in the dataset"],
222
+ [None, "Perform clustering analysis and visualize the results"],
223
+ [None, "Calculate summary statistics and create box plots for numeric columns"],
224
+ [None, "Analyze trends and patterns in the data over time"],
 
 
 
 
 
 
225
  ],
226
+ inputs=[file, query]
227
  )
228
 
229
  return interface