# Source: app.py from alidenewade's Hugging Face Space
# (commit 547f02d, "Update app.py"; 15.3 kB).
import gradio as gr
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import warnings
import joblib
from statsmodels.tsa.statespace.sarimax import SARIMAX
# --- Setup and Configuration ---
warnings.filterwarnings('ignore')

# --- File Loading ---
# NOTE: When deploying to Hugging Face Spaces, upload these files to your Space
# (e.g. via the "Files" tab) and make sure the paths below match where they live.
TRACTS_PATH = Path("nyc_tracts.gpkg")
PANEL_PATH = Path("nyc_cesium_features.parquet")
# Optional: when this file is absent the app falls back to a rule-based model.
MODEL_PATH = Path("lgbm_crime_classifier.joblib")

try:
    # gpd.read_file does not raise FileNotFoundError for a missing path (its
    # fiona/pyogrio backends raise their own error types), so check up front
    # and raise the exception type the handler below expects.
    _missing_files = [str(p) for p in (TRACTS_PATH, PANEL_PATH) if not p.exists()]
    if _missing_files:
        raise FileNotFoundError(f"Missing data file(s): {', '.join(_missing_files)}")
    tracts_gdf = gpd.read_file(TRACTS_PATH)
    panel_df = pd.read_parquet(PANEL_PATH)
    # Convert month to datetime so it can be compared against date bounds.
    panel_df['month'] = pd.to_datetime(panel_df['month'])
except FileNotFoundError as e:
    print(f"Error loading data files: {e}")
    print("Please make sure 'nyc_tracts.gpkg' and 'nyc_cesium_features.parquet' are in the same directory as app.py")
    # Placeholder frames let the app start so its layout can be reviewed; the
    # dashboard/forecast functions detect the 'DUMMY' GEOID and show a notice.
    tracts_gdf = gpd.GeoDataFrame({'GEOID': ['DUMMY'], 'geometry': [None]})
    panel_df = pd.DataFrame({
        'GEOID': ['DUMMY'],
        'month': [pd.to_datetime('2023-01-01')],
        'crime_total': [0],
        'sr311_total': [0],
        'dob_permits_total': [0],
        'crime_felony': [0],
        'crime_misd': [0],
        'crime_viol': [0],
    })
# --- Model Loading (optional) ---
# The classifier is trained and saved by a separate script. When the file is
# absent or cannot be loaded, `model` stays None and predict_crime_level uses
# its rule-based fallback instead.
model = None
if not MODEL_PATH.exists():
    print(f"Model file not found at {MODEL_PATH}. A placeholder model will be used.")
else:
    try:
        # joblib.load unpickles arbitrary objects: only load model files you trust.
        model = joblib.load(MODEL_PATH)
    except Exception as e:  # e.g. corrupt file, or missing/incompatible ML library
        # Degrade to the fallback rather than crashing the whole app at import.
        print(f"Failed to load model from {MODEL_PATH}: {e}. A placeholder model will be used.")
        model = None
