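"""Data Analyst Assistant.

Gradio app that lets a user upload a CSV file, generates summary statistics and
visualizations with pandas/seaborn, and asks a BART summarization model for a
written analysis of the data.
"""
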
import os
import shutil
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import base64

# Define constants
MODEL_NAME = "facebook/bart-large-cnn"  # Fine-tuned for summarization
FIGURES_DIR = "./figures"
EXAMPLE_DIR = "./example"
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")

# Ensure the figures and example directories exist
os.makedirs(FIGURES_DIR, exist_ok=True)
os.makedirs(EXAMPLE_DIR, exist_ok=True)

# Download the Titanic dataset if it doesn't exist
if not os.path.isfile(EXAMPLE_FILE):
    print("Downloading the Titanic dataset for examples...")
    try:
        # Using seaborn's built-in Titanic dataset
        titanic = sns.load_dataset('titanic')
        titanic.to_csv(EXAMPLE_FILE, index=False)
        print(f"Example dataset saved to {EXAMPLE_FILE}.")
    except Exception as e:
        print(f"Failed to download the Titanic dataset: {e}")
        print("Please ensure the 'example/titanic.csv' file exists.")
        # Optionally, exit or continue without examples
        # exit(1)

# Initialize tokenizer and model
print("Loading model and tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    model.to('cpu')  # Ensure the model runs on CPU
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    exit(1)

# Define the base prompt for the model
base_prompt = """You are an expert data analyst.
Based on the following data description, determine an appropriate target feature.
List 3 insightful questions regarding the data.
Provide detailed answers to each question with relevant statistics.
Summarize the findings with real-world insights.

Data Description:
{data_description}

Additional Notes:
{additional_notes}

Please provide your response in a structured and detailed manner.
"""

example_notes = """This data is about the Titanic wreck in 1912.
The target figure is the survival of passengers, noted by 'Survived'.
pclass: A proxy for socio-economic status (SES)
1st = Upper
2nd = Middle
3rd = Lower
age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5
sibsp: Number of siblings/spouses aboard
parch: Number of parents/children aboard"""

def get_images_in_directory(directory):
    """Retrieve all image file paths from the specified directory."""
    image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}
    image_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if os.path.splitext(file)[1].lower() in image_extensions:
                image_files.append(os.path.join(root, file))
    return image_files

def generate_summary(prompt):
    """Generate a summary from the language model based on the prompt."""
    # BART's encoder accepts at most 1024 tokens, so truncate longer prompts to avoid errors
    inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = inputs.to('cpu')  # Ensure the model runs on CPU
    # Generate response
    with torch.no_grad():
        summary_ids = model.generate(
            inputs,
            max_length=500,
            num_beams=4,
            early_stopping=True
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

def analyze_data(data_file_path):
    """Perform data analysis on the uploaded CSV file."""
    try:
        data = pd.read_csv(data_file_path)
    except Exception as e:
        return None, f"Error loading CSV file: {e}", None

    # Generate data description
    data_description = f"- **Data Summary (.describe()):**\n{data.describe().to_markdown()}\n\n"
    data_description += f"- **Data Types:**\n{data.dtypes.to_frame().to_markdown()}\n"

    # Determine target variable (for demonstration, assume 'Survived' or first numeric column)
    if 'Survived' in data.columns:
        target = 'Survived'
    else:
        numeric_cols = data.select_dtypes(include='number').columns
        target = numeric_cols[0] if len(numeric_cols) > 0 else data.columns[0]

    # Generate visualizations
    visualization_paths = []

    # Correlation heatmap (numeric columns only, so non-numeric columns do not raise)
    plt.figure(figsize=(10, 8))
    sns.heatmap(data.corr(numeric_only=True), annot=True, fmt=".2f", cmap='coolwarm')
    plt.title("Correlation Heatmap")
    heatmap_path = os.path.join(FIGURES_DIR, "correlation_heatmap.png")
    plt.savefig(heatmap_path)
    plt.clf()
    visualization_paths.append(heatmap_path)

    # Distribution of target variable
    plt.figure(figsize=(8, 6))
    sns.countplot(x=target, data=data)
    plt.title(f"Distribution of {target}")
    distribution_path = os.path.join(FIGURES_DIR, f"{target}_distribution.png")
    plt.savefig(distribution_path)
    plt.clf()
    visualization_paths.append(distribution_path)

    # Pairplot (limited to first 5 numeric columns for performance)
    numeric_cols = data.select_dtypes(include='number').columns[:5]
    if len(numeric_cols) >= 2:
        sns.pairplot(data[numeric_cols].dropna())
        pairplot_path = os.path.join(FIGURES_DIR, "pairplot.png")
        plt.savefig(pairplot_path)
        plt.clf()
        visualization_paths.append(pairplot_path)

    return data_description, visualization_paths, target

def interact_with_agent(file_input, additional_notes):
    """Process the uploaded file and interact with the language model to analyze data."""
    # Clear and recreate the figures directory
    if os.path.exists(FIGURES_DIR):
        shutil.rmtree(FIGURES_DIR)
    os.makedirs(FIGURES_DIR, exist_ok=True)

    if file_input is None:
        return [{"role": "assistant", "content": "❌ No file uploaded. Please upload a CSV file to proceed."}]

    # Analyze the data (gr.File may pass a filepath string or a file object, depending on the Gradio version)
    file_path = file_input if isinstance(file_input, str) else file_input.name
    data_description, visualization_paths, target = analyze_data(file_path)
    if data_description is None:
        # On failure, the second return value from analyze_data holds the error message
        return [{"role": "assistant", "content": visualization_paths}]

    # Construct the prompt for the model
    prompt = base_prompt.format(
        data_description=data_description,
        additional_notes=additional_notes if additional_notes else "None."
    )

    # Generate summary from the model
    summary = generate_summary(prompt)

    # Prepare chat messages in 'messages' format
    messages = [
        {"role": "user", "content": "I have uploaded a CSV file for analysis."},
        {"role": "assistant", "content": "⏳ _Analyzing the data..._"}
    ]

    # Append the summary
    messages.append({"role": "assistant", "content": summary})

    # Append images by converting them to Base64
    for image_path in visualization_paths:
        # Ensure the image path is valid before attempting to display
        if os.path.isfile(image_path):
            with open(image_path, "rb") as img_file:
                img_bytes = img_file.read()
            encoded_img = base64.b64encode(img_bytes).decode()
            # Embed the image inline as a Markdown data URI so the chatbot can render it
            img_md = f"![{os.path.basename(image_path)}](data:image/png;base64,{encoded_img})"
            messages.append({"role": "assistant", "content": img_md})
        else:
            messages.append({"role": "assistant", "content": f"⚠️ Unable to find image: {image_path}"})

    return messages

# Define the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.orange,
    )
) as demo:
    gr.Markdown("""# 📊 Data Analyst Assistant

Upload a `.csv` file, add any additional notes, and **the assistant will analyze the data and generate visualizations and insights for you!**

**Example:** [Titanic Dataset](./example/titanic.csv)
""")

    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
        text_input = gr.Textbox(
            label="Additional Notes",
            placeholder="Enter any additional notes or leave blank..."
        )
    submit = gr.Button("Run Analysis", variant="primary")
    chatbot = gr.Chatbot(label="Data Analyst Agent", type='messages', height=500)

    # Handle examples only if the example file exists
    if os.path.isfile(EXAMPLE_FILE):
        gr.Examples(
            examples=[[EXAMPLE_FILE, example_notes]],
            inputs=[file_input, text_input],
            label="Examples",
            cache_examples=False
        )
    else:
        gr.Markdown("**No example files available.** Please upload your own CSV files.")

    # Connect the submit button to the interact_with_agent function
    submit.click(
        interact_with_agent,
        inputs=[file_input, text_input],
        outputs=[chatbot],
        api_name="run_analysis"
    )
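
# Note: share=True requests a temporary public gradio.live link in addition to the
# local server; when the app runs on Hugging Face Spaces it is already publicly
# hosted, so the flag is optional there.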

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)