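"""Generate topic-related word lists for crossword puzzles.

Combines sentence-embedding similarity search (a FAISS inner-product index over
MiniLM embeddings of a dictionary vocabulary) with Wikipedia category lookups,
then ranks the combined candidates with PageRank.
"""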
import csv
import os

import faiss
import networkx as nx
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans


def get_vocab():
    """Fetch a plain-text English word list over HTTP and return lowercase words longer than two characters."""
    url = "https://raw.githubusercontent.com/dwyl/english-words/master/words.txt"
    response = requests.get(url)
    if response.status_code == 200:
        words = (w.strip().lower() for w in response.text.splitlines())
        return [w for w in words if w and len(w) > 2]
    else:
        raise Exception("Failed to fetch vocabulary list")


class CrosswordGenerator2:
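    """Suggest topic-related words by combining embedding similarity search with Wikipedia category data."""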

    def get_dict_vocab(self):
        """Read the dictionary CSV file and return a list of words."""
        dict_path = os.path.join(os.path.dirname(__file__), 'dict-words', 'dict.csv')
        words = []

        try:
            with open(dict_path, 'r', encoding='utf-8') as csvfile:
                reader = csv.DictReader(csvfile)
                for row in reader:
                    word = row['word'].strip().lower()
                    # Skip blanks and very short words, which make poor crossword entries
                    if word and len(word) > 2:
                        words.append(word)
        except FileNotFoundError:
            raise Exception(f"Dictionary file not found: {dict_path}")
        except Exception as e:
            raise Exception(f"Error reading dictionary file: {e}") from e

        return words

    def __init__(self, cache_dir='./model_cache'):
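        """Load the vocabulary and embedding model, then build a FAISS inner-product index over normalized embeddings."""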
        self.vocab = self.get_dict_vocab()
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', cache_folder=cache_dir)
        embeddings = self.model.encode(self.vocab, convert_to_numpy=True)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        # With L2-normalized vectors, inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        self.dimension = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(self.dimension)
        self.faiss_index.add(embeddings)
        self.max_results = 50

    def get_wikipedia_subcats(self, topic):
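        """Return subcategory names for a topic from the Wikipedia category API, falling back to a page search."""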
        topic_cap = topic.capitalize().replace(' ', '_')
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{topic_cap}&cmtype=subcat&format=json&cmlimit=50"
        try:
            response = requests.get(url).json()
            members = response.get('query', {}).get('categorymembers', [])
            if members:
                return [member['title'].replace('Category:', '').lower() for member in members]
            else:
                # No direct category match: search for the topic's main article and inspect its categories
                search_url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={topic}&format=json"
                search_response = requests.get(search_url).json()
                search_results = search_response.get('query', {}).get('search', [])
                if search_results:
                    main_title = search_results[0]['title']
                    cat_url = f"https://en.wikipedia.org/w/api.php?action=query&prop=categories&titles={main_title}&format=json&cllimit=50"
                    cat_response = requests.get(cat_url).json()
                    pages = cat_response.get('query', {}).get('pages', {})
                    if pages:
                        cats = list(pages.values())[0].get('categories', [])
                        cat_titles = [cat['title'].replace('Category:', '').lower() for cat in cats]
                        # Keep only categories that mention one of the topic words
                        relevant_cats = [c for c in cat_titles if any(t in c for t in topic.lower().split())]
                        if relevant_cats:
                            subcat_topic = relevant_cats[0].capitalize().replace(' ', '_')
                            sub_url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{subcat_topic}&cmtype=subcat&format=json&cmlimit=50"
                            sub_response = requests.get(sub_url).json()
                            sub_members = sub_response.get('query', {}).get('categorymembers', [])
                            return [m['title'].replace('Category:', '').lower() for m in sub_members]
                return []
        except Exception:
            return []

    def get_category_pages(self, category):
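        """Return single-word page titles (longer than three characters) from a Wikipedia category."""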
        cat_cap = category.capitalize().replace(' ', '_')
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{cat_cap}&cmtype=page&format=json&cmlimit=50"
        try:
            response = requests.get(url).json()
            members = response.get('query', {}).get('categorymembers', [])
            # Keep only single-word titles long enough to be usable crossword entries
            return [member['title'].lower() for member in members if ' ' not in member['title'] and len(member['title']) > 3]
        except Exception:
            return []

    def is_subcategory(self, topic, word):
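        """Return True if any Wikipedia category of the word's article mentions the topic."""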
        url = f"https://en.wikipedia.org/w/api.php?action=query&prop=categories&format=json&titles={word.capitalize()}"
        try:
            response = requests.get(url).json()
            pages = response.get('query', {}).get('pages', {})
            if pages:
                cats = list(pages.values())[0].get('categories', [])
                return any(topic.lower() in cat['title'].lower() for cat in cats)
            return False
        except Exception:
            return False

    def generate_words(self, topic, num_words=20):
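        """Return up to num_words vocabulary words related to the topic.

        Candidates come from Wikipedia subcategory pages and embedding similarity
        search (over the topic, a singular/plural variant, and any subcategories);
        the accumulated scores are ranked with PageRank on a topic-centered star graph.
        """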
        variations = [topic.lower()]
        # Add a crude singular/plural variant of the topic
        if topic.endswith('s'):
            variations.append(topic[:-1])
        else:
            variations.append(topic + 's')

        all_results = {}

        subcats = self.get_wikipedia_subcats(topic)
        print('wiki subcats', subcats)

        # Words found on subcategory pages get a fixed boost
        for sub in subcats:
            pages = self.get_category_pages(sub)
            for p in pages:
                all_results[p] = all_results.get(p, 0) + 0.8

        for variation in variations:
            # Embed the topic variation, optionally jittered so repeated runs differ
            topic_embedding = self.model.encode([variation], convert_to_numpy=True)
            noise_factor = float(os.getenv("SEARCH_RANDOMNESS", "0.02"))
            if noise_factor > 0:
                noise = np.random.normal(0, noise_factor, topic_embedding.shape)
                topic_embedding += noise
            topic_embedding = np.ascontiguousarray(topic_embedding, dtype=np.float32)
            faiss.normalize_L2(topic_embedding)

            search_size = min(self.max_results * 3, len(self.vocab))
            scores, indices = self.faiss_index.search(topic_embedding, search_size)

            initial_results = []
            for i in range(len(indices[0])):
                idx = indices[0][i]
                score = scores[0][i]
                if score > 0.3:
                    initial_results.append(self.vocab[idx])

            # If Wikipedia returned no subcategories, promote similar words whose articles are categorized under the topic
            if not subcats:
                additional_subcats = [w for w in initial_results[:30] if self.is_subcategory(topic, w)]
                subcats.extend(additional_subcats)

            # Last resort: cluster the similar words and use the words nearest the cluster centers as subcategories
            if not subcats and len(initial_results) >= 3:
                result_embeddings = self.model.encode(initial_results, convert_to_numpy=True)
                result_embeddings = np.ascontiguousarray(result_embeddings, dtype=np.float32)
                faiss.normalize_L2(result_embeddings)
                kmeans = KMeans(n_clusters=min(3, len(initial_results)), random_state=42).fit(result_embeddings)
                cluster_centers = kmeans.cluster_centers_.astype(np.float32)
                faiss.normalize_L2(cluster_centers)
                _, subcat_indices = self.faiss_index.search(cluster_centers, 1)
                subcats = [self.vocab[subcat_indices[j][0]] for j in range(len(subcat_indices))]

            # Score words similar to each subcategory, discounted by 0.8 per level
            for level, subs in enumerate([subcats], start=1):
                for sub in subs:
                    sub_embedding = self.model.encode([sub], convert_to_numpy=True)
                    sub_embedding = np.ascontiguousarray(sub_embedding, dtype=np.float32)
                    faiss.normalize_L2(sub_embedding)
                    sub_scores, sub_indices = self.faiss_index.search(sub_embedding, search_size)
                    for i in range(len(sub_indices[0])):
                        idx = sub_indices[0][i]
                        score = sub_scores[0][i]
                        if score > 0.3:
                            w = self.vocab[idx]
                            weighted_score = score * (0.8 ** level)
                            all_results[w] = all_results.get(w, 0) + weighted_score

            # Direct matches for this variation count at full weight
            for i in range(len(indices[0])):
                idx = indices[0][i]
                score = scores[0][i]
                if score > 0.3:
                    w = self.vocab[idx]
                    all_results[w] = all_results.get(w, 0) + score

        # Rank candidates with PageRank on a star graph centered on the topic
        G = nx.Graph()
        G.add_node(topic)
        for w, score in all_results.items():
            G.add_edge(topic, w, weight=score)
        pr = nx.pagerank(G, weight='weight')

        sorted_results = sorted(pr.items(), key=lambda x: x[1], reverse=True)
        final_words = [w for w, _ in sorted_results if w != topic][:num_words]

        return final_words


if __name__ == "__main__":
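    # Build the generator once and print generated word lists for a few sample topics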
    cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)

    generator = CrosswordGenerator2(cache_dir=cache_dir)
    topics = ["animal", "animal", "science", "technology", "food", "indian food", "chinese food"]
    for topic in topics:
        print(f"------------- {topic} ------------")
        generated_words = generator.generate_words(topic)
        sorted_generated_words = sorted(generated_words)
        print(f"Generated words for topic '{topic}':")
        print(sorted_generated_words)