# --- Tab 1: EDA Dashboard Functions ---
def create_choropleth_map(metric, start_date, end_date):
    """Create a choropleth map of ``metric`` summed per census tract.

    Rows of ``panel_df`` whose ``month`` lies in ``[start_date, end_date]``
    (both inclusive) are summed per GEOID and joined onto ``tracts_gdf``;
    tracts with no rows in the range are drawn as 0.

    Args:
        metric: Column of ``panel_df`` to aggregate, e.g. ``'crime_total'``.
        start_date: Lower bound, as a ``'YYYY-MM-DD'`` string or anything
            ``pd.to_datetime`` accepts. A blank or unparseable value falls
            back to the earliest month in the data.
        end_date: Upper bound, same format; a blank or unparseable value
            falls back to the latest month in the data.

    Returns:
        A matplotlib Figure with the map, or with a short message when the
        data is not loaded, the metric is unknown, or no rows match.
    """
    print(f"DEBUG: create_choropleth_map called with metric={metric}, start_date={start_date}, end_date={end_date}")
    if panel_df is None or (panel_df['GEOID'] == 'DUMMY').any():
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "Data not loaded", ha='center', va='center')
        return fig

    if metric not in panel_df.columns:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.text(0.5, 0.5, f"Unknown metric: {metric}", ha='center', va='center')
        ax.set_axis_off()
        return fig

    def _parse_bound(value, fallback):
        # Parse each bound on its own so a typo in one date does not discard
        # the other. pd.to_datetime('') gives NaT (and None gives None); both
        # are treated as "no bound given".
        try:
            parsed = pd.to_datetime(value)
        except (ValueError, TypeError, OverflowError) as e:
            print(f"DEBUG: Date parsing error: {e}")
            return fallback
        return fallback if pd.isna(parsed) else parsed

    start_date = _parse_bound(start_date, panel_df['month'].min())
    end_date = _parse_bound(end_date, panel_df['month'].max())
    print(f"DEBUG: Parsed dates - start: {start_date}, end: {end_date}")

    filtered_df = panel_df[panel_df['month'].between(start_date, end_date)]
    print(f"DEBUG: Filtered dataframe length: {len(filtered_df)}")
    if filtered_df.empty:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.text(0.5, 0.5, "No data found for date range", ha='center', va='center')
        ax.set_title('No Data Available', fontsize=15)
        ax.set_axis_off()
        return fig

    geoid_totals = filtered_df.groupby('GEOID', as_index=False)[metric].sum()
    print(f"DEBUG: GEOID totals shape: {geoid_totals.shape}")
    merged_gdf = tracts_gdf.merge(geoid_totals, on='GEOID', how='left')
    # Tracts without rows in the range get 0. Fill only the metric column so
    # the geometry and the other tract attributes are left untouched.
    merged_gdf[metric] = merged_gdf[metric].fillna(0)

    pretty_metric = metric.replace('_', ' ').title()
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    merged_gdf.plot(column=metric,
                    ax=ax,
                    legend=True,
                    cmap='viridis',
                    legend_kwds={'label': f"Total {pretty_metric}",
                                 'orientation': "horizontal"})
    ax.set_title(f'Spatial Distribution of {pretty_metric}', fontsize=15)
    ax.set_axis_off()
    # Lay out this figure explicitly: plt.tight_layout() targets pyplot's
    # global "current figure", which may differ under concurrent requests.
    fig.tight_layout()
    return fig
def create_time_series_plot(metric, start_date, end_date):
    """Plot the monthly total of ``metric``, summed over all tracts.

    Args:
        metric: Column of ``panel_df`` to aggregate, e.g. ``'crime_total'``.
        start_date: Inclusive lower bound on ``month``, as a ``'YYYY-MM-DD'``
            string or anything ``pd.to_datetime`` accepts. A blank or
            unparseable value falls back to the earliest month in the data.
        end_date: Inclusive upper bound, same format; a blank or unparseable
            value falls back to the latest month in the data.

    Returns:
        A matplotlib Figure with the line plot, or with a short message when
        the data is not loaded, the metric is unknown, or no rows match.
    """
    print(f"DEBUG: create_time_series_plot called with metric={metric}, start_date={start_date}, end_date={end_date}")
    if panel_df is None or (panel_df['GEOID'] == 'DUMMY').any():
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "Data not loaded", ha='center', va='center')
        return fig

    if metric not in panel_df.columns:
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.text(0.5, 0.5, f"Unknown metric: {metric}", ha='center', va='center')
        return fig

    def _parse_bound(value, fallback):
        # Parse each bound on its own so a typo in one date does not discard
        # the other. pd.to_datetime('') gives NaT (and None gives None); both
        # are treated as "no bound given".
        try:
            parsed = pd.to_datetime(value)
        except (ValueError, TypeError, OverflowError) as e:
            print(f"DEBUG: Date parsing error: {e}")
            return fallback
        return fallback if pd.isna(parsed) else parsed

    start_date = _parse_bound(start_date, panel_df['month'].min())
    end_date = _parse_bound(end_date, panel_df['month'].max())
    print(f"DEBUG: Parsed dates - start: {start_date}, end: {end_date}")

    filtered_df = panel_df[panel_df['month'].between(start_date, end_date)]
    print(f"DEBUG: Filtered dataframe length: {len(filtered_df)}")
    if filtered_df.empty:
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.text(0.5, 0.5, "No data found for date range", ha='center', va='center')
        ax.set_title('No Data Available', fontsize=15)
        return fig

    monthly_totals = filtered_df.groupby('month')[metric].sum()
    print(f"DEBUG: Monthly totals shape: {monthly_totals.shape}")

    fig, ax = plt.subplots(figsize=(12, 6))
    monthly_totals.plot(ax=ax)
    ax.set_title(f'Monthly Total of {metric.replace("_", " ").title()}', fontsize=15)
    ax.set_xlabel('Month')
    ax.set_ylabel('Total Count')
    ax.grid(True)
    # Lay out this figure explicitly: plt.tight_layout() targets pyplot's
    # global "current figure", which may differ under concurrent requests.
    fig.tight_layout()
    return fig
