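# Crossword topic word-list generator.
#
# Pipeline: embed a ~10k-word English vocabulary with sentence-transformers,
# retrieve topic neighbors from a FAISS cosine index, expand the topic through
# Wikipedia-confirmed (or k-means-derived) subtopics, and rank the accumulated
# candidates with PageRank over a topic-centered star graph.
#
# Assumed dependencies: pip install faiss-cpu sentence-transformers
# scikit-learn networkx requests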
import os

import faiss
import networkx as nx
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

def get_vocab():
    """Fetch a ~10k-word English vocabulary (Google's most frequent words, swears removed)."""
    url = "https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-no-swears.txt"
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        return [word.strip().lower() for word in response.text.splitlines() if word.strip()]
    raise RuntimeError(f"Failed to fetch vocabulary list (HTTP {response.status_code})")

class CrosswordGenerator:
    def __init__(self):
        self.vocab = get_vocab()
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        # Embed the whole vocabulary once; L2-normalizing the vectors makes
        # the inner-product index below equivalent to cosine similarity.
        embeddings = self.model.encode(self.vocab, convert_to_numpy=True)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        faiss.normalize_L2(embeddings)
        self.dimension = embeddings.shape[1]

        self.faiss_index = faiss.IndexFlatIP(self.dimension)
        self.faiss_index.add(embeddings)
        self.max_results = 50
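
    # Illustrative check of the cosine setup (hypothetical snippet; "dog" is an
    # arbitrary probe word):
    #   g = CrosswordGenerator()
    #   q = g.model.encode(["dog"], convert_to_numpy=True).astype(np.float32)
    #   faiss.normalize_L2(q)
    #   sims, idxs = g.faiss_index.search(q, 5)  # sims: cosine scores in [-1, 1]
    #   print([g.vocab[i] for i in idxs[0]])
    # IndexFlatIP performs exact brute-force search, which is ample at ~10k vectors.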
    def is_subcategory(self, topic, word):
        """Return True if `word`'s Wikipedia page has a category whose title mentions `topic`."""
        url = ("https://en.wikipedia.org/w/api.php"
               f"?action=query&prop=categories&format=json&titles={word.capitalize()}")
        try:
            response = requests.get(url, timeout=10).json()
            pages = response.get('query', {}).get('pages', {})
            if pages:
                cats = list(pages.values())[0].get('categories', [])
                return any(topic.lower() in cat['title'].lower() for cat in cats)
            return False
        except Exception:
            # Treat network failures and malformed payloads as "not related".
            return False
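
    # Note: prop=categories returns at most 10 categories per request by default;
    # appending &cllimit=500 to the URL would make the check more thorough.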
    def generate_words(self, topic, num_words=20):
        """Return up to `num_words` vocabulary words related to `topic`.

        Fewer words may come back when the similarity cutoff leaves too few
        candidates.
        """
        # Only the lowercased topic for now; the list exists so extra
        # phrasings could be added later.
        variations = [topic.lower()]

        # word -> accumulated relevance score across all variations
        all_results = {}

        for variation in variations:
            topic_embedding = self.model.encode([variation], convert_to_numpy=True)

            # Optional Gaussian jitter so repeated runs on the same topic can
            # surface different neighbors (set SEARCH_RANDOMNESS=0 to disable).
            noise_factor = float(os.getenv("SEARCH_RANDOMNESS", "0.02"))
            if noise_factor > 0:
                noise = np.random.normal(0, noise_factor, topic_embedding.shape)
                topic_embedding += noise
            topic_embedding = np.ascontiguousarray(topic_embedding, dtype=np.float32)
            faiss.normalize_L2(topic_embedding)

            # Over-fetch so enough candidates survive the similarity cutoff.
            search_size = min(self.max_results * 3, len(self.vocab))
            scores, indices = self.faiss_index.search(topic_embedding, search_size)

            # Keep only reasonably similar words (cosine > 0.3, an empirical cutoff).
            initial_results = [self.vocab[idx]
                               for idx, score in zip(indices[0], scores[0])
                               if score > 0.3]

            # Prefer subtopics that Wikipedia confirms: top hits whose article
            # categories mention the topic.
            subcats = [w for w in initial_results[:30] if self.is_subcategory(topic, w)]
            print(f"subcats {subcats}")

            # Fallback: derive up to three pseudo-subtopics by clustering the
            # hits and snapping each cluster center to its nearest vocabulary word.
            if not subcats and len(initial_results) >= 3:
                result_embeddings = self.model.encode(initial_results, convert_to_numpy=True)
                result_embeddings = np.ascontiguousarray(result_embeddings, dtype=np.float32)
                faiss.normalize_L2(result_embeddings)
                kmeans = KMeans(n_clusters=min(3, len(initial_results)),
                                n_init=10, random_state=42).fit(result_embeddings)
                cluster_centers = kmeans.cluster_centers_.astype(np.float32)
                faiss.normalize_L2(cluster_centers)
                _, subcat_indices = self.faiss_index.search(cluster_centers, 1)
                subcats = [self.vocab[row[0]] for row in subcat_indices]

            # Expand each subtopic with its own neighborhood search. The
            # single-element list keeps the door open for deeper levels;
            # scores decay geometrically (0.8 ** level) with expansion depth.
            for level, subs in enumerate([subcats], start=1):
                for sub in subs:
                    sub_embedding = self.model.encode([sub], convert_to_numpy=True)
                    sub_embedding = np.ascontiguousarray(sub_embedding, dtype=np.float32)
                    faiss.normalize_L2(sub_embedding)
                    sub_scores, sub_indices = self.faiss_index.search(sub_embedding, search_size)
                    for idx, score in zip(sub_indices[0], sub_scores[0]):
                        if score > 0.3:
                            w = self.vocab[idx]
                            weighted_score = score * (0.8 ** level)
                            all_results[w] = all_results.get(w, 0) + weighted_score

            # Direct hits on the topic itself count at full weight.
            for idx, score in zip(indices[0], scores[0]):
                if score > 0.3:
                    w = self.vocab[idx]
                    all_results[w] = all_results.get(w, 0) + score

        # Rank candidates with PageRank over a star graph centered on the topic.
        # On a star, a leaf's PageRank is monotone in its edge weight, so this
        # amounts to a smoothed renormalization of the accumulated scores.
        G = nx.Graph()
        G.add_node(topic)
        for w, score in all_results.items():
            G.add_edge(topic, w, weight=score)
        pr = nx.pagerank(G, weight='weight')

        sorted_results = sorted(pr.items(), key=lambda x: x[1], reverse=True)
        final_words = [w for w, _ in sorted_results if w != topic][:num_words]
        return final_words
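
# Note: each generate_words() call performs one FAISS search for the topic,
# one per subtopic, and up to 30 Wikipedia category lookups, so the Wikipedia
# requests dominate end-to-end latency.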
if __name__ == "__main__":
    generator = CrosswordGenerator()
    # "animal" is listed twice: with SEARCH_RANDOMNESS jitter enabled, repeated
    # runs of the same topic can yield different word lists.
    topics = ["animal", "animal", "science", "technology", "food", "indian food", "chinese food"]
    for topic in topics:
        print(f"------------- {topic} ------------")
        generated_words = generator.generate_words(topic)
        print(f"Generated words for topic '{topic}':")
        print(sorted(generated_words))