"""
Hangman DQN Simulator Playground

A comprehensive tool for monitoring, comparing, and debugging DQN model improvements.
"""

import argparse
import json
import os
import time
from argparse import Namespace
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, List, Optional

import torch

try:
    import matplotlib.pyplot as plt
    import seaborn as sns  # noqa: F401 -- optional; presence gates plotting
    HAS_PLOTTING = True
except ImportError:
    HAS_PLOTTING = False
    print("Warning: matplotlib/seaborn not available. Visualization will be skipped.")

from hangman.rl.models import DuelingQNet
from hangman.rl.utils import load_dict, by_len
from hangman.rl.priors import build_length_priors, build_positional_priors, CandCache
from hangman.rl.eval import greedy_rollout, run_solver
from hangman.rl.envs import BatchEnv


@dataclass
class ModelMetrics:
    """Container for model performance metrics."""
    model_name: str
    win_rate: float
    avg_turns: float
    avg_reward: float
    episodes_tested: int
    strategy: str
    timestamp: str
    model_path: str
    training_config: Dict


@dataclass
class GameResult:
    """Container for individual game results."""
    word: str
    word_length: int
    won: bool
    turns_taken: int
    final_reward: float
    guesses: List[str]
    pattern_history: List[str]
    q_values_history: List[Dict[str, float]]


class HangmanSimulator:
    """Comprehensive hangman simulator for monitoring DQN improvements."""

    def __init__(self, dict_path: str = "data/words_250000_train.txt",
                 len_lo: int = 4, len_hi: int = 12, max_len: int = 35):
        self.dict_path = dict_path
        self.len_lo = len_lo
        self.len_hi = len_hi
        self.max_len = max_len

        print("Loading dictionary and building priors...")
        self.words = load_dict(dict_path)
        self.buckets = by_len(self.words, len_lo, len_hi)
        self.priors = build_length_priors(self.buckets)
        self.pos_priors = build_positional_priors(self.buckets, max_len)
        self.cand_cache = CandCache(100_000)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.model_metrics: List[ModelMetrics] = []
        self.game_results: List[GameResult] = []

        self.output_dir = Path("simulator_results")
        self.output_dir.mkdir(exist_ok=True)

    def load_model(self, model_path: str, model_name: Optional[str] = None) -> Optional[DuelingQNet]:
        """Load a trained model from a checkpoint; returns None if the file is missing."""
        if model_name is None:
            model_name = Path(model_path).stem

        print(f"Loading model: {model_name} from {model_path}")

        model = DuelingQNet(
            d_model=128, nhead=4, nlayers=2, ff_mult=4,
            max_len=self.max_len, dropout=0.1
        )

        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            model.load_state_dict(checkpoint['model'])
            print("✅ Model loaded successfully")
        else:
            print(f"❌ Model file not found: {model_path}")
            return None

        model.to(self.device)
        return model
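
    # Checkpoints are expected to store the weights under the "model" key; a
    # compatible file can be produced with, e.g. (path illustrative):
    #
    #   torch.save({"model": net.state_dict()}, "checkpoints/dqn_best.pt")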

    def evaluate_model(self, model: DuelingQNet, model_name: str,
                       episodes: int = 1000, detailed: bool = False) -> ModelMetrics:
        """Evaluate a model and return comprehensive metrics."""
        print(f"\n📊 Evaluating {model_name} over {episodes} episodes...")

        env = BatchEnv(self.buckets, 6, 64, sorted(self.buckets.keys()), self.max_len)
        env.reset()

        start_time = time.time()
        win_rate = greedy_rollout(env, model, self.device, N=episodes,
                                  priors=self.priors, log_stride=100)
        eval_time = time.time() - start_time

        # greedy_rollout only reports a win rate, so the remaining metrics are
        # placeholders: a fixed turn budget and a +/-1 win/loss reward.
        avg_turns = 6.0
        avg_reward = win_rate * 1.0 + (1 - win_rate) * (-1.0)

        metrics = ModelMetrics(
            model_name=model_name,
            win_rate=win_rate,
            avg_turns=avg_turns,
            avg_reward=avg_reward,
            episodes_tested=episodes,
            strategy="dqn",
            timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
            model_path="",
            training_config={}
        )

        print(f"✅ {model_name}: Win Rate = {win_rate:.3f} ({eval_time:.1f}s)")
        return metrics
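
    # Programmatic usage sketch (checkpoint path and model name illustrative):
    #
    #   sim = HangmanSimulator()
    #   model = sim.load_model("checkpoints/dqn_best.pt", "dqn_v2")
    #   if model is not None:
    #       metrics = sim.evaluate_model(model, "dqn_v2", episodes=500)
    #       print(asdict(metrics))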

    def compare_with_baselines(self, model: DuelingQNet, model_name: str,
                               episodes: int = 1000) -> Dict[str, float]:
        """Compare model performance with heuristic baselines."""
        print(f"\n🔍 Comparing {model_name} with baseline strategies...")

        results = {}

        dqn_metrics = self.evaluate_model(model, model_name, episodes)
        results['dqn'] = dqn_metrics.win_rate

        # Heuristic solver modes implemented by run_solver.
        strategies = ['cand', 'igx', 'pos', 'len', 'ig']

        for strategy in strategies:
            print(f"Testing {strategy} baseline...")
            args = Namespace(
                solver_mode=strategy,
                tries=6,
                batch_env=64,
                max_len=self.max_len,
                solver_eval_N=episodes,
                csv_log=False
            )

            win_rate = run_solver(args, self.buckets, self.priors, self.pos_priors)
            results[strategy] = win_rate
            print(f"  {strategy}: {win_rate:.3f}")

        return results
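
    # The returned dict maps strategy name to win rate, e.g. (values
    # illustrative): {"dqn": 0.42, "cand": 0.85, "igx": 0.71, ...}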

    def analyze_model_decisions(self, model: DuelingQNet, model_name: str,
                                num_games: int = 10) -> List[GameResult]:
        """Analyze individual game decisions for debugging."""
        print(f"\n🔬 Analyzing {model_name} decisions in {num_games} games...")

        model.eval()
        if hasattr(model, "remove_noise"):
            model.remove_noise()

        results = []
        env = BatchEnv(self.buckets, 6, 1, sorted(self.buckets.keys()), self.max_len)

        for game_idx in range(num_games):
            env.reset()
            word = env.words[0]
            L = len(word)

            game_result = GameResult(
                word=word,
                word_length=L,
                won=False,
                turns_taken=0,
                final_reward=0.0,
                guesses=[],
                pattern_history=[],
                q_values_history=[]
            )

            while not env.done[0]:
                pat_idx, tried, lens, tries = env.observe()
                B_now = pat_idx.size(0)

                lp = torch.zeros((B_now, 26), dtype=torch.float32)
                lp[0, :] = torch.tensor(self.priors.get(L, [0.0] * 26))

                tn = (tries.float() / 6.0).unsqueeze(1)

                with torch.no_grad():
                    q_values = model(
                        pat_idx.to(self.device),
                        tried.to(self.device),
                        lens.to(self.device),
                        lp.to(self.device),
                        tn.to(self.device)
                    )

                # Greedy action: pick the letter with the highest Q-value.
                q_vals = q_values[0].cpu().numpy()
                action = int(q_vals.argmax())
                letter = chr(ord('a') + action)

                q_dict = {chr(ord('a') + i): float(q_vals[i]) for i in range(26)}
                game_result.q_values_history.append(q_dict)
                game_result.guesses.append(letter)
                game_result.pattern_history.append(env.patterns[0])

                reward = env.step(torch.tensor([action]))[0].item()
                game_result.turns_taken += 1
                game_result.final_reward = reward
                game_result.won = bool(env.won[0].item())

            results.append(game_result)
            print(f"  Game {game_idx+1}: {word} -> {'WON' if game_result.won else 'LOST'} "
                  f"({game_result.turns_taken} turns)")

        model.train()
        if hasattr(model, "resample_noise"):
            model.resample_noise()

        return results
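
    # Hedged helper sketch (not part of the original tool): rank each turn's
    # Q-values so it is easier to see which letters the model was weighing.
    @staticmethod
    def top_k_guesses(result: GameResult, k: int = 3) -> List[List[str]]:
        """For each turn, return the k letters with the highest Q-values."""
        return [sorted(q, key=q.get, reverse=True)[:k]
                for q in result.q_values_history]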

    def create_performance_report(self, results: Dict[str, float],
                                  model_name: str) -> str:
        """Create a comprehensive performance report."""
        report = f"""
# Hangman DQN Performance Report
**Model**: {model_name}
**Timestamp**: {time.strftime("%Y-%m-%d %H:%M:%S")}
**Episodes Tested**: 1000

## Performance Comparison

| Strategy | Win Rate | Performance vs DQN |
|----------|----------|-------------------|"""

        dqn_rate = results.get('dqn', 0.0)

        for strategy, rate in results.items():
            if strategy == 'dqn':
                continue
            diff = rate - dqn_rate
            diff_pct = (diff / max(dqn_rate, 0.001)) * 100
            report += f"\n| {strategy.upper()} | {rate:.3f} | {diff:+.3f} ({diff_pct:+.1f}%) |"

        report += f"\n| **DQN** | **{dqn_rate:.3f}** | **baseline** |"

        best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'],
                            key=lambda x: x[1])

        gap_pct = (best_baseline[1] - dqn_rate) / max(best_baseline[1], 0.001) * 100
        report += f"""

## Analysis
- **Best Baseline**: {best_baseline[0].upper()} ({best_baseline[1]:.3f} win rate)
- **DQN Performance**: {dqn_rate:.3f} win rate
- **Gap to Best**: {best_baseline[1] - dqn_rate:.3f} ({gap_pct:.1f}% behind)

## Recommendations
"""

        if dqn_rate < 0.1:
            report += "- ❌ **Critical**: DQN performance is extremely poor. Check training data quality and model architecture.\n"
        elif dqn_rate < best_baseline[1] * 0.5:
            report += "- ⚠️ **Poor**: DQN significantly underperforms the best baseline. Consider retraining with a better teacher strategy.\n"
        elif dqn_rate < best_baseline[1] * 0.8:
            report += "- 🔶 **Fair**: DQN shows promise but needs improvement. Fine-tune training parameters.\n"
        else:
            report += "- ✅ **Good**: DQN performance is competitive with the baselines.\n"

        if best_baseline[0] == 'cand':
            report += "- 🎯 **Priority**: Use the 'cand' strategy as a teacher for retraining (85%+ win rate)\n"

        return report

    def save_results(self, results: Dict[str, float], model_name: str,
                     game_analysis: Optional[List[GameResult]] = None):
        """Save all results to files."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        results_file = self.output_dir / f"performance_{model_name}_{timestamp}.json"
        with open(results_file, 'w') as f:
            json.dump({
                'model_name': model_name,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'results': results,
                'config': {
                    'dict_path': self.dict_path,
                    'len_lo': self.len_lo,
                    'len_hi': self.len_hi,
                    'max_len': self.max_len
                }
            }, f, indent=2)

        report = self.create_performance_report(results, model_name)
        report_file = self.output_dir / f"report_{model_name}_{timestamp}.md"
        with open(report_file, 'w') as f:
            f.write(report)

        if game_analysis:
            analysis_file = self.output_dir / f"analysis_{model_name}_{timestamp}.json"
            with open(analysis_file, 'w') as f:
                json.dump([asdict(result) for result in game_analysis], f, indent=2)

        print("\n💾 Results saved:")
        print(f"  - Performance: {results_file}")
        print(f"  - Report: {report_file}")
        if game_analysis:
            print(f"  - Analysis: {analysis_file}")
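
    # Example: reloading a saved run for cross-run comparison (file name
    # illustrative):
    #
    #   with open("simulator_results/performance_dqn_v2_20240101_120000.json") as f:
    #       run = json.load(f)
    #   print(run["results"]["dqn"])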

    def create_visualization(self, results: Dict[str, float], model_name: str):
        """Create performance visualization."""
        if not HAS_PLOTTING:
            print("⚠️ Visualization skipped - matplotlib not available")
            return

        plt.figure(figsize=(12, 8))

        # Panel 1: win rate per strategy (DQN highlighted in red).
        plt.subplot(2, 2, 1)
        strategies = list(results.keys())
        rates = list(results.values())
        colors = ['red' if s == 'dqn' else 'blue' for s in strategies]

        bars = plt.bar(strategies, rates, color=colors, alpha=0.7)
        plt.title(f'Win Rate Comparison - {model_name}')
        plt.ylabel('Win Rate')
        plt.xticks(rotation=45)

        for bar, strategy in zip(bars, strategies):
            if strategy == 'dqn':
                bar.set_alpha(1.0)

        # Panel 2: signed gap of each strategy relative to DQN.
        plt.subplot(2, 2, 2)
        dqn_rate = results.get('dqn', 0.0)
        gaps = [rate - dqn_rate for rate in rates]
        colors = ['red' if gap < 0 else 'green' for gap in gaps]

        plt.bar(strategies, gaps, color=colors, alpha=0.7)
        plt.title('Performance Gap vs DQN')
        plt.ylabel('Win Rate Difference')
        plt.xticks(rotation=45)
        plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)

        # Panel 3: dictionary composition by word length.
        plt.subplot(2, 2, 3)
        word_lengths = sorted(self.buckets.keys())
        word_counts = [len(self.buckets[L]) for L in word_lengths]
        plt.bar(word_lengths, word_counts, alpha=0.7)
        plt.title('Word Distribution by Length')
        plt.xlabel('Word Length')
        plt.ylabel('Count')

        # Panel 4: placeholder for a future training-progress plot.
        plt.subplot(2, 2, 4)
        plt.text(0.5, 0.5, 'Performance Trend\n(Coming Soon)',
                 ha='center', va='center', fontsize=12)
        plt.title('Training Progress')
        plt.axis('off')

        plt.tight_layout()

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        plot_file = self.output_dir / f"visualization_{model_name}_{timestamp}.png"
        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
        plt.show()

        print(f"📊 Visualization saved: {plot_file}")


def main():
    parser = argparse.ArgumentParser(description="Hangman DQN Simulator Playground")
    parser.add_argument("--model", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--model-name", type=str, help="Name for the model (default: filename)")
    parser.add_argument("--episodes", type=int, default=1000, help="Number of episodes to test")
    parser.add_argument("--detailed", action="store_true", help="Run detailed analysis")
    parser.add_argument("--analyze-decisions", action="store_true", help="Analyze individual game decisions")
    parser.add_argument("--num-games", type=int, default=10, help="Number of games for decision analysis")
    parser.add_argument("--dict-path", type=str, default="data/words_250000_train.txt", help="Dictionary path")
    parser.add_argument("--len-lo", type=int, default=4, help="Minimum word length")
    parser.add_argument("--len-hi", type=int, default=12, help="Maximum word length")
    parser.add_argument("--max-len", type=int, default=35, help="Model max sequence length")

    args = parser.parse_args()

    simulator = HangmanSimulator(
        dict_path=args.dict_path,
        len_lo=args.len_lo,
        len_hi=args.len_hi,
        max_len=args.max_len
    )

    model_name = args.model_name or Path(args.model).stem
    model = simulator.load_model(args.model, model_name)

    if model is None:
        print("❌ Failed to load model. Exiting.")
        return

    print(f"\n🚀 Starting comprehensive evaluation of {model_name}...")

    results = simulator.compare_with_baselines(model, model_name, args.episodes)

    game_analysis = None
    if args.analyze_decisions:
        game_analysis = simulator.analyze_model_decisions(model, model_name, args.num_games)

    simulator.save_results(results, model_name, game_analysis)

    simulator.create_visualization(results, model_name)

    print(f"\n📊 Summary for {model_name}:")
    print(f"  DQN Win Rate: {results.get('dqn', 0.0):.3f}")
    best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'], key=lambda x: x[1])
    print(f"  Best Baseline: {best_baseline[0]} ({best_baseline[1]:.3f})")
    print(f"  Gap: {best_baseline[1] - results.get('dqn', 0.0):.3f}")


if __name__ == "__main__":
    main()