import os
import shutil
import base64

import gradio as gr
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Define constants
MODEL_NAME = "facebook/bart-large-cnn"  # Fine-tuned for summarization
FIGURES_DIR = "./figures"
EXAMPLE_DIR = "./example"
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")

# Ensure the figures and example directories exist
os.makedirs(FIGURES_DIR, exist_ok=True)
os.makedirs(EXAMPLE_DIR, exist_ok=True)

# Download the Titanic dataset if it doesn't exist
if not os.path.isfile(EXAMPLE_FILE):
    print("Downloading the Titanic dataset for examples...")
    try:
        # Use seaborn's built-in Titanic dataset
        titanic = sns.load_dataset('titanic')
        titanic.to_csv(EXAMPLE_FILE, index=False)
        print(f"Example dataset saved to {EXAMPLE_FILE}.")
    except Exception as e:
        print(f"Failed to download the Titanic dataset: {e}")
        print("Please ensure the 'example/titanic.csv' file exists.")
        # Optionally, exit or continue without examples
        # exit(1)

# Initialize tokenizer and model
print("Loading model and tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    model.to('cpu')  # Ensure the model runs on CPU
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    exit(1)

# Define the base prompt for the model
base_prompt = """You are an expert data analyst. Based on the following data description, determine an appropriate target feature. List 3 insightful questions regarding the data. Provide detailed answers to each question with relevant statistics. Summarize the findings with real-world insights.

Data Description:
{data_description}

Additional Notes:
{additional_notes}

Please provide your response in a structured and detailed manner.
"""

example_notes = """This data is about the Titanic wreck in 1912.
The target figure is the survival of passengers, noted by 'Survived'.

pclass: A proxy for socio-economic status (SES)
1st = Upper
2nd = Middle
3rd = Lower

age: Age is fractional if less than 1.
If the age is estimated, it is in the form of xx.5

sibsp: Number of siblings/spouses aboard
parch: Number of parents/children aboard"""


def get_images_in_directory(directory):
    """Retrieve all image file paths from the specified directory."""
    image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}
    image_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if os.path.splitext(file)[1].lower() in image_extensions:
                image_files.append(os.path.join(root, file))
    return image_files


def generate_summary(prompt):
    """Generate a summary from the language model based on the prompt."""
    # Truncate to BART's 1024-token context window so long prompts don't raise errors
    inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = inputs.to('cpu')  # Ensure the model runs on CPU

    # Generate response
    with torch.no_grad():
        summary_ids = model.generate(
            inputs,
            max_length=500,
            num_beams=4,
            early_stopping=True
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary


def analyze_data(data_file_path):
    """Perform data analysis on the uploaded CSV file."""
    try:
        data = pd.read_csv(data_file_path)
    except Exception as e:
        return None, f"Error loading CSV file: {e}", None

    # Generate data description
    data_description = f"- **Data Summary (.describe()):**\n{data.describe().to_markdown()}\n\n"
    data_description += f"- **Data Types:**\n{data.dtypes.to_frame().to_markdown()}\n"

    # Determine target variable (for demonstration, assume 'Survived' or first numeric column)
    if 'Survived' in data.columns:
        target = 'Survived'
    else:
        numeric_cols = data.select_dtypes(include='number').columns
        target = numeric_cols[0] if len(numeric_cols) > 0 else data.columns[0]

    # Generate visualizations
    visualization_paths = []

    # Correlation heatmap (numeric columns only, so non-numeric data does not raise)
    plt.figure(figsize=(10, 8))
    sns.heatmap(data.corr(numeric_only=True), annot=True, fmt=".2f", cmap='coolwarm')
    plt.title("Correlation Heatmap")
    heatmap_path = os.path.join(FIGURES_DIR, "correlation_heatmap.png")
    plt.savefig(heatmap_path)
    plt.close()
    visualization_paths.append(heatmap_path)

    # Distribution of target variable
    plt.figure(figsize=(8, 6))
    sns.countplot(x=target, data=data)
    plt.title(f"Distribution of {target}")
    distribution_path = os.path.join(FIGURES_DIR, f"{target}_distribution.png")
    plt.savefig(distribution_path)
    plt.close()
    visualization_paths.append(distribution_path)

    # Pairplot (limited to first 5 numeric columns for performance)
    numeric_cols = data.select_dtypes(include='number').columns[:5]
    if len(numeric_cols) >= 2:
        sns.pairplot(data[numeric_cols].dropna())
        pairplot_path = os.path.join(FIGURES_DIR, "pairplot.png")
        plt.savefig(pairplot_path)
        plt.close()
        visualization_paths.append(pairplot_path)

    return data_description, visualization_paths, target


def interact_with_agent(file_input, additional_notes):
    """Process the uploaded file and interact with the language model to analyze data."""
    # Clear and recreate the figures directory
    if os.path.exists(FIGURES_DIR):
        shutil.rmtree(FIGURES_DIR)
    os.makedirs(FIGURES_DIR, exist_ok=True)

    if file_input is None:
        return [{"role": "assistant", "content": "❌ No file uploaded. Please upload a CSV file to proceed."}]

    # gr.File may yield a filepath string or a tempfile-like object, depending on the Gradio version
    file_path = file_input if isinstance(file_input, str) else file_input.name

    # Analyze the data
    data_description, visualization_paths, target = analyze_data(file_path)
    if data_description is None:
        # On failure, analyze_data returns the error message in its second slot
        return [{"role": "assistant", "content": visualization_paths}]

    # Construct the prompt for the model
    prompt = base_prompt.format(
        data_description=data_description,
        additional_notes=additional_notes if additional_notes else "None."
    )

    # Generate summary from the model
    summary = generate_summary(prompt)

    # Prepare chat messages in 'messages' format
    messages = [
        {"role": "user", "content": "I have uploaded a CSV file for analysis."},
        {"role": "assistant", "content": "⏳ _Analyzing the data..._"}
    ]

    # Append the summary
    messages.append({"role": "assistant", "content": summary})

    # Append images by converting them to Base64
    for image_path in visualization_paths:
        # Ensure the image path is valid before attempting to display
        if os.path.isfile(image_path):
            with open(image_path, "rb") as img_file:
                img_bytes = img_file.read()
            encoded_img = base64.b64encode(img_bytes).decode()
            img_md = f"![{os.path.basename(image_path)}](data:image/png;base64,{encoded_img})"
            messages.append({"role": "assistant", "content": img_md})
        else:
            messages.append({"role": "assistant", "content": f"⚠️ Unable to find image: {image_path}"})

    return messages


# Define the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.orange,
    )
) as demo:
    gr.Markdown("""# 📊 Data Analyst Assistant

Upload a `.csv` file, add any additional notes, and **the assistant will analyze the data and generate visualizations and insights for you!**

**Example:** [Titanic Dataset](./example/titanic.csv)
""")
    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
        text_input = gr.Textbox(
            label="Additional Notes",
            placeholder="Enter any additional notes or leave blank..."
        )
    submit = gr.Button("Run Analysis", variant="primary")
    chatbot = gr.Chatbot(label="Data Analyst Agent", type='messages', height=500)

    # Offer examples only if the example file exists
    if os.path.isfile(EXAMPLE_FILE):
        gr.Examples(
            examples=[[EXAMPLE_FILE, example_notes]],
            inputs=[file_input, text_input],
            label="Examples",
            cache_examples=False
        )
    else:
        gr.Markdown("**No example files available.** Please upload your own CSV files.")

    # Connect the submit button to the interact_with_agent function
    submit.click(
        interact_with_agent,
        inputs=[file_input, text_input],
        outputs=[chatbot],
        api_name="run_analysis"
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)
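
# Usage sketch (assuming the script is saved as app.py; package list inferred from the imports above):
#   pip install gradio transformers torch pandas matplotlib seaborn tabulate
#   python app.py
# Note: tabulate is required by pandas' DataFrame.to_markdown(); with share=True the app is served
# locally and also through a temporary public Gradio link.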