egpivo committed on
Commit
eb42ef5
·
verified ·
1 Parent(s): 0c86088

Upload simulator_playground.py with huggingface_hub

Files changed (1)
  1. simulator_playground.py +469 -0
simulator_playground.py ADDED
@@ -0,0 +1,469 @@
+ #!/usr/bin/env python3
+ """
+ Hangman DQN Simulator Playground
+ A comprehensive tool for monitoring, comparing, and debugging DQN model improvements.
+ """
+
+ import argparse
+ import os
+ import time
+ import json
+ import torch
+ import numpy as np
+ from pathlib import Path
+ from typing import Dict, List, Tuple, Optional
+ from dataclasses import dataclass, asdict
+ try:
+     import matplotlib.pyplot as plt
+     import seaborn as sns
+     HAS_PLOTTING = True
+ except ImportError:
+     HAS_PLOTTING = False
+     print("Warning: matplotlib/seaborn not available. Visualization will be skipped.")
+ from collections import defaultdict, Counter
+
+ # Import hangman modules
+ from hangman.rl.models import DuelingQNet, ActorCriticNet
+ from hangman.rl.utils import load_dict, by_len, set_seed
+ from hangman.rl.priors import build_length_priors, build_positional_priors, CandCache
+ from hangman.rl.eval import greedy_rollout, run_solver
+ from hangman.rl.envs import BatchEnv
+ from hangman.rl.replay import Replay, SuccessReplay
+ from hangman.rl.seed_bc import seed_expert
+ from argparse import Namespace
+
+
+ @dataclass
+ class ModelMetrics:
+     """Container for model performance metrics"""
+     model_name: str
+     win_rate: float
+     avg_turns: float
+     avg_reward: float
+     episodes_tested: int
+     strategy: str
+     timestamp: str
+     model_path: str
+     training_config: Dict
+
+
+ @dataclass
+ class GameResult:
+     """Container for individual game results"""
+     word: str
+     word_length: int
+     won: bool
+     turns_taken: int
+     final_reward: float
+     guesses: List[str]
+     pattern_history: List[str]
+     q_values_history: List[Dict[str, float]]
+
+
+ class HangmanSimulator:
+     """Comprehensive hangman simulator for monitoring DQN improvements"""
+
+     def __init__(self, dict_path: str = "data/words_250000_train.txt",
+                  len_lo: int = 4, len_hi: int = 12, max_len: int = 35):
+         self.dict_path = dict_path
+         self.len_lo = len_lo
+         self.len_hi = len_hi
+         self.max_len = max_len
+
+         # Load data
+         print("Loading dictionary and building priors...")
+         self.words = load_dict(dict_path)
+         self.buckets = by_len(self.words, len_lo, len_hi)
+         self.priors = build_length_priors(self.buckets)
+         self.pos_priors = build_positional_priors(self.buckets, max_len)
+         self.cand_cache = CandCache(100_000)
+
+         # Device setup
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         print(f"Using device: {self.device}")
+
+         # Results storage
+         self.model_metrics: List[ModelMetrics] = []
+         self.game_results: List[GameResult] = []
+
+         # Create output directory
+         self.output_dir = Path("simulator_results")
+         self.output_dir.mkdir(exist_ok=True)
+
+     def load_model(self, model_path: str, model_name: str = None) -> Optional[DuelingQNet]:
+         """Load a trained model from checkpoint; returns None if the file is missing"""
+         if model_name is None:
+             model_name = Path(model_path).stem
+
+         print(f"Loading model: {model_name} from {model_path}")
+
+         model = DuelingQNet(
+             d_model=128, nhead=4, nlayers=2, ff_mult=4,
+             max_len=self.max_len, dropout=0.1
+         )
+
+         if os.path.exists(model_path):
+             checkpoint = torch.load(model_path, map_location=self.device)
+             model.load_state_dict(checkpoint['model'])
+             print("✅ Model loaded successfully")
+         else:
+             print(f"❌ Model file not found: {model_path}")
+             return None
+
+         model.to(self.device)
+         return model
+
+     def evaluate_model(self, model: DuelingQNet, model_name: str,
+                        episodes: int = 1000, detailed: bool = False) -> ModelMetrics:
+         """Evaluate a model and return comprehensive metrics"""
+         print(f"\n🔍 Evaluating {model_name} over {episodes} episodes...")
+
+         # Setup environment
+         env = BatchEnv(self.buckets, 6, 64, sorted(self.buckets.keys()), self.max_len)
+         env.reset()
+
+         # Run evaluation
+         start_time = time.time()
+         win_rate = greedy_rollout(env, model, self.device, N=episodes,
+                                   priors=self.priors, log_stride=100)
+         eval_time = time.time() - start_time
+
+         # Calculate additional metrics
+         avg_turns = 6.0  # Placeholder - would need to track this in evaluation
+         avg_reward = win_rate * 1.0 + (1 - win_rate) * (-1.0)  # Simplified
+
+         metrics = ModelMetrics(
+             model_name=model_name,
+             win_rate=win_rate,
+             avg_turns=avg_turns,
+             avg_reward=avg_reward,
+             episodes_tested=episodes,
+             strategy="dqn",
+             timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
+             model_path="",  # Will be set by caller
+             training_config={}
+         )
+
+         print(f"✅ {model_name}: Win Rate = {win_rate:.3f} ({eval_time:.1f}s)")
+         return metrics
+
+     def compare_with_baselines(self, model: DuelingQNet, model_name: str,
+                                episodes: int = 1000) -> Dict[str, float]:
+         """Compare model performance with heuristic baselines"""
+         print(f"\n📊 Comparing {model_name} with baseline strategies...")
+
+         results = {}
+
+         # Evaluate DQN model
+         dqn_metrics = self.evaluate_model(model, model_name, episodes)
+         results['dqn'] = dqn_metrics.win_rate
+
+         # Evaluate baseline strategies
+         strategies = ['cand', 'igx', 'pos', 'len', 'ig']
+
+         for strategy in strategies:
+             print(f"Testing {strategy} baseline...")
+             args = Namespace(
+                 solver_mode=strategy,
+                 tries=6,
+                 batch_env=64,
+                 max_len=self.max_len,
+                 solver_eval_N=episodes,
+                 csv_log=False
+             )
+
+             win_rate = run_solver(args, self.buckets, self.priors, self.pos_priors)
+             results[strategy] = win_rate
+             print(f"  {strategy}: {win_rate:.3f}")
+
+         return results
+
+     def analyze_model_decisions(self, model: DuelingQNet, model_name: str,
+                                 num_games: int = 10) -> List[GameResult]:
+         """Analyze individual game decisions for debugging"""
+         print(f"\n🔬 Analyzing {model_name} decisions in {num_games} games...")
+
+         model.eval()
+         if hasattr(model, "remove_noise"):
+             model.remove_noise()
+
+         results = []
+         env = BatchEnv(self.buckets, 6, 1, sorted(self.buckets.keys()), self.max_len)
+
+         for game_idx in range(num_games):
+             env.reset()
+             word = env.words[0]
+             L = len(word)
+
+             game_result = GameResult(
+                 word=word,
+                 word_length=L,
+                 won=False,
+                 turns_taken=0,
+                 final_reward=0.0,
+                 guesses=[],
+                 pattern_history=[],
+                 q_values_history=[]
+             )
+
+             while not env.done[0]:
+                 # Get model prediction
+                 pat_idx, tried, lens, tries = env.observe()
+                 B_now = pat_idx.size(0)
+
+                 # Prepare inputs
+                 lp = torch.zeros((B_now, 26), dtype=torch.float32)
+                 lp[0, :] = torch.tensor(self.priors.get(L, [0.0] * 26))
+
+                 tn = (tries.float() / 6.0).unsqueeze(1)
+
+                 # Get Q-values
+                 with torch.no_grad():
+                     q_values = model(
+                         pat_idx.to(self.device),
+                         tried.to(self.device),
+                         lens.to(self.device),
+                         lp.to(self.device),
+                         tn.to(self.device)
+                     )
+
+                 # Pick the greedy action from the Q-values
+                 q_vals = q_values[0].cpu().numpy()
+                 action = int(q_vals.argmax())
+                 letter = chr(ord('a') + action)
+
+                 # Store decision info
+                 q_dict = {chr(ord('a') + i): float(q_vals[i]) for i in range(26)}
+                 game_result.q_values_history.append(q_dict)
+                 game_result.guesses.append(letter)
+                 game_result.pattern_history.append(env.patterns[0])
+
+                 # Take action
+                 reward = env.step(torch.tensor([action]))[0].item()
+                 game_result.turns_taken += 1
+                 game_result.final_reward = reward
+                 game_result.won = bool(env.won[0].item())
+
+             results.append(game_result)
+             print(f"  Game {game_idx+1}: {word} -> {'WON' if game_result.won else 'LOST'} "
+                   f"({game_result.turns_taken} turns)")
+
+         model.train()
+         if hasattr(model, "resample_noise"):
+             model.resample_noise()
+
+         return results
+
+     def create_performance_report(self, results: Dict[str, float],
+                                   model_name: str) -> str:
+         """Create a comprehensive performance report"""
+         report = f"""
+ # Hangman DQN Performance Report
+ **Model**: {model_name}
+ **Timestamp**: {time.strftime("%Y-%m-%d %H:%M:%S")}
+ **Episodes Tested**: 1000
+
+ ## Performance Comparison
+
+ | Strategy | Win Rate | Performance vs DQN |
+ |----------|----------|-------------------|"""
+
+         dqn_rate = results.get('dqn', 0.0)
+
+         for strategy, rate in results.items():
+             if strategy == 'dqn':
+                 continue
+             diff = rate - dqn_rate
+             diff_pct = (diff / max(dqn_rate, 0.001)) * 100
+             report += f"\n| {strategy.upper()} | {rate:.3f} | {diff:+.3f} ({diff_pct:+.1f}%) |"
+
+         report += f"\n| **DQN** | **{dqn_rate:.3f}** | **baseline** |"
+
+         # Analysis
+         best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'],
+                             key=lambda x: x[1])
+
+         report += f"""
+
+ ## Analysis
+ - **Best Baseline**: {best_baseline[0].upper()} ({best_baseline[1]:.3f} win rate)
+ - **DQN Performance**: {dqn_rate:.3f} win rate
+ - **Gap to Best**: {best_baseline[1] - dqn_rate:.3f} ({(best_baseline[1] - dqn_rate)/best_baseline[1]*100:.1f}% behind)
+
+ ## Recommendations
+ """
+
+         if dqn_rate < 0.1:
+             report += "- ❌ **Critical**: DQN performance is extremely poor. Check training data quality and model architecture.\n"
+         elif dqn_rate < best_baseline[1] * 0.5:
+             report += "- ⚠️ **Poor**: DQN significantly underperforms the best baseline. Consider retraining with a better teacher strategy.\n"
+         elif dqn_rate < best_baseline[1] * 0.8:
+             report += "- 🔶 **Fair**: DQN shows promise but needs improvement. Fine-tune training parameters.\n"
+         else:
+             report += "- ✅ **Good**: DQN performance is competitive with baselines.\n"
+
+         if best_baseline[0] == 'cand':
+             report += "- 🎯 **Priority**: Use the 'cand' strategy as the teacher for retraining (85%+ win rate)\n"
+
+         return report
+
+     def save_results(self, results: Dict[str, float], model_name: str,
+                      game_analysis: List[GameResult] = None):
+         """Save all results to files"""
+         timestamp = time.strftime("%Y%m%d_%H%M%S")
+
+         # Save performance comparison
+         results_file = self.output_dir / f"performance_{model_name}_{timestamp}.json"
+         with open(results_file, 'w') as f:
+             json.dump({
+                 'model_name': model_name,
+                 'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
+                 'results': results,
+                 'config': {
+                     'dict_path': self.dict_path,
+                     'len_lo': self.len_lo,
+                     'len_hi': self.len_hi,
+                     'max_len': self.max_len
+                 }
+             }, f, indent=2)
+
+         # Save performance report
+         report = self.create_performance_report(results, model_name)
+         report_file = self.output_dir / f"report_{model_name}_{timestamp}.md"
+         with open(report_file, 'w') as f:
+             f.write(report)
+
+         # Save game analysis if provided
+         if game_analysis:
+             analysis_file = self.output_dir / f"analysis_{model_name}_{timestamp}.json"
+             with open(analysis_file, 'w') as f:
+                 json.dump([asdict(result) for result in game_analysis], f, indent=2)
+
+         print("\n💾 Results saved:")
+         print(f"  - Performance: {results_file}")
+         print(f"  - Report: {report_file}")
+         if game_analysis:
+             print(f"  - Analysis: {analysis_file}")
+
+     def create_visualization(self, results: Dict[str, float], model_name: str):
+         """Create performance visualization"""
+         if not HAS_PLOTTING:
+             print("⚠️ Visualization skipped - matplotlib not available")
+             return
+
+         plt.figure(figsize=(12, 8))
+
+         # Performance comparison bar chart
+         plt.subplot(2, 2, 1)
+         strategies = list(results.keys())
+         rates = list(results.values())
+         colors = ['red' if s == 'dqn' else 'blue' for s in strategies]
+
+         bars = plt.bar(strategies, rates, color=colors, alpha=0.7)
+         plt.title(f'Win Rate Comparison - {model_name}')
+         plt.ylabel('Win Rate')
+         plt.xticks(rotation=45)
+
+         # Highlight DQN
+         for i, (strategy, rate) in enumerate(zip(strategies, rates)):
+             if strategy == 'dqn':
+                 bars[i].set_color('red')
+                 bars[i].set_alpha(1.0)
+
+         # Performance gap analysis
+         plt.subplot(2, 2, 2)
+         dqn_rate = results.get('dqn', 0.0)
+         gaps = [rate - dqn_rate for rate in rates]
+         colors = ['red' if gap < 0 else 'green' for gap in gaps]
+
+         plt.bar(strategies, gaps, color=colors, alpha=0.7)
+         plt.title('Performance Gap vs DQN')
+         plt.ylabel('Win Rate Difference')
+         plt.xticks(rotation=45)
+         plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
+
+         # Word length analysis (placeholder)
+         plt.subplot(2, 2, 3)
+         word_lengths = sorted(self.buckets.keys())
+         word_counts = [len(self.buckets[L]) for L in word_lengths]
+         plt.bar(word_lengths, word_counts, alpha=0.7)
+         plt.title('Word Distribution by Length')
+         plt.xlabel('Word Length')
+         plt.ylabel('Count')
+
+         # Performance trend (placeholder for future use)
+         plt.subplot(2, 2, 4)
+         plt.text(0.5, 0.5, 'Performance Trend\n(Coming Soon)',
+                  ha='center', va='center', fontsize=12)
+         plt.title('Training Progress')
+         plt.axis('off')
+
+         plt.tight_layout()
+
+         # Save plot
+         timestamp = time.strftime("%Y%m%d_%H%M%S")
+         plot_file = self.output_dir / f"visualization_{model_name}_{timestamp}.png"
+         plt.savefig(plot_file, dpi=300, bbox_inches='tight')
+         plt.show()
+
+         print(f"📊 Visualization saved: {plot_file}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Hangman DQN Simulator Playground")
+     parser.add_argument("--model", type=str, required=True, help="Path to model checkpoint")
+     parser.add_argument("--model-name", type=str, help="Name for the model (default: filename)")
+     parser.add_argument("--episodes", type=int, default=1000, help="Number of episodes to test")
+     parser.add_argument("--detailed", action="store_true", help="Run detailed analysis")
+     parser.add_argument("--analyze-decisions", action="store_true", help="Analyze individual game decisions")
+     parser.add_argument("--num-games", type=int, default=10, help="Number of games for decision analysis")
+     parser.add_argument("--dict-path", type=str, default="data/words_250000_train.txt", help="Dictionary path")
+     parser.add_argument("--len-lo", type=int, default=4, help="Minimum word length")
+     parser.add_argument("--len-hi", type=int, default=12, help="Maximum word length")
+     parser.add_argument("--max-len", type=int, default=35, help="Model max sequence length")
+
+     args = parser.parse_args()
+
+     # Initialize simulator
+     simulator = HangmanSimulator(
+         dict_path=args.dict_path,
+         len_lo=args.len_lo,
+         len_hi=args.len_hi,
+         max_len=args.max_len
+     )
+
+     # Load model
+     model_name = args.model_name or Path(args.model).stem
+     model = simulator.load_model(args.model, model_name)
+
+     if model is None:
+         print("❌ Failed to load model. Exiting.")
+         return
+
+     # Run comprehensive evaluation
+     print(f"\n🚀 Starting comprehensive evaluation of {model_name}...")
+
+     # Compare with baselines
+     results = simulator.compare_with_baselines(model, model_name, args.episodes)
+
+     # Detailed analysis if requested
+     game_analysis = None
+     if args.analyze_decisions:
+         game_analysis = simulator.analyze_model_decisions(model, model_name, args.num_games)
+
+     # Save results
+     simulator.save_results(results, model_name, game_analysis)
+
+     # Create visualization
+     simulator.create_visualization(results, model_name)
+
+     # Print summary
+     print(f"\n📋 Summary for {model_name}:")
+     print(f"  DQN Win Rate: {results.get('dqn', 0.0):.3f}")
+     best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'], key=lambda x: x[1])
+     print(f"  Best Baseline: {best_baseline[0]} ({best_baseline[1]:.3f})")
+     print(f"  Gap: {best_baseline[1] - results.get('dqn', 0.0):.3f}")
+
+
+ if __name__ == "__main__":
+     main()
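
Usage note: the evaluation pipeline in this file can also be driven directly from Python instead of the CLI. The sketch below only calls the public methods defined above; the checkpoint path checkpoints/dqn.pt and the model name dqn_v1 are hypothetical placeholders, and it assumes the hangman package and data/words_250000_train.txt are available locally.

# Minimal sketch (assumed paths), mirroring what main() does via argparse.
from simulator_playground import HangmanSimulator

simulator = HangmanSimulator(dict_path="data/words_250000_train.txt",
                             len_lo=4, len_hi=12, max_len=35)
model = simulator.load_model("checkpoints/dqn.pt", "dqn_v1")  # placeholder checkpoint path
if model is not None:
    # Compare the DQN against the heuristic baselines, then persist and plot the results.
    results = simulator.compare_with_baselines(model, "dqn_v1", episodes=200)
    simulator.save_results(results, "dqn_v1")
    simulator.create_visualization(results, "dqn_v1")

The equivalent command-line invocation, using the same placeholder checkpoint, would be:
python simulator_playground.py --model checkpoints/dqn.pt --model-name dqn_v1 --episodes 200 --analyze-decisions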