MANOJSEQ committed on
Commit f75d93a · verified · 1 parent: e1dfc68

Update main.py

Files changed (1): main.py (+152 -52)
main.py CHANGED
@@ -1,6 +1,5 @@
 # ----------------- Imports (Stdlib + Typing) -----------------
 from fastapi import FastAPI, Query, HTTPException, Body
-from fastapi import FastAPI, Query, HTTPException, Body
 from typing import Optional, List, Dict, Any, Tuple, Set
 import os
 import time
@@ -451,18 +450,59 @@ SECTION_HINTS = {
 
 KEYWORDS = {
     "sports": r"\b(NBA|NFL|MLB|NHL|Olympic|goal|match|tournament|coach|transfer)\b",
-    "business": r"\b(stocks?|earnings|IPO|merger|acquisition|revenue|inflation|market)\b",
+    "business": r"\b(stocks?|earnings|IPO|merger|acquisition|revenue|inflation|market|tax|budget|deficit)\b",
     "technology": r"\b(AI|software|chip|semiconductor|app|startup|cyber|hack|quantum|robot)\b",
     "science": r"\b(researchers?|study|physics|astronomy|genome|spacecraft|telescope)\b",
-    "health": r"\b(virus|vaccine|disease|hospital|doctor|public health|covid)\b",
+    "health": r"\b(virus|vaccine|disease|hospital|doctor|public health|covid|recall|FDA|contamination|disease outbreak)\b",
     "entertainment": r"\b(movie|film|box office|celebrity|series|show|album|music)\b",
     "crime": r"\b(arrested|charged|police|homicide|fraud|theft|court|lawsuit)\b",
     "weather": r"\b(hurricane|storm|flood|heatwave|blizzard|tornado|forecast)\b",
     "environment": r"\b(climate|emissions|wildfire|deforestation|biodiversity)\b",
     "travel": r"\b(flight|airline|airport|tourism|visa|cruise|hotel)\b",
-    "politics": r"\b(president|parliament|congress|minister|policy|campaign|election)\b",
+    "politics": r"\b(president|parliament|congress|minister|policy|campaign|election|rally|protest|demonstration)\b",
+}
+
+# ----------------- Category normalization to frontend set -----------------
+FRONTEND_CATS = {
+    "politics","technology","sports","business","entertainment",
+    "science","health","crime","weather","environment","travel",
+    "viral","general"
 }
 
+ML_TO_FRONTEND = {
+    "arts_&_culture": "entertainment",
+    "business": "business",
+    "business_&_entrepreneurs": "business",
+    "celebrity_&_pop_culture": "entertainment",
+    "crime": "crime",
+    "diaries_&_daily_life": "viral",
+    "entertainment": "entertainment",
+    "environment": "environment",
+    "fashion_&_style": "entertainment",
+    "film_tv_&_video": "entertainment",
+    "fitness_&_health": "health",
+    "food_&_dining": "entertainment",
+    "general": "general",
+    "learning_&_educational": "science",
+    "news_&_social_concern": "politics",
+    "politics": "politics",
+    "science_&_technology": "science",
+    "sports": "sports",
+    "technology": "technology",
+    "travel_&_adventure": "travel",
+    "other_hobbies": "viral"
+}
+
+def normalize_category(c: Optional[str]) -> str:
+    s = (c or "").strip().lower()
+    if not s:
+        return "general"
+    if s in FRONTEND_CATS:
+        return s
+    return ML_TO_FRONTEND.get(s, "general")
+
+
+
 def get_news_clf():
     # Lazy-init topic classifier
     global _news_clf
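The normalize_category helper added here funnels both raw provider categories and the topic classifier's labels onto the frontend's fixed set, falling back to "general" for anything unknown. A minimal spot-check sketch, assuming these definitions are importable from main.py (the assertions are illustrative, not part of the commit):

    from main import normalize_category

    assert normalize_category("Politics") == "politics"              # already a frontend category
    assert normalize_category("film_tv_&_video") == "entertainment"  # ML-style label remapped
    assert normalize_category("") == "general"                       # empty input falls back
    assert normalize_category("astrology") == "general"              # unknown label falls back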
@@ -496,29 +536,25 @@ def _infer_category_from_text(text: str) -> Optional[str]:
 
 def infer_category(article_url, title, description, provided):
     if provided:
-        p = provided.strip().lower()
-        if p:
-            return p
+        got = normalize_category(provided)
+        if got:
+            return got
     try:
         p = urlparse(article_url).path or ""
         cat = _infer_category_from_url_path(p)
         if cat:
-            return cat
+            return normalize_category(cat)
     except Exception:
         pass
     text = f"{title or ''} {description or ''}".strip()
     cat = _infer_category_from_text(text)
     if cat:
-        return cat
+        return normalize_category(cat)
     try:
         preds = get_news_clf()(text[:512])
-        if isinstance(preds[0], list):
-            label = preds[0][0]["label"]
-        else:
-            label = preds[0]["label"]
-        return label.lower()
-    except Exception as e:
-        log.warning(f"ML category failed: {e}")
+        label = preds[0][0]["label"] if isinstance(preds[0], list) else preds[0]["label"]
+        return normalize_category(label)
+    except Exception:
         return "general"
 
 # ----------------- Language Detection / Embeddings -----------------
@@ -710,25 +746,48 @@ def opus_model_for(src2: str, tgt2: str) -> Optional[str]:
 SUPPORTED = {"en", "fr", "de", "es", "it", "hi", "ar", "ru", "ja", "ko", "pt", "zh"}
 LIBRETRANSLATE_URL = os.getenv("LIBRETRANSLATE_URL")
 
+def _lt_lang(code: str) -> str:
+    if not code:
+        return code
+    c = code.lower()
+    # LibreTranslate uses zh-Hans; normalize zh* to zh-Hans
+    if c.startswith("zh"):
+        return "zh-Hans"
+    return c
+
 def _translate_via_libre(text: str, src: str, tgt: str) -> Optional[str]:
     url = LIBRETRANSLATE_URL
     if not url or not text or src == tgt:
         return None
-    try:
-        r = SESSION.post(
-            f"{url.rstrip('/')}/translate",
-            json={"q": text, "source": src, "target": tgt, "format": "text"},
-            timeout=6
-        )
-        if r.status_code == 200:
-            j = r.json()
-            out = j.get("translatedText")
-            return out if isinstance(out, str) and out else None
-        else:
-            log.warning("LibreTranslate HTTP %s: %s", r.status_code, r.text[:200])
-    except Exception as e:
-        log.warning("LibreTranslate failed: %s", e)
-    return None
+
+    payload = {
+        "q": text,
+        "source": _lt_lang(src),
+        "target": _lt_lang(tgt),
+        "format": "text",
+    }
+
+    # First call can be slow while LT warms models; retry once.
+    for attempt in (1, 2):
+        try:
+            r = SESSION.post(
+                f"{url.rstrip('/')}/translate",
+                json=payload,
+                timeout=15  # was 6
+            )
+            if r.status_code == 200:
+                j = r.json()
+                out = j.get("translatedText")
+                return out if isinstance(out, str) and out else None
+            else:
+                log.warning("LibreTranslate HTTP %s: %s", r.status_code, r.text[:200])
+                return None
+        except Exception as e:
+            if attempt == 2:
+                log.warning("LibreTranslate failed: %s", e)
+                return None
+            time.sleep(0.5)
+
 
 def _hf_call(model_id: str, payload: dict) -> Optional[str]:
     if not (HUGGINGFACE_API_TOKEN and ALLOW_HF_REMOTE):
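The rewritten _translate_via_libre targets LibreTranslate's standard /translate endpoint (JSON in, translatedText out), normalizes zh* codes to zh-Hans, raises the timeout, and retries once since the first call can block while models warm. A standalone sketch of that same contract, assuming a local instance at http://localhost:5000 (the app itself reads LIBRETRANSLATE_URL from the environment):

    import requests

    # Hypothetical probe of a LibreTranslate instance; the URL is an assumption.
    r = requests.post(
        "http://localhost:5000/translate",
        json={"q": "Olá mundo", "source": "pt", "target": "en", "format": "text"},
        timeout=15,
    )
    r.raise_for_status()
    print(r.json()["translatedText"])  # e.g. "Hello world"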
@@ -886,10 +945,16 @@ def _warm_once():
 
 @app.on_event("startup")
 def warm():
+    try:
+        _translate_cached.cache_clear()
+    except Exception:
+        pass
+
     get_sbert()
     get_news_clf()
     threading.Thread(target=_warm_once, daemon=True).start()
 
+
 # ----------------- GDELT Query Helpers -----------------
 _GDELT_LANG = {
     "en": "english",
@@ -1026,8 +1091,8 @@ def fetch_gdelt_multi(limit=120, query=None, language=None, timespan="48h", cate
 
 
 # ----------------- Provider Flags / Keys / Logging -----------------
-USE_GNEWS_API = True
-USE_NEWSDATA_API = True
+USE_GNEWS_API = False
+USE_NEWSDATA_API = False
 USE_GDELT_API = True
 USE_NEWSAPI = False
 
@@ -1157,7 +1222,11 @@ def enrich_article(a, language=None, translate=False, target_lang=None):
     sentiment = classify_sentiment(f"{orig_title} {orig_description}")
     seed = f"{source_name}|{article_url}|{title}"
     uid = hashlib.md5(seed.encode("utf-8")).hexdigest()[:12]
-    cat = infer_category(article_url, orig_title, orig_description, None)
+    provided = a.get("category")
+    if provided and normalize_category(provided) != "general":
+        cat = normalize_category(provided)
+    else:
+        cat = infer_category(article_url, orig_title, orig_description, provided)
     return {
         "id": uid,
         "title": title,
@@ -1232,7 +1301,7 @@ def event_payload_from_cluster(cluster, enriched_articles):
         "sample_urls": [a["url"] for a in arts[:3] if a.get("url")],
     }
 
-def aggregate_event_by_country(cluster, enriched_articles):
+def aggregate_event_by_country(cluster, enriched_articles, max_samples: int | None = 5):
     idxs = cluster["indices"]
     arts = [enriched_articles[i] for i in idxs]
     by_country: Dict[str, Dict[str, Any]] = {}
@@ -1251,6 +1320,7 @@ def aggregate_event_by_country(cluster, enriched_articles):
         avg_sent = "positive" if avg > 0.15 else "negative" if avg < -0.15 else "neutral"
         top_sources = [s for s, _ in Counter([a["source"] for a in arr]).most_common(3)]
         summary = " • ".join([a["title"] for a in arr[:2]])
+        use = arr if (max_samples in (None, 0) or max_samples < 0) else arr[:max_samples]
        results.append(
             {
                 "country": c,
@@ -1270,7 +1340,7 @@ def aggregate_event_by_country(cluster, enriched_articles):
                         "sentiment": a["sentiment"],
                         "detected_lang": a.get("detected_lang"),
                     }
-                    for a in arr[:5]
+                    for a in use
                 ],
             }
         )
@@ -1500,13 +1571,10 @@ def combine_raw_articles(category=None, query=None, language=None, limit_each=30
     a3 = fetch_gnews_articles(limit=limit_each, query=query, language=language) if USE_GNEWS_API else []
     gdelt_limit = limit_each
     a4 = fetch_gdelt_multi(
-        limit=gdelt_limit,
-        query=query,
-        language=language,
-        timespan=timespan,
-        category=category,
-        speed=speed,
-    )
+        limit=limit_each, query=query, language=language,
+        timespan=timespan, category=category, speed=speed
+    ) if USE_GDELT_API else []
+
     seen, merged = set(), []
     for a in a1 + a3 + a2 + a4:
         if a.get("url"):
@@ -1566,6 +1634,7 @@ def get_event_details(
     translate: Optional[bool] = Query(False),
     target_lang: Optional[str] = Query(None),
     limit_each: int = Query(150, ge=5, le=250),
+    max_samples: int = Query(5, ge=0, le=1000),
 ):
     if cache_key:
         parts = cache_key.split("|")
@@ -1600,7 +1669,7 @@
     if not cluster:
         raise HTTPException(status_code=404, detail="Event not found with current filters")
     payload = event_payload_from_cluster(cluster, eview)
-    countries = aggregate_event_by_country(cluster, eview)
+    countries = aggregate_event_by_country(cluster, eview, max_samples=max_samples)
     payload["articles_in_event"] = sum(c["count"] for c in countries)
     return {"event": payload, "countries": countries}
 
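With max_samples threaded from the query string down to aggregate_event_by_country, the client now controls how many sample articles each country entry carries, and 0 (or None internally) means return them all. A hypothetical client call, with host, route, and parameter names assumed for illustration (none of them are shown in this diff):

    import requests

    base = "http://localhost:8000"  # assumed host; route and id below are placeholders
    capped = requests.get(f"{base}/event-details", params={"event_id": "abc123", "max_samples": 3}).json()
    full = requests.get(f"{base}/event-details", params={"event_id": "abc123", "max_samples": 0}).json()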
@@ -1746,14 +1815,45 @@ def client_metric(payload: Dict[str, Any] = Body(...)):
 
 # ----------------- Diagnostics: Translation Health -----------------
 @app.get("/diag/translate")
-def diag_translate():
-    remote = _hf_call("Helsinki-NLP/opus-mt-es-en", {"inputs":"Hola mundo"})
-    local = _translate_local("Hola mundo", "es", "en")
-    libre = _translate_via_libre("Hola mundo", "es", "en")
+def diag_translate(
+    src: str = Query("pt"),
+    tgt: str = Query("en"),
+    text: str = Query("Olá mundo")
+):
+    # Try each path explicitly (same order the runtime uses)
+    libre = _translate_via_libre(text, src, tgt)
+    remote = None
+    local = None
+
+    opus_id = opus_model_for(src, tgt)
+    if opus_id:
+        remote = _hf_call(opus_id, {"inputs": text})
+        local = _translate_local(text, src, tgt)
+
+    # Optional: try primary NLLB if configured
+    nllb = None
+    if HF_MODEL_PRIMARY and (src in NLLB_CODES) and (tgt in NLLB_CODES):
+        nllb = _hf_call(
+            HF_MODEL_PRIMARY,
+            {
+                "inputs": text,
+                "parameters": {"src_lang": NLLB_CODES[src], "tgt_lang": NLLB_CODES[tgt]},
+                "options": {"wait_for_model": True},
+            },
+        )
+
+    sample_out = libre or remote or local or nllb
+    out_lang = detect_lang(sample_out or "") or None
+
     return {
-        "token": bool(HUGGINGFACE_API_TOKEN),
+        "src": src, "tgt": tgt, "text": text,
+        "libre_url": LIBRETRANSLATE_URL,
+        "token_present": bool(HUGGINGFACE_API_TOKEN),
+        "libre_ok": bool(libre),
         "remote_ok": bool(remote),
         "local_ok": bool(local),
-        "libre_ok": bool(libre),
-        "sample": libre or remote or local
+        "nllb_ok": bool(nllb),
+        "sample_out": sample_out,
+        "sample_out_lang_detected": out_lang,
+        "lang_match": (out_lang == tgt)
     }
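The reworked /diag/translate now exercises every translation path (LibreTranslate, remote Opus via HF, local Opus, and optionally NLLB) for an arbitrary src/tgt pair and reports whether the detected output language matches the target. A sketch of probing it, with host and port assumed:

    import requests

    # Hypothetical health probe; host/port are assumptions.
    info = requests.get(
        "http://localhost:8000/diag/translate",
        params={"src": "es", "tgt": "en", "text": "Hola mundo"},
        timeout=60,
    ).json()
    print(info["libre_ok"], info["remote_ok"], info["local_ok"], info["nllb_ok"])
    print(info["sample_out"], info["sample_out_lang_detected"], info["lang_match"])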