Spaces:
Runtime error
Runtime error
File size: 1,698 Bytes
8cf4695 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from sktime.forecasting.fbprophet import Prophet
class ProphetForecaster():
    """Wrapper around sktime's ``Prophet`` producing a back-test and a forecast.

    After :meth:`fit_predict` runs, results are stored on the instance:

    * ``predict`` / ``predict_interval`` -- point and interval predictions
      for ``fh_test`` from the model fitted on ``y_train``.
    * ``forecast`` / ``forecast_interval`` -- point and interval predictions
      for ``fh`` after the fitted model has been updated with ``y_future``.
    """

    def __init__(self) -> None:
        print('[Prophet] Init customized prophet model')

    @staticmethod
    def _rounding_decimals(round_val) -> int:
        """Return how many decimals the stored outputs are rounded to.

        A truthy ``round_val`` rounds to whole numbers; otherwise 4 decimals
        are kept.
        """
        # NOTE(review): the original used ``0.4`` here, which is not a valid
        # ``decimals`` value for pandas/numpy rounding (TypeError); 4 decimals
        # is assumed to be the intent -- confirm with the author.
        return 0 if round_val else 4

    def fit_predict(
        self,
        y_train,
        y_future,
        fh,
        fh_test,
        sp,
        freq,
        round_val=False,
        X=None,
        seasonality_mode=None,
        add_country_holidays=None,
        yearly_seasonality=False,
        weekly_seasonality=False,
        daily_seasonality=False,
        coverage=0.9,
    ):
        """Fit Prophet, predict ``fh_test``, update with ``y_future``, forecast ``fh``.

        Args:
            y_train: Series the model is fitted on.
            y_future: Observations appended to the model via
                ``update(..., update_params=False)`` before forecasting ``fh``.
            fh: Forecasting horizon for the final forecast.
            fh_test: Forecasting horizon for the back-test prediction.
            sp: Divisor for ``len(y_train)`` when setting ``n_changepoints``
                (presumably the seasonal period -- confirm with callers).
            freq: Currently unused.
            round_val: If truthy, round all outputs to whole numbers;
                otherwise round to 4 decimals.
            X: Currently unused; it is not passed to the model.
            seasonality_mode, add_country_holidays, yearly_seasonality,
            weekly_seasonality, daily_seasonality: Forwarded to ``Prophet``.
            coverage: Nominal coverage of the prediction intervals
                (default 0.9, matching the previously hard-coded value).

        Returns:
            None. Results are stored on ``self`` (see the class docstring).
        """
        print('[Prophet] Start forecasting')
        # Fixed: the original tested the builtin ``round`` (always truthy)
        # instead of the ``round_val`` argument.
        round_decimal = self._rounding_decimals(round_val)
        forecaster = Prophet(
            seasonality_mode=seasonality_mode,
            # One potential changepoint per ``sp`` training observations.
            n_changepoints=int(len(y_train) / sp),
            add_country_holidays=add_country_holidays,
            yearly_seasonality=yearly_seasonality,
            weekly_seasonality=weekly_seasonality,
            daily_seasonality=daily_seasonality,
        )
        forecaster.fit(y_train)

        # Back-test on fh_test using the model fitted on y_train only.
        self.predict = forecaster.predict(fh_test)
        self.predict_interval = forecaster.predict_interval(
            fh_test, coverage=coverage)

        # Add y_future without re-estimating parameters, then forecast fh.
        forecaster.update(y_future, update_params=False)
        self.forecast = forecaster.predict(fh)
        self.forecast_interval = forecaster.predict_interval(
            fh, coverage=coverage)

        self.predict = round(self.predict, round_decimal)
        self.predict_interval = round(self.predict_interval, round_decimal)
        self.forecast = round(self.forecast, round_decimal)
        self.forecast_interval = round(self.forecast_interval, round_decimal)
        print('[Prophet] Fit-predict completed')
|