jzou19950715 committed on
Commit
2952e7a
·
verified ·
1 Parent(s): cf9db86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -64
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import List, Optional, Union
3
 
4
  import gradio as gr
5
  import pandas as pd
@@ -7,20 +7,9 @@ from smolagents import CodeAgent, LiteLLMModel, tool
7
 
8
 
9
  # Tool definitions to showcase smolagents capabilities
10
- @tool
11
- def search_web(query: str) -> str:
12
- """Simulate web search (for demo purposes)"""
13
- return f"Simulated web search results for: {query}"
14
-
15
  @tool
16
  def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
17
- """
18
- Analyze a pandas DataFrame based on specified analysis type.
19
-
20
- Args:
21
- df: DataFrame to analyze
22
- analysis_type: Type of analysis to perform
23
- """
24
  if analysis_type == "summary":
25
  return str(df.describe())
26
  elif analysis_type == "info":
@@ -29,13 +18,7 @@ def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
29
 
30
  @tool
31
  def plot_data(df: pd.DataFrame, plot_type: str) -> None:
32
- """
33
- Create plots from DataFrame.
34
-
35
- Args:
36
- df: DataFrame to plot
37
- plot_type: Type of plot to create
38
- """
39
  import matplotlib.pyplot as plt
40
  import seaborn as sns
41
 
@@ -47,45 +30,39 @@ def plot_data(df: pd.DataFrame, plot_type: str) -> None:
47
  df.hist(figsize=(15, 10))
48
  plt.tight_layout()
49
 
50
- def process_files(files: List[gr.File]) -> Optional[pd.DataFrame]:
51
- """Process uploaded files into a DataFrame."""
52
- if not files:
53
  return None
54
 
55
- dfs = []
56
- for file in files:
57
- try:
58
- if file.name.endswith('.csv'):
59
- df = pd.read_csv(file.name)
60
- elif file.name.endswith(('.xlsx', '.xls')):
61
- df = pd.read_excel(file.name)
62
- else:
63
- continue
64
- dfs.append(df)
65
- except Exception as e:
66
- print(f"Error reading {file.name}: {str(e)}")
67
-
68
- if not dfs:
69
  return None
70
-
71
- return pd.concat(dfs) if len(dfs) > 1 else dfs[0]
72
 
73
  def analyze_data(
74
- files: List[gr.File],
75
  query: str,
76
  api_key: str,
77
  temperature: float = 0.7,
78
  ) -> str:
79
- """Process user request and generate analysis using smolagents."""
80
 
81
  if not api_key:
82
  return "Error: Please provide an API key."
83
 
84
- if not files:
85
- return "Error: Please upload at least one file."
86
 
87
  try:
88
- # Set up the environment
89
  os.environ["OPENAI_API_KEY"] = api_key
90
 
91
  # Create model and agent
@@ -94,9 +71,9 @@ def analyze_data(
94
  temperature=temperature
95
  )
96
 
