hopelessDev commited on
Commit
d02c1bf
·
verified ·
1 Parent(s): 17b7ace

Delete sentiment_api.py

Browse files
Files changed (1) hide show
  1. sentiment_api.py +0 -135
sentiment_api.py DELETED
@@ -1,135 +0,0 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from typing import List, Dict
4
- import os
5
- import requests
6
- import yfinance as yf
7
- import numpy as np
8
- from transformers import pipeline
9
- from cachetools import TTLCache, cached
10
-
11
- # -----------------------------
12
- # CONFIG
13
- # -----------------------------
14
- NEWSAPI_KEY = os.environ.get("NEWSAPI_KEY", "").strip() # optional
15
- MAX_HEADLINES = 3
16
- MODEL_A = "yiyanghkust/finbert-tone"
17
- MODEL_B = "mrm8488/distilroberta-finetuned-financial-news-sentiment"
18
-
19
- # -----------------------------
20
- # Load models
21
- # -----------------------------
22
- sentiment_a = pipeline("sentiment-analysis", model=MODEL_A, device=-1)
23
- sentiment_b = pipeline("sentiment-analysis", model=MODEL_B, device=-1)
24
-
25
- LABEL_MAP = {
26
- "positive": "positive", "neutral": "neutral", "negative": "negative",
27
- "Positive": "positive", "Neutral": "neutral", "Negative": "negative",
28
- "LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"
29
- }
30
-
31
- # -----------------------------
32
- # Caching
33
- # -----------------------------
34
- # Cache up to 100 stocks, expires after 10 minutes
35
- stock_cache = TTLCache(maxsize=100, ttl=600)
36
-
37
- # -----------------------------
38
- # News fetchers
39
- # -----------------------------
40
- def fetch_news_newsapi(query: str, limit: int = MAX_HEADLINES) -> List[str]:
41
- if not NEWSAPI_KEY:
42
- return []
43
- url = "https://newsapi.org/v2/everything"
44
- params = {
45
- "q": query,
46
- "language": "en",
47
- "pageSize": limit,
48
- "sortBy": "publishedAt",
49
- "apiKey": NEWSAPI_KEY,
50
- }
51
- try:
52
- r = requests.get(url, params=params, timeout=6)
53
- r.raise_for_status()
54
- data = r.json()
55
- articles = data.get("articles", [])[:limit]
56
- return [a.get("title", "") for a in articles if a.get("title")]
57
- except:
58
- return []
59
-
60
- def fetch_news_yfinance(ticker: str, limit: int = MAX_HEADLINES) -> List[str]:
61
- try:
62
- t = yf.Ticker(ticker)
63
- news_items = getattr(t, "news", None) or []
64
- return [n.get("title") for n in news_items if n.get("title")][:limit]
65
- except:
66
- return []
67
-
68
- def fetch_headlines(stock: str, limit: int = MAX_HEADLINES) -> List[str]:
69
- headlines = fetch_news_newsapi(stock, limit)
70
- if headlines:
71
- return headlines
72
- return fetch_news_yfinance(stock, limit)
73
-
74
- # -----------------------------
75
- # Ensemble utilities
76
- # -----------------------------
77
- def model_to_vector(pred: Dict) -> np.ndarray:
78
- label = pred.get("label", "")
79
- score = float(pred.get("score", 0.0))
80
- mapped = LABEL_MAP.get(label, label.lower())
81
- vec = np.zeros(3)
82
- if mapped == "negative":
83
- vec[0] = score
84
- elif mapped == "neutral":
85
- vec[1] = score
86
- elif mapped == "positive":
87
- vec[2] = score
88
- else:
89
- vec[1] = score
90
- return vec
91
-
92
- def headline_score_ensemble(headline: str) -> np.ndarray:
93
- a = sentiment_a(headline)[0]
94
- b = sentiment_b(headline)[0]
95
- return (model_to_vector(a) + model_to_vector(b)) / 2.0
96
-
97
- def aggregate_headlines_vectors(vectors: List[np.ndarray]) -> np.ndarray:
98
- if not vectors:
99
- return np.array([0.0,1.0,0.0])
100
- mean_vec = np.mean(vectors, axis=0)
101
- mean_vec = np.clip(mean_vec, 0.0, None)
102
- total = mean_vec.sum()
103
- return mean_vec / total if total > 0 else np.array([0.0,1.0,0.0])
104
-
105
- def vector_to_score(vec: np.ndarray) -> float:
106
- neg, neu, pos = vec.tolist()
107
- score = pos + 0.5 * neu
108
- return max(0.0, min(1.0, score))
109
-
110
- # -----------------------------
111
- # FastAPI app
112
- # -----------------------------
113
- app = FastAPI(title="Financial Sentiment API")
114
-
115
- class StocksRequest(BaseModel):
116
- stocks: List[str]
117
-
118
- @cached(stock_cache)
119
- def analyze_single_stock(stock: str) -> float:
120
- headlines = fetch_headlines(stock)
121
- vectors = [headline_score_ensemble(h) for h in headlines if h and len(h.strip())>10]
122
- agg = aggregate_headlines_vectors(vectors)
123
- score = round(vector_to_score(agg), 2)
124
- return score if score else 0.5
125
-
126
- @app.post("/analyze")
127
- def analyze_stocks(req: StocksRequest):
128
- results = {}
129
- for stock in req.stocks:
130
- results[stock] = analyze_single_stock(stock)
131
- return results
132
-
133
- if __name__ == "__main__":
134
- import uvicorn
135
- uvicorn.run(app, host="0.0.0.0", port=7860)