Vishwas1 commited on
Commit
b60a104
·
verified ·
1 Parent(s): 0be8fa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -76
app.py CHANGED
@@ -10,15 +10,22 @@ import matplotlib.pyplot as plt
10
 
11
  import torch
12
  from gluonts.dataset.common import ListDataset
 
13
 
14
- # Moirai 2.0 via Uni2TS (per Salesforce's example)
15
- # https://www.salesforce.com/blog/moirai-2-0/
16
- from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module # type: ignore
 
 
 
 
 
 
 
17
 
18
  MODEL_ID = "Salesforce/moirai-2.0-R-small"
19
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # Load the Moirai 2.0 module once at startup
22
  _MODULE = None
23
  def load_module():
24
  global _MODULE
@@ -26,6 +33,59 @@ def load_module():
26
  _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
27
  return _MODULE
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def fetch_series(ticker: str, years: int) -> pd.Series:
30
  """Fetch daily close price and align to business-day frequency."""
31
  data = yf.download(
@@ -39,17 +99,14 @@ def fetch_series(ticker: str, years: int) -> pd.Series:
39
  if data is None or data.empty:
40
  raise gr.Error(f"No price data found for '{ticker}'.")
41
 
42
- # Choose a price column
43
  col = "Close" if "Close" in data.columns else ("Adj Close" if "Adj Close" in data.columns else None)
44
  if col is None:
45
  raise gr.Error(f"Unexpected columns from yfinance: {list(data.columns)}")
46
 
47
- # yfinance can sometimes return a MultiIndex (e.g., if a list of tickers slips through)
48
  if isinstance(data.columns, pd.MultiIndex):
49
  if ticker in data[col].columns:
50
  s = data[col][ticker]
51
  else:
52
- # fall back to the first column
53
  s = data[col].iloc[:, 0]
54
  else:
55
  s = data[col]
@@ -66,98 +123,160 @@ def fetch_series(ticker: str, years: int) -> pd.Series:
66
  raise gr.Error(f"Only missing values for '{ticker}'.")
67
  return y
68
 
69
-
70
- def forecast_ticker(ticker: str,
71
- horizon: int,
72
- lookback_years: int,
73
- context_hint: int):
74
  ticker = (ticker or "").strip().upper()
75
  if not ticker:
76
  raise gr.Error("Please enter a ticker symbol (e.g., AAPL).")
77
  if horizon < 1:
78
  raise gr.Error("Forecast horizon must be at least 1.")
79
-
80
- # 1) Get history
81
  y = fetch_series(ticker, lookback_years)
82
- if len(y) < 50:
83
- raise gr.Error("Not enough history to forecast (need at least 50 points).")
84
 
85
- # 2) Build dataset for GluonTS-style predictor
86
- # Use business-day freq ('B'); pick a context <= history length.
87
- default_ctx = 1680 # from Moirai 2.0 examples
88
- ctx = int(np.clip(context_hint or default_ctx, 32, len(y)))
89
- target = y.values[-ctx:]
90
- start_idx = y.index[-ctx]
 
 
 
91
 
92
- ds = ListDataset([{"start": start_idx, "target": target}], freq="B")
 
 
93
 
94
- # 3) Create forecast wrapper and predictor
95
- module = load_module()
96
- model = Moirai2Forecast(
97
- module=module,
98
- prediction_length=int(horizon),
99
- context_length=ctx,
100
- target_dim=1,
101
- feat_dynamic_real_dim=0,
102
- past_feat_dynamic_real_dim=0,
103
- )
104
- predictor = model.create_predictor(batch_size=32) # remove device=...
105
 
 
 
 
 
106
 
107
- # 4) Predict
108
- forecast = next(iter(predictor.predict(ds)))
 
109
 
110
- # 5) Extract a reasonable central estimate
111
- if hasattr(forecast, "mean"):
112
- yhat = np.asarray(forecast.mean)
113
- elif hasattr(forecast, "quantile"):
114
- # 50th percentile as point
115
- yhat = np.asarray(forecast.quantile(0.5))
116
- elif hasattr(forecast, "samples"):
117
- yhat = np.asarray(forecast.samples).mean(axis=0)
118
  else:
119
- # very defensive fallback
120
- yhat = np.asarray(forecast)
 
 
 
 
 
121
 
122
- # Guard length (some forecast objects can be slightly longer)
123
- yhat = np.asarray(yhat).ravel()[:horizon]
124
 
125
- # 6) Assemble dates & outputs
126
- # Next business days after the last historical date
127
- future_idx = pd.bdate_range(y.index[-1] + pd.tseries.offsets.BDay(), periods=horizon)
128
- pred = pd.Series(yhat, index=future_idx, name="predicted_close")
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- # 7) Plot
131
- fig = plt.figure(figsize=(10, 5))
132
- plt.plot(y.index, y.values, label="history")
133
- plt.plot(pred.index, pred.values, label="forecast")
134
- plt.title(f"{ticker} close price forecast (Moirai 2.0 R-small)")
135
- plt.xlabel("Date"); plt.ylabel("Price"); plt.legend(); plt.tight_layout()
 
 
 
 
 
 
 
136
 
137
- # 8) Table
138
- out_df = pd.DataFrame({"date": pred.index, "predicted_close": pred.values})
139
- return fig, out_df
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- with gr.Blocks(title="Moirai 2.0 — Stock Price Forecast (Research)") as demo:
 
142
  gr.Markdown(
143
  """
144
- # Moirai 2.0 — Stock Price Forecast (Research)
145
- Enter a ticker to fetch recent daily prices and generate a short-term forecast using **Salesforce/moirai-2.0-R-small**.
146
- > **Important**: For **research/educational** use only. Not investment advice. Model license is **CC-BY-NC-4.0 (non-commercial)**.
 
147
  """
148
  )
149
- with gr.Row():
150
- ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA")
151
- horizon = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (business days)")
152
- with gr.Row():
153
- lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
154
- ctx = gr.Slider(64, 2000, value=1680, step=16, label="Context length (points)")
155
 
156
- run = gr.Button("Run forecast", variant="primary")
157
- plot = gr.Plot(label="History + Forecast")
158
- table = gr.Dataframe(label="Forecast table", interactive=False)
 
 
 
 
 
 
 
 
159
 
160
- run.click(forecast_ticker, inputs=[ticker, horizon, lookback, ctx], outputs=[plot, table])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  if __name__ == "__main__":
163
  demo.launch()
 
 
10
 
11
  import torch
12
  from gluonts.dataset.common import ListDataset
13
+ from pandas.tseries.frequencies import to_offset
14
 
15
+ # Moirai 2.0 via Uni2TS
16
+ try:
17
+ from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module
18
+ except Exception as e:
19
+ raise ImportError(
20
+ "Moirai 2.0 not found in your Uni2TS install. "
21
+ "Make sure requirements.txt installs Uni2TS from GitHub: "
22
+ "git+https://github.com/SalesforceAIResearch/uni2ts.git\n"
23
+ f"Original error: {e}"
24
+ )
25
 
26
  MODEL_ID = "Salesforce/moirai-2.0-R-small"
 
27
 
28
+ # ---- Model loader (one-time) ----
29
  _MODULE = None
30
  def load_module():
31
  global _MODULE
 
33
  _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
34
  return _MODULE
35
 
36
+ # ---- Utilities ----
37
+ def _future_index(last_idx: pd.Timestamp, freq: str, horizon: int) -> pd.DatetimeIndex:
38
+ """Create future timestamps continuing the given freq."""
39
+ off = to_offset(freq)
40
+ start = last_idx + off
41
+ return pd.date_range(start=start, periods=horizon, freq=freq)
42
+
43
+ def _run_forecast_on_series(y: pd.Series, freq: str, horizon: int, context_hint: int, title: str):
44
+ """Core forecasting routine on an indexed univariate series y with pandas freq string."""
45
+ if len(y) < 50:
46
+ raise gr.Error("Need at least 50 points to forecast.")
47
+ ctx = int(np.clip(context_hint or 1680, 32, len(y)))
48
+ target = y.values[-ctx:].astype(np.float32)
49
+ start_idx = y.index[-ctx]
50
+
51
+ ds = ListDataset([{"start": start_idx, "target": target}], freq=freq)
52
+
53
+ module = load_module()
54
+ model = Moirai2Forecast(
55
+ module=module,
56
+ prediction_length=int(horizon),
57
+ context_length=ctx,
58
+ target_dim=1,
59
+ feat_dynamic_real_dim=0,
60
+ past_feat_dynamic_real_dim=0,
61
+ )
62
+ predictor = model.create_predictor(batch_size=32) # device managed internally
63
+
64
+ forecast = next(iter(predictor.predict(ds)))
65
+ if hasattr(forecast, "mean"):
66
+ yhat = np.asarray(forecast.mean)
67
+ elif hasattr(forecast, "quantile"):
68
+ yhat = np.asarray(forecast.quantile(0.5))
69
+ elif hasattr(forecast, "samples"):
70
+ yhat = np.asarray(forecast.samples).mean(axis=0)
71
+ else:
72
+ yhat = np.asarray(forecast)
73
+
74
+ yhat = np.asarray(yhat).ravel()[:horizon]
75
+ future_idx = _future_index(y.index[-1], freq, horizon)
76
+ pred = pd.Series(yhat, index=future_idx, name="predicted")
77
+
78
+ # Plot
79
+ fig = plt.figure(figsize=(10, 5))
80
+ plt.plot(y.index, y.values, label="history")
81
+ plt.plot(pred.index, pred.values, label="forecast")
82
+ plt.title(title)
83
+ plt.xlabel("Time"); plt.ylabel("Value"); plt.legend(); plt.tight_layout()
84
+
85
+ out_df = pd.DataFrame({"date": pred.index, "prediction": pred.values})
86
+ return fig, out_df
87
+
88
+ # ---- Ticker path ----
89
  def fetch_series(ticker: str, years: int) -> pd.Series:
90
  """Fetch daily close price and align to business-day frequency."""
91
  data = yf.download(
 
99
  if data is None or data.empty:
100
  raise gr.Error(f"No price data found for '{ticker}'.")
101
 
 
102
  col = "Close" if "Close" in data.columns else ("Adj Close" if "Adj Close" in data.columns else None)
103
  if col is None:
104
  raise gr.Error(f"Unexpected columns from yfinance: {list(data.columns)}")
105
 
 
106
  if isinstance(data.columns, pd.MultiIndex):
107
  if ticker in data[col].columns:
108
  s = data[col][ticker]
109
  else:
 
110
  s = data[col].iloc[:, 0]
111
  else:
112
  s = data[col]
 
123
  raise gr.Error(f"Only missing values for '{ticker}'.")
124
  return y
125
 
126
+ def forecast_ticker(ticker: str, horizon: int, lookback_years: int, context_hint: int):
 
 
 
 
127
  ticker = (ticker or "").strip().upper()
128
  if not ticker:
129
  raise gr.Error("Please enter a ticker symbol (e.g., AAPL).")
130
  if horizon < 1:
131
  raise gr.Error("Forecast horizon must be at least 1.")
 
 
132
  y = fetch_series(ticker, lookback_years)
133
+ return _run_forecast_on_series(y, "B", horizon, context_hint, f"{ticker} — forecast (Moirai 2.0 R-small)")
 
134
 
135
+ # ---- CSV path ----
136
+ def _read_csv_columns(file_path: str) -> pd.DataFrame:
137
+ # Try very tolerant CSV read
138
+ try:
139
+ df = pd.read_csv(file_path)
140
+ except Exception:
141
+ # if it’s actually TSV or weird delimiter, try python engine
142
+ df = pd.read_csv(file_path, sep=None, engine="python")
143
+ return df
144
 
145
+ def _coerce_numeric_series(s: pd.Series) -> pd.Series:
146
+ s = pd.to_numeric(s, errors="coerce")
147
+ return s.dropna().astype(np.float32)
148
 
149
+ def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str) -> tuple[pd.Series, str]:
150
+ """
151
+ Returns (series y with DateTimeIndex, freq string).
152
+ - If date_col provided: parse dates and (optionally) infer freq.
153
+ - If no date_col: require freq_choice != 'auto'; build synthetic dates from 2000-01-01.
154
+ """
155
+ if file is None:
156
+ raise gr.Error("Please upload a CSV file.")
 
 
 
157
 
158
+ # Gradio v4/v5 file object compatibility
159
+ path = getattr(file, "name", None) or getattr(file, "path", None) or (file if isinstance(file, str) else None)
160
+ if path is None:
161
+ raise gr.Error("Could not read the uploaded file path.")
162
 
163
+ df = _read_csv_columns(path)
164
+ if df.empty:
165
+ raise gr.Error("Uploaded file is empty.")
166
 
167
+ # Pick value column
168
+ if value_col:
169
+ if value_col not in df.columns:
170
+ raise gr.Error(f"Value column '{value_col}' not found. Available: {list(df.columns)}")
171
+ vals = _coerce_numeric_series(df[value_col])
 
 
 
172
  else:
173
+ # Try the first numeric-looking column
174
+ numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
175
+ if not numeric_cols:
176
+ # Coerce first column
177
+ vals = _coerce_numeric_series(df.iloc[:, 0])
178
+ else:
179
+ vals = _coerce_numeric_series(df[numeric_cols[0]])
180
 
181
+ if vals.empty or len(vals) < 10:
182
+ raise gr.Error("Not enough numeric values after parsing (need at least 10).")
183
 
184
+ # With datetime column
185
+ if date_col:
186
+ if date_col not in df.columns:
187
+ raise gr.Error(f"Date column '{date_col}' not found. Available: {list(df.columns)}")
188
+ dt = pd.to_datetime(df[date_col], errors="coerce")
189
+ mask = dt.notna() & vals.notna()
190
+ dt = pd.DatetimeIndex(dt[mask])
191
+ vals = vals[mask]
192
+ if len(vals) < 10:
193
+ raise gr.Error("Too few valid rows after parsing date/value columns.")
194
+ # sort by date
195
+ order = np.argsort(dt.values)
196
+ dt = dt[order]
197
+ vals = vals.iloc[order].reset_index(drop=True)
198
+ y = pd.Series(vals.values, index=dt, name=value_col or "value").copy()
199
+ y.index = y.index.tz_localize(None)
200
 
201
+ # Determine frequency
202
+ freq = None
203
+ if freq_choice and freq_choice != "auto":
204
+ freq = freq_choice
205
+ y = y.asfreq(freq, method="ffill")
206
+ else:
207
+ # try to infer; if None, fallback to 'D'
208
+ freq = pd.infer_freq(y.index)
209
+ if freq is None:
210
+ # try business day if looks like weekdays only
211
+ weekday_ratio = (y.index.dayofweek < 5).mean()
212
+ freq = "B" if weekday_ratio > 0.95 else "D"
213
+ y = y.asfreq(freq, method="ffill")
214
 
215
+ else:
216
+ # No date column: require explicit freq
217
+ if not freq_choice or freq_choice == "auto":
218
+ raise gr.Error("No date column given. Please choose a frequency (e.g., D, B, H).")
219
+ freq = freq_choice
220
+ idx = pd.date_range(start="2000-01-01", periods=len(vals), freq=freq)
221
+ y = pd.Series(vals.values, index=idx, name=value_col or "value").copy()
222
+
223
+ # Final sanity
224
+ if y.isna().all():
225
+ raise gr.Error("Series is all-NaN after processing.")
226
+ return y, freq
227
+
228
+ def forecast_csv(file, value_col: str, date_col: str, freq_choice: str, horizon: int, context_hint: int):
229
+ y, freq = build_series_from_csv(file, value_col.strip(), date_col.strip(), freq_choice.strip())
230
+ return _run_forecast_on_series(y, freq, horizon, context_hint, f"Uploaded series — forecast (freq={freq})")
231
 
232
+ # ---- UI ----
233
+ with gr.Blocks(title="Moirai 2.0 — Time Series Forecast (Research)") as demo:
234
  gr.Markdown(
235
  """
236
+ # Moirai 2.0 — Time Series Forecast (Research)
237
+ Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ticker *or* a generic CSV time series.
238
+
239
+ > **Important**: Research/educational use only. Not investment advice. Model license: **CC-BY-NC-4.0 (non-commercial)**.
240
  """
241
  )
 
 
 
 
 
 
242
 
243
+ with gr.Tab("By Ticker"):
244
+ with gr.Row():
245
+ ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA")
246
+ horizon_t = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (steps)")
247
+ with gr.Row():
248
+ lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
249
+ ctx_t = gr.Slider(64, 2000, value=1680, step=16, label="Context length")
250
+ run_t = gr.Button("Run forecast", variant="primary")
251
+ plot_t = gr.Plot(label="History + Forecast")
252
+ table_t = gr.Dataframe(label="Forecast table", interactive=False)
253
+ run_t.click(forecast_ticker, inputs=[ticker, horizon_t, lookback, ctx_t], outputs=[plot_t, table_t])
254
 
255
+ with gr.Tab("Upload CSV"):
256
+ gr.Markdown(
257
+ "Upload a CSV with either (1) a **date/time column** and a **value column**, "
258
+ "or (2) just a numeric value column (then choose a frequency)."
259
+ )
260
+ with gr.Row():
261
+ file = gr.File(label="CSV file", file_types=[".csv"])
262
+ with gr.Row():
263
+ date_col = gr.Textbox(label="Date/time column (optional)", placeholder="e.g., date, timestamp")
264
+ value_col = gr.Textbox(label="Value column (optional — auto-detects first numeric)", placeholder="e.g., value, close")
265
+ with gr.Row():
266
+ freq_choice = gr.Dropdown(
267
+ label="Frequency",
268
+ value="auto",
269
+ choices=["auto", "B", "D", "H", "W", "M", "MS"],
270
+ info="If no date column, pick a freq (e.g., D)."
271
+ )
272
+ with gr.Row():
273
+ horizon_u = gr.Slider(1, 500, value=60, step=1, label="Forecast horizon (steps)")
274
+ ctx_u = gr.Slider(32, 5000, value=512, step=16, label="Context length")
275
+ run_u = gr.Button("Run forecast on CSV", variant="primary")
276
+ plot_u = gr.Plot(label="History + Forecast (CSV)")
277
+ table_u = gr.Dataframe(label="Forecast table (CSV)", interactive=False)
278
+ run_u.click(forecast_csv, inputs=[file, value_col, date_col, freq_choice, horizon_u, ctx_u], outputs=[plot_u, table_u])
279
 
280
  if __name__ == "__main__":
281
  demo.launch()
282
+