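"""Model Unload/Reload Util.

An extension script for the AUTOMATIC1111 / Forge Stable Diffusion WebUI that
adds two buttons: one unloads the current Stable Diffusion checkpoint to free
VRAM, the other reloads the last checkpoint this extension unloaded.
Typically placed in an extension's scripts/ folder.
"""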
import gradio as gr
from modules import scripts, shared, sd_models, lowvram, devices
import gc
import torch
import os

# Forge-specific imports; on stock A1111 WebUI these fail and we fall back to the stubs below.
try:
    from modules.sd_models import forge_model_reload, model_data, CheckpointInfo
    from modules_forge.main_entry import forge_unet_storage_dtype_options
    from backend.memory_management import free_memory as forge_free_memory
    forge = True
except ImportError:
    forge = False
    class CheckpointInfo:
        def __init__(self, filename):
            self.filename = filename
            self.name = os.path.splitext(os.path.basename(filename))[0]
            self.name_or_path = filename
            self.sha256 = None
            self.ids = None
            self.model_name = self.name
            self.title = self.name

class ModelUtilState:
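    """Mutable state shared across button clicks; persists only for the WebUI process lifetime."""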
    last_loaded_checkpoint_info_dict = None
    last_forge_model_params = None
    is_model_unloaded_by_ext = False

state = ModelUtilState()

def get_current_checkpoint_info():
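    """Best-effort lookup of the currently loaded checkpoint's CheckpointInfo:
    Forge's model_data first, then shared.sd_model, then the WebUI dropdown selection."""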
    if forge and hasattr(model_data, 'sd_checkpoint_info') and model_data.sd_checkpoint_info:
        return model_data.sd_checkpoint_info
    if hasattr(shared, 'sd_model') and shared.sd_model and hasattr(shared.sd_model, 'sd_checkpoint_info') and shared.sd_model.sd_checkpoint_info:
        return shared.sd_model.sd_checkpoint_info
    if shared.opts.sd_model_checkpoint:
        # Resolve the selected checkpoint title via A1111's (misspelled) helper.
        matched_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_model_checkpoint)
        if matched_info:
            return matched_info
    return None

def checkpoint_info_to_dict(chkpt_info):
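    """Serialize a CheckpointInfo to a plain dict so it survives the model being unloaded."""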
    if not chkpt_info:
        return None
    return {
        "filename": getattr(chkpt_info, 'filename', None),
        "name": getattr(chkpt_info, 'name', None),
        "name_or_path": getattr(chkpt_info, 'name_or_path', getattr(chkpt_info, 'filename', None)),
        "sha256": getattr(chkpt_info, 'sha256', None),
        "model_name": getattr(chkpt_info, 'model_name', None),
        "title": getattr(chkpt_info, 'title', None),
    }

def ensure_name_or_path(info_obj):
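    """Backfill a missing name_or_path attribute from filename, title, or name, in that order."""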
    if not info_obj:
        return info_obj
    if not hasattr(info_obj, 'name_or_path') or not getattr(info_obj, 'name_or_path', None):
        filename_attr = getattr(info_obj, 'filename', None)
        title_attr = getattr(info_obj, 'title', None)
        name_attr = getattr(info_obj, 'name', None)
        if filename_attr:
            print(f"Info object was missing 'name_or_path'. Setting from 'filename': {filename_attr}")
            info_obj.name_or_path = filename_attr
        elif title_attr:
            print(f"Info object was missing 'name_or_path/filename'. Setting from 'title': {title_attr}")
            info_obj.name_or_path = title_attr
        elif name_attr:
            print(f"Info object was missing 'name_or_path/filename/title'. Setting from 'name': {name_attr}")
            info_obj.name_or_path = name_attr
        else:
            print(f"CRITICAL: Info object is missing 'name_or_path', 'filename', 'title', and 'name'. Cannot reliably set 'name_or_path'.")
    return info_obj

def dict_to_checkpoint_info(chkpt_dict):
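    """Locate the stored checkpoint in sd_models.checkpoints_list; if absent, rebuild a
    CheckpointInfo from the saved path when the file still exists on disk."""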
    if not chkpt_dict or not chkpt_dict.get('name_or_path'):
        print(f"Warning: chkpt_dict is invalid or missing 'name_or_path': {chkpt_dict}")
        return None

    target_model_identifier = chkpt_dict['name_or_path']
    print(f"Attempting to find CheckpointInfo for: {target_model_identifier}")

    available_checkpoints = sd_models.checkpoints_list
    found_info = None

    for name, info_obj_from_list in available_checkpoints.items():
        info_name_or_path = getattr(info_obj_from_list, 'name_or_path', None)
        info_filename = getattr(info_obj_from_list, 'filename', None)
        info_title = getattr(info_obj_from_list, 'title', None)
        match_found = target_model_identifier in (info_name_or_path, info_filename, name, info_title)
        if match_found:
            print(f"Found matching CheckpointInfo in available_checkpoints: {name}")
            found_info = info_obj_from_list
            break
    if found_info:
        return ensure_name_or_path(found_info)

    print(f"CheckpointInfo for '{target_model_identifier}' not found in list. Attempting to create new one.")
    if os.path.exists(target_model_identifier):
        print(f"File exists at path: {target_model_identifier}. Creating new CheckpointInfo.")
        newly_created_info = CheckpointInfo(target_model_identifier)
        for key, value in chkpt_dict.items():
            if not hasattr(newly_created_info, key) or getattr(newly_created_info, key) is None:
                setattr(newly_created_info, key, value)
        return ensure_name_or_path(newly_created_info)
    else:
        print(f"File does not exist at path: {target_model_identifier}. Cannot create CheckpointInfo.")

    print(f"Warning: Could not reconstruct CheckpointInfo for {target_model_identifier}.")
    return None


def unload_model_logic():
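    """Record the current checkpoint's identity in `state`, then drop all model references and free VRAM."""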
    model_loaded = (forge and hasattr(model_data, 'sd_model') and model_data.sd_model) or \
                   (not forge and hasattr(shared, 'sd_model') and shared.sd_model)
    if not model_loaded:
        state.is_model_unloaded_by_ext = False
        return "Model is already unloaded or not loaded."

    print("Unloading SD model...")
    current_info = get_current_checkpoint_info()
    if current_info:
        state.last_loaded_checkpoint_info_dict = checkpoint_info_to_dict(current_info)
        print(f"Storing info for model: {state.last_loaded_checkpoint_info_dict.get('name_or_path')}")
    else:
        state.last_loaded_checkpoint_info_dict = None
        print("Could not get current checkpoint info to store.")

    if forge:
        if hasattr(model_data, "forge_loading_parameters") and model_data.forge_loading_parameters:
            state.last_forge_model_params = model_data.forge_loading_parameters.copy()
        else:
            state.last_forge_model_params = None
        sd_models.model_data.sd_model = None
        if hasattr(sd_models.model_data, 'loaded_sd_models'):
            sd_models.model_data.loaded_sd_models = []
        if hasattr(sd_models.model_data, 'forge_objects'):
            for attr in ['unet', 'vae', 'clip_l', 'clip_g', 'clip_vision', 'gligen', 'controlnet_predict', 'patch_manager', 'conditioner']:
                if hasattr(sd_models.model_data.forge_objects, attr):
                    setattr(sd_models.model_data.forge_objects, attr, None)
        if torch.cuda.is_available():
            cuda_device_str = devices.get_cuda_device_string()
            forge_free_memory(torch.cuda.memory_allocated(cuda_device_str), cuda_device_str, free_all=True)
        print("Forge model components cleared and memory freed.")
    else:
        sd_models.unload_model_weights()
        print("Standard model unloaded.")

    lowvram.module_in_gpu = None
    shared.sd_model = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    state.is_model_unloaded_by_ext = True
    return "Model unloaded successfully. VRAM freed."

def _ensure_module_on_device(module, module_name, target_device, indent="  "):
    """Move `module` to `target_device` if its parameters live elsewhere; return True if a move happened."""
    if not isinstance(module, torch.nn.Module):
        return False
    first_param = next(module.parameters(), None)
    if first_param is None:
        return False
    current_device = first_param.device
    # Treat a bare 'cuda' (index None) as 'cuda:0' so the comparison does not trigger a spurious move.
    if current_device.type != target_device.type or \
            (target_device.type == 'cuda' and (current_device.index or 0) != (target_device.index or 0)):
        print(f"{indent}Moving {module_name} from {current_device} to {target_device}...")
        module.to(target_device)
        return True
    return False

def reload_last_model_logic():
    """Reload the checkpoint recorded by unload_model_logic(), falling back to the WebUI's selected checkpoint."""
    # shared.sd_model_empty may not exist on every WebUI version; treat the sentinel as optional.
    sd_model_empty = getattr(shared, 'sd_model_empty', None)
    model_currently_loaded = (forge and getattr(model_data, 'sd_model', None) is not None and model_data.sd_model is not sd_model_empty) or \
                             (not forge and getattr(shared, 'sd_model', None) is not None and shared.sd_model is not sd_model_empty)

    if model_currently_loaded and not state.is_model_unloaded_by_ext:
        return "Model is already loaded and was not unloaded by this extension. No action taken."

    if not state.last_loaded_checkpoint_info_dict:
        if shared.opts.sd_model_checkpoint:
            print(f"No specific model info stored by extension, trying to use WebUI's selected model: {shared.opts.sd_model_checkpoint}")
            matched_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_model_checkpoint)
            if matched_info:
                state.last_loaded_checkpoint_info_dict = checkpoint_info_to_dict(matched_info)
            else:
                return "No last model information found and WebUI's selected model could not be resolved."
        else:
            return "No last model information found to reload."

    chkpt_info_to_load = dict_to_checkpoint_info(state.last_loaded_checkpoint_info_dict)
    if not chkpt_info_to_load or not getattr(chkpt_info_to_load, 'name_or_path', None):
        return f"Could not reconstruct valid CheckpointInfo from stored data: {state.last_loaded_checkpoint_info_dict}. Cannot reload."

    model_display_name = getattr(chkpt_info_to_load, 'name_or_path', getattr(chkpt_info_to_load, 'filename', 'Unknown Model'))
    print(f"Reloading SD model: {model_display_name}")

    try:
        devices.torch_gc()

        if forge:
            print("Forge: Reloading using forge_model_reload()...")
            if state.last_forge_model_params:
                sd_models.model_data.forge_loading_parameters = state.last_forge_model_params.copy()
                sd_models.model_data.forge_loading_parameters['checkpoint_info'] = chkpt_info_to_load
            else:
                print("Warning: No specific Forge params stored, building defaults for reload.")
                unet_storage_dtype, _ = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False))
                sd_models.model_data.forge_loading_parameters = dict(
                    checkpoint_info=chkpt_info_to_load,
                    additional_modules=shared.opts.forge_additional_modules,
                    unet_storage_dtype=unet_storage_dtype
                )
            sd_models.model_data.forge_hash = None
            
            forge_model_reload()

            if not sd_models.model_data.sd_model:
                raise RuntimeError("forge_model_reload() did not populate model_data.sd_model.")
            
            shared.sd_model = sd_models.model_data.sd_model
            print("Forge: forge_model_reload() completed.")

            if torch.cuda.is_available():
                cuda_device = torch.device(devices.get_cuda_device_string())
                print(f"Forge: Verifying device placement on {cuda_device} after reload...")
                
                _ensure_module_on_device(shared.sd_model, "shared.sd_model (main)", cuda_device)
                
                if hasattr(shared.sd_model, 'forge_objects') and shared.sd_model.forge_objects:
                    fo = shared.sd_model.forge_objects
                    _ensure_module_on_device(getattr(fo, 'unet', None), "UNet (from forge_objects)", cuda_device)
                    _ensure_module_on_device(getattr(fo, 'vae', None), "VAE (from forge_objects)", cuda_device)
                    _ensure_module_on_device(getattr(fo, 'clip', None), "CLIP (main from forge_objects)", cuda_device)
                    if hasattr(fo, 'clip') and fo.clip:
                        _ensure_module_on_device(getattr(fo.clip,'cond_stage_model', None), "CLIP cond_stage_model", cuda_device)

                if hasattr(shared.sd_model, 'conditioner') and shared.sd_model.conditioner:
                    _ensure_module_on_device(shared.sd_model.conditioner, "Conditioner", cuda_device)
                    if hasattr(shared.sd_model.conditioner, 'embedders'):
                        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
                            _ensure_module_on_device(embedder, f"Embedder {i}", cuda_device)

                print("Forge: Device verification and correction attempt finished.")
        else:
            sd_models.load_model(chkpt_info_to_load)
            print("Standard model reloaded.")
            if torch.cuda.is_available() and shared.sd_model:
                 cuda_device = torch.device(devices.get_cuda_device_string())
                 _ensure_module_on_device(shared.sd_model, "shared.sd_model (main)", cuda_device)


        state.is_model_unloaded_by_ext = False
        return f"Model '{model_display_name}' reloaded successfully."

    except Exception as e:
        print(f"Error reloading model: {e}")
        import traceback
        traceback.print_exc()
        lowvram.module_in_gpu = None
        shared.sd_model = None
        if forge and hasattr(model_data, 'sd_model'): model_data.sd_model = None
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        return f"Error reloading model: {e}. Model remains unloaded."


class UnloadReloadModelScript(scripts.Script):
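    """Always-visible accordion exposing the unload/reload buttons; performs no per-generation processing."""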
    def title(self):
        return "Model Unload/Reload Util"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
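        """Build the accordion UI and wire both buttons to the module-level logic functions."""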
        with gr.Accordion(self.title(), open=False):
            with gr.Row():
                unload_button = gr.Button("Unload Current SD Model (Free VRAM)")
                reload_button = gr.Button("Reload Last Unloaded SD Model")
            status_text = gr.Textbox(label="Status", value="Ready.", interactive=False, lines=3, max_lines=3)

            unload_button.click(fn=unload_model_logic, inputs=[], outputs=[status_text])
            reload_button.click(fn=reload_last_model_logic, inputs=[], outputs=[status_text])
        return [unload_button, reload_button, status_text]