Vishwas1 commited on
Commit
41714f8
·
verified ·
1 Parent(s): b60a104

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -57
app.py CHANGED
@@ -1,31 +1,34 @@
1
- import os
2
  import warnings
3
  warnings.filterwarnings("ignore")
4
 
5
  import gradio as gr
6
- import pandas as pd
7
  import numpy as np
 
8
  import yfinance as yf
9
  import matplotlib.pyplot as plt
10
 
11
- import torch
12
- from gluonts.dataset.common import ListDataset
13
  from pandas.tseries.frequencies import to_offset
 
14
 
15
- # Moirai 2.0 via Uni2TS
 
 
16
  try:
17
  from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module
18
  except Exception as e:
19
  raise ImportError(
20
- "Moirai 2.0 not found in your Uni2TS install. "
21
- "Make sure requirements.txt installs Uni2TS from GitHub: "
22
- "git+https://github.com/SalesforceAIResearch/uni2ts.git\n"
23
  f"Original error: {e}"
24
  )
25
 
26
  MODEL_ID = "Salesforce/moirai-2.0-R-small"
 
27
 
28
- # ---- Model loader (one-time) ----
 
 
29
  _MODULE = None
30
  def load_module():
31
  global _MODULE
@@ -33,18 +36,25 @@ def load_module():
33
  _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
34
  return _MODULE
35
 
36
- # ---- Utilities ----
 
 
37
  def _future_index(last_idx: pd.Timestamp, freq: str, horizon: int) -> pd.DatetimeIndex:
38
- """Create future timestamps continuing the given freq."""
39
  off = to_offset(freq)
40
  start = last_idx + off
41
  return pd.date_range(start=start, periods=horizon, freq=freq)
42
 
43
- def _run_forecast_on_series(y: pd.Series, freq: str, horizon: int, context_hint: int, title: str):
44
- """Core forecasting routine on an indexed univariate series y with pandas freq string."""
 
 
 
 
 
45
  if len(y) < 50:
46
  raise gr.Error("Need at least 50 points to forecast.")
47
- ctx = int(np.clip(context_hint or 1680, 32, len(y)))
 
48
  target = y.values[-ctx:].astype(np.float32)
49
  start_idx = y.index[-ctx]
50
 
@@ -59,7 +69,7 @@ def _run_forecast_on_series(y: pd.Series, freq: str, horizon: int, context_hint:
59
  feat_dynamic_real_dim=0,
60
  past_feat_dynamic_real_dim=0,
61
  )
62
- predictor = model.create_predictor(batch_size=32) # device managed internally
63
 
64
  forecast = next(iter(predictor.predict(ds)))
65
  if hasattr(forecast, "mean"):
