|
|
|
""" |
|
Thematic Word Generator using Sentence Transformers |
|
|
|
Generates thematically related words from a set of input words/sentences. |
|
Uses semantic centroids to understand broader themes and find related vocabulary. |
|
""" |
|
|
|
import os |
|
import csv |
|
import pickle |
|
import numpy as np |
|
import logging |
|
from typing import List, Tuple, Optional |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from sklearn.cluster import KMeans |
|
import nltk |
|
from nltk.corpus import words, brown |
|
from datetime import datetime |
|
import time |
|
from collections import Counter |
|
|
|
|
|
try: |
|
from wordfreq import word_frequency, zipf_frequency, top_n_list |
|
HAS_WORDFREQ = True |
|
except ImportError: |
|
HAS_WORDFREQ = False |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format='%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s', |
|
datefmt='%Y-%m-%d %H:%M:%S' |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
def get_timestamp(): |
|
return datetime.now().strftime("%H:%M:%S") |
|
def get_datetimestamp(): |
|
return datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
|
|
|
class ThematicWordGeneranor_v0_1: |
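    """Generate thematically related vocabulary from seed words or sentences.

    A reference vocabulary is embedded once with a sentence-transformer model and
    cached; candidate words are then ranked by cosine similarity to the semantic
    centroid of the inputs.

    Example (illustrative sketch; the first run downloads the model and builds the
    vocabulary embedding cache, which can take a while):

        generator = ThematicWordGeneranor_v0_1()
        words = generator.generate_thematic_words(["ocean", "waves"], num_words=10)
    """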
|
def __init__(self, cache_dir: Optional[str] = None, model_name: str = 'all-mpnet-base-v2'): |
|
"""Initialize the thematic word generator. |
|
|
|
Args: |
|
cache_dir: Directory to cache the embedding model |
|
model_name: Sentence transformer model to use |
|
""" |
|
if cache_dir is None: |
|
cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache') |
|
|
|
self.cache_dir = cache_dir |
|
os.makedirs(cache_dir, exist_ok=True) |
|
|
|
|
|
logger.info("Loading embedding model...") |
|
self.model = SentenceTransformer( |
|
f'sentence-transformers/{model_name}', |
|
cache_folder=cache_dir |
|
) |
|
logger.info("Model loaded successfully.") |
|
|
|
|
|
self.vocabulary, self.vocab_embeddings = self._load_or_create_vocab_embeddings() |
|
|
|
|
|
self.word_frequencies = self._load_frequency_data() |
|
self.frequency_tiers = self._create_frequency_tiers() |
|
|
|
def _load_or_create_vocab_embeddings(self) -> Tuple[List[str], np.ndarray]: |
|
"""Load vocabulary and embeddings from cache or create them.""" |
|
|
|
vocab_cache_path = os.path.join(self.cache_dir, 'dictionary.pkl') |
|
embeddings_cache_path = os.path.join(self.cache_dir, 'vocab_embeddings.npy') |
|
|
|
|
|
if os.path.exists(vocab_cache_path) and os.path.exists(embeddings_cache_path): |
|
try: |
|
logger.info("Loading vocabulary and embeddings from cache...") |
|
start_time = time.time() |
|
|
|
with open(vocab_cache_path, 'rb') as f: |
|
vocabulary = pickle.load(f) |
|
embeddings = np.load(embeddings_cache_path) |
|
|
|
load_time = time.time() - start_time |
|
logger.info(f"✓ Loaded {len(vocabulary):,} words and embeddings from cache in {load_time:.2f}s") |
|
return vocabulary, embeddings |
|
|
|
except Exception as e: |
|
logger.error(f"Error loading from cache: {e}") |
|
logger.info("Rebuilding vocabulary and embeddings...") |
|
|
|
|
|
logger.info("Creating new vocabulary and embeddings...") |
|
vocabulary = self._load_vocabulary() |
|
|
|
embeddings = self._create_vocab_embeddings(vocabulary) |
|
|
|
|
|
try: |
|
logger.info("Saving vocabulary and embeddings to cache...") |
|
with open(vocab_cache_path, 'wb') as f: |
|
pickle.dump(vocabulary, f) |
|
np.save(embeddings_cache_path, embeddings) |
|
logger.info("✓ Cache saved successfully") |
|
except Exception as e: |
|
logger.warning(f"Could not save cache: {e}") |
|
|
|
return vocabulary, embeddings |
|
|
|
def _load_dictionary(self) -> List[str]: |
|
"""Load words from the dictionary CSV file.""" |
|
dict_path = os.path.join(os.path.dirname(__file__), 'dict-words', 'dict.csv') |
|
words = [] |
|
|
|
try: |
|
with open(dict_path, 'r', encoding='utf-8') as csvfile: |
|
reader = csv.DictReader(csvfile) |
|
for row in reader: |
|
word = row['word'].strip().lower() |
|
if word and len(word) > 1: |
|
words.append(word) |
|
except FileNotFoundError: |
|
raise Exception(f"Dictionary file not found: {dict_path}") |
|
except Exception as e: |
|
raise Exception(f"Error reading dictionary: {e}") |
|
|
|
return words |
|
|
|
def _load_frequency_data(self) -> Counter: |
|
"""Load word frequency data from WordFreq or Brown corpus fallback.""" |
|
|
|
if HAS_WORDFREQ: |
|
wordfreq_cache_path = os.path.join(self.cache_dir, 'wordfreq_frequencies.pkl') |
|
|
|
|
|
if os.path.exists(wordfreq_cache_path): |
|
try: |
|
logger.info("Loading WordFreq data from cache...") |
|
with open(wordfreq_cache_path, 'rb') as f: |
|
word_freq = pickle.load(f) |
|
logger.info(f"✓ Loaded WordFreq data for {len(word_freq):,} words") |
|
return word_freq |
|
except Exception as e: |
|
logger.warning(f"Error loading WordFreq cache: {e}") |
|
|
|
|
|
logger.info("Generating frequency data from WordFreq (comprehensive multi-source)...") |
|
try: |
|
word_freq = self._generate_wordfreq_data() |
|
|
|
|
|
try: |
|
with open(wordfreq_cache_path, 'wb') as f: |
|
pickle.dump(word_freq, f) |
|
logger.info("✓ Cached WordFreq data") |
|
except Exception as e: |
|
logger.warning(f"Could not cache WordFreq data: {e}") |
|
|
|
return word_freq |
|
|
|
except Exception as e: |
|
logger.error(f"Error loading WordFreq data: {e}") |
|
logger.info("Falling back to Brown corpus...") |
|
|
|
|
|
return self._load_brown_frequency_data() |
|
|
|
def _generate_wordfreq_data(self) -> Counter: |
|
"""Generate frequency data from WordFreq's comprehensive vocabulary.""" |
|
logger.info("Fetching comprehensive vocabulary from WordFreq...") |
|
|
|
try: |
|
|
|
top_words = top_n_list('en', 500000, wordlist='large') |
|
logger.info(f"Retrieved {len(top_words):,} words from WordFreq") |
|
|
|
frequency_data = Counter() |
|
processed_count = 0 |
|
|
|
|
|
batch_size = 5000 |
|
total_batches = (len(top_words) + batch_size - 1) // batch_size |
|
|
|
for batch_num in range(total_batches): |
|
start_idx = batch_num * batch_size |
|
end_idx = min(start_idx + batch_size, len(top_words)) |
|
batch_words = top_words[start_idx:end_idx] |
|
|
|
for word in batch_words: |
|
try: |
|
|
|
freq = word_frequency(word, 'en', wordlist='large') |
|
if freq > 0: |
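                            # word_frequency() returns a proportion; scale it to an
                            # integer pseudo-count (occurrences per billion words)
                            # so it can be stored in a Counter.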
|
|
|
count = int(freq * 1_000_000_000) |
|
if count > 0: |
|
frequency_data[word] = count |
|
processed_count += 1 |
|
else: |
|
|
|
frequency_data[word] = 1 |
|
processed_count += 1 |
|
except Exception: |
|
continue |
|
|
|
|
|
if batch_num % 20 == 0 or batch_num == total_batches - 1: |
|
logger.info(f" Batch {batch_num + 1:3d}/{total_batches} | " |
|
f"Processed {end_idx:6,}/{len(top_words):,} words | " |
|
f"Found {processed_count:,} with frequencies") |
|
|
|
logger.info(f"✓ Generated WordFreq data: {len(frequency_data):,} words with frequencies") |
|
return frequency_data |
|
|
|
except Exception as e: |
|
logger.error(f"Error generating WordFreq data: {e}") |
|
raise |
|
|
|
def _load_brown_frequency_data(self) -> Counter: |
|
"""Load frequency data from Brown corpus (fallback).""" |
|
freq_cache_path = os.path.join(self.cache_dir, 'brown_frequencies.pkl') |
|
|
|
|
|
if os.path.exists(freq_cache_path): |
|
try: |
|
logger.info("Loading Brown corpus frequency data from cache...") |
|
with open(freq_cache_path, 'rb') as f: |
|
word_freq = pickle.load(f) |
|
logger.info(f"✓ Loaded Brown corpus data for {len(word_freq):,} words") |
|
return word_freq |
|
except Exception as e: |
|
logger.warning(f"Error loading Brown corpus cache: {e}") |
|
|
|
|
|
logger.info("Generating frequency data from Brown corpus (1960s academic fallback)...") |
|
try: |
|
nltk.download('brown', quiet=True) |
|
brown_words = [word.lower() for word in brown.words() if word.isalpha()] |
|
word_freq = Counter(brown_words) |
|
logger.info(f"✓ Generated Brown corpus data for {len(word_freq):,} unique words") |
|
|
|
|
|
try: |
|
with open(freq_cache_path, 'wb') as f: |
|
pickle.dump(word_freq, f) |
|
logger.info("✓ Cached Brown corpus data") |
|
except Exception as e: |
|
logger.warning(f"Could not cache Brown corpus data: {e}") |
|
|
|
return word_freq |
|
|
|
except Exception as e: |
|
logger.error(f"Error loading Brown corpus: {e}") |
|
|
|
return Counter() |
|
|
|
def _create_frequency_tiers(self) -> dict: |
|
"""Create detailed frequency tier classifications with 10 tiers.""" |
|
if not self.word_frequencies: |
|
return {} |
|
|
|
tiers = {} |
|
|
|
|
|
all_counts = list(self.word_frequencies.values()) |
|
all_counts.sort(reverse=True) |
|
|
|
|
|
tier_definitions = [ |
|
("tier_1_ultra_common", 0.999, "Ultra Common (Top 0.1%)"), |
|
("tier_2_extremely_common", 0.995, "Extremely Common (Top 0.5%)"), |
|
("tier_3_very_common", 0.99, "Very Common (Top 1%)"), |
|
("tier_4_highly_common", 0.97, "Highly Common (Top 3%)"), |
|
("tier_5_common", 0.92, "Common (Top 8%)"), |
|
("tier_6_moderately_common", 0.85, "Moderately Common (Top 15%)"), |
|
("tier_7_somewhat_uncommon", 0.70, "Somewhat Uncommon (Top 30%)"), |
|
("tier_8_uncommon", 0.50, "Uncommon (Top 50%)"), |
|
("tier_9_rare", 0.25, "Rare (Top 75%)"), |
|
("tier_10_very_rare", 0.0, "Very Rare (Bottom 25%)") |
|
] |
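        # Convert each percentile cut-off into a raw-count threshold taken from the
        # descending count list; each word is assigned the first (most common) tier
        # whose threshold its count meets.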
|
|
|
|
|
thresholds = [] |
|
for tier_name, percentile, description in tier_definitions: |
|
if percentile > 0: |
|
idx = int((1 - percentile) * len(all_counts)) |
|
threshold = all_counts[min(idx, len(all_counts) - 1)] |
|
else: |
|
threshold = 0 |
|
thresholds.append((tier_name, threshold, description)) |
|
|
|
|
|
self.tier_descriptions = {name: desc for name, _, desc in thresholds} |
|
|
|
|
|
for word, count in self.word_frequencies.items(): |
|
assigned = False |
|
for tier_name, threshold, description in thresholds: |
|
if count >= threshold: |
|
tiers[word] = tier_name |
|
assigned = True |
|
break |
|
|
|
if not assigned: |
|
tiers[word] = "tier_10_very_rare" |
|
|
|
|
|
for word in self.vocabulary: |
|
if word not in tiers: |
|
tiers[word] = "tier_10_very_rare" |
|
|
|
|
|
tier_counts = Counter(tiers.values()) |
|
logger.info(f"✓ Created 10-tier frequency system for {len(tiers):,} words:") |
|
|
|
tier_order = [f"tier_{i}_{name}" for i, name in enumerate([ |
|
"ultra_common", "extremely_common", "very_common", "highly_common", |
|
"common", "moderately_common", "somewhat_uncommon", "uncommon", |
|
"rare", "very_rare" |
|
], 1)] |
|
|
|
for tier_key in tier_order: |
|
if tier_key in tier_counts: |
|
count = tier_counts[tier_key] |
|
percentage = (count / len(tiers)) * 100 if tiers else 0 |
|
description = self.tier_descriptions.get(tier_key, tier_key) |
|
logger.info(f" - {description}: {count:,} words ({percentage:.1f}%)") |
|
|
|
return tiers |
|
|
|
def get_word_frequency_info(self, word: str) -> Tuple[float, str]: |
|
"""Get relative frequency and tier for a word.""" |
|
count = self.word_frequencies.get(word, 0) |
|
total_words = sum(self.word_frequencies.values()) if self.word_frequencies else 1 |
|
relative_freq = count / total_words if total_words > 0 else 0.0 |
|
tier = self.frequency_tiers.get(word, "tier_10_very_rare") |
|
return relative_freq, tier |
|
|
|
def get_tier_description(self, tier: str) -> str: |
|
"""Get human-readable description for a tier.""" |
|
return getattr(self, 'tier_descriptions', {}).get(tier, tier) |
|
|
|
def get_tier_number(self, tier: str) -> int: |
|
"""Extract tier number from tier string.""" |
|
if tier.startswith("tier_"): |
|
try: |
|
return int(tier.split("_")[1]) |
|
except (IndexError, ValueError): |
|
return 10 |
|
return 10 |
|
|
|
def _load_vocabulary(self) -> List[str]: |
|
"""Load vocabulary from NLTK words corpus with frequency filtering.""" |
|
try: |
|
logger.info("Downloading NLTK data...") |
|
|
|
nltk.download('words', quiet=True) |
|
word_list = list(words.words()) |
|
logger.info(f"✓ Downloaded {len(word_list):,} words from NLTK") |
|
|
|
|
|
logger.info("Filtering vocabulary...") |
|
filtered_words = [] |
|
for word in word_list: |
|
word_clean = word.lower().strip() |
|
|
|
if len(word_clean) >= 3 and word_clean.isalpha(): |
|
filtered_words.append(word_clean) |
|
|
|
|
|
unique_words = list(set(filtered_words)) |
|
logger.info(f"✓ Filtered to {len(unique_words):,} unique words") |
|
|
|
|
|
|
|
vocabulary = sorted(unique_words) |
|
|
|
|
|
max_vocab_size = 50000 |
|
if len(vocabulary) > max_vocab_size: |
|
logger.info(f"Reducing vocabulary from {len(vocabulary):,} to {max_vocab_size:,} words for performance") |
|
vocabulary = vocabulary[:max_vocab_size] |
|
|
|
logger.info(f"✓ Final vocabulary: {len(vocabulary):,} words") |
|
return vocabulary |
|
|
|
except Exception as e: |
|
logger.error(f"Error loading NLTK vocabulary: {e}") |
|
logger.info("Using fallback vocabulary...") |
|
|
|
basic_words = [ |
|
"animal", "science", "technology", "ocean", "forest", "mountain", |
|
"computer", "music", "art", "book", "travel", "food", "nature", |
|
"space", "history", "culture", "sports", "weather", "education", |
|
"health", "family", "friend", "house", "car", "city", "country", |
|
"water", "fire", "earth", "air", "light", "dark", "color", |
|
"sound", "time", "world", "life", "death", "love", "peace", |
|
"war", "power", "money", "work", "play", "game", "sport", |
|
"business", "school", "university", "government", "law", |
|
"medicine", "hospital", "doctor", "nurse", "teacher", "student", |
|
"writer", "artist", "musician", "actor", "director", "producer" |
|
] |
|
logger.info(f"Using fallback vocabulary with {len(basic_words)} words.") |
|
return basic_words |
|
|
|
def _create_vocab_embeddings(self, vocabulary: List[str]) -> np.ndarray: |
|
"""Create embeddings for all vocabulary words with detailed progress.""" |
|
batch_size = 512 |
|
all_embeddings = [] |
|
|
|
total_batches = (len(vocabulary) + batch_size - 1) // batch_size |
|
total_words = len(vocabulary) |
|
|
|
logger.info(f"Creating embeddings for {total_words:,} words in {total_batches} batches...") |
|
start_time = time.time() |
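        # Encode the vocabulary in fixed-size batches to bound memory use and to
        # report progress/ETA as the batches complete.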
|
|
|
for i in range(0, len(vocabulary), batch_size): |
|
batch_start_time = time.time() |
|
batch_words = vocabulary[i:i + batch_size] |
|
batch_num = i // batch_size + 1 |
|
|
|
batch_embeddings = self.model.encode( |
|
batch_words, |
|
convert_to_tensor=False, |
|
show_progress_bar=False |
|
) |
|
all_embeddings.append(batch_embeddings) |
|
|
|
|
|
batch_time = time.time() - batch_start_time |
|
words_processed = min(i + batch_size, total_words) |
|
progress_pct = (words_processed / total_words) * 100 |
|
|
|
elapsed_total = time.time() - start_time |
|
if words_processed > 0: |
|
words_per_second = words_processed / elapsed_total |
|
remaining_words = total_words - words_processed |
|
eta_seconds = remaining_words / words_per_second if words_per_second > 0 else 0 |
|
eta_str = f"{eta_seconds:.0f}s" if eta_seconds < 60 else f"{eta_seconds/60:.1f}m" |
|
else: |
|
eta_str = "calculating..." |
|
|
|
logger.info(f" Batch {batch_num:3d}/{total_batches} | " |
|
f"{words_processed:6,}/{total_words:,} words ({progress_pct:5.1f}%) | " |
|
f"ETA: {eta_str}") |
|
|
|
total_time = time.time() - start_time |
|
words_per_second = total_words / total_time |
|
logger.info(f"✓ Created embeddings for {total_words:,} words in {total_time:.2f}s " |
|
f"({words_per_second:.0f} words/sec)") |
|
|
|
return np.vstack(all_embeddings) |
|
|
|
def _compute_theme_vector(self, inputs: List[str]) -> np.ndarray: |
|
"""Compute semantic centroid from input words/sentences.""" |
|
logger.info(f"entered _compute_theme_vector") |
|
|
|
input_embeddings = self.model.encode(inputs, convert_to_tensor=False, show_progress_bar=False) |
|
logger.info(f"completed _compute_theme_vector model.encode") |
|
|
|
|
|
theme_vector = np.mean(input_embeddings, axis=0) |
|
|
|
return theme_vector.reshape(1, -1) |
|
|
|
def _detect_multiple_themes(self, inputs: List[str], max_themes: int = 3) -> List[np.ndarray]: |
|
"""Detect multiple themes using clustering.""" |
|
if len(inputs) < 2: |
|
return [self._compute_theme_vector(inputs)] |
|
logger.info(f"entered _detect_multiple_themes") |
|
|
|
|
|
logger.info("starting model.encode") |
|
input_embeddings = self.model.encode(inputs, convert_to_tensor=False, show_progress_bar=False) |
|
logger.info("completed model.encode") |
|
|
|
|
|
n_clusters = min(max_themes, len(inputs), 3) |
|
logger.info(f"num of clusters: {n_clusters:2d}") |
|
|
|
if n_clusters == 1: |
|
return [np.mean(input_embeddings, axis=0).reshape(1, -1)] |
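        # Cluster the input embeddings with KMeans; each cluster centroid is
        # treated as a separate theme vector.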
|
|
|
|
|
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) |
|
kmeans.fit(input_embeddings) |
|
|
|
|
|
return [center.reshape(1, -1) for center in kmeans.cluster_centers_] |
|
|
|
def generate_thematic_words(self, |
|
inputs: List[str], |
|
num_words: int = 20, |
|
min_similarity: float = 0.3, |
|
diversity_factor: float = 0.1, |
|
multi_theme: bool = False) -> List[Tuple[str, float]]: |
|
"""Generate thematically related words from input seeds. |
|
|
|
Args: |
|
inputs: List of words or sentences as theme seeds |
|
num_words: Number of words to return |
|
min_similarity: Minimum similarity threshold |
|
diversity_factor: Balance between relevance and diversity (0.0-1.0) |
|
multi_theme: Whether to detect and use multiple themes |
|
|
|
Returns: |
|
List of tuples (word, similarity_score) sorted by relevance |
|
""" |
|
logger.info(f"entered generate_thematic_words") |
|
if not inputs: |
|
return [] |
|
|
|
|
|
clean_inputs = [inp.strip().lower() for inp in inputs if inp.strip()] |
|
if not clean_inputs: |
|
return [] |
|
|
|
|
|
logger.info(f"{multi_theme=},{clean_inputs=}") |
|
if multi_theme and len(clean_inputs) > 2: |
|
theme_vectors = self._detect_multiple_themes(clean_inputs) |
|
else: |
|
theme_vectors = [self._compute_theme_vector(clean_inputs)] |
|
logger.info("done with getting theme_vectors") |
|
|
|
|
|
all_similarities = np.zeros(len(self.vocabulary)) |
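        # Score every vocabulary word against each theme vector and average the
        # cosine similarities across themes.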
|
|
|
for theme_vector in theme_vectors: |
|
|
|
similarities = cosine_similarity(theme_vector, self.vocab_embeddings)[0] |
|
all_similarities += similarities / len(theme_vectors) |
|
|
|
logger.info("done with cosine similarity") |
|
|
|
top_indices = np.argsort(all_similarities)[::-1] |
|
logger.info("done with argsort") |
|
|
|
|
|
results = [] |
|
input_words_set = set(clean_inputs) |
|
seen_words = set() |
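        # Walk candidates from most to least similar, skipping seed words and
        # anything below the threshold; collect up to 3x the requested count so a
        # later diversity pass has room to choose.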
|
|
|
for idx in top_indices: |
|
word = self.vocabulary[idx] |
|
similarity_score = all_similarities[idx] |
|
|
|
|
|
if (word not in input_words_set and |
|
word not in seen_words and |
|
similarity_score >= min_similarity): |
|
results.append((word, similarity_score)) |
|
seen_words.add(word) |
|
|
|
if len(results) >= num_words * 3: |
|
break |
|
|
|
logger.info("starting with _apply_diversity_filter") |
|
diversity_factor = 0.0 |
|
|
|
if diversity_factor > 0.0 and len(results) > num_words: |
|
results = self._apply_diversity_filter(results, num_words, diversity_factor) |
|
logger.info("done with _apply_diversity_filter") |
|
|
|
|
|
return results |
|
|
|
def _apply_diversity_filter(self, |
|
candidates: List[Tuple[str, float]], |
|
target_count: int, |
|
diversity_factor: float) -> List[Tuple[str, float]]: |
|
"""Apply diversity filtering to reduce semantic redundancy - optimized version.""" |
|
if len(candidates) <= target_count: |
|
return candidates |
|
|
|
if diversity_factor <= 0.0: |
|
return candidates[:target_count] |
|
|
|
logger.info(f"Applying diversity filter to {len(candidates)} candidates for {target_count} targets") |
|
|
|
|
|
candidate_words = [word for word, _ in candidates] |
|
logger.info("Computing embeddings for all candidates...") |
|
start_time = time.time() |
|
candidate_embeddings = self.model.encode(candidate_words, convert_to_tensor=False, show_progress_bar=False) |
|
embed_time = time.time() - start_time |
|
logger.info(f"✓ Computed {len(candidate_words)} embeddings in {embed_time:.2f}s") |
|
|
|
|
|
selected_indices = [0] |
|
selected_embeddings = [candidate_embeddings[0]] |
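        # Greedy selection: at each step pick the remaining candidate with the best
        # blend of its original similarity score and a diversity bonus, where the
        # bonus grows as the candidate's minimum similarity to the already-selected
        # words shrinks.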
|
|
|
|
|
start_time = time.time() |
|
for _ in range(1, min(target_count, len(candidates))): |
|
best_idx = -1 |
|
best_score = -1 |
|
|
|
|
|
for i in range(len(candidates)): |
|
if i in selected_indices: |
|
continue |
|
|
|
|
|
candidate_emb = candidate_embeddings[i].reshape(1, -1) |
|
min_sim_to_selected = float('inf') |
|
|
|
for selected_emb in selected_embeddings: |
|
selected_emb = selected_emb.reshape(1, -1) |
|
sim = cosine_similarity(candidate_emb, selected_emb)[0][0] |
|
min_sim_to_selected = min(min_sim_to_selected, sim) |
|
|
|
|
|
original_score = candidates[i][1] |
|
diversity_bonus = (1.0 - min_sim_to_selected) * diversity_factor |
|
combined_score = original_score + diversity_bonus |
|
|
|
if combined_score > best_score: |
|
best_score = combined_score |
|
best_idx = i |
|
|
|
|
|
if best_idx >= 0: |
|
selected_indices.append(best_idx) |
|
selected_embeddings.append(candidate_embeddings[best_idx]) |
|
else: |
|
break |
|
|
|
selection_time = time.time() - start_time |
|
logger.info(f"✓ Completed diversity selection in {selection_time:.2f}s") |
|
|
|
|
|
return [candidates[i] for i in selected_indices] |
|
|
|
def get_theme_embedding(self, inputs: List[str]) -> np.ndarray: |
|
"""Get the theme embedding vector for debugging/analysis.""" |
|
return self._compute_theme_vector(inputs)[0] |
|
|
|
|
|
def main(): |
|
"""Demo the thematic word generator.""" |
|
logger.info("Initializing Thematic Word Generator...") |
|
generator = ThematicWordGeneranor_v0_1() |
|
|
|
|
|
test_cases = [ |
|
{ |
|
"name": "Ocean Theme", |
|
"inputs": ["ocean", "waves", "sailing"], |
|
"description": "Maritime and ocean-related concepts" |
|
        },
    ]
|
|
|
print("\n" + "="*70) |
|
print("THEMATIC WORD GENERATOR DEMO") |
|
print("="*70) |
|
|
|
for test_case in test_cases: |
|
print(f"\n{test_case['name']}: {test_case['description']}") |
|
print(f"Input: {test_case['inputs']}") |
|
print("-" * 50) |
|
|
|
|
|
thematic_words = generator.generate_thematic_words( |
|
test_case['inputs'], |
|
num_words=12, |
|
min_similarity=0.2 |
|
) |
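        # Sort helper: frequency tier first (common words before rare ones), then
        # descending similarity within a tier.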
|
|
|
        def sort_by_tier_and_similarity(item):
            word, similarity = item
            freq, tier = generator.get_word_frequency_info(word)
            tier_num = generator.get_tier_number(tier)
            return (tier_num, -similarity)

        if thematic_words:
            thematic_words_sorted = sorted(thematic_words, key=sort_by_tier_and_similarity)
|
|
|
print("Related words (sorted by frequency tier T1→T10):") |
|
for i, (word, score) in enumerate(thematic_words_sorted): |
|
freq, tier = generator.get_word_frequency_info(word) |
|
tier_desc = generator.get_tier_description(tier) |
|
tier_num = generator.get_tier_number(tier) |
|
print(f" {i+1:2d}. {word:<15} (sim: {score:.3f}, freq: {freq:.8f}) [T{tier_num}: {tier_desc}]") |
|
else: |
|
print(" No related words found.") |
|
|
|
|
|
if len(test_case['inputs']) > 1: |
|
diverse_words = generator.generate_thematic_words( |
|
test_case['inputs'], |
|
num_words=8, |
|
diversity_factor=0.3, |
|
multi_theme=True |
|
) |
|
|
|
|
|
diverse_words_sorted = sorted(diverse_words, key=sort_by_tier_and_similarity) |
|
|
|
print(f"\nWith diversity (showing {len(diverse_words_sorted)} words, sorted by tier):") |
|
for i, (word, score) in enumerate(diverse_words_sorted): |
|
freq, tier = generator.get_word_frequency_info(word) |
|
tier_desc = generator.get_tier_description(tier) |
|
tier_num = generator.get_tier_number(tier) |
|
print(f" {i+1:2d}. {word:<15} (sim: {score:.3f}, freq: {freq:.8f}) [T{tier_num}: {tier_desc}]") |
|
|
|
|
|
print("\n" + "="*70) |
|
print("INTERACTIVE MODE") |
|
print("Enter words/sentences separated by commas (type 'quit' to exit)") |
|
print("="*70) |
|
|
|
while True: |
|
try: |
|
start = get_timestamp() |
|
user_input = input(f"\n[{start}] Enter theme words/sentences: ").strip() |
|
|
|
if user_input.lower() == 'quit': |
|
break |
|
|
|
if not user_input: |
|
continue |
|
|
|
|
|
inputs = [inp.strip() for inp in user_input.split(',') if inp.strip()] |
|
|
|
if not inputs: |
|
print("Please provide at least one word or sentence.") |
|
continue |
|
|
|
start = get_timestamp() |
|
print(f"\n[{start}] Generating thematic words for: {inputs}") |
|
print("-" * 40) |
|
|
|
|
|
thematic_words = generator.generate_thematic_words( |
|
inputs, |
|
num_words=50, |
|
diversity_factor=0.2, |
|
multi_theme=len(inputs) > 2 |
|
) |
|
logger.info("returned from generate_thematic_words") |
|
|
|
|
|
def sort_by_tier_and_similarity(item): |
|
word, similarity = item |
|
freq, tier = generator.get_word_frequency_info(word) |
|
tier_num = generator.get_tier_number(tier) |
|
return (tier_num, -similarity) |
|
|
|
thematic_words = sorted(thematic_words, key=sort_by_tier_and_similarity) |
|
|
|
if thematic_words: |
|
print(f"\nGenerated {len(thematic_words)} thematic words (sorted by frequency tier T1→T10):") |
|
current_tier = None |
|
for i, (word, score) in enumerate(thematic_words): |
|
freq, tier = generator.get_word_frequency_info(word) |
|
tier_desc = generator.get_tier_description(tier) |
|
tier_num = generator.get_tier_number(tier) |
|
|
|
|
|
if tier_num != current_tier: |
|
current_tier = tier_num |
|
print(f"\n === TIER {tier_num}: {tier_desc} ===") |
|
|
|
print(f" {i+1:2d}. {word:<15} (sim: {score:.3f}, freq: {freq:.8f})") |
|
else: |
|
print(" No thematic words found. Try different inputs or lower similarity threshold.") |
|
|
|
except KeyboardInterrupt: |
|
break |
|
except Exception as e: |
|
logger.error(f"Error in main loop: {e}") |
|
print(f"Error: {e}") |
|
|
|
print("\nGoodbye!") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|