Jose Marie Antonio Minoza commited on
Commit
9c02486
·
1 Parent(s): ad860b5

Initial commit

Browse files
Files changed (1) hide show
  1. models/loader.py +33 -36
models/loader.py CHANGED
@@ -49,9 +49,7 @@ class ModelLoader:
49
  self.model_registry[ModelType.VQVAE] = VQVAE
50
  except ImportError:
51
  print("Warning: VQVAE model implementation not found")
52
-
53
- # Temporary removal of CNN and GAN models
54
- """
55
  try:
56
  from .cnn import CNN
57
  self.model_registry[ModelType.SRCNN] = CNN
@@ -60,11 +58,9 @@ class ModelLoader:
60
 
61
  try:
62
  from .gan import UncertainESRGAN
63
- self.model_registry[ModelType.GAN] = UncertainESRGAN
64
  except ImportError:
65
- print("Warning: Model implementation not found")
66
- """
67
-
68
 
69
  def load_model(self, model_type: str, checkpoint_path: str, config_overrides: Optional[Dict] = None):
70
  """
@@ -94,7 +90,12 @@ class ModelLoader:
94
  # Initialize model with potentially modified config
95
  model_class = self.model_registry[model_type]
96
  model = model_class(**model_config)
97
- model = torch.compile(model)
 
 
 
 
 
98
 
99
  # Move model to device
100
  model = model.to(self.device)
@@ -102,35 +103,31 @@ class ModelLoader:
102
  # Load checkpoint
103
  try:
104
  checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
105
- print(f"Successfully loaded {model_type.value} model from {checkpoint_path}")
106
-
107
- model.load_state_dict(checkpoint['state_dict'], strict = False)
108
-
109
- print(checkpoint['state_dict'].keys())
110
- print(checkpoint.keys())
111
-
112
- model.uncertainty_tracker.ema_errors = checkpoint['state_dict'].get(
113
- '_orig_mod.uncertainty_tracker.ema_errors',
114
- torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
115
- )
116
- model.uncertainty_tracker.ema_quantile = checkpoint['state_dict'].get(
117
- '_orig_mod.uncertainty_tracker.ema_quantile',
118
- torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
119
- )
120
- model.uncertainty_tracker._initialize_buffers(64, 64, self.device)
121
-
122
- # Store model
123
-
124
- model.uncertainty_tracker.ema_errors = checkpoint['state_dict']['_orig_mod.uncertainty_tracker.ema_errors']
125
- model.uncertainty_tracker.ema_quantile = checkpoint['state_dict']['_orig_mod.uncertainty_tracker.ema_quantile']
126
-
127
- model.uncertainty_tracker.calibrated = checkpoint.get('calibrated', True)
128
 
129
- if model.uncertainty_tracker.calibrated:
130
- if 'block_scale_means' in checkpoint:
131
- model.uncertainty_tracker.block_scale_means = checkpoint['block_scale_means']
132
- if 'block_scale_stds' in checkpoint:
133
- model.uncertainty_tracker.block_scale_stds = checkpoint['block_scale_stds']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  print(f"Successfully loaded {model_type.value} model from {checkpoint_path}")
136
  except Exception as e:
 
49
  self.model_registry[ModelType.VQVAE] = VQVAE
50
  except ImportError:
51
  print("Warning: VQVAE model implementation not found")
52
+
 
 
53
  try:
54
  from .cnn import CNN
55
  self.model_registry[ModelType.SRCNN] = CNN
 
58
 
59
  try:
60
  from .gan import UncertainESRGAN
61
+ self.model_registry[ModelType.GAN] = UncertainESRGAN
62
  except ImportError:
63
+ print("Warning: GAN model implementation not found")
 
 
64
 
65
  def load_model(self, model_type: str, checkpoint_path: str, config_overrides: Optional[Dict] = None):
66
  """
 
90
  # Initialize model with potentially modified config
91
  model_class = self.model_registry[model_type]
92
  model = model_class(**model_config)
93
+
94
+ # Apply torch compilation for performance
95
+ try:
96
+ model = torch.compile(model)
97
+ except Exception as e:
98
+ print(f"Warning: Could not compile model: {e}")
99
 
100
  # Move model to device
101
  model = model.to(self.device)
 
103
  # Load checkpoint
104
  try:
105
  checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
106
+ print(f"Checkpoint keys: {checkpoint.keys()}")
107
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ # Explicitly load uncertainty tracker buffers
110
+ if hasattr(model, 'uncertainty_tracker'):
111
+ # Load EMA buffers
112
+ model.uncertainty_tracker.ema_errors = checkpoint['state_dict'].get(
113
+ '_orig_mod.uncertainty_tracker.ema_errors',
114
+ torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
115
+ )
116
+ model.uncertainty_tracker.ema_quantile = checkpoint['state_dict'].get(
117
+ '_orig_mod.uncertainty_tracker.ema_quantile',
118
+ torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
119
+ )
120
+
121
+ # Initialize buffers with proper dimensions (64x64 for super-resolution output)
122
+ model.uncertainty_tracker._initialize_buffers(64, 64, self.device)
123
+
124
+ # Load calibration state if available
125
+ model.uncertainty_tracker.calibrated = checkpoint.get('calibrated', False)
126
+ if model.uncertainty_tracker.calibrated:
127
+ if 'block_scale_means' in checkpoint:
128
+ model.uncertainty_tracker.block_scale_means = checkpoint['block_scale_means']
129
+ if 'block_scale_stds' in checkpoint:
130
+ model.uncertainty_tracker.block_scale_stds = checkpoint['block_scale_stds']
131
 
132
  print(f"Successfully loaded {model_type.value} model from {checkpoint_path}")
133
  except Exception as e: