Anupam202224 commited on
Commit
5058119
·
verified ·
1 Parent(s): c4c8dcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -70
app.py CHANGED
@@ -8,37 +8,37 @@ import matplotlib.pyplot as plt
8
  import seaborn as sns
9
 
10
  # Define constants
11
- MODEL_NAME = "meta-llama/Llama-2-7b-hf" # Replace with a smaller model suitable for CPU
12
  FIGURES_DIR = "./figures"
13
 
14
  # Ensure the figures directory exists
15
  os.makedirs(FIGURES_DIR, exist_ok=True)
16
 
17
  # Initialize tokenizer and model
18
- # Note: Loading large models on CPU can be very slow and may not be feasible
19
  try:
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")
 
 
22
  except Exception as e:
23
  print(f"Error loading model: {e}")
24
  exit(1)
25
 
26
- # Define the base prompt
27
  base_prompt = """You are an expert data analyst.
28
- According to the features you have and the data structure given below, determine which feature should be the target.
29
- Then list 3 interesting questions that could be asked on this data, for instance about specific correlations with target variable.
30
- Then answer these questions one by one, by finding the relevant numbers.
31
- Meanwhile, plot some figures using matplotlib/seaborn and save them to the (already existing) folder './figures/': take care to clear each figure with plt.clf() before doing another plot.
32
 
33
- In your final answer: summarize these correlations and trends
34
- After each number derive real worlds insights, for instance: "Correlation between is_december and boredness is 1.3453, which suggest people are more bored in winter".
35
- Your final answer should be a long string with at least 3 numbered and detailed parts.
36
 
37
- Structure of the data:
38
- {structure_notes}
39
 
40
- The data file is passed to you as the variable data_file, it is a pandas dataframe, you can use it directly.
41
- DO NOT try to load data_file, it is already a dataframe pre-loaded in your python interpreter!
42
  """
43
 
44
  example_notes = """This data is about the Titanic wreck in 1912.
@@ -48,13 +48,8 @@ pclass: A proxy for socio-economic status (SES)
48
  2nd = Middle
49
  3rd = Lower
50
  age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5
51
- sibsp: The dataset defines family relations in this way...
52
- Sibling = brother, sister, stepbrother, stepsister
53
- Spouse = husband, wife (mistresses and fiancés were ignored)
54
- parch: The dataset defines family relations in this way...
55
- Parent = mother, father
56
- Child = daughter, son, stepdaughter, stepson
57
- Some children traveled only with a nanny, therefore parch=0 for them."""
58
 
59
  def get_images_in_directory(directory):
60
  """Retrieve all image file paths from the specified directory."""
@@ -66,25 +61,75 @@ def get_images_in_directory(directory):
66
  image_files.append(os.path.join(root, file))
67
  return image_files
68
 
69
- def generate_response(prompt):
70
- """Generate a response from the language model based on the prompt."""
71
- inputs = tokenizer(prompt, return_tensors="pt")
72
  inputs = inputs.to('cpu') # Ensure the model runs on CPU
73
 
74
- # Generate response (adjust parameters as needed)
75
  with torch.no_grad():
76
  outputs = model.generate(
77
- **inputs,
78
- max_length=2048,
79
  do_sample=True,
80
  top_p=0.95,
81
  temperature=0.7,
82
- eos_token_id=tokenizer.eos_token_id
 
83
  )
84
 
85
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
86
  return response
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def interact_with_agent(file_input, additional_notes):
89
  """Process the uploaded file and interact with the language model to analyze data."""
90
  # Clear and recreate the figures directory
@@ -92,36 +137,37 @@ def interact_with_agent(file_input, additional_notes):
92
  shutil.rmtree(FIGURES_DIR)
93
  os.makedirs(FIGURES_DIR, exist_ok=True)
94
 
95
- # Load the data file into a pandas dataframe
96
- try:
97
- data_file = pd.read_csv(file_input.name)
98
- except Exception as e:
99
- yield [("Error loading CSV file.",)]
100
  return
101
 
102
- # Create structure notes
103
- data_structure_notes = f"""- Description (output of .describe()):
104
- {data_file.describe()}
105
- - Columns with dtypes:
106
- {data_file.dtypes}"""
 
107
 
108
- # Construct the prompt
109
- prompt = base_prompt.format(structure_notes=data_structure_notes)
 
 
 
110
 
111
- if additional_notes and additional_notes.strip():
112
- prompt += "\nAdditional notes on the data:\n" + additional_notes
113
 
114
- # Initialize chat history
115
- messages = [("User", prompt)]
116
- yield messages + [("Assistant", " _Starting analysis..._")]
 
 
117
 
118
- # Generate response from the model
119
- response = generate_response(prompt)
120
- messages.append(("Assistant", response))
121
 
122
- # Extract and display generated images
123
- image_paths = get_images_in_directory(FIGURES_DIR)
124
- for image_path in image_paths:
125
  messages.append(("Assistant", gr.Image.update(value=image_path)))
126
 
127
  yield messages
@@ -129,40 +175,40 @@ def interact_with_agent(file_input, additional_notes):
129
  # Define the Gradio interface
130
  with gr.Blocks(
131
  theme=gr.themes.Soft(
132
- primary_hue=gr.themes.colors.yellow,
133
- secondary_hue=gr.themes.colors.blue,
134
  )
135
  ) as demo:
136
- gr.Markdown("""# Llama-2 Data Analyst 📊🤔
 
 
 
 
 
137
 
138
- Drop a `.csv` file below, add notes to describe this data if needed, and **the model will analyze the file content and draw figures for you!**""")
139
-
140
  with gr.Row():
141
- file_input = gr.File(label="Your file to analyze", type="file")
142
  text_input = gr.Textbox(
143
- label="Additional notes to support the analysis",
144
- placeholder="Enter any additional notes here..."
145
  )
146
-
147
- submit = gr.Button("Run analysis!", variant="primary")
148
-
149
- chatbot = gr.Chatbot(
150
- label="Data Analyst Agent",
151
- height=400,
152
- )
153
-
154
  gr.Examples(
155
  examples=[["./example/titanic.csv", example_notes]],
156
  inputs=[file_input, text_input],
 
157
  cache_examples=False
158
  )
159
-
160
  # Connect the submit button to the interact_with_agent function
161
  submit.click(
162
  interact_with_agent,
163
  inputs=[file_input, text_input],
164
  outputs=[chatbot],
165
- show_progress=True
166
  )
167
 
168
  # Launch the Gradio app
 
8
  import seaborn as sns
9
 
10
  # Define constants
11
+ MODEL_NAME = "gpt2" # Publicly accessible model suitable for CPU
12
  FIGURES_DIR = "./figures"
13
 
14
  # Ensure the figures directory exists
15
  os.makedirs(FIGURES_DIR, exist_ok=True)
16
 
17
  # Initialize tokenizer and model
18
+ print("Loading model and tokenizer...")
19
  try:
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
22
+ model.to('cpu') # Ensure the model runs on CPU
23
+ print("Model and tokenizer loaded successfully.")
24
  except Exception as e:
25
  print(f"Error loading model: {e}")
26
  exit(1)
27
 
28
+ # Define the base prompt for the model
29
  base_prompt = """You are an expert data analyst.
30
+ Based on the following data description, determine an appropriate target feature.
31
+ List 3 insightful questions regarding the data.
32
+ Provide detailed answers to each question with relevant statistics.
33
+ Summarize the findings with real-world insights.
34
 
35
+ Data Description:
36
+ {data_description}
 
37
 
38
+ Additional Notes:
39
+ {additional_notes}
40
 
41
+ Please provide your response in a structured and detailed manner.
 
42
  """
43
 
44
  example_notes = """This data is about the Titanic wreck in 1912.
 
48
  2nd = Middle
49
  3rd = Lower
50
  age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5
51
+ sibsp: Number of siblings/spouses aboard
52
+ parch: Number of parents/children aboard"""
 
 
 
 
 
53
 
54
  def get_images_in_directory(directory):
55
  """Retrieve all image file paths from the specified directory."""
 
61
  image_files.append(os.path.join(root, file))
62
  return image_files
63
 
64
+ def generate_summary(prompt):
65
+ """Generate a summary from the language model based on the prompt."""
66
+ inputs = tokenizer.encode(prompt, return_tensors="pt")
67
  inputs = inputs.to('cpu') # Ensure the model runs on CPU
68
 
69
+ # Generate response
70
  with torch.no_grad():
71
  outputs = model.generate(
72
+ inputs,
73
+ max_length=500,
74
  do_sample=True,
75
  top_p=0.95,
76
  temperature=0.7,
77
+ eos_token_id=tokenizer.eos_token_id,
78
+ pad_token_id=tokenizer.eos_token_id
79
  )
80
 
81
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
82
  return response
83
 
