import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Shared styling so every chart rendered by this module looks consistent.
_ACCENT_YELLOW = "#FFD21E"
_ACCENT_BLUE = "#2563EB"
_PLOT_TEMPLATE = "simple_white"

# Bar charts of category counts are capped at this many categories.
_MAX_CATEGORIES = 20

# Pandas period frequency aliases for the aggregation levels that are
# expressed as "start of the containing period".
_PERIOD_FREQ = {"Week": "W", "Month": "M", "Quarter": "Q"}

# Maps the aggregation labels shown in the UI to pandas aggregation names.
_AGG_METHODS = {
    "Mean": "mean",
    "Sum": "sum",
    "Min": "min",
    "Max": "max",
    "Count": "count",
}


def _density_curve(values, max_bins=60):
    """Return ``(x, y)`` arrays describing a smoothed empirical density.

    The density is a normalised histogram (area 1) evaluated at the bin
    centres and smoothed with a 3-point binomial kernel, so it lives on the
    same y-scale as a ``histnorm='probability density'`` histogram.

    Args:
        values: Numeric values (list, array or Series); non-numeric and
            non-finite entries are ignored.
        max_bins: Upper bound on the number of histogram bins.

    Returns:
        Two equal-length NumPy arrays. Both are empty when fewer than two
        finite values exist or when all values are identical.
    """
    data = pd.to_numeric(pd.Series(values), errors="coerce").astype(float).to_numpy()
    data = data[np.isfinite(data)]
    if data.size < 2 or data.min() == data.max():
        return np.array([]), np.array([])

    n_bins = int(min(max_bins, max(5, int(np.ceil(np.sqrt(data.size))))))
    density, edges = np.histogram(data, bins=n_bins, density=True)
    centers = (edges[:-1] + edges[1:]) / 2.0

    # Edge padding keeps the output length equal to the number of bins and
    # leaves the total area under the curve unchanged.
    padded = np.pad(density, 1, mode="edge")
    smoothed = np.convolve(padded, np.array([0.25, 0.5, 0.25]), mode="valid")
    return centers, smoothed


def _correlation_pairs(corr):
    """Return the upper-triangle pairs of ``corr`` ordered by |correlation|.

    Args:
        corr: Square correlation DataFrame (as produced by ``DataFrame.corr``).

    Returns:
        DataFrame with columns ``Feature 1``, ``Feature 2`` and
        ``Correlation``, strongest absolute correlation first. Pairs whose
        correlation is NaN are dropped because they cannot be ranked.
    """
    names = list(corr.columns)
    rows = [
        {
            'Feature 1': names[i],
            'Feature 2': names[j],
            'Correlation': corr.iloc[i, j],
        }
        for i in range(len(names))
        for j in range(i + 1, len(names))
    ]
    pairs = pd.DataFrame(rows, columns=['Feature 1', 'Feature 2', 'Correlation'])
    if pairs.empty:
        return pairs

    pairs = pairs.dropna(subset=['Correlation'])
    order = pairs['Correlation'].abs().sort_values(ascending=False, kind="stable").index
    return pairs.loc[order].reset_index(drop=True)


def _looks_like_dates(series, sample_size=5):
    """Heuristically report whether a column's first values look like dates.

    A column qualifies when every one of its first ``sample_size`` non-null
    values contains a ``/`` or ``-`` once converted to ``str``.
    """
    sample = series.dropna().head(sample_size).tolist()
    return bool(sample) and all('/' in str(x) or '-' in str(x) for x in sample)


