import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Colours shared by the charts rendered in this module.
_PRIMARY_COLOR = "#FFD21E"
_ACCENT_COLOR = "#2563EB"


def render_dataset_visualization(dataset, dataset_type):
    """
    Renders visualizations for the dataset.

    Args:
        dataset: The dataset to visualize (pandas DataFrame). When ``None``, a
            warning is shown and nothing else is rendered.
        dataset_type: The type of dataset (csv, json, etc.). Not used by this
            function; kept for interface compatibility with callers.
    """
    if dataset is None:
        st.warning("No dataset to visualize.")
        return

    # NOTE(review): the HTML markup of this header was lost from the source
    # (only the text survived); a plain Markdown heading is used instead.
    st.markdown("## Dataset Visualization", unsafe_allow_html=True)

    # Classify columns once; every view below picks from these lists.
    numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = dataset.select_dtypes(include=['object', 'category']).columns.tolist()
    # is_datetime64_any_dtype also matches tz-aware and non-nanosecond datetime
    # columns, which an equality test against 'datetime64[ns]' would miss.
    date_cols = [
        col for col, dtype in dataset.dtypes.items()
        if pd.api.types.is_datetime64_any_dtype(dtype)
    ]

    viz_type = st.selectbox(
        "Select visualization type",
        ["Distribution", "Correlation", "Categories", "Time Series", "Custom"],
        help="Choose the type of visualization to create",
    )

    if viz_type == "Distribution":
        _render_distribution(dataset, numeric_cols)
    elif viz_type == "Correlation":
        _render_correlation(dataset, numeric_cols)
    elif viz_type == "Categories":
        _render_categories(dataset, categorical_cols, numeric_cols)
    elif viz_type == "Time Series":
        _render_time_series(dataset, date_cols, categorical_cols, numeric_cols)
    elif viz_type == "Custom":
        _render_custom(dataset, categorical_cols, numeric_cols)


def _gaussian_kde_curve(series, num_points=200, max_samples=10_000, seed=0):
    """Estimate a Gaussian kernel density curve for a numeric Series.

    Returns ``(grid, density)`` arrays on the same scale as a histogram drawn
    with ``histnorm='probability density'``, or ``None`` when no density can
    be estimated (fewer than two finite values, or zero spread).

    Inputs larger than ``max_samples`` are randomly subsampled with a fixed
    seed, so reruns are stable and the ``num_points x samples`` kernel matrix
    stays small.
    """
    values = series.dropna().to_numpy(dtype=float)
    values = values[np.isfinite(values)]
    if values.size > max_samples:
        values = np.random.default_rng(seed).choice(values, size=max_samples, replace=False)
    if values.size < 2:
        return None
    spread = values.std(ddof=1)
    if not np.isfinite(spread) or spread == 0:
        return None
    # Scott's rule of thumb for the bandwidth of a 1-D Gaussian kernel.
    bandwidth = spread * values.size ** (-1 / 5)
    grid = np.linspace(values.min() - 3 * bandwidth, values.max() + 3 * bandwidth, num_points)
    z = (grid[:, np.newaxis] - values[np.newaxis, :]) / bandwidth
    density = np.exp(-0.5 * z ** 2).sum(axis=1) / (values.size * bandwidth * np.sqrt(2 * np.pi))
    return grid, density