84
+ def analyze_data(data_file_path):
85
+ """Perform data analysis on the uploaded CSV file."""
86
+ try:
87
+ data = pd.read_csv(data_file_path)
88
+ except Exception as e:
89
+ return None, f"Error loading CSV file: {e}"
90
+
91
+ # Generate data description
92
+ data_description = f"- **Data Summary (.describe()):**\n{data.describe().to_markdown()}\n\n"
93
+ data_description += f"- **Data Types:**\n{data.dtypes.to_frame().to_markdown()}\n"
94
+
95
+ # Determine target variable (for demonstration, assume 'Survived' or first numeric column)
96
+ if 'Survived' in data.columns:
97
+ target = 'Survived'
98
+ else:
99
+ numeric_cols = data.select_dtypes(include='number').columns
100
+ target = numeric_cols[0] if len(numeric_cols) > 0 else data.columns[0]
101
+
102
+ # Generate visualizations
103
+ visualization_paths = []
104
+
105
+ # Correlation heatmap
106
+ plt.figure(figsize=(10, 8))
107
+ sns.heatmap(data.corr(), annot=True, fmt=".2f", cmap='coolwarm')
108
+ plt.title("Correlation Heatmap")
109
+ heatmap_path = os.path.join(FIGURES_DIR, "correlation_heatmap.png")
110
+ plt.savefig(heatmap_path)
111
+ plt.clf()
112
+ visualization_paths.append(heatmap_path)
113
+
114
+ # Distribution of target variable
115
+ plt.figure(figsize=(8, 6))
116
+ sns.countplot(x=target, data=data)
117
+ plt.title(f"Distribution of {target}")
118
+ plt.savefig(os.path.join(FIGURES_DIR, f"{target}_distribution.png"))
119
+ plt.clf()
120
+ visualization_paths.append(os.path.join(FIGURES_DIR, f"{target}_distribution.png"))
121
+
122
+ # Pairplot (limited to first 5 numeric columns for performance)
123
+ numeric_cols = data.select_dtypes(include='number').columns[:5]
124
+ if len(numeric_cols) >= 2:
125
+ sns.pairplot(data[numeric_cols].dropna())
126
+ pairplot_path = os.path.join(FIGURES_DIR, "pairplot.png")
127
+ plt.savefig(pairplot_path)
128
+ plt.clf()
129
+ visualization_paths.append(pairplot_path)
130
+
131
+ return data_description, visualization_paths, target
132
+
133
  def interact_with_agent(file_input, additional_notes):
134
  """Process the uploaded file and interact with the language model to analyze data."""
135
  # Clear and recreate the figures directory
 
137
  shutil.rmtree(FIGURES_DIR)
138
  os.makedirs(FIGURES_DIR, exist_ok=True)
139
 
140
+ if file_input is None:
141
+ yield [("Error", "No file uploaded.")]
 
 
 
142
  return
143
 
144
+ # Analyze the data
145
+ data_description, visualization_paths, target = analyze_data(file_input.name)
146
+
147
+ if data_description is None:
148
+ yield [("Error", visualization_paths)] # visualization_paths contains the error message
149
+ return
150
 
151
+ # Construct the prompt for the model
152
+ prompt = base_prompt.format(
153
+ data_description=data_description,
154
+ additional_notes=additional_notes if additional_notes else "None."
155
+ )
156
 
157
+ # Generate summary from the model
158
+ summary = generate_summary(prompt)
159
 
160
+ # Prepare chat messages
161
+ messages = [
162
+ ("User", "I have uploaded a CSV file for analysis."),
163
+ ("Assistant", "⏳ _Analyzing the data..._")
164
+ ]
165
 
166
+ # Append the summary
167
+ messages.append(("Assistant", summary))
 
168
 
169
+ # Append images
170
+ for image_path in visualization_paths:
 
171
  messages.append(("Assistant", gr.Image.update(value=image_path)))
172
 
173
  yield messages
 
175
  # Define the Gradio interface
176
  with gr.Blocks(
177
  theme=gr.themes.Soft(
178
+ primary_hue=gr.themes.colors.blue,
179
+ secondary_hue=gr.themes.colors.orange,
180
  )
181
  ) as demo:
182
+ gr.Markdown("""# 📊 Data Analyst Assistant
183
+
184
+ Upload a `.csv` file, add any additional notes, and **the assistant will analyze the data and generate visualizations and insights for you!**
185
+
186
+ **Example:** [Titanic Dataset](./example/titanic.csv)
187
+ """)
188
 
 
 
189
  with gr.Row():
190
+ file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
191
  text_input = gr.Textbox(
192
+ label="Additional Notes",
193
+ placeholder="Enter any additional notes or leave blank..."
194
  )
195
+
196
+ submit = gr.Button("Run Analysis", variant="primary")
197
+ chatbot = gr.Chatbot(label="Data Analyst Agent")
198
+
 
 
 
 
199
  gr.Examples(
200
  examples=[["./example/titanic.csv", example_notes]],
201
  inputs=[file_input, text_input],
202
+ label="Examples",
203
  cache_examples=False
204
  )
205
+
206
  # Connect the submit button to the interact_with_agent function
207
  submit.click(
208
  interact_with_agent,
209
  inputs=[file_input, text_input],
210
  outputs=[chatbot],
211
+ api_name="run_analysis"
212
  )
213
 
214
  # Launch the Gradio app