def _period_starts(dates, time_period):
    """Map each timestamp to the key of the period it is aggregated into.

    Args:
        dates: Datetime Series (tz-naive or tz-aware).
        time_period: One of "Day", "Week", "Month", "Quarter", "Year".

    Returns:
        A Series of midnight timestamps for "Day", period start timestamps
        for "Week"/"Month"/"Quarter", and the calendar year otherwise.
    """
    if dates.dt.tz is not None:
        # Periods carry no timezone; drop it explicitly (keeping wall time).
        dates = dates.dt.tz_localize(None)
    if time_period == "Day":
        # Stay in datetime64 (instead of datetime.date objects) so the values
        # remain usable as a numeric axis, e.g. for an OLS trendline.
        return dates.dt.normalize()
    freq = _PERIOD_FREQ.get(time_period)
    if freq is None:  # "Year"
        return dates.dt.year
    return dates.dt.to_period(freq).dt.start_time


def _pie_frame(dataset, cat_col, value_col, top_n):
    """Build the data for a pie chart.

    Args:
        dataset: Source DataFrame.
        cat_col: Column whose values become the pie slices.
        value_col: Numeric column to sum per slice, or ``None`` to count rows.
        top_n: Keep only the ``top_n`` largest slices; ``0`` keeps everything.

    Returns:
        ``(frame, values_column_name)`` ready to pass to ``px.pie``.
    """
    if value_col is not None:
        frame = dataset.groupby(cat_col)[value_col].sum().reset_index()
        if top_n > 0:
            frame = frame.sort_values(value_col, ascending=False).head(top_n)
        return frame, value_col

    counts = dataset[cat_col].value_counts().reset_index()
    counts.columns = [cat_col, 'count']
    if top_n > 0:
        # value_counts() is already sorted from most to least frequent.
        counts = counts.head(top_n)
    return counts, 'count'


def _select_color_column(dataset):
    """Render the optional colour-dimension widgets; return the column or None."""
    if st.checkbox("Add color dimension"):
        return st.selectbox("Color by", dataset.columns.tolist())
    return None


def _render_distribution(dataset, numeric_cols):
    """Render histograms and summary statistics for selected numeric columns."""
    if not numeric_cols:
        st.warning("No numeric columns found for distribution visualization.")
        return

    selected_cols = st.multiselect(
        "Select columns to visualize",
        numeric_cols,
        default=numeric_cols[:3],
    )
    if not selected_cols:
        st.warning("Please select at least one column to visualize.")
        return

    if len(selected_cols) == 1:
        # Single column: density-normalised histogram plus a density curve.
        col = selected_cols[0]
        fig = px.histogram(
            dataset,
            x=col,
            histnorm='probability density',
            title=f"Distribution of {col}",
            color_discrete_sequence=[_ACCENT_YELLOW],
            template=_PLOT_TEMPLATE,
        )
        curve_x, curve_y = _density_curve(dataset[col])
        if curve_x.size:
            fig.add_trace(
                go.Scatter(
                    x=curve_x,
                    y=curve_y,
                    mode='lines',
                    line=dict(color=_ACCENT_BLUE, width=3),
                    name='Smoothed',
                )
            )
        st.plotly_chart(fig, use_container_width=True)
    else:
        # Several columns: one histogram per column in a grid up to 2 wide.
        num_cols = min(len(selected_cols), 2)
        num_rows = (len(selected_cols) + num_cols - 1) // num_cols
        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=[f"Distribution of {col}" for col in selected_cols],
        )
        for i, col in enumerate(selected_cols):
            fig.add_trace(
                go.Histogram(x=dataset[col], name=col, marker_color=_ACCENT_YELLOW),
                row=i // num_cols + 1,
                col=i % num_cols + 1,
            )
        fig.update_layout(
            title="Distribution of Selected Features",
            showlegend=False,
            template=_PLOT_TEMPLATE,
            height=300 * num_rows,
        )
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Distribution Statistics")
    st.dataframe(dataset[selected_cols].describe().T, use_container_width=True)


