Spaces:
Sleeping
Sleeping
\ | |
import json, os | |
import numpy as np, pandas as pd | |
import faiss | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
class SloganSearcher:
    """Nearest-neighbour search over a slogan table using a prebuilt FAISS index.

    The constructor loads three files from ``assets_dir``:

    * ``meta.json`` - must contain ``"model_name"``; may contain
      ``"text_col"`` (default ``"description"``), ``"fallback_col"``
      (default ``"tagline"``) and ``"normalized"`` (default ``True``).
    * ``slogans_clean.parquet`` - the table whose rows are returned.
    * ``faiss.index`` - the vector index that is queried.

    NOTE(review): row ``i`` of the parquet table is assumed to correspond to
    vector ``i`` of the index (results are looked up positionally) - confirm
    against the asset-building script.
    """

    def __init__(self, assets_dir="assets", use_rerank=False, rerank_model="cross-encoder/stsb-roberta-base"):
        """Load metadata, the slogan table, the FAISS index and the models.

        Args:
            assets_dir: Directory holding ``meta.json``,
                ``slogans_clean.parquet`` and ``faiss.index``.
            use_rerank: When true, a ``CrossEncoder`` is loaded and
                ``search`` re-orders candidates by its scores.
            rerank_model: Model name passed to ``CrossEncoder``; only used
                when ``use_rerank`` is true.

        Raises:
            FileNotFoundError: If ``meta.json`` is missing from ``assets_dir``.
        """
        meta_path = os.path.join(assets_dir, "meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Missing {meta_path}. Build assets first.")
        # Explicit encoding so the read does not depend on the platform locale.
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

        self.df = pd.read_parquet(os.path.join(assets_dir, "slogans_clean.parquet"))
        self.index = faiss.read_index(os.path.join(assets_dir, "faiss.index"))
        self.encoder = SentenceTransformer(self.meta["model_name"])

        self.use_rerank = use_rerank
        self.reranker = CrossEncoder(rerank_model) if use_rerank else None

        self.text_col = self.meta.get("text_col", "description")
        self.fallback_col = self.meta.get("fallback_col", "tagline")
        # Forwarded to the encoder as ``normalize_embeddings`` at query time.
        self.norm = bool(self.meta.get("normalized", True))

    def _result_columns(self):
        """Return the column names of the frame produced by ``search``."""
        columns = ["display", "score"]
        if self.use_rerank:
            columns.append("rerank_score")
        return columns

    def search(self, query: str, top_k=5, rerank_top_n=20):
        """Return up to ``top_k`` rows of the slogan table matching ``query``.

        Args:
            query: Free-text query. A non-string or blank query yields an
                empty result.
            top_k: Maximum number of rows to return. Values ``<= 0`` yield an
                empty result.
            rerank_top_n: With reranking enabled, the number of candidates
                fetched from the index before reranking (at least ``top_k``
                are always fetched). Ignored otherwise.

        Returns:
            A ``pandas.DataFrame`` with columns ``display`` (the value of the
            fallback column), ``score`` (the value returned by the index) and,
            when reranking is enabled, ``rerank_score``. With reranking the
            rows are ordered by ``rerank_score`` descending; otherwise they
            keep the order returned by the index.
        """
        columns = self._result_columns()
        if not isinstance(query, str) or not query.strip():
            return pd.DataFrame(columns=columns)

        top_k = int(top_k)
        if top_k <= 0:
            # Nothing requested; avoid asking the index for k <= 0 neighbours.
            return pd.DataFrame(columns=columns)

        n_candidates = max(top_k, int(rerank_top_n)) if self.use_rerank else top_k

        q = self.encoder.encode([query], convert_to_numpy=True, normalize_embeddings=self.norm)
        # FAISS expects a C-contiguous float32 matrix.
        q = np.ascontiguousarray(q, dtype=np.float32)
        sims, idxs = self.index.search(q, n_candidates)

        # FAISS pads with label -1 when it finds fewer than k neighbours.
        # Passing -1 to ``iloc`` would silently select the LAST row, so those
        # slots are dropped together with their scores.
        hits = [
            (int(i), float(s))
            for i, s in zip(idxs[0].tolist(), sims[0].tolist())
            if i >= 0
        ]
        if not hits:
            return pd.DataFrame(columns=columns)

        results = self.df.iloc[[i for i, _ in hits]].copy()
        results["score"] = [s for _, s in hits]

        if self.use_rerank:
            # Rerank on the main text column, using the fallback column
            # wherever the main text is missing.
            texts = results[self.text_col].fillna(results[self.fallback_col]).astype(str).tolist()
            pairs = [[query, t] for t in texts]
            results["rerank_score"] = self.reranker.predict(pairs)
            results = results.sort_values("rerank_score", ascending=False).head(top_k)
        else:
            results = results.head(top_k)

        results["display"] = results[self.fallback_col]
        return results[columns]