@@ -73,7 +83,7 @@ def _run_forecast_on_series(y: pd.Series, freq: str, horizon: int, context_hint:
73
 
74
  yhat = np.asarray(yhat).ravel()[:horizon]
75
  future_idx = _future_index(y.index[-1], freq, horizon)
76
- pred = pd.Series(yhat, index=future_idx, name="predicted")
77
 
78
  # Plot
79
  fig = plt.figure(figsize=(10, 5))
@@ -85,9 +95,11 @@ def _run_forecast_on_series(y: pd.Series, freq: str, horizon: int, context_hint:
85
  out_df = pd.DataFrame({"date": pred.index, "prediction": pred.values})
86
  return fig, out_df
87
 
88
- # ---- Ticker path ----
 
 
89
  def fetch_series(ticker: str, years: int) -> pd.Series:
90
- """Fetch daily close price and align to business-day frequency."""
91
  data = yf.download(
92
  ticker,
93
  period=f"{years}y",
@@ -115,7 +127,7 @@ def fetch_series(ticker: str, years: int) -> pd.Series:
115
  y.name = ticker
116
  y.index = pd.DatetimeIndex(y.index).tz_localize(None)
117
 
118
- # Business-day index, forward-fill market holidays
119
  bidx = pd.bdate_range(y.index.min(), y.index.max())
120
  y = y.reindex(bidx).ffill()
121
 
@@ -132,13 +144,13 @@ def forecast_ticker(ticker: str, horizon: int, lookback_years: int, context_hint
132
  y = fetch_series(ticker, lookback_years)
133
  return _run_forecast_on_series(y, "B", horizon, context_hint, f"{ticker} — forecast (Moirai 2.0 R-small)")
134
 
135
- # ---- CSV path ----
 
 
136
  def _read_csv_columns(file_path: str) -> pd.DataFrame:
137
- # Try very tolerant CSV read
138
  try:
139
  df = pd.read_csv(file_path)
140
  except Exception:
141
- # if it’s actually TSV or weird delimiter, try python engine
142
  df = pd.read_csv(file_path, sep=None, engine="python")
143
  return df
144
 
@@ -146,16 +158,16 @@ def _coerce_numeric_series(s: pd.Series) -> pd.Series:
146
  s = pd.to_numeric(s, errors="coerce")
147
  return s.dropna().astype(np.float32)
148
 
149
- def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str) -> tuple[pd.Series, str]:
150
  """
151
  Returns (series y with DateTimeIndex, freq string).
152
- - If date_col provided: parse dates and (optionally) infer freq.
153
- - If no date_col: require freq_choice != 'auto'; build synthetic dates from 2000-01-01.
154
  """
155
  if file is None:
156
  raise gr.Error("Please upload a CSV file.")
157
 
158
- # Gradio v4/v5 file object compatibility
159
  path = getattr(file, "name", None) or getattr(file, "path", None) or (file if isinstance(file, str) else None)
160
  if path is None:
161
  raise gr.Error("Could not read the uploaded file path.")
@@ -164,72 +176,75 @@ def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str)
164
  if df.empty:
165
  raise gr.Error("Uploaded file is empty.")
166
 
167
- # Pick value column
 
168
  if value_col:
169
  if value_col not in df.columns:
170
  raise gr.Error(f"Value column '{value_col}' not found. Available: {list(df.columns)}")
171
  vals = _coerce_numeric_series(df[value_col])
172
  else:
173
- # Try the first numeric-looking column
174
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
175
- if not numeric_cols:
176
- # Coerce first column
177
- vals = _coerce_numeric_series(df.iloc[:, 0])
178
- else:
179
  vals = _coerce_numeric_series(df[numeric_cols[0]])
 
 
180
 
181
  if vals.empty or len(vals) < 10:
182
  raise gr.Error("Not enough numeric values after parsing (need at least 10).")
183
 
184
- # With datetime column
 
 
185
  if date_col:
186
  if date_col not in df.columns:
187
  raise gr.Error(f"Date column '{date_col}' not found. Available: {list(df.columns)}")
188
  dt = pd.to_datetime(df[date_col], errors="coerce")
189
  mask = dt.notna() & vals.notna()
190
- dt = pd.DatetimeIndex(dt[mask])
191
  vals = vals[mask]
 
192
  if len(vals) < 10:
193
  raise gr.Error("Too few valid rows after parsing date/value columns.")
194
- # sort by date
 
195
  order = np.argsort(dt.values)
196
  dt = dt[order]
197
  vals = vals.iloc[order].reset_index(drop=True)
 
198
  y = pd.Series(vals.values, index=dt, name=value_col or "value").copy()
199
- y.index = y.index.tz_localize(None)
200
 
201
- # Determine frequency
202
- freq = None
203
- if freq_choice and freq_choice != "auto":
204
- freq = freq_choice
205
- y = y.asfreq(freq, method="ffill")
206
  else:
207
- # try to infer; if None, fallback to 'D'
208
- freq = pd.infer_freq(y.index)
209
- if freq is None:
210
- # try business day if looks like weekdays only
211
  weekday_ratio = (y.index.dayofweek < 5).mean()
212
  freq = "B" if weekday_ratio > 0.95 else "D"
213
- y = y.asfreq(freq, method="ffill")
 
 
214
 
215
  else:
216
- # No date column: require explicit freq
217
- if not freq_choice or freq_choice == "auto":
218
- raise gr.Error("No date column given. Please choose a frequency (e.g., D, B, H).")
219
- freq = freq_choice
220
  idx = pd.date_range(start="2000-01-01", periods=len(vals), freq=freq)
221
  y = pd.Series(vals.values, index=idx, name=value_col or "value").copy()
222
 
223
- # Final sanity
224
  if y.isna().all():
225
  raise gr.Error("Series is all-NaN after processing.")
226
  return y, freq
227
 
228
  def forecast_csv(file, value_col: str, date_col: str, freq_choice: str, horizon: int, context_hint: int):
229
- y, freq = build_series_from_csv(file, value_col.strip(), date_col.strip(), freq_choice.strip())
230
  return _run_forecast_on_series(y, freq, horizon, context_hint, f"Uploaded series — forecast (freq={freq})")
231
 
232
- # ---- UI ----
 
 
233
  with gr.Blocks(title="Moirai 2.0 — Time Series Forecast (Research)") as demo:
234
  gr.Markdown(
235
  """
@@ -246,7 +261,7 @@ Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ti
246
  horizon_t = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (steps)")
247
  with gr.Row():
248
  lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
249
- ctx_t = gr.Slider(64, 2000, value=1680, step=16, label="Context length")
250
  run_t = gr.Button("Run forecast", variant="primary")
251
  plot_t = gr.Plot(label="History + Forecast")
252
  table_t = gr.Dataframe(label="Forecast table", interactive=False)
@@ -255,7 +270,7 @@ Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ti
255
  with gr.Tab("Upload CSV"):
256
  gr.Markdown(
257
  "Upload a CSV with either (1) a **date/time column** and a **value column**, "
258
- "or (2) just a numeric value column (then choose a frequency)."
259
  )
260
  with gr.Row():
261
  file = gr.File(label="CSV file", file_types=[".csv"])
@@ -267,7 +282,7 @@ Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ti
267
  label="Frequency",
268
  value="auto",
269
  choices=["auto", "B", "D", "H", "W", "M", "MS"],
270
- info="If no date column, pick a freq (e.g., D)."
271
  )
272
  with gr.Row():
273
  horizon_u = gr.Slider(1, 500, value=60, step=1, label="Forecast horizon (steps)")
@@ -275,7 +290,11 @@ Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ti
275
  run_u = gr.Button("Run forecast on CSV", variant="primary")
276
  plot_u = gr.Plot(label="History + Forecast (CSV)")
277
  table_u = gr.Dataframe(label="Forecast table (CSV)", interactive=False)
278
- run_u.click(forecast_csv, inputs=[file, value_col, date_col, freq_choice, horizon_u, ctx_u], outputs=[plot_u, table_u])
 
 
 
 
279
 
280
  if __name__ == "__main__":
281
  demo.launch()
 
 
1
  import warnings
2
  warnings.filterwarnings("ignore")
3
 
4
  import gradio as gr
 
5
  import numpy as np
6
+ import pandas as pd
7
  import yfinance as yf
8
  import matplotlib.pyplot as plt
9
 
 
 
10
  from pandas.tseries.frequencies import to_offset
11
+ from gluonts.dataset.common import ListDataset
12
 
13
+ # --- Moirai 2.0 via Uni2TS ---
14
+ # Make sure your requirements install Uni2TS from GitHub:
15
+ # git+https://github.com/SalesforceAIResearch/uni2ts.git
16
  try:
17
  from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module
18
  except Exception as e:
19
  raise ImportError(
20
+ "Moirai 2.0 not found in your Uni2TS install.\n"
21
+ "Ensure requirements.txt includes:\n"
22
+ " git+https://github.com/SalesforceAIResearch/uni2ts.git\n"
23
  f"Original error: {e}"
24
  )
25
 
26
  MODEL_ID = "Salesforce/moirai-2.0-R-small"
27
+ DEFAULT_CONTEXT = 1680 # from Moirai examples, but we clamp to series length
28
 
29
+ # ----------------------------
30
+ # Model loader (single instance)
31
+ # ----------------------------
32
  _MODULE = None
33
  def load_module():
34
  global _MODULE
 
36
  _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
37
  return _MODULE
38
 
39
+ # ----------------------------
40
+ # Shared forecasting core
41
+ # ----------------------------
42
  def _future_index(last_idx: pd.Timestamp, freq: str, horizon: int) -> pd.DatetimeIndex:
 
43
  off = to_offset(freq)
44
  start = last_idx + off
45
  return pd.date_range(start=start, periods=horizon, freq=freq)
46
 
47
+ def _run_forecast_on_series(
48
+ y: pd.Series,
49
+ freq: str,
50
+ horizon: int,
51
+ context_hint: int,
52
+ title: str,
53
+ ):
54
  if len(y) < 50:
55
  raise gr.Error("Need at least 50 points to forecast.")
56
+
57
+ ctx = int(np.clip(context_hint or DEFAULT_CONTEXT, 32, len(y)))
58
  target = y.values[-ctx:].astype(np.float32)
59
  start_idx = y.index[-ctx]
60
 
 
69
  feat_dynamic_real_dim=0,
70
  past_feat_dynamic_real_dim=0,
71
  )
72
+ predictor = model.create_predictor(batch_size=32) # device handled internally
73
 
74
  forecast = next(iter(predictor.predict(ds)))
75
  if hasattr(forecast, "mean"):
 
83
 
84
  yhat = np.asarray(yhat).ravel()[:horizon]
85
  future_idx = _future_index(y.index[-1], freq, horizon)
86
+ pred = pd.Series(yhat, index=future_idx, name="prediction")
87
 
88
  # Plot
89
  fig = plt.figure(figsize=(10, 5))
 
95
  out_df = pd.DataFrame({"date": pred.index, "prediction": pred.values})
96
  return fig, out_df
97
 
98
+ # ----------------------------
99
+ # Ticker helpers
100
+ # ----------------------------
101
  def fetch_series(ticker: str, years: int) -> pd.Series:
102
+ """Fetch daily close prices and align to business-day frequency."""
103
  data = yf.download(
104
  ticker,
105
  period=f"{years}y",
 
127
  y.name = ticker
128
  y.index = pd.DatetimeIndex(y.index).tz_localize(None)
129
 
130
+ # Business-day index; forward-fill holidays
131
  bidx = pd.bdate_range(y.index.min(), y.index.max())
132
  y = y.reindex(bidx).ffill()
133
 
 
144
  y = fetch_series(ticker, lookback_years)
145
  return _run_forecast_on_series(y, "B", horizon, context_hint, f"{ticker} — forecast (Moirai 2.0 R-small)")
146
 
147
+ # ----------------------------
148
+ # CSV helpers
149
+ # ----------------------------
150
  def _read_csv_columns(file_path: str) -> pd.DataFrame:
 
151
  try:
152
  df = pd.read_csv(file_path)
153
  except Exception:
 
154
  df = pd.read_csv(file_path, sep=None, engine="python")
155
  return df
156
 
 
158
  s = pd.to_numeric(s, errors="coerce")
159
  return s.dropna().astype(np.float32)
160
 
161
+ def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str):
162
  """
163
  Returns (series y with DateTimeIndex, freq string).
164
+ - If date_col is provided: parse dates and infer/align frequency.
165
+ - If NO date_col: create a synthetic date index using freq_choice (default to 'D' if auto/blank).
166
  """
167
  if file is None:
168
  raise gr.Error("Please upload a CSV file.")
169
 
170
+ # Gradio file object handling (v4/v5)
171
  path = getattr(file, "name", None) or getattr(file, "path", None) or (file if isinstance(file, str) else None)
172
  if path is None:
173
  raise gr.Error("Could not read the uploaded file path.")
 
176
  if df.empty:
177
  raise gr.Error("Uploaded file is empty.")
178
 
179
+ # Value column selection
180
+ value_col = (value_col or "").strip()
181
  if value_col:
182
  if value_col not in df.columns:
183
  raise gr.Error(f"Value column '{value_col}' not found. Available: {list(df.columns)}")
184
  vals = _coerce_numeric_series(df[value_col])
185
  else:
 
186
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
187
+ if numeric_cols:
 
 
 
188
  vals = _coerce_numeric_series(df[numeric_cols[0]])
189
+ else:
190
+ vals = _coerce_numeric_series(df.iloc[:, 0])
191
 
192
  if vals.empty or len(vals) < 10:
193
  raise gr.Error("Not enough numeric values after parsing (need at least 10).")
194
 
195
+ date_col = (date_col or "").strip()
196
+ freq_choice_norm = (freq_choice or "").strip().upper()
197
+
198
  if date_col:
199
  if date_col not in df.columns:
200
  raise gr.Error(f"Date column '{date_col}' not found. Available: {list(df.columns)}")
201
  dt = pd.to_datetime(df[date_col], errors="coerce")
202
  mask = dt.notna() & vals.notna()
203
+ dt = pd.DatetimeIndex(dt[mask]).tz_localize(None)
204
  vals = vals[mask]
205
+
206
  if len(vals) < 10:
207
  raise gr.Error("Too few valid rows after parsing date/value columns.")
208
+
209
+ # Sort & dedupe index BEFORE inferring/aligning freq
210
  order = np.argsort(dt.values)
211
  dt = dt[order]
212
  vals = vals.iloc[order].reset_index(drop=True)
213
+
214
  y = pd.Series(vals.values, index=dt, name=value_col or "value").copy()
215
+ y = y[~y.index.duplicated(keep="last")].sort_index()
216
 
217
+ # Choose frequency
218
+ if freq_choice_norm and freq_choice_norm != "AUTO":
219
+ freq = freq_choice_norm
 
 
220
  else:
221
+ inferred = pd.infer_freq(y.index)
222
+ if inferred:
223
+ freq = inferred
224
+ else:
225
  weekday_ratio = (y.index.dayofweek < 5).mean()
226
  freq = "B" if weekday_ratio > 0.95 else "D"
227
+
228
+ # Align to chosen frequency
229
+ y = y.asfreq(freq, method="ffill")
230
 
231
  else:
232
+ # No date column: build synthetic index
233
+ freq = "D" if (not freq_choice_norm or freq_choice_norm == "AUTO") else freq_choice_norm
 
 
234
  idx = pd.date_range(start="2000-01-01", periods=len(vals), freq=freq)
235
  y = pd.Series(vals.values, index=idx, name=value_col or "value").copy()
236
 
 
237
  if y.isna().all():
238
  raise gr.Error("Series is all-NaN after processing.")
239
  return y, freq
240
 
241
  def forecast_csv(file, value_col: str, date_col: str, freq_choice: str, horizon: int, context_hint: int):
242
+ y, freq = build_series_from_csv(file, value_col, date_col, freq_choice)
243
  return _run_forecast_on_series(y, freq, horizon, context_hint, f"Uploaded series — forecast (freq={freq})")
244
 
245
+ # ----------------------------
246
+ # UI
247
+ # ----------------------------
248
  with gr.Blocks(title="Moirai 2.0 — Time Series Forecast (Research)") as demo:
249
  gr.Markdown(
250
  """
 
261
  horizon_t = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (steps)")
262
  with gr.Row():
263
  lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
264
+ ctx_t = gr.Slider(64, 5000, value=1680, step=16, label="Context length")
265
  run_t = gr.Button("Run forecast", variant="primary")
266
  plot_t = gr.Plot(label="History + Forecast")
267
  table_t = gr.Dataframe(label="Forecast table", interactive=False)
 
270
  with gr.Tab("Upload CSV"):
271
  gr.Markdown(
272
  "Upload a CSV with either (1) a **date/time column** and a **value column**, "
273
+ "or (2) just a numeric value column (then choose a frequency, or leave **auto** to default to **D**)."
274
  )
275
  with gr.Row():
276
  file = gr.File(label="CSV file", file_types=[".csv"])
 
282
  label="Frequency",
283
  value="auto",
284
  choices=["auto", "B", "D", "H", "W", "M", "MS"],
285
+ info="If no date column, 'auto' defaults to D (daily)."
286
  )
287
  with gr.Row():
288
  horizon_u = gr.Slider(1, 500, value=60, step=1, label="Forecast horizon (steps)")
 
290
  run_u = gr.Button("Run forecast on CSV", variant="primary")
291
  plot_u = gr.Plot(label="History + Forecast (CSV)")
292
  table_u = gr.Dataframe(label="Forecast table (CSV)", interactive=False)
293
+ run_u.click(
294
+ forecast_csv,
295
+ inputs=[file, value_col, date_col, freq_choice, horizon_u, ctx_u],
296
+ outputs=[plot_u, table_u],
297
+ )
298
 
299
  if __name__ == "__main__":
300
  demo.launch()