#!/usr/bin/env python3
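"""Generate topic-related word lists for crossword construction.

Words from a local dictionary CSV are embedded with a sentence-transformer model and
indexed with FAISS; candidate words are gathered via semantic search, Wikipedia
category lookups, and KMeans clustering, then ranked with PageRank.
"""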
import os
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import requests
from sklearn.cluster import KMeans
import networkx as nx
import csv


def get_vocab():
    """Fetch a large English word list from the dwyl/english-words GitHub repository."""
    url = "https://raw.githubusercontent.com/dwyl/english-words/master/words.txt"
    response = requests.get(url)
    if response.status_code == 200:
        # Keep only non-empty words longer than two characters
        return [word.strip().lower() for word in response.text.splitlines() if word.strip() and len(word) > 2]
    else:
        raise Exception("Failed to fetch vocabulary list")


class CrosswordGenerator2:
    def get_dict_vocab(self):
        """Read the dictionary CSV file and return a list of words."""
        dict_path = os.path.join(os.path.dirname(__file__), 'dict-words', 'dict.csv')
        words = []
        try:
            with open(dict_path, 'r', encoding='utf-8') as csvfile:
                reader = csv.DictReader(csvfile)
                for row in reader:
                    word = row['word'].strip().lower()
                    if word and len(word) > 2:  # Skip very short words
                        words.append(word)
        except FileNotFoundError:
            raise Exception(f"Dictionary file not found: {dict_path}")
        except Exception as e:
            raise Exception(f"Error reading dictionary file: {e}")
        return words

    def __init__(self, cache_dir='./model_cache'):
        # Load the vocabulary and embed every word with a sentence-transformer model
        self.vocab = self.get_dict_vocab()
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', cache_folder=cache_dir)
        embeddings = self.model.encode(self.vocab, convert_to_numpy=True)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        # Normalize so inner-product search is equivalent to cosine similarity
        faiss.normalize_L2(embeddings)
        self.dimension = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(self.dimension)
        self.faiss_index.add(embeddings)
        self.max_results = 50  # Adjustable

    def get_wikipedia_subcats(self, topic):
        """Return subcategory names for the topic from the Wikipedia category API."""
        topic_cap = topic.capitalize().replace(' ', '_')
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{topic_cap}&cmtype=subcat&format=json&cmlimit=50"
        try:
            response = requests.get(url).json()
            members = response.get('query', {}).get('categorymembers', [])
            if members:
                return [member['title'].replace('Category:', '').lower() for member in members]
            else:
                # Fallback: search for the topic's main page and use its most relevant category
                search_url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={topic}&format=json"
                search_response = requests.get(search_url).json()
                search_results = search_response.get('query', {}).get('search', [])
                if search_results:
                    main_title = search_results[0]['title']
                    cat_url = f"https://en.wikipedia.org/w/api.php?action=query&prop=categories&titles={main_title}&format=json&cllimit=50"
                    cat_response = requests.get(cat_url).json()
                    pages = cat_response.get('query', {}).get('pages', {})
                    if pages:
                        cats = list(pages.values())[0].get('categories', [])
                        cat_titles = [cat['title'].replace('Category:', '').lower() for cat in cats]
                        relevant_cats = [c for c in cat_titles if any(t in c for t in topic.lower().split())]
                        if relevant_cats:
                            subcat_topic = relevant_cats[0].capitalize().replace(' ', '_')
                            sub_url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{subcat_topic}&cmtype=subcat&format=json&cmlimit=50"
                            sub_response = requests.get(sub_url).json()
                            sub_members = sub_response.get('query', {}).get('categorymembers', [])
                            return [m['title'].replace('Category:', '').lower() for m in sub_members]
                return []
        except Exception:
            return []

    def get_category_pages(self, category):
        """Return single-word page titles belonging to a Wikipedia category."""
        cat_cap = category.capitalize().replace(' ', '_')
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{cat_cap}&cmtype=page&format=json&cmlimit=50"
        try:
            response = requests.get(url).json()
            members = response.get('query', {}).get('categorymembers', [])
            # Keep only lowercase single-word titles longer than three characters
            return [member['title'].lower() for member in members if ' ' not in member['title'] and len(member['title']) > 3]
        except Exception:
            return []

    def is_subcategory(self, topic, word):
        """Check whether the word's Wikipedia page is categorized under the topic."""
        url = f"https://en.wikipedia.org/w/api.php?action=query&prop=categories&format=json&titles={word.capitalize()}"
        try:
            response = requests.get(url).json()
            pages = response.get('query', {}).get('pages', {})
            if pages:
                cats = list(pages.values())[0].get('categories', [])
                return any(topic.lower() in cat['title'].lower() for cat in cats)
            return False
        except Exception:
            return False

    def generate_words(self, topic, num_words=20):
        """Generate up to num_words topic-related words, ranked via PageRank over similarity scores."""
        # Consider both singular and plural forms of the topic
        variations = [topic.lower()]
        if topic.endswith('s'):
            variations.append(topic[:-1])
        else:
            variations.append(topic + 's')
        all_results = {}
        subcats = self.get_wikipedia_subcats(topic)
        print('wiki subcats', subcats)
        # Add words from subcategory pages with a high base score for direct Wikipedia matches
        for sub in subcats:
            pages = self.get_category_pages(sub)
            for p in pages:
                all_results[p] = all_results.get(p, 0) + 0.8
        for variation in variations:
            # Embed the topic variation, optionally perturbed with Gaussian noise for variety
            topic_embedding = self.model.encode([variation], convert_to_numpy=True)
            noise_factor = float(os.getenv("SEARCH_RANDOMNESS", "0.02"))
            if noise_factor > 0:
                noise = np.random.normal(0, noise_factor, topic_embedding.shape)
                topic_embedding += noise
            topic_embedding = np.ascontiguousarray(topic_embedding, dtype=np.float32)
            faiss.normalize_L2(topic_embedding)
            search_size = min(self.max_results * 3, len(self.vocab))
            scores, indices = self.faiss_index.search(topic_embedding, search_size)
            initial_results = []
            for i in range(len(indices[0])):
                idx = indices[0][i]
                score = scores[0][i]
                if score > 0.3:
                    initial_results.append(self.vocab[idx])
            # Identify additional subcats from initial results if Wikipedia didn't provide any
            if not subcats:
                additional_subcats = [w for w in initial_results[:30] if self.is_subcategory(topic, w)]
                subcats.extend(additional_subcats)
            # Fallback: derive subcats by clustering the initial results
            if not subcats and len(initial_results) >= 3:
                result_embeddings = self.model.encode(initial_results, convert_to_numpy=True)
                result_embeddings = np.ascontiguousarray(result_embeddings, dtype=np.float32)
                faiss.normalize_L2(result_embeddings)
                kmeans = KMeans(n_clusters=min(3, len(initial_results)), random_state=42).fit(result_embeddings)
                cluster_centers = kmeans.cluster_centers_.astype(np.float32)
                faiss.normalize_L2(cluster_centers)
                _, subcat_indices = self.faiss_index.search(cluster_centers, 1)
                subcats = [self.vocab[subcat_indices[j][0]] for j in range(len(subcat_indices))]
            # Search each subcategory, down-weighting matches by level
            for level, subs in enumerate([subcats], start=1):
                for sub in subs:
                    sub_embedding = self.model.encode([sub], convert_to_numpy=True)
                    sub_embedding = np.ascontiguousarray(sub_embedding, dtype=np.float32)
                    faiss.normalize_L2(sub_embedding)
                    sub_scores, sub_indices = self.faiss_index.search(sub_embedding, search_size)
                    for i in range(len(sub_indices[0])):
                        idx = sub_indices[0][i]
                        score = sub_scores[0][i]
                        if score > 0.3:
                            w = self.vocab[idx]
                            weighted_score = score * (0.8 ** level)
                            all_results[w] = all_results.get(w, 0) + weighted_score
            # Add the direct topic-search results at full weight
            for i in range(len(indices[0])):
                idx = indices[0][i]
                score = scores[0][i]
                if score > 0.3:
                    w = self.vocab[idx]
                    all_results[w] = all_results.get(w, 0) + score
        # Combine scores with graph-based weighting: PageRank over a star graph centered on the topic
        G = nx.Graph()
        G.add_node(topic)
        for w, score in all_results.items():
            G.add_edge(topic, w, weight=score)
        pr = nx.pagerank(G, weight='weight')
        # Sort by PageRank and return the top words, excluding the topic itself
        sorted_results = sorted(pr.items(), key=lambda x: x[1], reverse=True)
        final_words = [w for w, _ in sorted_results if w != topic][:num_words]
        return final_words


if __name__ == "__main__":
    # Create a cache directory if it doesn't exist
    cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    generator = CrosswordGenerator2(cache_dir=cache_dir)
    topics = ["animal", "animal", "science", "technology", "food", "indian food", "chinese food"]  # Example topics
    for topic in topics:
        print(f"------------- {topic} ------------")
        generated_words = generator.generate_words(topic)
        sorted_generated_words = sorted(generated_words)
        print(f"Generated words for topic '{topic}':")
        print(sorted_generated_words)