def _render_correlation(dataset, numeric_cols):
    """Render a correlation heatmap, scatter matrix and ranked pair chart."""
    if len(numeric_cols) < 2:
        st.warning("Need at least two numeric columns for correlation analysis.")
        return

    st.markdown("### Correlation Matrix")
    selected_cols = st.multiselect(
        "Select columns for correlation analysis",
        numeric_cols,
        default=numeric_cols[:5],
    )
    if len(selected_cols) < 2:
        st.warning("Please select at least two columns for correlation analysis.")
        return

    corr = dataset[selected_cols].corr()
    fig = px.imshow(
        corr,
        color_continuous_scale="RdBu_r",
        title="Correlation Matrix",
        template=_PLOT_TEMPLATE,
        text_auto=True,
    )
    st.plotly_chart(fig, use_container_width=True)

    # The scatter matrix is only drawn for 3 to 5 columns to stay readable.
    if 2 < len(selected_cols) <= 5:
        st.markdown("### Scatter Plot Matrix")
        fig = px.scatter_matrix(
            dataset,
            dimensions=selected_cols,
            color_discrete_sequence=[_ACCENT_BLUE],
            title="Scatter Plot Matrix",
            template=_PLOT_TEMPLATE,
        )
        fig.update_traces(diagonal_visible=False)
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Top Correlation Pairs")
    pairs = _correlation_pairs(corr)
    if not pairs.empty:
        pair_labels = [
            f"{first} & {second}"
            for first, second in zip(pairs['Feature 1'], pairs['Feature 2'])
        ]
        # Bar height is the strength; colour keeps the sign of the correlation.
        fig = px.bar(
            x=pair_labels,
            y=pairs['Correlation'].abs().tolist(),
            color=pairs['Correlation'],
            color_continuous_scale="RdBu_r",
            labels={'x': 'Feature Pairs', 'y': 'Absolute Correlation'},
            title="Top Feature Correlations",
        )
        st.plotly_chart(fig, use_container_width=True)


def _render_categories(dataset, numeric_cols, categorical_cols):
    """Render category counts and a numeric-by-category breakdown."""
    if not categorical_cols:
        st.warning("No categorical columns found for category visualization.")
        return

    selected_cat = st.selectbox("Select categorical column", categorical_cols)
    value_counts = dataset[selected_cat].value_counts()
    if len(value_counts) > _MAX_CATEGORIES:
        st.info(f"Showing top {_MAX_CATEGORIES} categories out of {len(value_counts)}")
        value_counts = value_counts.head(_MAX_CATEGORIES)

    fig = px.bar(
        x=value_counts.index,
        y=value_counts.values,
        title=f"Category Counts for {selected_cat}",
        labels={'x': selected_cat, 'y': 'Count'},
        color_discrete_sequence=[_ACCENT_YELLOW],
    )
    st.plotly_chart(fig, use_container_width=True)

    if not numeric_cols:
        return

    st.markdown(f"### {selected_cat} vs Numeric Features")
    selected_num = st.selectbox("Select numeric column", numeric_cols)
    fig = px.box(
        dataset,
        x=selected_cat,
        y=selected_num,
        title=f"{selected_cat} vs {selected_num}",
        color_discrete_sequence=[_ACCENT_BLUE],
        template=_PLOT_TEMPLATE,
    )
    st.plotly_chart(fig, use_container_width=True)

    st.markdown(f"### Statistics of {selected_num} by {selected_cat}")
    stats_by_cat = dataset.groupby(selected_cat)[selected_num].describe()
    st.dataframe(stats_by_cat, use_container_width=True)