# --- Tab 2: Predictive ML & TS Functions ---
def predict_crime_level(crime_felony, crime_misd, crime_viol, sr311_total, dob_permits_total):
    """Predict a crime level for a tract from one month of indicator counts.

    Uses the loaded classifier when ``model`` is set; otherwise falls back to
    a hand-written rule set over the same inputs.

    Args:
        crime_felony: Felony count.
        crime_misd: Misdemeanor count.
        crime_viol: Violation count.
        sr311_total: Number of 311 service requests.
        dob_permits_total: Number of DOB permits issued.

    Returns:
        A ``(message, confidence)`` tuple: a display string naming the
        predicted level, and a dict mapping each label (``str``) to a score
        (``float``) that sums to 1. If the model raises, the message
        describes the error and the dict is empty.
    """
    print(f"DEBUG: predict_crime_level called with inputs: {crime_felony}, {crime_misd}, {crime_viol}, {sr311_total}, {dob_permits_total}")
    if model is None:
        # Rule-based stand-in used when no trained classifier is available.
        total_crime = crime_felony + crime_misd + crime_viol
        if total_crime <= 20:
            prediction = "Low"
            confidence = {"Low": 0.7, "Medium": 0.2, "High": 0.1}
        elif total_crime <= 50:
            prediction = "Medium"
            confidence = {"Low": 0.2, "Medium": 0.6, "High": 0.2}
        else:
            prediction = "High"
            confidence = {"Low": 0.1, "Medium": 0.3, "High": 0.6}

        # Many 311 requests bump a "Low" prediction up to "Medium".
        if sr311_total > 500 and prediction == "Low":
            prediction = "Medium"
            confidence = {"Low": 0.4, "Medium": 0.5, "High": 0.1}

        # Heavy permit activity shifts weight towards "Medium" (capped at 0.8).
        if dob_permits_total > 25:
            confidence["Medium"] = min(0.8, confidence["Medium"] + 0.2)
            # The boost pushes the total above 1 (e.g. 0.7 + 0.4 + 0.1 = 1.2);
            # rescale so the scores remain a probability distribution.
            total = sum(confidence.values())
            confidence = {label: score / total for label, score in confidence.items()}

        print(f"DEBUG: Dummy prediction result: {prediction}, confidence: {confidence}")
        return f"Predicted Crime Level: {prediction} (using fallback model)", confidence

    try:
        # Single-row frame whose column names match the training features.
        input_data = pd.DataFrame({
            'crime_felony': [crime_felony],
            'crime_misd': [crime_misd],
            'crime_viol': [crime_viol],
            'sr311_total': [sr311_total],
            'dob_permits_total': [dob_permits_total],
        })
        probabilities = model.predict_proba(input_data)[0]
        labels = model.classes_
        prediction = labels[int(np.argmax(probabilities))]
        # Plain str/float keys and values keep NumPy scalar types out of the
        # dict handed to the UI, whatever the classifier's label dtype is.
        confidence = {str(label): float(prob) for label, prob in zip(labels, probabilities)}
        print(f"DEBUG: Real model prediction result: {prediction}, confidence: {confidence}")
        return f"Predicted Crime Level: {prediction}", confidence
    except Exception as e:  # UI boundary: report the error instead of crashing
        print(f"DEBUG: Error in model prediction: {e}")
        return f"Error in prediction: {str(e)}", {}
