Spaces:
Runtime error
Runtime error
from sktime.forecasting.fbprophet import Prophet | |
class ProphetForecaster():
    """Wrapper around sktime's ``Prophet`` that back-tests and forecasts in one call.

    ``fit_predict`` stores its results on the instance as ``predict``,
    ``predict_interval``, ``forecast`` and ``forecast_interval``.
    """

    def __init__(self) -> None:
        print('[Prophet] Init customized prophet model')

    @staticmethod
    def _round_decimals(round_val):
        """Return the number of decimals outputs are rounded to.

        ``0`` (whole numbers) when ``round_val`` is truthy, otherwise ``4``.
        """
        return 0 if round_val else 4

    @staticmethod
    def _n_changepoints(n_obs, sp):
        """Return ``int(n_obs / sp)``: one changepoint per ``sp`` observations.

        Raises:
            ValueError: if ``sp`` is not positive.
        """
        if sp <= 0:
            raise ValueError(f'sp must be a positive number, got {sp!r}')
        return int(n_obs / sp)

    def fit_predict(
        self,
        y_train,
        y_future,
        fh,
        fh_test,
        sp,
        freq,
        round_val=False,
        X=None,
        seasonality_mode=None,
        add_country_holidays=None,
        yearly_seasonality=False,
        weekly_seasonality=False,
        daily_seasonality=False,
        coverage=.9,
    ):
        """Fit Prophet on ``y_train``, predict ``fh_test``, update, then forecast ``fh``.

        Results are stored on the instance (nothing is returned):
        ``predict`` / ``predict_interval`` for ``fh_test`` from the model fitted
        on ``y_train``, and ``forecast`` / ``forecast_interval`` for ``fh`` after
        ``y_future`` has been passed to ``update``.

        Args:
            y_train: series the model is fitted on.
            y_future: observations passed to ``forecaster.update`` with
                ``update_params=False`` before the final forecast; presumably
                the data following ``y_train`` (confirm with caller).
            fh: forecasting horizon for the final forecast.
            fh_test: forecasting horizon for the pre-update predictions.
            sp: presumably the seasonal period; ``int(len(y_train) / sp)``
                is used as Prophet's ``n_changepoints``.
            freq: currently unused.
            round_val: if truthy, outputs are rounded to 0 decimals, else to 4.
            X: currently unused (not passed to the forecaster).
            seasonality_mode, add_country_holidays, yearly_seasonality,
            weekly_seasonality, daily_seasonality: forwarded to ``Prophet``.
            coverage: coverage passed to ``predict_interval`` (default 0.9).

        Raises:
            ValueError: if ``sp`` is not positive.
        """
        print('[Prophet] Start forecasting')
        # Previously this tested the builtin ``round`` (always truthy) and used
        # 0.4, which is not a valid number of decimals for rounding.
        round_decimal = self._round_decimals(round_val)
        forecaster = Prophet(
            seasonality_mode=seasonality_mode,
            n_changepoints=self._n_changepoints(len(y_train), sp),
            add_country_holidays=add_country_holidays,
            yearly_seasonality=yearly_seasonality,
            weekly_seasonality=weekly_seasonality,
            daily_seasonality=daily_seasonality,
        )
        forecaster.fit(y_train)

        # Predictions for fh_test from the model fitted on y_train only.
        predict = forecaster.predict(fh_test)
        predict_interval = forecaster.predict_interval(fh_test, coverage=coverage)

        # Feed the newer observations without re-estimating parameters.
        forecaster.update(y_future, update_params=False)
        forecast = forecaster.predict(fh)
        forecast_interval = forecaster.predict_interval(fh, coverage=coverage)

        self.predict = round(predict, round_decimal)
        self.predict_interval = round(predict_interval, round_decimal)
        self.forecast = round(forecast, round_decimal)
        self.forecast_interval = round(forecast_interval, round_decimal)
        print('[Prophet] Fit-predict completed')