def _render_distribution(dataset, numeric_cols):
    """Render histogram(s) of selected numeric columns plus summary statistics."""
    if not numeric_cols:
        st.warning("No numeric columns found for distribution visualization.")
        return

    selected_cols = st.multiselect(
        "Select columns to visualize",
        numeric_cols,
        default=numeric_cols[:3],
    )
    if not selected_cols:
        st.warning("Please select at least one column to visualize.")
        return

    if len(selected_cols) == 1:
        # Single column: density-normalised histogram with a KDE overlay.
        col = selected_cols[0]
        fig = px.histogram(
            dataset,
            x=col,
            histnorm='probability density',
            title=f"Distribution of {col}",
            color_discrete_sequence=[_PRIMARY_COLOR],
            template="simple_white",
        )
        curve = _gaussian_kde_curve(dataset[col])
        if curve is not None:
            grid, density = curve
            fig.add_trace(
                go.Scatter(
                    x=grid,
                    y=density,
                    mode='lines',
                    line=dict(color=_ACCENT_COLOR, width=3),
                    name='Smoothed',
                )
            )
        st.plotly_chart(fig, use_container_width=True)
    else:
        # Several columns: one histogram per column in a grid of up to 2 columns.
        num_cols = min(len(selected_cols), 2)
        num_rows = (len(selected_cols) + num_cols - 1) // num_cols
        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=[f"Distribution of {col}" for col in selected_cols],
        )
        for i, col in enumerate(selected_cols):
            fig.add_trace(
                go.Histogram(x=dataset[col], name=col, marker_color=_PRIMARY_COLOR),
                row=i // num_cols + 1,
                col=i % num_cols + 1,
            )
        fig.update_layout(
            title="Distribution of Selected Features",
            showlegend=False,
            template="simple_white",
            height=300 * num_rows,
        )
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Distribution Statistics")
    stats_df = dataset[selected_cols].describe().T
    st.dataframe(stats_df, use_container_width=True)