def _render_time_series(dataset, numeric_cols, categorical_cols, date_cols):
    """Render a numeric column aggregated over a chosen time period."""
    # Real datetime columns first, then text columns that look like dates.
    potential_date_cols = list(date_cols) + [
        col for col in categorical_cols if _looks_like_dates(dataset[col])
    ]
    if not potential_date_cols:
        st.warning("No date columns found for time series visualization.")
        return

    date_col = st.selectbox("Select date column", potential_date_cols)
    dates = dataset[date_col]
    if not pd.api.types.is_datetime64_any_dtype(dates):
        try:
            dates = pd.to_datetime(dates)
        except (ValueError, TypeError, OverflowError):
            st.error(f"Could not convert {date_col} to datetime.")
            return
        if not pd.api.types.is_datetime64_any_dtype(dates):
            # Conversion did not raise but did not yield a datetime dtype
            # either, so the `.dt` accessor would not be usable.
            st.error(f"Could not convert {date_col} to datetime.")
            return

    if not numeric_cols:
        st.warning("No numeric columns found for time series values.")
        return

    value_col = st.selectbox("Select value column", numeric_cols)
    time_period = st.selectbox(
        "Aggregate by",
        ["Day", "Week", "Month", "Quarter", "Year"],
    )
    agg_method = st.selectbox("Aggregation method", list(_AGG_METHODS))

    # Work on a two-column frame instead of copying the whole dataset, and
    # make sure the grouping key can never overwrite the value column.
    period_key = 'period' if value_col != 'period' else 'time_period'
    frame = dataset[[value_col]].copy()
    frame[period_key] = _period_starts(dates, time_period).to_numpy()
    time_series = (
        frame.groupby(period_key)[value_col]
        .agg(_AGG_METHODS[agg_method])
        .reset_index()
    )

    axis_titles = dict(
        xaxis_title=time_period,
        yaxis_title=f"{agg_method} of {value_col}",
    )
    fig = px.line(
        time_series,
        x=period_key,
        y=value_col,
        title=f"{agg_method} of {value_col} by {time_period}",
        markers=True,
        color_discrete_sequence=[_ACCENT_BLUE],
        template=_PLOT_TEMPLATE,
    )
    fig.update_layout(**axis_titles)
    st.plotly_chart(fig, use_container_width=True)

    if st.checkbox("Show trendline"):
        try:
            trend_fig = px.scatter(
                time_series,
                x=period_key,
                y=value_col,
                trendline="ols",
                title=f"{agg_method} of {value_col} by {time_period} with Trendline",
                color_discrete_sequence=[_ACCENT_BLUE],
                template=_PLOT_TEMPLATE,
            )
        except ImportError:
            # Plotly's OLS trendline depends on the optional statsmodels package.
            st.error("Trendline requires the 'statsmodels' package to be installed.")
        else:
            trend_fig.update_layout(**axis_titles)
            st.plotly_chart(trend_fig, use_container_width=True)

    st.dataframe(time_series, use_container_width=True)


