import json
import io
import os
import tempfile
import datetime
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sktime.utils.plotting import plot_series
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from src.forecaster import Forecaster
from src.forecaster.models import XGBoost
from src.analyser import Analyser
from src.idsc import IDSC
from src.forecaster.models import ProphetForecaster


class GradioApp():
    """State holder and event handlers for the forecasting Gradio UI.

    Naming convention used below:
    * ``<component>__<name>__<event>`` -- event handlers (e.g.
      ``btn__fit_data__click``); their return tuples are presumably wired to
      the outputs declared where the UI is built (not visible here).
    * ``update__<component>__<name>`` -- builders for an output component's
      new value.
    """

    def __init__(
        self
    ) -> None:
        self.forecaster = Forecaster()
        self.analyser = Analyser()
        self.idsc = IDSC()

        self.historical_demo_data = 'data/multivariate/demo_historical.csv'
        self.future_demo_data = 'data/multivariate/demo_future.csv'

        # Historical data, plus future (exogenous-only) rows once uploaded.
        self.data: pd.DataFrame = None
        self.n_predict = 3
        self.window_length = 7
        self.target_column = 'y'
        # Columns of the uploaded future file; treated as exogenous features.
        self.exog_columns = []
        # Define if the model's result is going to be rounded
        self.round_results = True
        # Delete old temp files older than n minutes
        self.delete_file_old_than_n_minutes = 10
        self.plot_figsize_full_screen = (20, 4)

        # -------------------- #
        # Model Related Params #
        # -------------------- #

        # XGBoost #
        self.xgboost = XGBoost()
        self.xgboost_cv = False
        self.xgboost_params = self.xgboost.cv_params
        self.xgboost_strategy = 'recursive'
        self.xgboost_forecast = None
        self.xgboost_test = None
        print('Init Gradio app')

        # Prophet #
        self.prophet = ProphetForecaster()
        self.prophet__seasonality_mode = 'multiplicative'
        self.prophet__add_country_holidays = {'country_name': 'Singapore'}
        self.prophet__yearly_seasonality = True
        self.prophet__weekly_seasonality = False
        self.prophet__daily_seasonality = False

    def checkbox__round_results__change(self, val):
        """Store whether model results should be rounded."""
        self.round_results = val

    def textbox__target_column__change(self, val):
        """Store the name of the column to forecast."""
        print('Updating textbox__target_column:', val)
        self.target_column = val

    def btn__profiling__click(self):
        """Run the analyser over the loaded data.

        Returns the profiling markdown and the change-point plot.
        """
        self.analyser.fit(self.data)
        self.analyser.profiling()
        return (
            self.update__md__profiling(),
            self.update__plot__changepoints())

    def btn__plot_correlation__click(self):
        """Return the correlation heatmap of the loaded data."""
        return self.update__plot__correlation()

    def _refresh_data_views(self):
        """Build the data-dependent views shared by every upload/load handler.

        Returns (table view, chart filter dropdown, decomposition dropdown,
        chart view plot).
        """
        return (
            self.update__df__table_view(),
            self.update__dropdown__chart_view_filter(),
            self.update__dropdown__seasonality_decompose(),
            self.update__plot__chart_view())

    def file__historical__upload(
        self,
        file
    ):
        """Load an uploaded historical CSV (indexed by ``datetime``) into ``self.data``."""
        self.data = pd.read_csv(
            file.name,
            index_col='datetime',
            parse_dates=['datetime'])
        print('[file__historical__upload]')
        return self._refresh_data_views()

    def file__future__upload(
        self,
        file
    ):
        """Append an uploaded future (exogenous) CSV to ``self.data``.

        Returns the refreshed data views plus the updated ``n_predict`` box.
        """
        self.__handle_future_data_upload(file.name)
        return (*self._refresh_data_views(), self.update__number__n_predict())

    def btn__load_future_demo__click(
        self
    ):
        """Same as ``file__future__upload`` but using the bundled demo file."""
        self.__handle_future_data_upload(self.future_demo_data)
        return (*self._refresh_data_views(), self.update__number__n_predict())

    def __handle_future_data_upload(
        self,
        path
    ):
        """Read the future CSV at ``path`` and append its rows to ``self.data``.

        Every column in the file becomes an exogenous column, and the number
        of rows becomes the forecast horizon ``n_predict``.
        NOTE(review): calling this twice appends the future rows twice, and
        ``pd.concat`` silently ignores ``self.data`` when it is still None.
        """
        data = pd.read_csv(
            path,
            index_col='datetime',
            parse_dates=['datetime'])
        self.exog_columns = data.columns.tolist()
        self.n_predict = len(data)
        print(
            f"[file__future__upload] with {self.exog_columns} columns")
        self.data = pd.concat(
            [self.data, data],
            axis=0)

    def number__n_predict__change(
        self,
        val
    ):
        """Store the forecast horizon length."""
        print(f'[number__n_predict__change], {val}')
        self.n_predict = val

    def number__window_length__change(
        self,
        val
    ):
        """Store the window length passed to the forecaster."""
        print(f'[number__window_length__change], {val}')
        self.window_length = val

    def btn__fit_data__click(
        self
    ):
        """Fit the forecaster on the loaded data and lock the input widgets.

        The target frame is ``self.data`` without the exogenous columns and
        with every row containing a NaN dropped (presumably the appended future
        rows, whose target is empty). The exogenous frame, if any, keeps all
        rows.
        """
        data = self.data.drop(columns=self.exog_columns).dropna(how='any')
        self.forecaster.fit(
            data,
            target_col=self.target_column,
            n_predict=self.n_predict,
            window_length=self.window_length,
            exog=None if len(
                self.exog_columns) == 0 else self.data[self.exog_columns])
        return (
            gr.Number(interactive=False),  # number__n_predict
            gr.Number(interactive=False),  # number__window_length
            gr.File(interactive=False),  # file__historical
            gr.File(interactive=False),  # file__future
            gr.Button(visible=False),  # btn__fit_data
            gr.Column(visible=True),  # column__models
            gr.Button(visible=False),  # btn__load_historical_demo
            gr.Button(visible=False),  # btn__load_future_demo
            self.update__md__forecast_data_info()
        )

    def btn__load_historical_demo__click(
        self
    ):
        """Load the bundled historical demo CSV into ``self.data``."""
        self.data = pd.read_csv(
            self.historical_demo_data,
            index_col='datetime',
            parse_dates=['datetime'])
        return self._refresh_data_views()

    def dropdown__chart_view_filter__change(self, options):
        """Redraw the chart view with only the selected columns."""
        return self.update__plot__chart_view(options)

    def dropdown__seasonality_decompose__change(self, col):
        """Return the decomposition plot and the ACF/PACF plot for ``col``."""
        return (
            self.update__plot__seasonality_decompose(col),
            self.update__plot_acg_pacf(col))

    # ------------------------ #
    # XGboost Model Operations #
    # ------------------------ #

    def btn__train_xgboost__click(self):
        """Fit XGBoost on the forecaster's splits; return plot, CSV file and table."""
        (test, forecast, best_params) = self.xgboost.fit_predict(
            y=self.forecaster.y,
            y_train=self.forecaster.y_train,
            window_length=self.forecaster.window_length,
            fh=self.forecaster.fh,
            fh_test=self.forecaster.fh_test,
            params=self.xgboost_params,
            X=self.forecaster.X,
            X_train=self.forecaster.X_train,
            X_test=self.forecaster.X_test,
            X_future=self.forecaster.X_future
        )
        print(test, forecast, best_params)
        self.xgboost_forecast = forecast
        self.xgboost_test = test
        return (
            self.update__plot__xgboost_result(test, forecast),
            self.update__file__xgboost_result(),
            self.update__df__xgboost_result())

    def btn__set_xgboost_params__click(self, text):
        """Parse ``text`` as JSON (single quotes allowed) into the XGBoost params."""
        # Naive quote swap so Python-dict-style input also parses; it will
        # corrupt values that legitimately contain apostrophes.
        params = json.loads(text.replace("'", '"'))
        self.xgboost_params = params
        return self.update__json_xgboost_params()

    def checkbox__xgboost_round__change(self, val):
        """Toggle rounding of the XGBoost results."""
        self.xgboost.round_result = val

    # ----------------------------------- #
    # Prophet Model Operations & Updaters #
    # ----------------------------------- #

    def btn__forecast_with_prophet__click(self):
        """Fit Prophet on the forecaster's splits; return plot, CSV file and table."""
        self.prophet.fit_predict(
            self.forecaster.y_train,
            self.forecaster.y,
            self.forecaster.fh,
            self.forecaster.fh_test,
            self.forecaster.period,
            self.forecaster.freq,
            X=self.forecaster.exog,
            seasonality_mode=self.prophet__seasonality_mode,
            add_country_holidays=self.prophet__add_country_holidays,
            yearly_seasonality=self.prophet__yearly_seasonality,
            weekly_seasonality=self.prophet__weekly_seasonality,
            daily_seasonality=self.prophet__daily_seasonality,
            round_val=self.round_results)
        return (
            self.update__plot__prophet_result(),
            self.update__file__prophet_result(),
            self.update__df__prophet_result())

    def update__plot__prophet_result(self):
        """Plot the last two periods of training data with test, prediction and forecast."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        plot_series(
            self.forecaster.y_train[-2 * self.forecaster.period:],
            self.forecaster.y_test,
            self.prophet.predict,
            self.prophet.forecast,
            pred_interval=self.prophet.forecast_interval,
            labels=['Train', 'Test', 'Predicted - Test', 'Forecast'],
            ax=ax)
        ax.set_title('Prophet Forecast Result')
        ax.legend(loc='upper left')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__file__prophet_result(self):
        """Write the Prophet forecast to a temp CSV and return it as a file component."""
        prophet_forecast_df = pd.DataFrame(self.prophet.forecast)
        path = self.__create_temp_csv_file(prophet_forecast_df)
        return gr.File(path)

    def update__df__prophet_result(self):
        """Return the Prophet forecast (index reset into a column) as a table."""
        prophet_forecast_df = self.prophet.forecast.reset_index()
        return gr.Dataframe(value=prophet_forecast_df)

    # =============================== #
    # || Gradio Component Updaters || #
    # =============================== #

    @staticmethod
    def _draw_change_points(ax, series, change_points, predictability, title):
        """Plot ``series`` on ``ax`` and annotate its change points.

        ``predictability[0]`` labels the first segment and
        ``predictability[i + 1]`` labels the segment starting at
        ``change_points[i]``; each change point also gets a vertical line.
        """
        ax.plot(series)
        ax.text(series.index[0], ax.get_ylim()[1] * 0.9,
                predictability[0], fontsize=20)
        for i, point in enumerate(change_points):
            ax.axvline(x=point)
            ax.text(point, ax.get_ylim()[1] * 0.9,
                    predictability[i + 1], fontsize=20)
        ax.set_title(title)

    def update__plot__changepoints(self):
        """Plot column ``y`` twice: quantity and intermittent change points."""
        fig, axs = plt.subplots(2, 1, figsize=(20, 8))
        self._draw_change_points(
            axs[0], self.data[['y']],
            self.analyser.quantity_change_points,
            self.analyser.quantity_predictability,
            'Quantity Change Points & Predictability')
        self._draw_change_points(
            axs[1], self.data[['y']],
            self.analyser.intermittent_change_points,
            self.analyser.intermittent_predictability,
            'Intermittent Change Points & Predictability')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__md__profiling(self):
        """Render the analyser's profiling results as markdown."""
        return (f"""
            \n### Data Characteristic:
            \n # {self.analyser.characteristic}
            \n ---
            \n### Quantity Change Points: {self.analyser.quantity_change_points.astype(str).tolist()}
            \n### Quantity Predictability: {self.analyser.quantity_predictability}
            \n### Intermittent Change Points: {self.analyser.intermittent_change_points.astype(str).tolist()}
            \n### Intermittent Predictability: {self.analyser.intermittent_predictability}
            """)

    def update__md__forecast_data_info(self):
        """Describe the forecast horizon, data period and frequency as markdown."""
        return gr.Markdown(value=f' \
            **Forecasting for these timestamps**: \
            {self.forecaster.fh.to_pandas().astype(str).tolist()} \
            \n **Data Period**: {self.forecaster.period} \
            \n **Data Frequency**: {self.forecaster.freq} \
            ')

    def update__plot__correlation(self):
        """Heatmap of the lower triangle of the numeric-column correlation matrix."""
        fig, ax = plt.subplots(figsize=(20, 8))
        corr = self.data.corr(numeric_only=True)
        # Hide the upper triangle (and diagonal) since the matrix is symmetric.
        mask = np.triu(np.ones_like(corr, dtype=bool))
        sns.heatmap(
            corr,
            mask=mask,
            square=True,
            annot=True,
            cmap='coolwarm',
            linewidths=.5,
            cbar_kws={"shrink": .5},
            ax=ax)
        fig.tight_layout()
        return gr.Plot(fig)

    def update__df__table_view(
        self
    ):
        """Return ``self.data`` (with ``datetime`` as a column) as a table."""
        data = self.data.reset_index()
        return gr.Dataframe(value=data)

    def update__number__n_predict(
        self
    ):
        """Return a read-only number box showing the current ``n_predict``."""
        return gr.Number(self.n_predict, interactive=False)

    def update__dropdown__chart_view_filter(self):
        """Dropdown of all data columns, all pre-selected."""
        options = self.data.columns.tolist()
        return gr.Dropdown(options, value=options)

    def update__dropdown__seasonality_decompose(self):
        """Dropdown of all data columns with no default selection."""
        options = self.data.columns.tolist()
        return gr.Dropdown(options)

    def update__plot__seasonality_decompose(self, col):
        """Plot statsmodels' seasonal decomposition of ``col`` (NaNs dropped)."""
        seasonal = seasonal_decompose(self.data[col].dropna())
        fig = seasonal.plot()
        return gr.Plot(fig)

    def update__plot_acg_pacf(self, col):
        """Plot ACF and PACF of ``col`` (NaNs dropped, lag 0 excluded)."""
        fig, axs = plt.subplots(2, 1, sharex=True, sharey=True)
        plot_acf(self.data[col].dropna(), ax=axs[0], zero=False)
        plot_pacf(self.data[col].dropna(), ax=axs[1], zero=False)
        axs[0].set_title('Auto Correlation')
        axs[1].set_title('Partial Auto Correlation')
        return gr.Plot(fig)

    # ---------------------- #
    # Update XGboost Results #
    # ---------------------- #

    def update__json_xgboost_params(self):
        """Show the current XGBoost params as JSON."""
        return gr.JSON(value=self.xgboost_params)

    def update__plot__xgboost_result(self, test, predict):
        """Plot the last two periods of training data with test, prediction and forecast."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        plot_series(
            self.forecaster.y_train[-2 * self.forecaster.period:],
            self.forecaster.y_test,
            test,
            predict,
            labels=["y_train (part)", "y_test", "y_pred", 'y_forecast'],
            x_label='Date',
            ax=ax)
        # Rotate in place instead of re-setting the labels, which would freeze
        # them with a FixedFormatter (matplotlib warns about that pattern).
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        return gr.Plot(fig)

    def update__plot__chart_view(self, cols=None):
        """Line-plot the given columns of ``self.data`` (all columns when ``cols`` is None)."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        _cols = self.data.columns if cols is None else cols
        print('[update__plot__chart_view]')
        for col in _cols:
            ax.plot(self.data[[col]], label=col)
        fig.legend()
        fig.tight_layout()
        return gr.Plot(fig)

    def update__file__xgboost_result(self):
        """Write the XGBoost forecast to a temp CSV and return it as a file component."""
        path = self.__create_temp_csv_file(self.xgboost_forecast)
        return gr.File(path)

    def update__df__xgboost_result(self):
        """Return the XGBoost forecast as a ``datetime``/``y`` table."""
        # xgboost_forecast is actually a Series instead of proper DataFrame
        # Re constructing a proper dataframe for gradio to take
        data = pd.DataFrame(
            {"datetime": self.xgboost_forecast.index,
             "y": self.xgboost_forecast.values})
        return gr.Dataframe(value=data)

    # ------------- #
    # Util Function #
    # ------------- #

    def __create_temp_csv_file(self, df) -> str:
        """Write ``df`` to a new, uniquely named CSV in ``temp/`` and return its path.

        Before writing, this removes files in ``temp/`` whose names start with
        a ``%Y%m%d%H%M%S`` timestamp older than
        ``self.delete_file_old_than_n_minutes``. Files whose names carry no
        such timestamp are never touched.
        """
        time_format = "%Y%m%d%H%M%S"
        directory = 'temp'
        now = datetime.datetime.now()
        max_age = datetime.timedelta(
            minutes=self.delete_file_old_than_n_minutes)

        # os.listdir raises if the directory is missing (e.g. fresh checkout).
        os.makedirs(directory, exist_ok=True)

        # Check if there are old files, remove them #
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            # Names are '<timestamp>.csv' (older scheme) or
            # '<timestamp>_<random>.csv'; the timestamp is the leading part.
            stamp = filename.split('.')[0].split('_')[0]
            try:
                file_time = datetime.datetime.strptime(stamp, time_format)
            except ValueError:
                # Not a file we created (e.g. '.gitkeep'): leave it alone.
                continue
            # Delete files older than the configured number of minutes.
            if now > max_age + file_time:
                print('deleting olde file: ', filename)
                try:
                    os.remove(file_path)
                except FileNotFoundError:
                    # Already removed by a concurrent request; nothing to do.
                    pass

        # mkstemp guarantees a unique name, so two exports in the same second
        # no longer overwrite each other; the timestamp prefix keeps the
        # age-based cleanup above working.
        fd, new_file_path = tempfile.mkstemp(
            prefix=now.strftime(time_format) + '_',
            suffix='.csv',
            dir=directory)
        os.close(fd)
        df.to_csv(new_file_path)
        return new_file_path