97
- # Create agent with various tools to showcase capabilities
98
  agent = CodeAgent(
99
- tools=[search_web, analyze_dataframe, plot_data],
100
  model=model,
101
  additional_authorized_imports=[
102
  "pandas",
@@ -111,20 +88,19 @@ def analyze_data(
111
  verbosity_level=1
112
  )
113
 
114
- # Process uploaded files
115
- df = process_files(files)
116
  if df is None:
117
- return "Error: Could not process uploaded files."
118
 
119
  # Build context
120
- file_info = "\n".join([
121
- "Uploaded files:",
122
- *[f"- {f.name}" for f in files],
123
- f"\nDataFrame Shape: {df.shape}",
124
- f"Columns: {', '.join(df.columns)}",
125
- "\nColumn Types:",
126
- *[f"- {col}: {dtype}" for col, dtype in df.dtypes.items()]
127
- ])
128
 
129
  # Build prompt
130
  prompt = f"""
@@ -132,7 +108,6 @@ def analyze_data(
132
 
133
  The data has been loaded into a pandas DataFrame called 'df'.
134
  Available tools:
135
- - search_web: Search for relevant information
136
  - analyze_dataframe: Perform basic DataFrame analysis
137
  - plot_data: Create various plots
138
 
@@ -148,8 +123,6 @@ def analyze_data(
148
  2. Code for the analysis
149
  3. Visualizations where relevant
150
  4. Key insights and findings
151
-
152
- Make use of the available tools and libraries to provide comprehensive analysis.
153
  """
154
 
155
  # Run analysis
@@ -160,7 +133,7 @@ def analyze_data(
160
  return f"Error occurred: {str(e)}"
161
 
162
  def create_interface():
163
- """Create Gradio interface."""
164
 
165
  with gr.Blocks(title="AI Agent Testing Interface") as interface:
166
  gr.Markdown("""
@@ -171,7 +144,6 @@ def create_interface():
171
  **Features:**
172
  - Data analysis and visualization
173
  - Machine learning capabilities
174
- - Web search simulation
175
  - Statistical analysis
176
  - Custom tool integration
177
 
@@ -181,7 +153,7 @@ def create_interface():
181
  with gr.Row():
182
  with gr.Column():
183
  file = gr.File(
184
- label="Upload Data Files (CSV/Excel)",
185
  file_types=[".csv", ".xlsx", ".xls"]
186
  )
187
  query = gr.Textbox(
@@ -221,7 +193,6 @@ def create_interface():
221
  [None, "Identify and analyze outliers in the dataset"],
222
  [None, "Perform clustering analysis and visualize the results"],
223
  [None, "Calculate summary statistics and create box plots for numeric columns"],
224
- [None, "Analyze trends and patterns in the data over time"],
225
  ],
226
  inputs=[file, query]
227
  )
 
1
  import os
2
+ from typing import Optional
3
 
4
  import gradio as gr
5
  import pandas as pd
 
7
 
8
 
9
  # Tool definitions to showcase smolagents capabilities
 
 
 
 
 
10
  @tool
11
  def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
12
+ """Analyze a pandas DataFrame"""
 
 
 
 
 
 
13
  if analysis_type == "summary":
14
  return str(df.describe())
15
  elif analysis_type == "info":
 
18
 
19
  @tool
20
  def plot_data(df: pd.DataFrame, plot_type: str) -> None:
21
+ """Create plots from DataFrame"""
 
 
 
 
 
 
22
  import matplotlib.pyplot as plt
23
  import seaborn as sns
24
 
 
30
  df.hist(figsize=(15, 10))
31
  plt.tight_layout()
32
 
33
+ def process_file(file: gr.File) -> Optional[pd.DataFrame]:
34
+ """Process uploaded file into a DataFrame"""
35
+ if not file:
36
  return None
37
 
38
+ try:
39
+ if file.name.endswith('.csv'):
40
+ df = pd.read_csv(file.name)
41
+ elif file.name.endswith(('.xlsx', '.xls')):
42
+ df = pd.read_excel(file.name)
43
+ else:
44
+ return None
45
+ return df
46
+ except Exception as e:
47
+ print(f"Error reading {file.name}: {str(e)}")
 
 
 
 
48
  return None
 
 
49
 
50
  def analyze_data(
51
+ file: gr.File,
52
  query: str,
53
  api_key: str,
54
  temperature: float = 0.7,
55
  ) -> str:
56
+ """Process user request and generate analysis using smolagents"""
57
 
58
  if not api_key:
59
  return "Error: Please provide an API key."
60
 
61
+ if not file:
62
+ return "Error: Please upload a file."
63
 
64
  try:
65
+ # Set up environment
66
  os.environ["OPENAI_API_KEY"] = api_key
67
 
68
  # Create model and agent
 
71
  temperature=temperature
72
  )
73
 
74
+ # Create agent with various tools
75
  agent = CodeAgent(
76
+ tools=[analyze_dataframe, plot_data],
77
  model=model,
78
  additional_authorized_imports=[
79
  "pandas",
 
88
  verbosity_level=1
89
  )
90
 
91
+ # Process uploaded file
92
+ df = process_file(file)
93
  if df is None:
94
+ return "Error: Could not process uploaded file."
95
 
96
  # Build context
97
+ file_info = f"""
98
+ Uploaded file: {file.name}
99
+ DataFrame Shape: {df.shape}
100
+ Columns: {', '.join(df.columns)}
101
+ Column Types:
102
+ {chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])}
103
+ """
 
104
 
105
  # Build prompt
106
  prompt = f"""
 
108
 
109
  The data has been loaded into a pandas DataFrame called 'df'.
110
  Available tools:
 
111
  - analyze_dataframe: Perform basic DataFrame analysis
112
  - plot_data: Create various plots
113
 
 
123
  2. Code for the analysis
124
  3. Visualizations where relevant
125
  4. Key insights and findings
 
 
126
  """
127
 
128
  # Run analysis
 
133
  return f"Error occurred: {str(e)}"
134
 
135
  def create_interface():
136
+ """Create Gradio interface"""
137
 
138
  with gr.Blocks(title="AI Agent Testing Interface") as interface:
139
  gr.Markdown("""
 
144
  **Features:**
145
  - Data analysis and visualization
146
  - Machine learning capabilities
 
147
  - Statistical analysis
148
  - Custom tool integration
149
 
 
153
  with gr.Row():
154
  with gr.Column():
155
  file = gr.File(
156
+ label="Upload Data File (CSV/Excel)",
157
  file_types=[".csv", ".xlsx", ".xls"]
158
  )
159
  query = gr.Textbox(
 
193
  [None, "Identify and analyze outliers in the dataset"],
194
  [None, "Perform clustering analysis and visualize the results"],
195
  [None, "Calculate summary statistics and create box plots for numeric columns"],
 
196
  ],
197
  inputs=[file, query]
198
  )