"""
Local LLM Clue Generator for Crossword Puzzles

Uses google/flan-t5-base for generating contextual crossword clues.
Designed to work within Hugging Face Spaces constraints.
"""

import os
import time
import logging
from typing import List, Dict, Optional, Tuple, Any
from pathlib import Path


# transformers is an optional dependency; degrade gracefully when it is missing.
try:
    from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    logging.warning("Transformers not available - LLM clue generation disabled")


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


class LLMClueGenerator:
    """
    Local LLM-based clue generator using google/flan-t5-base.
    Optimized for Hugging Face Spaces deployment.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """Initialize the LLM clue generator.

        Args:
            cache_dir: Directory to cache the model files
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError("transformers library is required for LLM clue generation")

        if cache_dir is None:
            cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache')

        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Generation settings
        self.model_name = "google/flan-t5-base"
        self.max_length = 64
        self.num_return_sequences = 3

        # Lazily-loaded model state (populated by initialize())
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.is_initialized = False

        self.clue_templates = {
            "definition": """Write a crossword clue for the word '{word}' (topic: {topic}).

Examples:
- CAT (animals) → "Feline pet"
- GUITAR (music) → "Stringed instrument"
- AIRPORT (transportation) → "Flight terminal"

Now write a clue for '{word}' (topic: {topic}) in 2-5 words:""",

            "description": """Create a crossword clue by describing '{word}' from the {topic} category.

Examples:
- DOG (animals) → "Loyal canine companion"
- PIZZA (food) → "Italian bread dish"
- DATABASE (technology) → "Information storage system"

Describe '{word}' (topic: {topic}) in 3-6 words:""",

            "simple": """Complete this crossword clue pattern.

Examples:
VIOLIN (music) = "Bowed string instrument"
SCIENTIST (science) = "Research professional"
SWIMMING (sports) = "Aquatic athletic activity"

{word} ({topic}) =""",
        }
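
        # Note: the templates above lean on few-shot examples rather than bare
        # instructions; in practice flan-t5-class models tend to imitate the
        # "examples, then task" pattern more reliably. (Observation, not a
        # guarantee; tune the examples for your domain.)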

    def initialize(self):
        """Initialize the LLM model and tokenizer."""
        if self.is_initialized:
            return

        start_time = time.time()
        logger.info(f"🤖 Initializing LLM clue generator with {self.model_name}")
        logger.info(f"📁 Cache directory: {self.cache_dir}")

        try:
            logger.info("📦 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=str(self.cache_dir)
            )

            logger.info("📦 Loading model...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_name,
                cache_dir=str(self.cache_dir)
            )

            logger.info("🔄 Creating generation pipeline...")
            self.generator = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                device=-1  # CPU; no GPU on the free Spaces tier
            )
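
            # Sampling (do_sample=True) is what yields varied candidates;
            # _try_clue_generation overrides temperature on a per-call basis.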

            self.is_initialized = True
            init_time = time.time() - start_time
            logger.info(f"✅ LLM clue generator initialized in {init_time:.2f}s")

        except Exception as e:
            logger.error(f"❌ Failed to initialize LLM clue generator: {e}")
            raise

    def generate_clue(self,
                      word: str,
                      topic: str,
                      clue_style: str = "definition",
                      difficulty: str = "medium") -> str:
        """Generate a single clue for the given word and topic.

        Args:
            word: The word to generate a clue for
            topic: The theme/topic context
            clue_style: Style of clue ('definition', 'description', or
                'simple'); unknown styles fall back to 'definition'
            difficulty: Difficulty level ('easy', 'medium', 'hard')

        Returns:
            Generated clue string
        """
        if not self.is_initialized:
            self.initialize()
        candidates = self.generate_clue_candidates(word, topic, clue_style, difficulty)
        return self._select_best_clue(candidates, word) if candidates else self._fallback_clue(word, topic)

    def generate_clue_candidates(self,
                                 word: str,
                                 topic: str,
                                 clue_style: str = "definition",
                                 difficulty: str = "medium",
                                 num_candidates: int = 5) -> List[str]:
        """Generate multiple clue candidates using different strategies.

        Args:
            word: The word to generate clues for
            topic: The theme/topic context
            clue_style: Style of clue to generate
            difficulty: Difficulty level
            num_candidates: Number of candidates to generate

        Returns:
            List of generated clue candidates
        """
        if not self.is_initialized:
            self.initialize()

        logger.info(f"🎯 Generating {num_candidates} clues for '{word}' (topic: {topic}, style: {clue_style})")

        candidates = []

        # First pass: the requested style.
        candidates.extend(self._try_clue_generation(word, topic, clue_style, difficulty, num_candidates // 2))
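
        # If the primary style produced too few usable clues, fall back to the
        # other templates, then to a lower-temperature retry as a last resort.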
        if len(candidates) < 2:
            backup_styles = ["definition", "description", "simple"]
            for backup_style in backup_styles:
                if backup_style != clue_style:
                    backup_candidates = self._try_clue_generation(word, topic, backup_style, difficulty, 2)
                    candidates.extend(backup_candidates)
                    if len(candidates) >= 3:
                        break

        if len(candidates) < 2:
            logger.debug("⚠️ Low quality candidates, trying with different temperature")
            candidates.extend(self._try_clue_generation(word, topic, "simple", difficulty, 3, temperature=0.5))

        logger.debug(f"✅ Generated {len(candidates)} valid candidates total")
        return candidates[:num_candidates]

    def _try_clue_generation(self, word: str, topic: str, clue_style: str, difficulty: str,
                             attempts: int, temperature: float = 0.8) -> List[str]:
        """Try generating clues with specific parameters."""
        template = self.clue_templates.get(clue_style, self.clue_templates["definition"])
        prompt = self._create_prompt(word, topic, template, difficulty)

        candidates = []

        try:
            for _ in range(attempts):
                result = self.generator(
                    prompt,
                    max_length=self.max_length,
                    do_sample=True,
                    temperature=temperature,
                    num_return_sequences=1,
                    pad_token_id=self.tokenizer.pad_token_id  # T5 defines a real pad token
                )

                if result and len(result) > 0:
                    generated_text = result[0]['generated_text'].strip()

                    # Clean and validate; an empty string means the clue was rejected.
                    clean_clue = self._clean_generated_clue(generated_text, word)

                    if clean_clue and clean_clue not in candidates:
                        candidates.append(clean_clue)
                        logger.debug(f"✅ Valid clue #{len(candidates)}: {clean_clue}")
                    else:
                        logger.debug(f"❌ Rejected clue: {generated_text[:100]}...")

        except Exception as e:
            logger.error(f"❌ Error in clue generation attempt: {e}")

        return candidates

    def generate_clues_batch(self,
                             words_and_topics: List[Tuple[str, str]],
                             clue_style: str = "definition",
                             difficulty: str = "medium") -> Dict[str, str]:
        """Generate clues for multiple words in batch.

        Args:
            words_and_topics: List of (word, topic) tuples
            clue_style: Style of clue to generate
            difficulty: Difficulty level

        Returns:
            Dictionary mapping words to their generated clues
        """
        if not self.is_initialized:
            self.initialize()
        if not words_and_topics:
            return {}

        logger.info(f"🎯 Generating {len(words_and_topics)} clues in batch")

        results = {}
        start_time = time.time()
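
        # Clues are generated sequentially (one forward pass per word); on CPU
        # this dominates runtime, so progress is logged every five words.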
        for i, (word, topic) in enumerate(words_and_topics):
            try:
                clue = self.generate_clue(word, topic, clue_style, difficulty)
                results[word] = clue

                if (i + 1) % 5 == 0:
                    elapsed = time.time() - start_time
                    avg_time = elapsed / (i + 1)
                    logger.info(f"📈 Progress: {i+1}/{len(words_and_topics)} ({avg_time:.2f}s per clue)")

            except Exception as e:
                logger.error(f"❌ Failed to generate clue for '{word}': {e}")
                results[word] = self._fallback_clue(word, topic)

        total_time = time.time() - start_time
        logger.info(f"✅ Batch generation complete in {total_time:.2f}s (avg: {total_time/len(words_and_topics):.2f}s per clue)")

        return results

    def _create_prompt(self, word: str, topic: str, template: str, difficulty: str) -> str:
        """Create a difficulty-aware prompt for the LLM."""
        difficulty_hints = {
            "easy": "Keep it simple and clear.",
            "medium": "Make it moderately challenging.",
            "hard": "Make it clever and challenging."
        }

        base_prompt = template.format(word=word, topic=topic)
        hint = difficulty_hints.get(difficulty, "")

        return f"{base_prompt} {hint}".strip()

    def _clean_generated_clue(self, generated_text: str, word: str) -> str:
        """Clean and validate the generated clue text with improved filtering."""
        if not generated_text:
            return ""

        clue = generated_text.strip()

        artifacts_to_remove = [
            "Your clue:", "Your answer:", "Clue:", "Answer:", "Format:", "Example:",
            "Rules:", "Here's", "The clue is", "A good clue would be", "This is",
            "I would suggest", "One option could be", "Consider this",
        ]
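
        # Strip any prompt scaffolding the model echoed back; only the text
        # after each matched marker is kept.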
        for artifact in artifacts_to_remove:
            # Recompute the lowercase view each pass, since `clue` shrinks.
            clue_lower = clue.lower()
            artifact_pos = clue_lower.find(artifact.lower())
            if artifact_pos >= 0:
                clue = clue[artifact_pos + len(artifact):].strip()

        clue = clue.strip('"\'[](){}<>')

        # Reject clues that leak the answer.
        word_lower = word.lower()
        clue_words = set(clue.lower().split())

        if word_lower in clue_words:
            logger.debug(f"⚠️ Rejecting clue containing target word '{word}': {clue}")
            return ""

        # Partial-overlap check; very short clue words (e.g. "a", "an", "is")
        # are skipped so they don't trigger spurious substring matches.
        if any((word_lower in clue_word or clue_word in word_lower)
               for clue_word in clue_words if len(clue_word) >= 3):
            logger.debug(f"⚠️ Rejecting clue with partial word match for '{word}': {clue}")
            return ""

        # Enforce a reasonable clue length.
        if len(clue) < 5 or len(clue) > 80:
            logger.debug(f"⚠️ Rejecting clue with bad length ({len(clue)}): {clue}")
            return ""

        word_count = len(clue.split())
        if word_count > 15:
            logger.debug(f"⚠️ Rejecting wordy clue ({word_count} words): {clue}")
            return ""

        if self._is_nonsensical(clue):
            logger.debug(f"⚠️ Rejecting nonsensical clue: {clue}")
            return ""

        # Uppercase only the first character so acronyms and proper nouns keep
        # their casing, then normalize the terminal punctuation.
        clue = clue[0].upper() + clue[1:]
        if not clue.endswith('.'):
            clue = clue.rstrip('.,!?') + '.'

        return clue

    def _is_nonsensical(self, clue: str) -> bool:
        """Check if the clue appears nonsensical or inappropriate."""
        clue_lower = clue.lower()

        nonsense_indicators = [
            "shit", "crap", "damn", "fuck",
            "nicolas", "fender", "omelets are sometimes",
            "for the most part", "go to a party",
            "help for the kids", "new ways to get",
        ]
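
        # The indicators above are a hand-maintained blocklist: profanity plus
        # phrases that presumably appeared in observed bad generations.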

        for indicator in nonsense_indicators:
            if indicator in clue_lower:
                return True

        # Reject generic filler openings.
        if clue_lower.startswith(("for the", "help for", "go to", "new ways")):
            return True

        # Reject clues where more than half of the words are repeats.
        words = clue_lower.split()
        if len(set(words)) < len(words) * 0.5:
            return True

        return False

    def _select_best_clue(self, candidates: List[str], word: str) -> str:
        """Select the best clue from candidates."""
        if not candidates:
            return ""
        scored_candidates = []

        for clue in candidates:
            score = 0

            # Length: 20-60 characters is the sweet spot.
            length = len(clue)
            if 20 <= length <= 60:
                score += 10
            elif length < 20:
                score += 5
            else:
                score -= (length - 60) // 10

            # Heavy penalty for leaking the answer.
            if word.lower() not in clue.lower():
                score += 15
            else:
                score -= 20

            # Small bonus for punctuation, which suggests a complete phrase.
            if any(p in clue for p in '.!?,:;'):
                score += 3

            scored_candidates.append((score, clue))

        scored_candidates.sort(key=lambda x: x[0], reverse=True)

        best_clue = scored_candidates[0][1]
        logger.debug(f"🏆 Selected best clue: '{best_clue}' (score: {scored_candidates[0][0]})")

        return best_clue

    def _fallback_clue(self, word: str, topic: str) -> str:
        """Generate a simple fallback clue when the LLM fails."""
        word_lower = word.lower()
        topic_lower = topic.lower()

        # Map common topic keywords to a canned clue prefix.
        if any(keyword in topic_lower for keyword in ["animal", "pet", "wildlife"]):
            return f"Animal: {word_lower}"
        elif any(keyword in topic_lower for keyword in ["tech", "computer", "software"]):
            return f"Technology term: {word_lower}"
        elif any(keyword in topic_lower for keyword in ["science", "biology", "chemistry"]):
            return f"Science: {word_lower}"
        elif any(keyword in topic_lower for keyword in ["food", "cooking", "cuisine"]):
            return f"Food item: {word_lower}"
        elif any(keyword in topic_lower for keyword in ["music", "song", "instrument"]):
            return f"Music: {word_lower}"
        else:
            return f"Related to {topic_lower}: {word_lower}"

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        info = {
            "model_name": self.model_name,
            "is_initialized": self.is_initialized,
            "cache_directory": str(self.cache_dir),
            "transformers_available": TRANSFORMERS_AVAILABLE
        }

        if self.is_initialized and self.model:
            try:
                total_params = sum(p.numel() for p in self.model.parameters())
                info["model_parameters"] = total_params
                # Rough size estimate assuming fp32 weights (4 bytes/param).
                info["model_size_mb"] = total_params * 4 / (1024 * 1024)
            except Exception:
                pass

        return info


def main():
    """Test the LLM clue generator."""
    print("🚀 LLM Clue Generator Test")
    print("=" * 50)

    print("🔄 Initializing LLM clue generator...")
    generator = LLMClueGenerator()

    try:
        generator.initialize()

        info = generator.get_model_info()
        print("\n📊 Model Information:")
        print(f"   Model: {info['model_name']}")
        params = info.get('model_parameters')
        print(f"   Parameters: {params:,}" if params else "   Parameters: Unknown")
        print(f"   Size: {info.get('model_size_mb', 0):.1f} MB")

        print("\n🎯 Single Clue Generation:")
        print("-" * 30)

        test_cases = [
            ("elephant", "animals"),
            ("python", "technology"),
            ("ocean", "geography"),
            ("guitar", "music"),
            ("pizza", "food")
        ]

        for word, topic in test_cases:
            print(f"\nWord: '{word}' | Topic: '{topic}'")

            # Exercise each defined template style.
            for style in ["definition", "description", "simple"]:
                start_time = time.time()
                clue = generator.generate_clue(word, topic, clue_style=style)
                gen_time = time.time() - start_time

                print(f"   {style:12}: {clue} ({gen_time:.2f}s)")

        print("\n🎯 Batch Generation Test:")
        print("-" * 30)

        batch_words = [
            ("cat", "animals"),
            ("computer", "technology"),
            ("mountain", "geography"),
            ("piano", "music")
        ]

        batch_results = generator.generate_clues_batch(batch_words, clue_style="definition")

        for word, clue in batch_results.items():
            print(f"   {word:10}: {clue}")

        print("\n✅ LLM clue generator test completed!")

    except Exception as e:
        print(f"❌ Error during testing: {e}")
        print("This might be due to a missing transformers library or model download issues.")


if __name__ == "__main__":
    main()