Jose Marie Antonio Minoza committed
Commit 9c02486 · Parent(s): ad860b5
Initial commit

models/loader.py CHANGED (+33 -36)
@@ -49,9 +49,7 @@ class ModelLoader:
             self.model_registry[ModelType.VQVAE] = VQVAE
         except ImportError:
             print("Warning: VQVAE model implementation not found")
-
-        # Temporary removal of CNN and GAN models
-        """
+
         try:
             from .cnn import CNN
             self.model_registry[ModelType.SRCNN] = CNN
@@ -60,11 +58,9 @@ class ModelLoader:
 
         try:
             from .gan import UncertainESRGAN
-            self.model_registry[ModelType.GAN] = UncertainESRGAN
+            self.model_registry[ModelType.GAN] = UncertainESRGAN
         except ImportError:
-            print("Warning:
-        """
-
+            print("Warning: GAN model implementation not found")
 
     def load_model(self, model_type: str, checkpoint_path: str, config_overrides: Optional[Dict] = None):
         """
@@ -94,7 +90,12 @@ class ModelLoader:
         # Initialize model with potentially modified config
         model_class = self.model_registry[model_type]
        model = model_class(**model_config)
-
+
+        # Apply torch compilation for performance
+        try:
+            model = torch.compile(model)
+        except Exception as e:
+            print(f"Warning: Could not compile model: {e}")
 
         # Move model to device
         model = model.to(self.device)
@@ -102,35 +103,31 @@ class ModelLoader:
         # Load checkpoint
         try:
             checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
-            print(f"
-
-            model.load_state_dict(checkpoint['state_dict'], strict = False)
-
-            print(checkpoint['state_dict'].keys())
-            print(checkpoint.keys())
-
-            model.uncertainty_tracker.ema_errors = checkpoint['state_dict'].get(
-                '_orig_mod.uncertainty_tracker.ema_errors',
-                torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
-            )
-            model.uncertainty_tracker.ema_quantile = checkpoint['state_dict'].get(
-                '_orig_mod.uncertainty_tracker.ema_quantile',
-                torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
-            )
-            model.uncertainty_tracker._initialize_buffers(64, 64, self.device)
-
-            # Store model
-
-            model.uncertainty_tracker.ema_errors = checkpoint['state_dict']['_orig_mod.uncertainty_tracker.ema_errors']
-            model.uncertainty_tracker.ema_quantile = checkpoint['state_dict']['_orig_mod.uncertainty_tracker.ema_quantile']
-
-            model.uncertainty_tracker.calibrated = checkpoint.get('calibrated', True)
+            print(f"Checkpoint keys: {checkpoint.keys()}")
+            model.load_state_dict(checkpoint['state_dict'], strict=False)
 
-
-
-
-
-
+            # Explicitly load uncertainty tracker buffers
+            if hasattr(model, 'uncertainty_tracker'):
+                # Load EMA buffers
+                model.uncertainty_tracker.ema_errors = checkpoint['state_dict'].get(
+                    '_orig_mod.uncertainty_tracker.ema_errors',
+                    torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
+                )
+                model.uncertainty_tracker.ema_quantile = checkpoint['state_dict'].get(
+                    '_orig_mod.uncertainty_tracker.ema_quantile',
+                    torch.zeros(model.uncertainty_tracker.block_size**2).to(self.device)
+                )
+
+                # Initialize buffers with proper dimensions (64x64 for super-resolution output)
+                model.uncertainty_tracker._initialize_buffers(64, 64, self.device)
+
+                # Load calibration state if available
+                model.uncertainty_tracker.calibrated = checkpoint.get('calibrated', False)
+                if model.uncertainty_tracker.calibrated:
+                    if 'block_scale_means' in checkpoint:
+                        model.uncertainty_tracker.block_scale_means = checkpoint['block_scale_means']
+                    if 'block_scale_stds' in checkpoint:
+                        model.uncertainty_tracker.block_scale_stds = checkpoint['block_scale_stds']
 
             print(f"Successfully loaded {model_type.value} model from {checkpoint_path}")
         except Exception as e:
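
Note on the "_orig_mod." key prefix: the loader now wraps the model with torch.compile() before load_state_dict(), which lines up with checkpoints saved from a compiled model, because torch.compile() returns an OptimizedModule whose state_dict keys are prefixed with "_orig_mod.". As a rough sketch (not code from this repository), the prefix could instead be stripped when loading such a checkpoint into an uncompiled module; the helper name and the {'state_dict': ...} layout are assumptions taken from loader.py:

import torch

def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    # Hypothetical helper: drop the key prefix added by torch.compile's
    # OptimizedModule wrapper so the weights fit an uncompiled model.
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}

# Usage sketch, assuming the checkpoint layout seen in loader.py:
# ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# model.load_state_dict(strip_compile_prefix(ckpt["state_dict"]), strict=False)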
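For reference, the keys the updated loader reads ('state_dict', 'calibrated', 'block_scale_means', 'block_scale_stds') suggest a checkpoint layout roughly like the save-side sketch below. This is an assumption reconstructed from loader.py, not repository code, and the uncertainty_tracker attributes are assumed to exist as the loader expects:

import torch

def save_checkpoint_sketch(model, path):
    # Assumed save-side counterpart to the loading logic in loader.py.
    tracker = model.uncertainty_tracker  # attribute assumed, as in loader.py
    checkpoint = {
        "state_dict": model.state_dict(),  # carries "_orig_mod.*" keys if the model was compiled
        "calibrated": bool(tracker.calibrated),
        "block_scale_means": tracker.block_scale_means,
        "block_scale_stds": tracker.block_scale_stds,
    }
    torch.save(checkpoint, path)

Defaulting the EMA buffers to torch.zeros(block_size**2) and 'calibrated' to False keeps loading tolerant of older checkpoints that predate the tracker fields.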