import os, json, numpy as np, pandas as pd
import gradio as gr
import faiss
import re
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher
# -------------------- Config --------------------
ASSETS_DIR = "assets"
DATA_PATH = "data/slogan.csv"
PROMPT_PATH = "data/prompt.txt"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE = True
GEN_MODEL = "google/flan-t5-base"
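# Sampling settings for the generator: several diverse candidates are drawn per
# request and filtered afterwards.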
NUM_GEN_CANDIDATES = 12
MAX_NEW_TOKENS = 18
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.15
# choose the most relevant yet non-duplicate candidate
RELEVANCE_WEIGHT = 0.7
NOVELTY_WEIGHT = 0.3
DUPLICATE_MAX_SIM = 0.92
NOVELTY_SIM_THRESHOLD = 0.80  # keep some distance from retrieved
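# Example: a candidate with cosine similarity 0.80 to the description and a maximum
# similarity of 0.50 to any retrieved slogan scores 0.7 * 0.80 - 0.3 * 0.50 = 0.41;
# candidates whose similarity to a retrieved slogan reaches DUPLICATE_MAX_SIM are
# treated as near-duplicates and skipped whenever an alternative exists.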
META_PATH = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")
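# Cached artifacts written by _build_assets(); _ensure_assets() rebuilds them at
# startup when the meta, parquet, or index file is missing or the parquet cannot be read.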
def _log(m):
    print(f"[SLOGAN-SPACE] {m}", flush=True)
# -------------------- Asset build --------------------
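# Build the retrieval assets: clean the CSV, embed each row with the sentence encoder,
# index the embeddings in FAISS (inner product equals cosine similarity when the
# embeddings are normalized), and persist the dataframe, index, embeddings, and metadata.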
def _build_assets():
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)
    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")
    if "description" in df.columns and df["description"].notna().any():
        texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
        text_col, fallback_col = "description", "tagline"
    else:
        texts = df["tagline"].astype(str).tolist()
        text_col, fallback_col = "tagline", "tagline"
    _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)
    _log("Persisting assets …")
    df.to_parquet(PARQUET_PATH, index=False)
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, emb)
    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")
def _ensure_assets():
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()
# Build before UI
_ensure_assets()
# -------------------- Retrieval --------------------
searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
with open(META_PATH) as f:
    meta = json.load(f)
_encoder = SentenceTransformer(meta["model_name"])
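# The retrieval encoder is also reused in _pick_best() to score generated candidates
# against the description and the retrieved slogans.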
# -------------------- Generator --------------------
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
# keep this list small so we don't nuke relevant outputs
_BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
_PUNCT_CHARS = ":;–—-,.!?“”\"'`"
_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
_MIN_WORDS, _MAX_WORDS = 2, 8
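# These limits mirror the prompt's rules: generated slogans are stripped of punctuation,
# Title-Cased, kept to 2-8 words, and rejected if they contain a banned generic term.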
def _load_prompt():
    if os.path.exists(PROMPT_PATH):
        with open(PROMPT_PATH, "r", encoding="utf-8") as f:
            return f.read()
    return (
        "You are a professional slogan writer.\n"
        "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
        "Do not copy examples.\n"
        "Description:\n{description}\nSlogan:"
    )
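# Render the prompt: substitute the user's description into the template and append
# up to three retrieved slogans as explicit "do not copy" examples.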
def _render_prompt(description: str, retrieved=None) -> str:
    tmpl = _load_prompt()
    if "{description}" in tmpl:
        prompt = tmpl.replace("{description}", description)
    else:
        prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
    if retrieved:
        prompt += "\n\nDo NOT copy these existing slogans:\n"
        for s in retrieved[:3]:
            prompt += f"- {s}\n"
    return prompt
def _title_case(s: str) -> str:
    small = {"and", "or", "for", "of", "the", "to", "in", "on", "with", "a", "an"}
    words = [w for w in s.split() if w]
    out = []
    for i, w in enumerate(words):
        lw = w.lower()
        if i > 0 and lw in small:
            out.append(lw)
        else:
            out.append(lw.capitalize())
    return " ".join(out)

def _strip_punct(s: str) -> str:
    return _PUNCT_RE.sub("", s)
def _strict_ok(s: str) -> bool:
    if not s:
        return False
    wc = len(s.split())
    if wc < _MIN_WORDS or wc > _MAX_WORDS:
        return False
    lo = s.lower()
    if any(term in lo for term in _BANNED_TERMS):
        return False
    if lo in {"the", "a", "an"}:
        return False
    return True
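# Two-stage post-processing: the strict pass enforces every constraint; if it leaves
# nothing, the relaxed pass below only checks the word count (while still stripping
# punctuation and applying Title Case).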
def _postprocess_strict(texts):
    cleaned, seen = [], set()
    for t in texts:
        s = t.replace("Slogan:", "").strip().strip('"').strip("'")
        s = " ".join(s.split())
        s = _strip_punct(s)  # remove punctuation instead of rejecting
        s = _title_case(s)
        if _strict_ok(s):
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned
def _postprocess_relaxed(texts):
    # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
    cleaned, seen = [], set()
    for t in texts:
        s = t.strip().strip('"').strip("'")
        s = _strip_punct(s)
        s = " ".join(s.split())
        wc = len(s.split())
        if _MIN_WORDS <= wc <= _MAX_WORDS:
            s = _title_case(s)
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned
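# Sample NUM_GEN_CANDIDATES slogans from the seq2seq model with the banned terms
# blocked at decode time, then post-process (strict first, relaxed fallback).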
def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
    prompt = _render_prompt(description, retrieved_texts)
    # only block very generic junk at decode time
    bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
    inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        no_repeat_ngram_size=3,
        repetition_penalty=REPETITION_PENALTY,
        bad_words_ids=bad_ids if bad_ids else None,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    cands = _postprocess_strict(texts)
    if not cands:
        cands = _postprocess_relaxed(texts)  # graceful fallback
    return cands
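# Selection order: prefer candidates below DUPLICATE_MAX_SIM and take the best weighted
# score; otherwise take the highest-scoring candidate that stays under
# NOVELTY_SIM_THRESHOLD, and as a last resort the top score overall.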
def _pick_best(candidates, retrieved_texts, description):
    """Weighted relevance to description minus duplication vs retrieved."""
    if not candidates:
        return None
    c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
    d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
    rel = c_emb @ d_emb  # cosine sim to description
    if retrieved_texts:
        R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
        dup = np.max(R @ c_emb.T, axis=0)  # max sim to any retrieved
    else:
        dup = np.zeros(len(candidates), dtype=np.float32)
    # penalize near-duplicates outright
    mask = dup < DUPLICATE_MAX_SIM
    if mask.any():
        scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
        best_idx = np.argmax(scores)
        return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
    # else: pick most relevant that still clears a basic novelty bar, else top score
    scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
    order = np.argsort(-scores)
    for i in order:
        if dup[i] < NOVELTY_SIM_THRESHOLD:
            return candidates[i]
    return candidates[order[0]]
# -------------------- Inference pipeline --------------------
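# End-to-end flow per request: retrieve the top-3 most similar slogans, generate
# candidates conditioned on the description and the retrieved slogans, pick the best
# one, and format everything as Markdown for the UI.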
def run_pipeline(user_description: str):
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
    chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")
    lines = []
    lines.append("### Top 3 similar slogans")
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("_No similar slogans found._")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(chosen)
    return "\n".join(lines)
# -------------------- UI --------------------
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)
demo.queue(max_size=64).launch()