jzou19950715 committed on
Commit
21df540
·
verified ·
1 Parent(s): 2952e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -70
app.py CHANGED
@@ -3,22 +3,21 @@ from typing import Optional
3
 
4
  import gradio as gr
5
  import pandas as pd
6
- from smolagents import CodeAgent, LiteLLMModel, tool
7
 
 
8
 
9
- # Tool definitions to showcase smolagents capabilities
10
- @tool
11
  def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
12
- """Analyze a pandas DataFrame"""
13
  if analysis_type == "summary":
14
  return str(df.describe())
15
  elif analysis_type == "info":
16
- return str(df.info())
 
 
17
  return "Unknown analysis type"
18
 
19
- @tool
20
  def plot_data(df: pd.DataFrame, plot_type: str) -> None:
21
- """Create plots from DataFrame"""
22
  import matplotlib.pyplot as plt
23
  import seaborn as sns
24
 
@@ -31,21 +30,18 @@ def plot_data(df: pd.DataFrame, plot_type: str) -> None:
31
  plt.tight_layout()
32
 
33
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
34
- """Process uploaded file into a DataFrame"""
35
  if not file:
36
  return None
37
-
38
  try:
39
  if file.name.endswith('.csv'):
40
- df = pd.read_csv(file.name)
41
  elif file.name.endswith(('.xlsx', '.xls')):
42
- df = pd.read_excel(file.name)
43
- else:
44
- return None
45
- return df
46
  except Exception as e:
47
- print(f"Error reading {file.name}: {str(e)}")
48
- return None
49
 
