File size: 1,698 Bytes
8cf4695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from sktime.forecasting.fbprophet import Prophet


class ProphetForecaster:
    """Wrapper around sktime's ``Prophet`` for a test-then-forecast run.

    ``fit_predict`` does not return its results; it stores them on the
    instance:

    * ``predict`` / ``predict_interval``: point predictions and prediction
      intervals over ``fh_test`` from a model fitted on ``y_train`` only.
    * ``forecast`` / ``forecast_interval``: point predictions and prediction
      intervals over ``fh`` after the model has been updated with
      ``y_future`` (without re-estimating its parameters).
    """

    # Decimal places kept in the outputs when ``round_val`` is False.
    _DEFAULT_DECIMALS = 4

    def __init__(self) -> None:
        print('[Prophet] Init customized prophet model')

    @staticmethod
    def _round_decimals(round_val):
        """Return how many decimals to round outputs to.

        Whole numbers when ``round_val`` is truthy, otherwise
        ``_DEFAULT_DECIMALS`` places.
        """
        return 0 if round_val else ProphetForecaster._DEFAULT_DECIMALS

    @staticmethod
    def _n_changepoints(n_obs, sp):
        """Return one potential changepoint per full seasonal period.

        Raises:
            ValueError: if ``sp`` is missing or not positive.
        """
        if sp is None or sp <= 0:
            raise ValueError(f'sp must be a positive number, got {sp!r}')
        return int(n_obs // sp)

    @staticmethod
    def _resolve_seasonality_mode(seasonality_mode):
        """Map ``None`` to Prophet's default seasonality mode.

        Prophet only accepts ``'additive'`` or ``'multiplicative'``, so
        forwarding ``None`` would be rejected.
        """
        return 'additive' if seasonality_mode is None else seasonality_mode

    def fit_predict(
            self,
            y_train,
            y_future,
            fh,
            fh_test,
            sp,
            freq,
            round_val=False,
            X=None,
            seasonality_mode=None,
            add_country_holidays=None,
            yearly_seasonality=False,
            weekly_seasonality=False,
            daily_seasonality=False,
            *,
            coverage=0.9,
    ):
        """Fit Prophet, predict over ``fh_test``, update, then forecast ``fh``.

        Args:
            y_train: Series the model is fitted on.
            y_future: Observations passed to ``forecaster.update`` before the
                final forecast (presumably the data following ``y_train`` —
                confirm with callers).
            fh: Forecasting horizon for the post-update forecast.
            fh_test: Forecasting horizon for the pre-update predictions.
            sp: Seasonal period; the model gets ``len(y_train) // sp``
                potential changepoints. Must be positive.
            freq: Accepted for interface compatibility; currently unused.
            round_val: Round outputs to whole numbers when True, otherwise to
                ``_DEFAULT_DECIMALS`` decimal places.
            X: Accepted for interface compatibility; currently not passed to
                the model.
            seasonality_mode: ``'additive'`` or ``'multiplicative'``; ``None``
                falls back to ``'additive'``.
            add_country_holidays: Forwarded to sktime's ``Prophet``.
            yearly_seasonality: Forwarded to sktime's ``Prophet``.
            weekly_seasonality: Forwarded to sktime's ``Prophet``.
            daily_seasonality: Forwarded to sktime's ``Prophet``.
            coverage: Nominal coverage of the prediction intervals.

        Returns:
            None. Results are stored in ``self.predict``,
            ``self.predict_interval``, ``self.forecast`` and
            ``self.forecast_interval``.

        Raises:
            ValueError: if ``sp`` is missing or not positive.
        """
        print('[Prophet] Start forecasting')

        decimals = self._round_decimals(round_val)

        forecaster = Prophet(
            seasonality_mode=self._resolve_seasonality_mode(seasonality_mode),
            n_changepoints=self._n_changepoints(len(y_train), sp),
            add_country_holidays=add_country_holidays,
            yearly_seasonality=yearly_seasonality,
            weekly_seasonality=weekly_seasonality,
            daily_seasonality=daily_seasonality,
        )

        forecaster.fit(y_train)

        # Predictions over the test horizon from the y_train-only model.
        predict = forecaster.predict(fh_test)
        predict_interval = forecaster.predict_interval(
            fh_test, coverage=coverage)

        # Feed in the new observations without re-estimating parameters,
        # then forecast from the updated cutoff.
        forecaster.update(y_future, update_params=False)

        forecast = forecaster.predict(fh)
        forecast_interval = forecaster.predict_interval(
            fh, coverage=coverage)

        self.predict = round(predict, decimals)
        self.predict_interval = round(predict_interval, decimals)
        self.forecast = round(forecast, decimals)
        self.forecast_interval = round(forecast_interval, decimals)

        print('[Prophet] Fit-predict completed')