hopelessDev committed
Commit 1cb61c4 · verified · 1 Parent(s): 35634ec

Update app.py

Files changed (1)
  1. app.py +37 -68
app.py CHANGED
@@ -7,14 +7,12 @@ import yfinance as yf
  import numpy as np
  from transformers import pipeline
  from cachetools import TTLCache, cached
- from datetime import datetime, timezone

  # -----------------------------
  # CONFIG
  # -----------------------------
  NEWSAPI_KEY = os.environ.get("NEWSAPI_KEY", "").strip()
  MAX_HEADLINES = 10 # fetch more for robustness
- DECAY_HALF_LIFE_HOURS = 12 # half-life for old news

  MODEL_A = "yiyanghkust/finbert-tone"
  MODEL_B = "ProsusAI/finbert"
@@ -36,64 +34,42 @@ LABEL_MAP = {
  # -----------------------------
  stock_cache = TTLCache(maxsize=100, ttl=600)

- # -----------------------------
- # Finance keywords filter
- # -----------------------------
- FINANCE_KEYWORDS = [
-     "stock", "share", "market", "profit", "loss", "earnings",
-     "investment", "IPO", "dividend", "trading", "NASDAQ", "NYSE"
- ]
-
- def is_relevant_headline(headline: str) -> bool:
-     headline_lower = headline.lower()
-     return any(k.lower() in headline_lower for k in FINANCE_KEYWORDS)
-
  # -----------------------------
  # News fetchers
  # -----------------------------
- def fetch_news_newsapi(stock: str, limit: int = MAX_HEADLINES) -> List[Dict]:
+ def fetch_news_newsapi(query: str, limit: int = MAX_HEADLINES) -> List[str]:
      if not NEWSAPI_KEY:
          return []
      url = "https://newsapi.org/v2/everything"
-     query = f'"{stock}" OR ${stock}'
      params = {
          "q": query,
          "language": "en",
-         "pageSize": limit*2,
+         "pageSize": limit,
          "sortBy": "publishedAt",
          "apiKey": NEWSAPI_KEY,
      }
      try:
          r = requests.get(url, params=params, timeout=6)
          r.raise_for_status()
-         articles = r.json().get("articles", [])
-         filtered = [
-             {"title": a.get("title"), "publishedAt": a.get("publishedAt")}
-             for a in articles if a.get("title") and is_relevant_headline(a.get("title"))
-         ]
-         return filtered[:limit]
+         articles = r.json().get("articles", [])[:limit]
+         return [a.get("title", "") for a in articles if a.get("title")]
      except Exception as e:
          print(f"[NewsAPI error] {e}")
          return []

- def fetch_news_yfinance(stock: str, limit: int = MAX_HEADLINES) -> List[Dict]:
+ def fetch_news_yfinance(ticker: str, limit: int = MAX_HEADLINES) -> List[str]:
      try:
-         t = yf.Ticker(stock)
+         t = yf.Ticker(ticker)
          news_items = getattr(t, "news", None) or []
-         filtered = [
-             {"title": n.get("title"), "publishedAt": n.get("providerPublishTime")}
-             for n in news_items if n.get("title") and is_relevant_headline(n.get("title"))
-         ]
-         return filtered[:limit]
+         return [n.get("title") for n in news_items if n.get("title")][:limit]
      except Exception as e:
          print(f"[Yahoo Finance error] {e}")
          return []

- def fetch_headlines(stock: str, limit: int = MAX_HEADLINES) -> List[Dict]:
+ def fetch_headlines(stock: str, limit: int = MAX_HEADLINES) -> List[str]:
      headlines = fetch_news_newsapi(stock, limit)
-     if len(headlines) < 2:
-         headlines_yf = fetch_news_yfinance(stock, limit)
-         headlines = list({h['title']: h for h in (headlines + headlines_yf)}.values())[:limit]
+     if not headlines:
+         headlines = fetch_news_yfinance(stock, limit)
      return headlines

  # -----------------------------
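After this hunk both fetchers return plain title strings instead of dicts, the keyword relevance filter is gone, and Yahoo Finance is used only when NewsAPI returns nothing. A tiny sketch of what the new NewsAPI post-processing does to a response body (the sample payload below is invented for illustration):

# Invented sample of a NewsAPI /v2/everything response, trimmed to the fields used.
sample = {
    "articles": [
        {"title": "AAPL hits record high", "publishedAt": "2024-01-01T12:00:00Z"},
        {"title": None},
    ]
}
articles = sample.get("articles", [])[:10]
titles = [a.get("title", "") for a in articles if a.get("title")]
print(titles)  # ['AAPL hits record high']; entries without a title are dropped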
@@ -119,63 +95,56 @@ def headline_score_ensemble(headline: str) -> np.ndarray:
      b = sentiment_b(headline)[0]
      return (model_to_vector(a) + model_to_vector(b)) / 2.0

- def aggregate_headlines_vectors(vectors: List[np.ndarray], timestamps: List[float]) -> np.ndarray:
+ def aggregate_headlines_vectors(vectors: List[np.ndarray]) -> np.ndarray:
      if not vectors:
          return np.array([0.0,1.0,0.0])
-
-     # Apply decay weights based on timestamps
-     now = datetime.now(timezone.utc).timestamp()
-     weights = np.array([0.5 ** ((now - ts)/(DECAY_HALF_LIFE_HOURS*3600)) for ts in timestamps])
-     weighted_vecs = np.array(vectors) * weights[:, None]
-     mean_vec = weighted_vecs.sum(axis=0) / weights.sum()
-
-     mean_vec = np.clip(mean_vec, 0.0, None)
+     mean_vec = np.mean(vectors, axis=0)
      total = mean_vec.sum()
      return mean_vec / total if total > 0 else np.array([0.0,1.0,0.0])

  def vector_to_score(vec: np.ndarray) -> float:
      neg, neu, pos = vec.tolist()
-     return round(max(0.0, min(1.0, pos + 0.5 * neu)), 2)
+     return max(0.0, min(1.0, pos + 0.5 * neu))
+
+ # -----------------------------
+ # Decay utilities
+ # -----------------------------
+ def get_decay_factor(num_headlines: int, max_headlines: int = MAX_HEADLINES,
+                      min_decay: float = 0.6, max_decay: float = 0.95) -> float:
+     """
+     Dynamic decay: more headlines → higher decay → score can approach extremes.
+     """
+     ratio = min(num_headlines / max_headlines, 1.0)
+     return min_decay + ratio * (max_decay - min_decay)

  # -----------------------------
  # FastAPI app
  # -----------------------------
- app = FastAPI(title="Financial Sentiment API with Decay")
+ app = FastAPI(title="Financial Sentiment API")

  class StocksRequest(BaseModel):
      stocks: List[str]

  @cached(stock_cache)
  def analyze_single_stock(stock: str) -> float | str:
-     headlines_data = fetch_headlines(stock)
-     headlines_data = [h for h in headlines_data if h.get("title") and len(h["title"].strip()) > 10]
+     headlines = fetch_headlines(stock)
+     headlines = [h for h in headlines if h and len(h.strip()) > 10]

-     if not headlines_data:
+     if not headlines or len(headlines) < 2:
          return "NO_DATA"

-     vectors = []
-     timestamps = []
-     for h in headlines_data:
-         vectors.append(headline_score_ensemble(h["title"]))
-         # convert publishedAt to timestamp
-         try:
-             ts = h.get("publishedAt")
-             if isinstance(ts, str):
-                 ts = datetime.fromisoformat(ts.replace("Z","+00:00")).timestamp()
-             elif isinstance(ts, (int, float)):
-                 ts = float(ts)
-             else:
-                 ts = datetime.now(timezone.utc).timestamp()
-         except:
-             ts = datetime.now(timezone.utc).timestamp()
-         timestamps.append(ts)
-
-     agg = aggregate_headlines_vectors(vectors, timestamps)
-     return vector_to_score(agg)
+     vectors = [headline_score_ensemble(h) for h in headlines]
+     agg = aggregate_headlines_vectors(vectors)
+     raw_score = vector_to_score(agg)
+
+     # Apply dynamic decay
+     decay = get_decay_factor(len(headlines))
+     adjusted_score = 0.5 + decay * (raw_score - 0.5)
+     return round(adjusted_score, 2)

  @app.get("/")
  def root():
-     return {"message": "Fin-senti API with Decay is running! Use POST /analyze"}
+     return {"message": "Fin-senti API is running! Use POST /analyze"}

  @app.post("/analyze")
  def analyze(req: StocksRequest):
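Taken together, the new scoring path is: average the per-headline [neg, neu, pos] vectors, map the result to a 0-1 score via pos + 0.5 * neu, then pull that score toward 0.5 by a factor that grows with the number of headlines. This replaces the removed timestamp half-life weighting with a simpler volume-based confidence adjustment. A small worked example using the defaults from the diff (the aggregated vector is made up for illustration):

import numpy as np

def vector_to_score(vec):
    neg, neu, pos = vec.tolist()
    return max(0.0, min(1.0, pos + 0.5 * neu))

def get_decay_factor(num_headlines, max_headlines=10, min_decay=0.6, max_decay=0.95):
    ratio = min(num_headlines / max_headlines, 1.0)
    return min_decay + ratio * (max_decay - min_decay)

agg = np.array([0.05, 0.15, 0.80])    # hypothetical aggregated sentiment vector
raw = vector_to_score(agg)            # 0.80 + 0.5 * 0.15 = 0.875
for n in (2, 10):
    decay = get_decay_factor(n)       # 0.67 with 2 headlines, 0.95 with 10
    adjusted = 0.5 + decay * (raw - 0.5)
    print(n, round(adjusted, 2))      # 2 -> 0.75, 10 -> 0.86

So a thinly covered ticker now lands closer to neutral than a heavily covered one with the same raw sentiment.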
 
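The request schema is unchanged: POST /analyze takes a JSON body matching StocksRequest. A hypothetical client call (the URL is an assumed local development address, and the response shape is not shown in this commit, so it is not asserted here):

import requests

resp = requests.post(
    "http://localhost:8000/analyze",      # assumed local dev URL
    json={"stocks": ["AAPL", "TSLA"]},    # body follows the StocksRequest model
    timeout=30,
)
print(resp.status_code, resp.json())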