def forecast_time_series(geoid):
    """Fit a seasonal ARIMA model to one tract's monthly crime totals and forecast 12 months.

    Args:
        geoid: Census tract GEOID as typed by the user; surrounding
            whitespace is ignored.

    Returns:
        A ``(figure, message)`` tuple. ``figure`` is a matplotlib Figure with
        the history, the forecast and its confidence band, or ``None`` when no
        forecast could be produced. ``message`` summarises the result or
        explains why no forecast was made.
    """
    if panel_df is None or (panel_df['GEOID'] == 'DUMMY').any():
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "Data not loaded", ha='center', va='center')
        return fig, "Data not loaded."

    geoid = "" if geoid is None else str(geoid).strip()
    if not geoid:
        return None, "Please enter a GEOID."

    # Match on the string form so the lookup works whether the GEOID column
    # holds text or integers.
    tract_rows = panel_df[panel_df['GEOID'].astype(str) == geoid]
    if tract_rows.empty:
        return None, f"GEOID {geoid} not found in the dataset."

    # One value per month, in date order, on a regular month-start index.
    # groupby sorts the months and collapses duplicate rows (summed), which
    # would otherwise make asfreq raise; min_count=1 keeps an all-NaN month
    # as NaN rather than 0. Gaps added by asfreq are NaN.
    tract_data = (
        tract_rows.groupby('month')['crime_total']
        .sum(min_count=1)
        .asfreq('MS')
    )
    # Require 24 observed months (two 12-month seasonal periods); NaN gap
    # slots do not count as data.
    if tract_data.notna().sum() < 24:
        return None, f"Not enough historical data for GEOID {geoid} to create a forecast."

    # Simple SARIMAX model for demonstration.
    try:
        model_ts = SARIMAX(tract_data, order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
        results = model_ts.fit(disp=False)
        forecast = results.get_forecast(steps=12)
    except Exception as e:  # UI boundary: report a failed fit instead of crashing
        print(f"DEBUG: Error fitting forecast model: {e}")
        return None, f"Could not fit a forecast model for GEOID {geoid}: {e}"
    forecast_mean = forecast.predicted_mean
    forecast_ci = forecast.conf_int()

    fig, ax = plt.subplots(figsize=(12, 6))
    tract_data.plot(ax=ax, label='Historical')
    forecast_mean.plot(ax=ax, label='Forecast')
    ax.fill_between(forecast_ci.index,
                    forecast_ci.iloc[:, 0],
                    forecast_ci.iloc[:, 1], color='k', alpha=.25)
    ax.set_title(f'Crime Forecast for Census Tract {geoid}')
    ax.set_xlabel('Date')
    ax.set_ylabel('Crime Total')
    ax.legend()
    ax.grid(True)
    # Lay out this figure explicitly: plt.tight_layout() targets pyplot's
    # global "current figure", which may differ under concurrent requests.
    fig.tight_layout()

    metrics_text = f"Forecast for GEOID: {geoid}\n"
    metrics_text += "Mean Absolute Error (on test set) would be calculated here in a full implementation."
    return fig, metrics_text
# --- Gradio App Layout ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NYC Urban Indicators Dashboard & Prediction")

    with gr.Tab("Dashboard"):
        gr.Markdown("## Exploratory Data Analysis of NYC Urban Data")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Controls")
                metric_selector = gr.Dropdown(
                    label="Select Metric",
                    choices=['crime_total', 'sr311_total', 'dob_permits_total'],
                    value='crime_total'
                )
                # Default the date range to the full span of the data.
                min_date = panel_df['month'].min().strftime('%Y-%m-%d')
                max_date = panel_df['month'].max().strftime('%Y-%m-%d')
                start_date_picker = gr.Textbox(
                    label="Start Date (YYYY-MM-DD)",
                    value=min_date,
                    placeholder="2023-01-01"
                )
                end_date_picker = gr.Textbox(
                    label="End Date (YYYY-MM-DD)",
                    value=max_date,
                    placeholder="2023-12-31"
                )
                update_button = gr.Button("Update Dashboard")
            with gr.Column(scale=3):
                gr.Markdown("### Visualizations")
                map_plot = gr.Plot()
                ts_plot = gr.Plot()

        def update_dashboard(metric, start_date, end_date):
            """Redraw both dashboard figures (map, time series) for the current controls."""
            print(f"DEBUG: update_dashboard called with {metric}, {start_date}, {end_date}")
            map_fig = create_choropleth_map(metric, start_date, end_date)
            ts_fig = create_time_series_plot(metric, start_date, end_date)
            return map_fig, ts_fig

        dashboard_inputs = [metric_selector, start_date_picker, end_date_picker]
        dashboard_outputs = [map_plot, ts_plot]

        # Redraw on button click and whenever a different metric is picked.
        update_button.click(fn=update_dashboard, inputs=dashboard_inputs, outputs=dashboard_outputs)
        metric_selector.change(fn=update_dashboard, inputs=dashboard_inputs, outputs=dashboard_outputs)
        # Date boxes redraw on Enter (.submit) rather than .change: .change
        # fires on every keystroke, re-rendering the full choropleth for each
        # partially typed date.
        start_date_picker.submit(fn=update_dashboard, inputs=dashboard_inputs, outputs=dashboard_outputs)
        end_date_picker.submit(fn=update_dashboard, inputs=dashboard_inputs, outputs=dashboard_outputs)
        # Populate the plots on page load instead of leaving them empty.
        demo.load(fn=update_dashboard, inputs=dashboard_inputs, outputs=dashboard_outputs)

    with gr.Tab("Predictive Analytics"):
        with gr.Tabs():
            with gr.TabItem("Machine Learning Prediction"):
                gr.Markdown("## Predict Next Month's Crime Level")
                gr.Markdown("Adjust the sliders to reflect the current month's data for a census tract.")
                with gr.Row():
                    with gr.Column():
                        felony_slider = gr.Slider(0, 100, label="Felony Count", step=1, value=5)
                        misd_slider = gr.Slider(0, 200, label="Misdemeanor Count", step=1, value=15)
                        viol_slider = gr.Slider(0, 200, label="Violation Count", step=1, value=10)
                        sr311_slider = gr.Slider(0, 1000, label="311 Service Requests", step=10, value=100)
                        dob_slider = gr.Slider(0, 50, label="DOB Permits Issued", step=1, value=3)
                        predict_button = gr.Button("Predict")
                    with gr.Column():
                        prediction_output = gr.Label(label="Prediction Result")
                        confidence_output = gr.Label(label="Prediction Confidence")
                predict_button.click(
                    fn=predict_crime_level,
                    inputs=[felony_slider, misd_slider, viol_slider, sr311_slider, dob_slider],
                    outputs=[prediction_output, confidence_output]
                )

            with gr.TabItem("Time Series Forecasting"):
                gr.Markdown("## Forecast Future Crime Counts")
                gr.Markdown("Enter a Census Tract GEOID to forecast the total crime count for the next 12 months.")
                with gr.Row():
                    with gr.Column():
                        geoid_input = gr.Textbox(label="Enter GEOID", placeholder="e.g., 36005000100")
                        forecast_button = gr.Button("Generate Forecast")
                    with gr.Column():
                        forecast_metrics = gr.Textbox(label="Forecast Metrics", interactive=False)
                forecast_plot = gr.Plot()
                forecast_button.click(
                    fn=forecast_time_series,
                    inputs=[geoid_input],
                    outputs=[forecast_plot, forecast_metrics]
                )

if __name__ == "__main__":
    demo.launch()