from sktime.forecasting.fbprophet import Prophet


class ProphetForecaster:
    """Wrapper around sktime's ``Prophet`` producing a back-test and a forecast.

    After :meth:`fit_predict` has run, the instance exposes four results:

    * ``predict`` / ``predict_interval``: point and interval predictions over
      ``fh_test`` from a model fitted on ``y_train`` only.
    * ``forecast`` / ``forecast_interval``: point and interval predictions
      over ``fh`` after ``y_future`` has been added to the model's data
      (without re-estimating parameters).
    """

    def __init__(self) -> None:
        print('[Prophet] Init customized prophet model')

    def fit_predict(
        self,
        y_train,
        y_future,
        fh,
        fh_test,
        sp,
        freq,
        round_val=False,
        X=None,
        seasonality_mode=None,
        add_country_holidays=None,
        yearly_seasonality=False,
        weekly_seasonality=False,
        daily_seasonality=False,
        coverage=0.9,
    ):
        """Fit Prophet on ``y_train``, predict ``fh_test``, then update and forecast ``fh``.

        Args:
            y_train: Series used to fit the model.
            y_future: Data passed to ``forecaster.update(..., update_params=False)``
                before producing the final forecast.
            fh: Forecasting horizon for ``self.forecast`` / ``self.forecast_interval``.
            fh_test: Forecasting horizon for ``self.predict`` / ``self.predict_interval``.
            sp: Seasonal period. Only used to derive
                ``n_changepoints = int(len(y_train) / sp)``. Presumably a
                positive integer; ``sp == 0`` raises ``ZeroDivisionError``.
            freq: Accepted for interface compatibility; currently unused.
            round_val: If True, round all outputs to 0 decimals; otherwise
                round to 4 decimals.
            X: Accepted for interface compatibility; currently unused (no
                exogenous data is passed to fit or predict).
            seasonality_mode, add_country_holidays, yearly_seasonality,
            weekly_seasonality, daily_seasonality: Forwarded to ``Prophet``.
            coverage: Nominal coverage of the prediction intervals
                (default 0.9, the previously hard-coded value).

        Returns:
            None. Results are stored on ``self`` (see the class docstring).
        """
        print('[Prophet] Start forecasting')
        # Bug fix: the original tested the builtin ``round`` (always truthy)
        # instead of ``round_val``, and used 0.4, which is not a valid digit count.
        round_decimal = 0 if round_val else 4

        forecaster = Prophet(
            seasonality_mode=seasonality_mode,
            # Roughly one candidate changepoint per seasonal cycle in the training data.
            n_changepoints=int(len(y_train) / sp),
            add_country_holidays=add_country_holidays,
            yearly_seasonality=yearly_seasonality,
            weekly_seasonality=weekly_seasonality,
            daily_seasonality=daily_seasonality,
        )

        # Back-test: predictions from a model that has only seen y_train.
        forecaster.fit(y_train)
        self.predict = forecaster.predict(fh_test)
        self.predict_interval = forecaster.predict_interval(
            fh_test, coverage=coverage)

        # Add y_future to the model's data without refitting parameters, then forecast.
        forecaster.update(y_future, update_params=False)
        self.forecast = forecaster.predict(fh)
        self.forecast_interval = forecaster.predict_interval(
            fh, coverage=coverage)

        self.predict = round(self.predict, round_decimal)
        self.predict_interval = round(self.predict_interval, round_decimal)
        self.forecast = round(self.forecast, round_decimal)
        self.forecast_interval = round(self.forecast_interval, round_decimal)
        print('[Prophet] Fit-predict completed')