def _render_correlation(dataset, numeric_cols):
    """Render a correlation heatmap, optional scatter matrix and ranked pairs."""
    if len(numeric_cols) < 2:
        st.warning("Need at least two numeric columns for correlation analysis.")
        return

    st.markdown("### Correlation Matrix")
    selected_cols = st.multiselect(
        "Select columns for correlation analysis",
        numeric_cols,
        default=numeric_cols[:5],
    )
    if len(selected_cols) < 2:
        st.warning("Please select at least two columns for correlation analysis.")
        return

    corr = dataset[selected_cols].corr()

    # Pin the diverging scale to [-1, 1] so zero correlation is the neutral colour.
    fig = px.imshow(
        corr,
        color_continuous_scale="RdBu_r",
        zmin=-1,
        zmax=1,
        title="Correlation Matrix",
        template="simple_white",
        text_auto=".2f",
    )
    st.plotly_chart(fig, use_container_width=True)

    # Scatter matrix only for 3-5 columns; beyond that it becomes unreadable.
    if 2 < len(selected_cols) <= 5:
        st.markdown("### Scatter Plot Matrix")
        fig = px.scatter_matrix(
            dataset,
            dimensions=selected_cols,
            color_discrete_sequence=[_ACCENT_COLOR],
            title="Scatter Plot Matrix",
            template="simple_white",
        )
        fig.update_traces(diagonal_visible=False)
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Top Correlation Pairs")
    columns = list(corr.columns)
    corr_pairs = []
    for i, first in enumerate(columns):
        for j in range(i + 1, len(columns)):
            value = corr.iloc[i, j]
            # Constant columns give NaN correlations, which would make the
            # ordering below undefined; they carry no information, so skip them.
            if pd.notna(value):
                corr_pairs.append({
                    'Feature 1': first,
                    'Feature 2': columns[j],
                    'Correlation': value,
                })

    if not corr_pairs:
        st.info("No valid correlation pairs to display.")
        return

    corr_pairs.sort(key=lambda pair: abs(pair['Correlation']), reverse=True)
    pair_labels = [f"{pair['Feature 1']} & {pair['Feature 2']}" for pair in corr_pairs]
    correlations = [pair['Correlation'] for pair in corr_pairs]
    fig = px.bar(
        x=pair_labels,
        y=[abs(c) for c in correlations],
        color=correlations,
        color_continuous_scale="RdBu_r",
        range_color=[-1, 1],
        labels={'x': 'Feature Pairs', 'y': 'Absolute Correlation', 'color': 'Correlation'},
        title="Top Feature Correlations",
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_categories(dataset, categorical_cols, numeric_cols):
    """Render category counts and, if possible, a numeric-by-category breakdown."""
    if not categorical_cols:
        st.warning("No categorical columns found for category visualization.")
        return

    selected_cat = st.selectbox("Select categorical column", categorical_cols)
    value_counts = dataset[selected_cat].value_counts()
    if len(value_counts) > 20:
        st.info(f"Showing top 20 categories out of {len(value_counts)}")
        value_counts = value_counts.head(20)

    fig = px.bar(
        x=value_counts.index,
        y=value_counts.values,
        title=f"Category Counts for {selected_cat}",
        labels={'x': selected_cat, 'y': 'Count'},
        color_discrete_sequence=[_PRIMARY_COLOR],
    )
    st.plotly_chart(fig, use_container_width=True)

    if not numeric_cols:
        return

    st.markdown(f"### {selected_cat} vs Numeric Features")
    selected_num = st.selectbox("Select numeric column", numeric_cols)
    fig = px.box(
        dataset,
        x=selected_cat,
        y=selected_num,
        title=f"{selected_cat} vs {selected_num}",
        color_discrete_sequence=[_ACCENT_COLOR],
        template="simple_white",
    )
    st.plotly_chart(fig, use_container_width=True)

    st.markdown(f"### Statistics of {selected_num} by {selected_cat}")
    stats_by_cat = dataset.groupby(selected_cat)[selected_num].describe()
    st.dataframe(stats_by_cat, use_container_width=True)


def _render_time_series(dataset, date_cols, categorical_cols, numeric_cols):
    """Render a numeric column aggregated over time with an optional trendline."""
    potential_date_cols = list(date_cols)
    # Heuristic: treat text columns whose first few values contain '/' or '-'
    # as date candidates; the conversion below validates that guess.
    for col in categorical_cols:
        sample = dataset[col].dropna().head(5).tolist()
        if sample and all('/' in str(x) or '-' in str(x) for x in sample):
            potential_date_cols.append(col)

    if not potential_date_cols:
        st.warning("No date columns found for time series visualization.")
        return

    date_col = st.selectbox("Select date column", potential_date_cols)

    temp_df = dataset.copy()
    if not pd.api.types.is_datetime64_any_dtype(temp_df[date_col]):
        try:
            temp_df[date_col] = pd.to_datetime(temp_df[date_col])
        except (ValueError, TypeError, OverflowError):
            st.error(f"Could not convert {date_col} to datetime.")
            return
        # Mixed time zones can come back as object dtype, which has no .dt accessor.
        if not pd.api.types.is_datetime64_any_dtype(temp_df[date_col]):
            st.error(f"Could not convert {date_col} to datetime.")
            return

    if not numeric_cols:
        st.warning("No numeric columns found for time series values.")
        return

    value_col = st.selectbox("Select value column", numeric_cols)
    time_period = st.selectbox(
        "Aggregate by",
        ["Day", "Week", "Month", "Quarter", "Year"],
    )

    dates = temp_df[date_col].dt
    if time_period == "Day":
        # normalize() keeps a datetime dtype; .dt.date would yield Python date
        # objects that the OLS trendline cannot convert to numbers.
        temp_df['period'] = dates.normalize()
    elif time_period == "Week":
        temp_df['period'] = dates.to_period('W').dt.start_time
    elif time_period == "Month":
        temp_df['period'] = dates.to_period('M').dt.start_time
    elif time_period == "Quarter":
        temp_df['period'] = dates.to_period('Q').dt.start_time
    else:  # Year
        temp_df['period'] = dates.year

    agg_method = st.selectbox("Aggregation method", ["Mean", "Sum", "Min", "Max", "Count"])
    agg_map = {
        "Mean": "mean",
        "Sum": "sum",
        "Min": "min",
        "Max": "max",
        "Count": "count",
    }
    time_series = temp_df.groupby('period')[value_col].agg(agg_map[agg_method]).reset_index()

    fig = px.line(
        time_series,
        x='period',
        y=value_col,
        title=f"{agg_method} of {value_col} by {time_period}",
        markers=True,
        color_discrete_sequence=[_ACCENT_COLOR],
        template="simple_white",
    )
    fig.update_layout(
        xaxis_title=time_period,
        yaxis_title=f"{agg_method} of {value_col}",
    )
    st.plotly_chart(fig, use_container_width=True)

    if st.checkbox("Show trendline"):
        # trendline="ols" needs statsmodels and a numeric/datetime x-axis;
        # report failures instead of crashing the page.
        try:
            fig = px.scatter(
                time_series,
                x='period',
                y=value_col,
                trendline="ols",
                title=f"{agg_method} of {value_col} by {time_period} with Trendline",
                color_discrete_sequence=[_ACCENT_COLOR],
                template="simple_white",
            )
        except (ImportError, ValueError) as err:
            st.error(f"Could not compute trendline: {err}")
        else:
            fig.update_layout(
                xaxis_title=time_period,
                yaxis_title=f"{agg_method} of {value_col}",
            )
            st.plotly_chart(fig, use_container_width=True)

    st.dataframe(time_series, use_container_width=True)


def _render_custom(dataset, categorical_cols, numeric_cols):
    """Render a user-configured chart (scatter, line, bar, box, violin, histogram, pie, 3D)."""
    st.markdown("### Custom Visualization")
    st.info("Create a custom plot by selecting axes and plot type")

    plot_type = st.selectbox(
        "Select plot type",
        ["Scatter", "Line", "Bar", "Box", "Violin", "Histogram", "Pie", "3D Scatter"],
    )
    all_cols = dataset.columns.tolist()
    value_choices = numeric_cols if numeric_cols else all_cols

    if plot_type in ["Scatter", "Line", "Bar", "3D Scatter"]:
        x_col = st.selectbox("X-axis", all_cols)
        y_col = st.selectbox("Y-axis", value_choices)
        if plot_type == "3D Scatter":
            z_col = st.selectbox("Z-axis", value_choices)

        use_color = st.checkbox("Add color dimension")
        color_col = None
        if use_color:
            color_col = st.selectbox("Color by", all_cols)

        if plot_type == "Scatter":
            fig = px.scatter(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"{y_col} vs {x_col}", template="simple_white",
            )
        elif plot_type == "Line":
            fig = px.line(
                dataset.sort_values(x_col), x=x_col, y=y_col, color=color_col,
                title=f"{y_col} vs {x_col}", template="simple_white",
            )
        elif plot_type == "Bar":
            fig = px.bar(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"{y_col} by {x_col}", template="simple_white",
            )
        else:  # 3D Scatter
            fig = px.scatter_3d(
                dataset, x=x_col, y=y_col, z=z_col, color=color_col,
                title=f"3D Scatter: {x_col}, {y_col}, {z_col}", template="simple_white",
            )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type in ["Box", "Violin"]:
        x_col = st.selectbox(
            "X-axis (categories)",
            categorical_cols if categorical_cols else all_cols,
        )
        y_col = st.selectbox("Y-axis (values)", value_choices)

        use_color = st.checkbox("Add color dimension")
        color_col = None
        if use_color:
            color_col = st.selectbox("Color by", all_cols)

        if plot_type == "Box":
            fig = px.box(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"Box Plot: {y_col} by {x_col}", template="simple_white",
            )
        else:  # Violin
            fig = px.violin(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"Violin Plot: {y_col} by {x_col}", template="simple_white",
            )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type == "Histogram":
        value_col = st.selectbox("Value column", all_cols)
        n_bins = st.slider("Number of bins", 5, 100, 20)

        use_color = st.checkbox("Add color dimension")
        color_col = None
        if use_color:
            color_col = st.selectbox("Color by", all_cols)

        fig = px.histogram(
            dataset, x=value_col, color=color_col, nbins=n_bins,
            title=f"Histogram of {value_col}", template="simple_white",
        )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type == "Pie":
        cat_col = st.selectbox(
            "Category column",
            categorical_cols if categorical_cols else all_cols,
        )
        use_values = st.checkbox("Use custom values")
        value_col = None
        if use_values and numeric_cols:
            value_col = st.selectbox("Value column", numeric_cols)

        top_n = st.slider(
            "Limit to top N categories", 0, 20, 10,
            help="Set to 0 to show all categories. Recommended to limit to top 10-15 categories for readability.",
        )

        if use_values and value_col:
            # Sum the chosen value per category.
            pie_data = dataset.groupby(cat_col)[value_col].sum().reset_index()
            if top_n > 0:
                pie_data = pie_data.sort_values(value_col, ascending=False).head(top_n)
        else:
            # Fall back to row counts per category (value_counts is already sorted).
            pie_data = dataset[cat_col].value_counts().reset_index()
            pie_data.columns = [cat_col, 'count']
            if top_n > 0:
                pie_data = pie_data.head(top_n)
            value_col = 'count'

        fig = px.pie(
            pie_data, names=cat_col, values=value_col,
            title=f"Pie Chart of {cat_col}", template="simple_white",
        )
        st.plotly_chart(fig, use_container_width=True)