50
  def analyze_data(
51
  file: gr.File,
@@ -53,68 +49,56 @@ def analyze_data(
53
  api_key: str,
54
  temperature: float = 0.7,
55
  ) -> str:
56
- """Process user request and generate analysis using smolagents"""
57
 
58
  if not api_key:
59
  return "Error: Please provide an API key."
60
-
61
  if not file:
62
  return "Error: Please upload a file."
63
-
64
  try:
65
  # Set up environment
66
  os.environ["OPENAI_API_KEY"] = api_key
67
 
68
- # Create model and agent
69
- model = LiteLLMModel(
70
  model_id="gpt-4o-mini",
71
  temperature=temperature
72
  )
73
 
74
- # Create agent with various tools
75
- agent = CodeAgent(
76
- tools=[analyze_dataframe, plot_data],
77
- model=model,
78
- additional_authorized_imports=[
79
- "pandas",
80
- "numpy",
81
- "matplotlib",
82
- "seaborn",
83
- "plotly",
84
- "sklearn",
85
- "scipy"
86
- ],
87
- max_steps=5,
88
- verbosity_level=1
89
  )
90
 
91
- # Process uploaded file
92
  df = process_file(file)
93
  if df is None:
94
- return "Error: Could not process uploaded file."
95
-
96
  # Build context
97
  file_info = f"""
98
- Uploaded file: {file.name}
99
- DataFrame Shape: {df.shape}
100
  Columns: {', '.join(df.columns)}
 
101
  Column Types:
102
  {chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])}
103
  """
104
 
105
- # Build prompt
106
  prompt = f"""
107
  {file_info}
108
-
109
- The data has been loaded into a pandas DataFrame called 'df'.
110
- Available tools:
111
- - analyze_dataframe: Perform basic DataFrame analysis
112
- - plot_data: Create various plots
113
 
114
- Additional capabilities:
115
- - Full pandas, numpy, matplotlib, seaborn access
116
- - Machine learning with sklearn
117
- - Statistical analysis with scipy
118
 
119
  User request: {query}
120
 
@@ -125,9 +109,7 @@ def analyze_data(
125
  4. Key insights and findings
126
  """
127
 
128
- # Run analysis
129
- result = agent.run(prompt, additional_args={"df": df})
130
- return result
131
 
132
  except Exception as e:
133
  return f"Error occurred: {str(e)}"
@@ -135,30 +117,29 @@ def analyze_data(
135
  def create_interface():
136
  """Create Gradio interface"""
137
 
138
- with gr.Blocks(title="AI Agent Testing Interface") as interface:
139
  gr.Markdown("""
140
- # AI Agent Testing Interface
141
 
142
- Test the capabilities of AI agents using smolagents library. Upload data files and ask questions in natural language.
143
 
144
  **Features:**
145
  - Data analysis and visualization
146
- - Machine learning capabilities
147
  - Statistical analysis
148
- - Custom tool integration
149
 
150
- **Note**: Requires your own API key for GPT-4.
151
  """)
152
 
153
  with gr.Row():
154
  with gr.Column():
155
  file = gr.File(
156
- label="Upload Data File (CSV/Excel)",
157
  file_types=[".csv", ".xlsx", ".xls"]
158
  )
159
  query = gr.Textbox(
160
  label="What would you like to analyze?",
161
- placeholder="e.g., Analyze the relationships between variables and create visualizations",
162
  lines=3
163
  )
164
  api_key = gr.Textbox(
@@ -178,25 +159,22 @@ def create_interface():
178
  with gr.Column():
179
  output = gr.Markdown(label="Output")
180
 
181
- # Handle submissions
182
  analyze_btn.click(
183
  analyze_data,
184
  inputs=[file, query, api_key, temperature],
185
  outputs=output
186
  )
187
 
188
- # Example queries
189
  gr.Examples(
190
  examples=[
191
- [None, "Perform comprehensive exploratory data analysis including distributions, correlations, and key statistics"],
192
- [None, "Create visualizations showing relationships between numeric variables"],
193
- [None, "Identify and analyze outliers in the dataset"],
194
- [None, "Perform clustering analysis and visualize the results"],
195
- [None, "Calculate summary statistics and create box plots for numeric columns"],
196
  ],
197
  inputs=[file, query]
198
  )
199
-
200
  return interface
201
 
202
  if __name__ == "__main__":
 
3
 
4
  import gradio as gr
5
  import pandas as pd
 
6
 
7
+ from minimal_agent import MinimalAgent
8
 
 
 
9
  def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
10
+ """Basic DataFrame analysis"""
11
  if analysis_type == "summary":
12
  return str(df.describe())
13
  elif analysis_type == "info":
14
+ buffer = []
15
+ df.info(buf=buffer)
16
+ return "\n".join(buffer)
17
  return "Unknown analysis type"
18
 
 
19
  def plot_data(df: pd.DataFrame, plot_type: str) -> None:
20
+ """Basic plotting function"""
21
  import matplotlib.pyplot as plt
22
  import seaborn as sns
23
 
 
30
  plt.tight_layout()
31
 
32
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
33
+ """Process uploaded file into DataFrame"""
34
  if not file:
35
  return None
36
+
37
  try:
38
  if file.name.endswith('.csv'):
39
+ return pd.read_csv(file.name)
40
  elif file.name.endswith(('.xlsx', '.xls')):
41
+ return pd.read_excel(file.name)
 
 
 
42
  except Exception as e:
43
+ print(f"Error reading file: {str(e)}")
44
+ return None
45
 
46
  def analyze_data(
47
  file: gr.File,
 
49
  api_key: str,
50
  temperature: float = 0.7,
51
  ) -> str:
52
+ """Process user request and generate analysis"""
53
 
54
  if not api_key:
55
  return "Error: Please provide an API key."
56
+
57
  if not file:
58
  return "Error: Please upload a file."
59
+
60
  try:
61
  # Set up environment
62
  os.environ["OPENAI_API_KEY"] = api_key
63
 
64
+ # Create agent
65
+ agent = MinimalAgent(
66
  model_id="gpt-4o-mini",
67
  temperature=temperature
68
  )
69
 
70
+ # Add tools
71
+ agent.add_tool(
72
+ "analyze_dataframe",
73
+ "Analyze DataFrame with various metrics",
74
+ analyze_dataframe
75
+ )
76
+ agent.add_tool(
77
+ "plot_data",
78
+ "Create various plots from DataFrame",
79
+ plot_data
 
 
 
 
 
80
  )
81
 
82
+ # Process file
83
  df = process_file(file)
84
  if df is None:
85
+ return "Error: Could not process file."
86
+
87
  # Build context
88
  file_info = f"""
89
+ File: {file.name}
90
+ Shape: {df.shape}
91
  Columns: {', '.join(df.columns)}
92
+
93
  Column Types:
94
  {chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])}
95
  """
96
 
97
+ # Run analysis
98
  prompt = f"""
99
  {file_info}
 
 
 
 
 
100
 
101
+ The data is loaded in a pandas DataFrame called 'df'.
 
 
 
102
 
103
  User request: {query}
104
 
 
109
  4. Key insights and findings
110
  """
111
 
112
+ return agent.run(prompt)
 
 
113
 
114
  except Exception as e:
115
  return f"Error occurred: {str(e)}"
 
117
  def create_interface():
118
  """Create Gradio interface"""
119
 
120
+ with gr.Blocks(title="AI Data Analysis Assistant") as interface:
121
  gr.Markdown("""
122
+ # AI Data Analysis Assistant
123
 
124
+ Upload your data file and ask questions in natural language.
125
 
126
  **Features:**
127
  - Data analysis and visualization
 
128
  - Statistical analysis
129
+ - Machine learning capabilities
130
 
131
+ **Note**: Requires your own GPT-4 API key.
132
  """)
133
 
134
  with gr.Row():
135
  with gr.Column():
136
  file = gr.File(
137
+ label="Upload Data File",
138
  file_types=[".csv", ".xlsx", ".xls"]
139
  )
140
  query = gr.Textbox(
141
  label="What would you like to analyze?",
142
+ placeholder="e.g., Create visualizations showing relationships between variables",
143
  lines=3
144
  )
145
  api_key = gr.Textbox(
 
159
  with gr.Column():
160
  output = gr.Markdown(label="Output")
161
 
 
162
  analyze_btn.click(
163
  analyze_data,
164
  inputs=[file, query, api_key, temperature],
165
  outputs=output
166
  )
167
 
 
168
  gr.Examples(
169
  examples=[
170
+ [None, "Show key statistics and create visualizations for numeric columns"],
171
+ [None, "Find correlations and patterns in the data"],
172
+ [None, "Identify outliers and unusual patterns"],
173
+ [None, "Create summary visualizations of the main variables"],
 
174
  ],
175
  inputs=[file, query]
176
  )
177
+
178
  return interface
179
 
180
  if __name__ == "__main__":