#!/usr/bin/env python3
"""
Hangman DQN Simulator Playground
A comprehensive tool for monitoring, comparing, and debugging DQN model improvements.
"""
import argparse
import os
import time
import json
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, asdict
try:
import matplotlib.pyplot as plt
import seaborn as sns
HAS_PLOTTING = True
except ImportError:
HAS_PLOTTING = False
print("Warning: matplotlib/seaborn not available. Visualization will be skipped.")
from collections import defaultdict, Counter
# Import hangman modules
from hangman.rl.models import DuelingQNet, ActorCriticNet
from hangman.rl.utils import load_dict, by_len, set_seed
from hangman.rl.priors import build_length_priors, build_positional_priors, CandCache
from hangman.rl.eval import greedy_rollout, run_solver
from hangman.rl.envs import BatchEnv
from hangman.rl.replay import Replay, SuccessReplay
from hangman.rl.seed_bc import seed_expert
from argparse import Namespace
@dataclass
class ModelMetrics:
"""Container for model performance metrics"""
model_name: str
win_rate: float
avg_turns: float
avg_reward: float
episodes_tested: int
strategy: str
timestamp: str
model_path: str
training_config: Dict
@dataclass
class GameResult:
"""Container for individual game results"""
word: str
word_length: int
won: bool
turns_taken: int
final_reward: float
guesses: List[str]
pattern_history: List[str]
q_values_history: List[Dict[str, float]]
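# Note: GameResult records are serialized with dataclasses.asdict() when
# save_results() writes the per-game analysis JSON (see below).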
class HangmanSimulator:
"""Comprehensive hangman simulator for monitoring DQN improvements"""
def __init__(self, dict_path: str = "data/words_250000_train.txt",
len_lo: int = 4, len_hi: int = 12, max_len: int = 35):
self.dict_path = dict_path
self.len_lo = len_lo
self.len_hi = len_hi
self.max_len = max_len
# Load data
print("Loading dictionary and building priors...")
self.words = load_dict(dict_path)
self.buckets = by_len(self.words, len_lo, len_hi)
self.priors = build_length_priors(self.buckets)
self.pos_priors = build_positional_priors(self.buckets, max_len)
self.cand_cache = CandCache(100_000)
# Device setup
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {self.device}")
# Results storage
self.model_metrics: List[ModelMetrics] = []
self.game_results: List[GameResult] = []
# Create output directory
self.output_dir = Path("simulator_results")
self.output_dir.mkdir(exist_ok=True)
    def load_model(self, model_path: str, model_name: Optional[str] = None) -> Optional[DuelingQNet]:
"""Load a trained model from checkpoint"""
if model_name is None:
model_name = Path(model_path).stem
print(f"Loading model: {model_name} from {model_path}")
model = DuelingQNet(
d_model=128, nhead=4, nlayers=2, ff_mult=4,
max_len=self.max_len, dropout=0.1
)
if os.path.exists(model_path):
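            # Checkpoint layout assumed here: a dict whose 'model' entry holds the state_dict.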
checkpoint = torch.load(model_path, map_location=self.device)
model.load_state_dict(checkpoint['model'])
print(f"βœ… Model loaded successfully")
else:
print(f"❌ Model file not found: {model_path}")
return None
model.to(self.device)
return model
def evaluate_model(self, model: DuelingQNet, model_name: str,
episodes: int = 1000, detailed: bool = False) -> ModelMetrics:
"""Evaluate a model and return comprehensive metrics"""
print(f"\nπŸ” Evaluating {model_name} over {episodes} episodes...")
# Setup environment
env = BatchEnv(self.buckets, 6, 64, sorted(self.buckets.keys()), self.max_len)
env.reset()
# Run evaluation
start_time = time.time()
win_rate = greedy_rollout(env, model, self.device, N=episodes,
priors=self.priors, log_stride=100)
eval_time = time.time() - start_time
# Calculate additional metrics
avg_turns = 6.0 # Placeholder - would need to track this in evaluation
avg_reward = win_rate * 1.0 + (1 - win_rate) * (-1.0) # Simplified
metrics = ModelMetrics(
model_name=model_name,
win_rate=win_rate,
avg_turns=avg_turns,
avg_reward=avg_reward,
episodes_tested=episodes,
strategy="dqn",
timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
model_path="", # Will be set by caller
training_config={}
)
print(f"βœ… {model_name}: Win Rate = {win_rate:.3f} ({eval_time:.1f}s)")
return metrics
def compare_with_baselines(self, model: DuelingQNet, model_name: str,
episodes: int = 1000) -> Dict[str, float]:
"""Compare model performance with heuristic baselines"""
print(f"\nπŸ“Š Comparing {model_name} with baseline strategies...")
results = {}
# Evaluate DQN model
dqn_metrics = self.evaluate_model(model, model_name, episodes)
results['dqn'] = dqn_metrics.win_rate
# Evaluate baseline strategies
strategies = ['cand', 'igx', 'pos', 'len', 'ig']
for strategy in strategies:
print(f"Testing {strategy} baseline...")
args = Namespace(
solver_mode=strategy,
tries=6,
batch_env=64,
max_len=self.max_len,
solver_eval_N=episodes,
csv_log=False
)
win_rate = run_solver(args, self.buckets, self.priors, self.pos_priors)
results[strategy] = win_rate
print(f" {strategy}: {win_rate:.3f}")
return results
def analyze_model_decisions(self, model: DuelingQNet, model_name: str,
num_games: int = 10) -> List[GameResult]:
"""Analyze individual game decisions for debugging"""
print(f"\nπŸ”¬ Analyzing {model_name} decisions in {num_games} games...")
model.eval()
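        # If the network exposes noisy (exploration) layers, disable the noise so the
        # analysis reflects deterministic greedy decisions.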
if hasattr(model, "remove_noise"):
model.remove_noise()
results = []
env = BatchEnv(self.buckets, 6, 1, sorted(self.buckets.keys()), self.max_len)
for game_idx in range(num_games):
env.reset()
word = env.words[0]
L = len(word)
game_result = GameResult(
word=word,
word_length=L,
won=False,
turns_taken=0,
final_reward=0.0,
guesses=[],
pattern_history=[],
q_values_history=[]
)
while not env.done[0]:
# Get model prediction
pat_idx, tried, lens, tries = env.observe()
B_now = pat_idx.size(0)
# Prepare inputs
lp = torch.zeros((B_now, 26), dtype=torch.float32)
lp[0, :] = torch.tensor(self.priors.get(L, [0.0] * 26))
tn = (tries.float() / 6.0).unsqueeze(1)
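                # tn: remaining tries normalized by the 6-try budget; lp: length-conditioned letter priors.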
# Get Q-values
with torch.no_grad():
q_values = model(
pat_idx.to(self.device),
tried.to(self.device),
lens.to(self.device),
lp.to(self.device),
tn.to(self.device)
)
# Convert Q-values to probabilities for analysis
q_vals = q_values[0].cpu().numpy()
action = int(q_vals.argmax())
letter = chr(ord('a') + action)
# Store decision info
q_dict = {chr(ord('a') + i): float(q_vals[i]) for i in range(26)}
game_result.q_values_history.append(q_dict)
game_result.guesses.append(letter)
game_result.pattern_history.append(env.patterns[0])
# Take action
reward = env.step(torch.tensor([action]))[0].item()
game_result.turns_taken += 1
game_result.final_reward = reward
game_result.won = bool(env.won[0].item())
results.append(game_result)
print(f" Game {game_idx+1}: {word} -> {'WON' if game_result.won else 'LOST'} "
f"({game_result.turns_taken} turns)")
model.train()
if hasattr(model, "resample_noise"):
model.resample_noise()
return results
def create_performance_report(self, results: Dict[str, float],
model_name: str) -> str:
"""Create a comprehensive performance report"""
report = f"""
# Hangman DQN Performance Report
**Model**: {model_name}
**Timestamp**: {time.strftime("%Y-%m-%d %H:%M:%S")}
**Episodes Tested**: 1000
## Performance Comparison
| Strategy | Win Rate | Performance vs DQN |
|----------|----------|-------------------|"""
dqn_rate = results.get('dqn', 0.0)
for strategy, rate in results.items():
if strategy == 'dqn':
continue
diff = rate - dqn_rate
diff_pct = (diff / max(dqn_rate, 0.001)) * 100
report += f"\n| {strategy.upper()} | {rate:.3f} | {diff:+.3f} ({diff_pct:+.1f}%)"
report += f"\n| **DQN** | **{dqn_rate:.3f}** | **baseline** |"
# Analysis
best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'],
key=lambda x: x[1])
report += f"""
## Analysis
- **Best Baseline**: {best_baseline[0].upper()} ({best_baseline[1]:.3f} win rate)
- **DQN Performance**: {dqn_rate:.3f} win rate
- **Gap to Best**: {best_baseline[1] - dqn_rate:.3f} ({(best_baseline[1] - dqn_rate)/best_baseline[1]*100:.1f}% behind)
## Recommendations
"""
        if dqn_rate < 0.1:
            report += "- ❌ **Critical**: DQN performance is extremely poor. Check training data quality and model architecture.\n"
        elif dqn_rate < best_baseline[1] * 0.5:
            report += "- ⚠️ **Poor**: DQN significantly underperforms the best baseline. Consider retraining with a better teacher strategy.\n"
        elif dqn_rate < best_baseline[1] * 0.8:
            report += "- 🔶 **Fair**: DQN shows promise but needs improvement. Fine-tune the training parameters.\n"
        else:
            report += "- ✅ **Good**: DQN performance is competitive with the baselines.\n"
        if best_baseline[0] == 'cand':
            report += "- 🎯 **Priority**: Use the 'cand' strategy as the teacher for retraining (85%+ win rate)\n"
return report
def save_results(self, results: Dict[str, float], model_name: str,
                     game_analysis: Optional[List[GameResult]] = None):
"""Save all results to files"""
timestamp = time.strftime("%Y%m%d_%H%M%S")
# Save performance comparison
results_file = self.output_dir / f"performance_{model_name}_{timestamp}.json"
with open(results_file, 'w') as f:
json.dump({
'model_name': model_name,
'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
'results': results,
'config': {
'dict_path': self.dict_path,
'len_lo': self.len_lo,
'len_hi': self.len_hi,
'max_len': self.max_len
}
}, f, indent=2)
# Save performance report
report = self.create_performance_report(results, model_name)
report_file = self.output_dir / f"report_{model_name}_{timestamp}.md"
with open(report_file, 'w') as f:
f.write(report)
# Save game analysis if provided
if game_analysis:
analysis_file = self.output_dir / f"analysis_{model_name}_{timestamp}.json"
with open(analysis_file, 'w') as f:
json.dump([asdict(result) for result in game_analysis], f, indent=2)
print(f"\nπŸ’Ύ Results saved:")
print(f" - Performance: {results_file}")
print(f" - Report: {report_file}")
if game_analysis:
print(f" - Analysis: {analysis_file}")
def create_visualization(self, results: Dict[str, float], model_name: str):
"""Create performance visualization"""
if not HAS_PLOTTING:
print("⚠️ Visualization skipped - matplotlib not available")
return
plt.figure(figsize=(12, 8))
# Performance comparison bar chart
plt.subplot(2, 2, 1)
strategies = list(results.keys())
rates = list(results.values())
colors = ['red' if s == 'dqn' else 'blue' for s in strategies]
bars = plt.bar(strategies, rates, color=colors, alpha=0.7)
plt.title(f'Win Rate Comparison - {model_name}')
plt.ylabel('Win Rate')
plt.xticks(rotation=45)
# Highlight DQN
for i, (strategy, rate) in enumerate(zip(strategies, rates)):
if strategy == 'dqn':
bars[i].set_color('red')
bars[i].set_alpha(1.0)
# Performance gap analysis
plt.subplot(2, 2, 2)
dqn_rate = results.get('dqn', 0.0)
gaps = [rate - dqn_rate for rate in rates]
colors = ['red' if gap < 0 else 'green' for gap in gaps]
plt.bar(strategies, gaps, color=colors, alpha=0.7)
plt.title('Performance Gap vs DQN')
plt.ylabel('Win Rate Difference')
plt.xticks(rotation=45)
plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
# Word length analysis (placeholder)
plt.subplot(2, 2, 3)
word_lengths = sorted(self.buckets.keys())
word_counts = [len(self.buckets[L]) for L in word_lengths]
plt.bar(word_lengths, word_counts, alpha=0.7)
plt.title('Word Distribution by Length')
plt.xlabel('Word Length')
plt.ylabel('Count')
# Performance trend (placeholder for future use)
plt.subplot(2, 2, 4)
plt.text(0.5, 0.5, 'Performance Trend\n(Coming Soon)',
ha='center', va='center', fontsize=12)
plt.title('Training Progress')
plt.axis('off')
plt.tight_layout()
# Save plot
timestamp = time.strftime("%Y%m%d_%H%M%S")
plot_file = self.output_dir / f"visualization_{model_name}_{timestamp}.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.show()
print(f"πŸ“Š Visualization saved: {plot_file}")
def main():
parser = argparse.ArgumentParser(description="Hangman DQN Simulator Playground")
parser.add_argument("--model", type=str, required=True, help="Path to model checkpoint")
parser.add_argument("--model-name", type=str, help="Name for the model (default: filename)")
parser.add_argument("--episodes", type=int, default=1000, help="Number of episodes to test")
parser.add_argument("--detailed", action="store_true", help="Run detailed analysis")
parser.add_argument("--analyze-decisions", action="store_true", help="Analyze individual game decisions")
parser.add_argument("--num-games", type=int, default=10, help="Number of games for decision analysis")
parser.add_argument("--dict-path", type=str, default="data/words_250000_train.txt", help="Dictionary path")
parser.add_argument("--len-lo", type=int, default=4, help="Minimum word length")
parser.add_argument("--len-hi", type=int, default=12, help="Maximum word length")
parser.add_argument("--max-len", type=int, default=35, help="Model max sequence length")
args = parser.parse_args()
# Initialize simulator
simulator = HangmanSimulator(
dict_path=args.dict_path,
len_lo=args.len_lo,
len_hi=args.len_hi,
max_len=args.max_len
)
# Load model
model_name = args.model_name or Path(args.model).stem
model = simulator.load_model(args.model, model_name)
if model is None:
print("❌ Failed to load model. Exiting.")
return
# Run comprehensive evaluation
print(f"\nπŸš€ Starting comprehensive evaluation of {model_name}...")
# Compare with baselines
results = simulator.compare_with_baselines(model, model_name, args.episodes)
# Detailed analysis if requested
game_analysis = None
if args.analyze_decisions:
game_analysis = simulator.analyze_model_decisions(model, model_name, args.num_games)
# Save results
simulator.save_results(results, model_name, game_analysis)
# Create visualization
simulator.create_visualization(results, model_name)
# Print summary
print(f"\nπŸ“‹ Summary for {model_name}:")
print(f" DQN Win Rate: {results.get('dqn', 0.0):.3f}")
best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'], key=lambda x: x[1])
print(f" Best Baseline: {best_baseline[0]} ({best_baseline[1]:.3f})")
print(f" Gap: {best_baseline[1] - results.get('dqn', 0.0):.3f}")
if __name__ == "__main__":
main()