def _render_custom(dataset, numeric_cols, categorical_cols):
    """Render a user-configured plot (axes, plot type and optional colour)."""
    st.markdown("### Custom Visualization")
    st.info("Create a custom plot by selecting axes and plot type")

    all_cols = dataset.columns.tolist()
    # Prefer typed column lists but fall back to every column when empty.
    numeric_choices = numeric_cols if numeric_cols else all_cols
    category_choices = categorical_cols if categorical_cols else all_cols

    plot_type = st.selectbox(
        "Select plot type",
        ["Scatter", "Line", "Bar", "Box", "Violin", "Histogram", "Pie", "3D Scatter"],
    )

    if plot_type in ("Scatter", "Line", "Bar", "3D Scatter"):
        x_col = st.selectbox("X-axis", all_cols)
        y_col = st.selectbox("Y-axis", numeric_choices)
        z_col = None
        if plot_type == "3D Scatter":
            z_col = st.selectbox("Z-axis", numeric_choices)
        color_col = _select_color_column(dataset)

        if plot_type == "Scatter":
            fig = px.scatter(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"{y_col} vs {x_col}", template=_PLOT_TEMPLATE,
            )
        elif plot_type == "Line":
            # Sort by x so the line is drawn left to right.
            fig = px.line(
                dataset.sort_values(x_col), x=x_col, y=y_col, color=color_col,
                title=f"{y_col} vs {x_col}", template=_PLOT_TEMPLATE,
            )
        elif plot_type == "Bar":
            fig = px.bar(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"{y_col} by {x_col}", template=_PLOT_TEMPLATE,
            )
        else:  # 3D Scatter
            fig = px.scatter_3d(
                dataset, x=x_col, y=y_col, z=z_col, color=color_col,
                title=f"3D Scatter: {x_col}, {y_col}, {z_col}",
                template=_PLOT_TEMPLATE,
            )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type in ("Box", "Violin"):
        x_col = st.selectbox("X-axis (categories)", category_choices)
        y_col = st.selectbox("Y-axis (values)", numeric_choices)
        color_col = _select_color_column(dataset)

        if plot_type == "Box":
            fig = px.box(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"Box Plot: {y_col} by {x_col}", template=_PLOT_TEMPLATE,
            )
        else:  # Violin
            fig = px.violin(
                dataset, x=x_col, y=y_col, color=color_col,
                title=f"Violin Plot: {y_col} by {x_col}", template=_PLOT_TEMPLATE,
            )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type == "Histogram":
        value_col = st.selectbox("Value column", all_cols)
        n_bins = st.slider("Number of bins", 5, 100, 20)
        color_col = _select_color_column(dataset)

        fig = px.histogram(
            dataset, x=value_col, color=color_col, nbins=n_bins,
            title=f"Histogram of {value_col}", template=_PLOT_TEMPLATE,
        )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type == "Pie":
        cat_col = st.selectbox("Category column", category_choices)
        use_values = st.checkbox("Use custom values")
        value_col = None
        if use_values and numeric_cols:
            value_col = st.selectbox("Value column", numeric_cols)
        top_n = st.slider(
            "Limit to top N categories", 0, 20, 10,
            help="Set to 0 to show all categories. Recommended to limit to top 10-15 categories for readability.",
        )

        pie_data, values_name = _pie_frame(dataset, cat_col, value_col, top_n)
        fig = px.pie(
            pie_data,
            names=cat_col,
            values=values_name,
            title=f"Pie Chart of {cat_col}",
            template=_PLOT_TEMPLATE,
        )
        st.plotly_chart(fig, use_container_width=True)


def render_dataset_visualization(dataset, dataset_type):
    """
    Renders interactive Streamlit visualizations for the dataset.

    The user picks one of five visualization types (Distribution,
    Correlation, Categories, Time Series, Custom); the matching section is
    then rendered with Plotly charts and summary tables.

    Args:
        dataset: The dataset to visualize (pandas DataFrame), or None, in
            which case only a warning is shown.
        dataset_type: The type of dataset (csv, json, etc.). Not used by
            this function; kept for interface compatibility.
    """
    if dataset is None:
        st.warning("No dataset to visualize.")
        return

    # NOTE(review): the original heading was a multi-line plain-quoted string
    # (a syntax error) whose markup appears to have been lost; a markdown
    # heading with the same text is used instead - confirm the intended styling.
    st.markdown("## Dataset Visualization")

    # Classify columns once; each renderer receives only the lists it needs.
    numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = dataset.select_dtypes(include=['object', 'category']).columns.tolist()
    # is_datetime64_any_dtype also recognises tz-aware and non-nanosecond
    # datetime columns, which a comparison to 'datetime64[ns]' misses.
    date_cols = [
        col for col in dataset.columns
        if pd.api.types.is_datetime64_any_dtype(dataset[col])
    ]

    viz_type = st.selectbox(
        "Select visualization type",
        ["Distribution", "Correlation", "Categories", "Time Series", "Custom"],
        help="Choose the type of visualization to create",
    )

    if viz_type == "Distribution":
        _render_distribution(dataset, numeric_cols)
    elif viz_type == "Correlation":
        _render_correlation(dataset, numeric_cols)
    elif viz_type == "Categories":
        _render_categories(dataset, numeric_cols, categorical_cols)
    elif viz_type == "Time Series":
        _render_time_series(dataset, numeric_cols, categorical_cols, date_cols)
    elif viz_type == "Custom":
        _render_custom(dataset, numeric_cols, categorical_cols)