Upload simulator_playground.py with huggingface_hub
simulator_playground.py
ADDED  +469 −0
@@ -0,0 +1,469 @@
#!/usr/bin/env python3
"""
Hangman DQN Simulator Playground
A comprehensive tool for monitoring, comparing, and debugging DQN model improvements.
"""
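
# Example invocations (checkpoint paths are illustrative, not shipped with the repo):
#   python simulator_playground.py --model checkpoints/dqn_best.pt --episodes 500
#   python simulator_playground.py --model checkpoints/dqn_best.pt \
#       --analyze-decisions --num-games 20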

import argparse
import os
import time
import json
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, asdict
from collections import defaultdict, Counter

try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    HAS_PLOTTING = True
except ImportError:
    HAS_PLOTTING = False
    print("Warning: matplotlib/seaborn not available. Visualization will be skipped.")

# Import hangman modules
from hangman.rl.models import DuelingQNet, ActorCriticNet
from hangman.rl.utils import load_dict, by_len, set_seed
from hangman.rl.priors import build_length_priors, build_positional_priors, CandCache
from hangman.rl.eval import greedy_rollout, run_solver
from hangman.rl.envs import BatchEnv
from hangman.rl.replay import Replay, SuccessReplay
from hangman.rl.seed_bc import seed_expert
from argparse import Namespace


@dataclass
class ModelMetrics:
    """Container for model performance metrics"""
    model_name: str
    win_rate: float
    avg_turns: float
    avg_reward: float
    episodes_tested: int
    strategy: str
    timestamp: str
    model_path: str
    training_config: Dict


@dataclass
class GameResult:
    """Container for individual game results"""
    word: str
    word_length: int
    won: bool
    turns_taken: int
    final_reward: float
    guesses: List[str]
    pattern_history: List[str]
    q_values_history: List[Dict[str, float]]
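
# Illustrative shape of one serialized GameResult (via asdict + json.dump in
# save_results below; values here are made up):
#   {"word": "apple", "word_length": 5, "won": true, "turns_taken": 4,
#    "final_reward": 1.0, "guesses": ["e", "a", "l", "p"],
#    "pattern_history": ["_____", "____e", "a___e", "a__le"],
#    "q_values_history": [{"a": 0.12, "b": -0.03, "...": 0.0}, ...]}
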
class HangmanSimulator:
    """Comprehensive hangman simulator for monitoring DQN improvements"""

    def __init__(self, dict_path: str = "data/words_250000_train.txt",
                 len_lo: int = 4, len_hi: int = 12, max_len: int = 35):
        self.dict_path = dict_path
        self.len_lo = len_lo
        self.len_hi = len_hi
        self.max_len = max_len

        # Load data
        print("Loading dictionary and building priors...")
        self.words = load_dict(dict_path)
        self.buckets = by_len(self.words, len_lo, len_hi)
        self.priors = build_length_priors(self.buckets)
        self.pos_priors = build_positional_priors(self.buckets, max_len)
        self.cand_cache = CandCache(100_000)

        # Device setup
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Results storage
        self.model_metrics: List[ModelMetrics] = []
        self.game_results: List[GameResult] = []

        # Create output directory
        self.output_dir = Path("simulator_results")
        self.output_dir.mkdir(exist_ok=True)

    def load_model(self, model_path: str, model_name: str = None) -> Optional[DuelingQNet]:
        """Load a trained model from a checkpoint; returns None if the file is missing."""
        if model_name is None:
            model_name = Path(model_path).stem

        print(f"Loading model: {model_name} from {model_path}")

        model = DuelingQNet(
            d_model=128, nhead=4, nlayers=2, ff_mult=4,
            max_len=self.max_len, dropout=0.1
        )

        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            model.load_state_dict(checkpoint['model'])
            print("✅ Model loaded successfully")
        else:
            print(f"❌ Model file not found: {model_path}")
            return None

        model.to(self.device)
        return model
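
    # Note: the DuelingQNet hyperparameters above are hard-coded, so a checkpoint
    # trained with a different architecture will raise a size-mismatch RuntimeError
    # from load_state_dict. A minimal guard (sketch):
    #   try:
    #       model.load_state_dict(checkpoint['model'])
    #   except RuntimeError as e:
    #       print(f"❌ Architecture mismatch: {e}")
    #       return None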

    def evaluate_model(self, model: DuelingQNet, model_name: str,
                       episodes: int = 1000, detailed: bool = False) -> ModelMetrics:
        """Evaluate a model and return comprehensive metrics"""
        print(f"\n🔍 Evaluating {model_name} over {episodes} episodes...")

        # Setup environment
        env = BatchEnv(self.buckets, 6, 64, sorted(self.buckets.keys()), self.max_len)
        env.reset()

        # Run evaluation
        start_time = time.time()
        win_rate = greedy_rollout(env, model, self.device, N=episodes,
                                  priors=self.priors, log_stride=100)
        eval_time = time.time() - start_time

        # Additional metrics (rough stand-ins until the rollout tracks them directly)
        avg_turns = 6.0  # placeholder -- would need to be tracked during evaluation
        avg_reward = win_rate * 1.0 + (1 - win_rate) * (-1.0)  # simplified +1/-1 payoff

        metrics = ModelMetrics(
            model_name=model_name,
            win_rate=win_rate,
            avg_turns=avg_turns,
            avg_reward=avg_reward,
            episodes_tested=episodes,
            strategy="dqn",
            timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
            model_path="",  # will be set by caller
            training_config={}
        )

        print(f"✅ {model_name}: Win Rate = {win_rate:.3f} ({eval_time:.1f}s)")
        return metrics
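
    # Statistical note: a win rate estimated from N episodes has standard error
    # roughly sqrt(p * (1 - p) / N); e.g. p = 0.5 at N = 1000 gives ~0.016, so
    # win-rate gaps of a couple of points between runs may be noise.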

    def compare_with_baselines(self, model: DuelingQNet, model_name: str,
                               episodes: int = 1000) -> Dict[str, float]:
        """Compare model performance with heuristic baselines"""
        print(f"\n🏆 Comparing {model_name} with baseline strategies...")

        results = {}

        # Evaluate DQN model
        dqn_metrics = self.evaluate_model(model, model_name, episodes)
        results['dqn'] = dqn_metrics.win_rate

        # Evaluate baseline strategies
        strategies = ['cand', 'igx', 'pos', 'len', 'ig']

        for strategy in strategies:
            print(f"Testing {strategy} baseline...")
            args = Namespace(
                solver_mode=strategy,
                tries=6,
                batch_env=64,
                max_len=self.max_len,
                solver_eval_N=episodes,
                csv_log=False
            )

            win_rate = run_solver(args, self.buckets, self.priors, self.pos_priors)
            results[strategy] = win_rate
            print(f"  {strategy}: {win_rate:.3f}")

        return results

    def analyze_model_decisions(self, model: DuelingQNet, model_name: str,
                                num_games: int = 10) -> List[GameResult]:
        """Analyze individual game decisions for debugging"""
        print(f"\n🔬 Analyzing {model_name} decisions in {num_games} games...")

        model.eval()
        if hasattr(model, "remove_noise"):
            model.remove_noise()

        results = []
        env = BatchEnv(self.buckets, 6, 1, sorted(self.buckets.keys()), self.max_len)

        for game_idx in range(num_games):
            env.reset()
            word = env.words[0]
            L = len(word)

            game_result = GameResult(
                word=word,
                word_length=L,
                won=False,
                turns_taken=0,
                final_reward=0.0,
                guesses=[],
                pattern_history=[],
                q_values_history=[]
            )

            while not env.done[0]:
                # Get model prediction
                pat_idx, tried, lens, tries = env.observe()
                B_now = pat_idx.size(0)

                # Prepare inputs
                lp = torch.zeros((B_now, 26), dtype=torch.float32)
                lp[0, :] = torch.tensor(self.priors.get(L, [0.0] * 26))

                tn = (tries.float() / 6.0).unsqueeze(1)

                # Get Q-values
                with torch.no_grad():
                    q_values = model(
                        pat_idx.to(self.device),
                        tried.to(self.device),
                        lens.to(self.device),
                        lp.to(self.device),
                        tn.to(self.device)
                    )

                # Pick the greedy action and decode it back to a letter
                q_vals = q_values[0].cpu().numpy()
                action = int(q_vals.argmax())
                letter = chr(ord('a') + action)

                # Store decision info
                q_dict = {chr(ord('a') + i): float(q_vals[i]) for i in range(26)}
                game_result.q_values_history.append(q_dict)
                game_result.guesses.append(letter)
                game_result.pattern_history.append(env.patterns[0])

                # Take action
                reward = env.step(torch.tensor([action]))[0].item()
                game_result.turns_taken += 1
                game_result.final_reward = reward
                game_result.won = bool(env.won[0].item())

            results.append(game_result)
            print(f"  Game {game_idx+1}: {word} -> {'WON' if game_result.won else 'LOST'} "
                  f"({game_result.turns_taken} turns)")

        model.train()
        if hasattr(model, "resample_noise"):
            model.resample_noise()

        return results
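
    # Illustrative post-hoc inspection of one analyzed game (uses only the
    # GameResult fields defined above):
    #   for turn, q in enumerate(result.q_values_history):
    #       top3 = sorted(q, key=q.get, reverse=True)[:3]
    #       print(turn, result.pattern_history[turn], result.guesses[turn], top3)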

    def create_performance_report(self, results: Dict[str, float],
                                  model_name: str) -> str:
        """Create a comprehensive performance report"""
        report = f"""
# Hangman DQN Performance Report
**Model**: {model_name}
**Timestamp**: {time.strftime("%Y-%m-%d %H:%M:%S")}
**Episodes Tested**: 1000

## Performance Comparison

| Strategy | Win Rate | Performance vs DQN |
|----------|----------|-------------------|"""

        dqn_rate = results.get('dqn', 0.0)

        for strategy, rate in results.items():
            if strategy == 'dqn':
                continue
            diff = rate - dqn_rate
            diff_pct = (diff / max(dqn_rate, 0.001)) * 100
            report += f"\n| {strategy.upper()} | {rate:.3f} | {diff:+.3f} ({diff_pct:+.1f}%) |"

        report += f"\n| **DQN** | **{dqn_rate:.3f}** | **baseline** |"

        # Analysis
        best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'],
                            key=lambda x: x[1])

        report += f"""

## Analysis
- **Best Baseline**: {best_baseline[0].upper()} ({best_baseline[1]:.3f} win rate)
- **DQN Performance**: {dqn_rate:.3f} win rate
- **Gap to Best**: {best_baseline[1] - dqn_rate:.3f} ({(best_baseline[1] - dqn_rate)/best_baseline[1]*100:.1f}% behind)

## Recommendations
"""

        if dqn_rate < 0.1:
            report += "- ❌ **Critical**: DQN performance is extremely poor. Check training data quality and model architecture.\n"
        elif dqn_rate < best_baseline[1] * 0.5:
            report += "- ⚠️ **Poor**: DQN significantly underperforms the best baseline. Consider retraining with a better teacher strategy.\n"
        elif dqn_rate < best_baseline[1] * 0.8:
            report += "- 🔶 **Fair**: DQN shows promise but needs improvement. Fine-tune training parameters.\n"
        else:
            report += "- ✅ **Good**: DQN performance is competitive with baselines.\n"

        if best_baseline[0] == 'cand':
            report += "- 🎯 **Priority**: Use the 'cand' strategy as teacher for retraining (85%+ win rate)\n"

        return report
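
    # The generated markdown renders roughly like this (numbers illustrative):
    #   | Strategy | Win Rate | Performance vs DQN |
    #   |----------|----------|--------------------|
    #   | CAND     | 0.850    | +0.530 (+165.6%)   |
    #   | **DQN**  | **0.320**| **baseline**       |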

    def save_results(self, results: Dict[str, float], model_name: str,
                     game_analysis: List[GameResult] = None):
        """Save all results to files"""
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        # Save performance comparison
        results_file = self.output_dir / f"performance_{model_name}_{timestamp}.json"
        with open(results_file, 'w') as f:
            json.dump({
                'model_name': model_name,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'results': results,
                'config': {
                    'dict_path': self.dict_path,
                    'len_lo': self.len_lo,
                    'len_hi': self.len_hi,
                    'max_len': self.max_len
                }
            }, f, indent=2)

        # Save performance report
        report = self.create_performance_report(results, model_name)
        report_file = self.output_dir / f"report_{model_name}_{timestamp}.md"
        with open(report_file, 'w') as f:
            f.write(report)

        # Save game analysis if provided
        if game_analysis:
            analysis_file = self.output_dir / f"analysis_{model_name}_{timestamp}.json"
            with open(analysis_file, 'w') as f:
                json.dump([asdict(result) for result in game_analysis], f, indent=2)

        print("\n💾 Results saved:")
        print(f"  - Performance: {results_file}")
        print(f"  - Report: {report_file}")
        if game_analysis:
            print(f"  - Analysis: {analysis_file}")
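
    # Output layout (timestamps use the %Y%m%d_%H%M%S format above):
    #   simulator_results/performance_<model>_<timestamp>.json
    #   simulator_results/report_<model>_<timestamp>.md
    #   simulator_results/analysis_<model>_<timestamp>.json  (only with game_analysis)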

    def create_visualization(self, results: Dict[str, float], model_name: str):
        """Create performance visualization"""
        if not HAS_PLOTTING:
            print("⚠️ Visualization skipped - matplotlib not available")
            return

        plt.figure(figsize=(12, 8))

        # Performance comparison bar chart
        plt.subplot(2, 2, 1)
        strategies = list(results.keys())
        rates = list(results.values())
        colors = ['red' if s == 'dqn' else 'blue' for s in strategies]

        bars = plt.bar(strategies, rates, color=colors, alpha=0.7)
        plt.title(f'Win Rate Comparison - {model_name}')
        plt.ylabel('Win Rate')
        plt.xticks(rotation=45)

        # Highlight DQN
        for i, (strategy, rate) in enumerate(zip(strategies, rates)):
            if strategy == 'dqn':
                bars[i].set_color('red')
                bars[i].set_alpha(1.0)

        # Performance gap analysis
        plt.subplot(2, 2, 2)
        dqn_rate = results.get('dqn', 0.0)
        gaps = [rate - dqn_rate for rate in rates]
        colors = ['red' if gap < 0 else 'green' for gap in gaps]

        plt.bar(strategies, gaps, color=colors, alpha=0.7)
        plt.title('Performance Gap vs DQN')
        plt.ylabel('Win Rate Difference')
        plt.xticks(rotation=45)
        plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)

        # Word length analysis (placeholder)
        plt.subplot(2, 2, 3)
        word_lengths = sorted(self.buckets.keys())
        word_counts = [len(self.buckets[L]) for L in word_lengths]
        plt.bar(word_lengths, word_counts, alpha=0.7)
        plt.title('Word Distribution by Length')
        plt.xlabel('Word Length')
        plt.ylabel('Count')

        # Performance trend (placeholder for future use)
        plt.subplot(2, 2, 4)
        plt.text(0.5, 0.5, 'Performance Trend\n(Coming Soon)',
                 ha='center', va='center', fontsize=12)
        plt.title('Training Progress')
        plt.axis('off')

        plt.tight_layout()

        # Save plot
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        plot_file = self.output_dir / f"visualization_{model_name}_{timestamp}.png"
        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
        plt.show()

        print(f"📈 Visualization saved: {plot_file}")
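
    # Note: plt.show() requires a display; on a headless machine, setting the
    # MPLBACKEND=Agg environment variable keeps savefig working while show()
    # becomes a no-op (recent matplotlib versions emit a warning instead of failing).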


def main():
    parser = argparse.ArgumentParser(description="Hangman DQN Simulator Playground")
    parser.add_argument("--model", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--model-name", type=str, help="Name for the model (default: filename)")
    parser.add_argument("--episodes", type=int, default=1000, help="Number of episodes to test")
    parser.add_argument("--detailed", action="store_true", help="Run detailed analysis")
    parser.add_argument("--analyze-decisions", action="store_true", help="Analyze individual game decisions")
    parser.add_argument("--num-games", type=int, default=10, help="Number of games for decision analysis")
    parser.add_argument("--dict-path", type=str, default="data/words_250000_train.txt", help="Dictionary path")
    parser.add_argument("--len-lo", type=int, default=4, help="Minimum word length")
    parser.add_argument("--len-hi", type=int, default=12, help="Maximum word length")
    parser.add_argument("--max-len", type=int, default=35, help="Model max sequence length")

    args = parser.parse_args()

    # Initialize simulator
    simulator = HangmanSimulator(
        dict_path=args.dict_path,
        len_lo=args.len_lo,
        len_hi=args.len_hi,
        max_len=args.max_len
    )

    # Load model
    model_name = args.model_name or Path(args.model).stem
    model = simulator.load_model(args.model, model_name)

    if model is None:
        print("❌ Failed to load model. Exiting.")
        return

    # Run comprehensive evaluation
    print(f"\n🚀 Starting comprehensive evaluation of {model_name}...")

    # Compare with baselines
    results = simulator.compare_with_baselines(model, model_name, args.episodes)

    # Detailed analysis if requested
    game_analysis = None
    if args.analyze_decisions:
        game_analysis = simulator.analyze_model_decisions(model, model_name, args.num_games)

    # Save results
    simulator.save_results(results, model_name, game_analysis)

    # Create visualization
    simulator.create_visualization(results, model_name)

    # Print summary
    print(f"\n📋 Summary for {model_name}:")
    print(f"  DQN Win Rate: {results.get('dqn', 0.0):.3f}")
    best_baseline = max([(k, v) for k, v in results.items() if k != 'dqn'], key=lambda x: x[1])
    print(f"  Best Baseline: {best_baseline[0]} ({best_baseline[1]:.3f})")
    print(f"  Gap: {best_baseline[1] - results.get('dqn', 0.0):.3f}")


if __name__ == "__main__":
    main()