TuanScientist commited on
Commit
a807e9f
1 Parent(s): 1d54feb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -34
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import gradio as gr
2
  import pandas as pd
3
- from neuralprophet import NeuralProphet
4
  import warnings
5
- import torch.optim as optim
6
- from torch.optim.lr_scheduler import OneCycleLR
7
 
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
10
  url = "VN Index Historical Data.csv"
@@ -14,58 +13,55 @@ df = df.rename(columns={"Date": "ds", "Price": "y"})
14
  df.fillna(method='ffill', inplace=True)
15
  df.dropna(inplace=True)
16
 
17
-
18
- class CustomNeuralProphet(NeuralProphet):
19
- def __init__(self, **kwargs):
20
- super().__init__(**kwargs)
21
- self.optimizer = None
22
-
23
- m = CustomNeuralProphet(
24
  n_forecasts=30,
25
  n_lags=12,
26
  changepoints_range=1,
27
  num_hidden_layers=3,
 
 
28
  yearly_seasonality=True,
29
  n_changepoints=150,
30
- trend_reg_threshold=False,
31
  d_hidden=3,
32
  global_normalization=True,
33
  seasonality_reg=1,
34
  unknown_data_normalization=True,
35
  seasonality_mode="multiplicative",
36
  drop_missing=True,
37
- learning_rate=0.03,
38
  )
39
 
40
- # Set the custom LR scheduler
41
- m.fit(df, freq='D') # Fit the model first before accessing the optimizer
42
- m.optimizer = optim.Adam(m.model.parameters(), lr=0.03) # Example optimizer, adjust as needed
43
-
44
- lr_scheduler = OneCycleLR(
45
- m.optimizer,
46
- max_lr=0.1,
47
- total_steps=100,
48
- pct_start=0.3,
49
- anneal_strategy='cos',
50
- ) # Example LR scheduler, adjust as needed
51
-
52
- m.trainer.lr_schedulers = [lr_scheduler] # Set the LR scheduler to the trainer
53
 
54
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
55
  forecast = m.predict(future)
56
 
57
-
58
  def predict_vn_index(option=None):
59
- fig = m.plot(forecast)
60
- path = "forecast_plot.png"
61
- fig.savefig(path)
 
 
 
 
 
62
  disclaimer = "Quý khách chỉ xem đây là tham khảo, công ty không chịu bất cứ trách nhiệm nào về tình trạng đầu tư của quý khách."
63
- return path, disclaimer
 
64
 
65
 
66
  if __name__ == "__main__":
67
  dropdown = gr.inputs.Dropdown(["VNIndex"], label="Choose an option", default="VNIndex")
68
- image_output = gr.outputs.Image(type="file", label="Forecast Plot")
69
- disclaimer_output = gr.outputs.Textbox(label="Disclaimer")
70
- interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=[image_output, disclaimer_output], title="Dự báo VN Index 30 ngày tới")
71
- interface.launch()
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from neuralprophet import NeuralProphet, set_log_level
4
  import warnings
 
 
5
 
6
+ set_log_level("ERROR")
7
  warnings.filterwarnings("ignore", category=UserWarning)
8
 
9
  url = "VN Index Historical Data.csv"
 
13
  df.fillna(method='ffill', inplace=True)
14
  df.dropna(inplace=True)
15
 
16
+ m = NeuralProphet(
 
 
 
 
 
 
17
  n_forecasts=30,
18
  n_lags=12,
19
  changepoints_range=1,
20
  num_hidden_layers=3,
21
+ daily_seasonality=False,
22
+ weekly_seasonality=True,
23
  yearly_seasonality=True,
24
  n_changepoints=150,
25
+ trend_reg_threshold=False, # Disable trend regularization threshold
26
  d_hidden=3,
27
  global_normalization=True,
28
  seasonality_reg=1,
29
  unknown_data_normalization=True,
30
  seasonality_mode="multiplicative",
31
  drop_missing=True,
32
+ learning_rate=0.03
33
  )
34
 
35
+ m.fit(df, freq='D')
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
38
  forecast = m.predict(future)
39
 
 
40
  def predict_vn_index(option=None):
41
+ fig1 = m.plot(forecast)
42
+ fig1_path = "forecast_plot1.png"
43
+ fig1.savefig(fig1_path)
44
+
45
+ # Add code to generate the second image (fig2)
46
+ fig2 = m.plot_latest_forecast(forecast) # Replace this line with code to generate the second image
47
+ fig2_path = "forecast_plot2.png"
48
+ fig2.savefig(fig2_path)
49
  disclaimer = "Quý khách chỉ xem đây là tham khảo, công ty không chịu bất cứ trách nhiệm nào về tình trạng đầu tư của quý khách."
50
+
51
+ return fig1_path, fig2_path, disclaimer
52
 
53
 
54
  if __name__ == "__main__":
55
  dropdown = gr.inputs.Dropdown(["VNIndex"], label="Choose an option", default="VNIndex")
56
+ outputs = [
57
+ gr.outputs.Image(type="filepath", label="First Image"),
58
+ gr.outputs.Image(type="filepath", label="Second Image"),
59
+ gr.outputs.Textbox(label="Disclaimer")
60
+ ]
61
+ interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=outputs, title="Dự báo VN Index 30 ngày tới")
62
+ interface.launch(share=True)
63
+
64
+
65
+
66
+
67
+