|
import os |
|
from collections import OrderedDict |
|
|
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig |
|
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm |
|
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset |
|
from toolkit.stable_diffusion_model import StableDiffusion |
|
import gc |
|
import torch |
|
from jobs.process import BaseExtensionProcess |
|
from toolkit.train_tools import get_torch_dtype |
|
|
|
|
|
def flush():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors freed by the collection are
    handed back to PyTorch's caching allocator before
    ``torch.cuda.empty_cache()`` releases the allocator's unoccupied blocks.
    The reverse order would leave that freshly freed memory cached.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
class PureLoraGenerator(BaseExtensionProcess):
    """Generate sample images with a model whose UNet is converted to LoRM.

    Config keys read in ``__init__``:

    * ``output_folder`` (required): directory the generated image paths are
      placed under.
    * ``device`` (default ``'cuda'``): device the UNet and text encoder(s)
      are moved to.
    * ``model`` (required): keyword arguments for ``ModelConfig``.
    * ``sample`` (required): keyword arguments for ``SampleConfig`` (prompts,
      size, seed, sampler, ...).
    * ``dtype`` (default ``'float16'``): dtype name passed to
      ``StableDiffusion``.
    * ``lorm`` (optional): keyword arguments for ``LoRMConfig``, used for the
      UNet conversion.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        # Not read elsewhere in this class; StableDiffusion receives the dtype name.
        self.torch_dtype = get_torch_dtype(self.dtype)
        lorm_config = self.get_conf('lorm', None)
        # NOTE(review): when 'lorm' is omitted, None is forwarded to
        # convert_diffusers_unet_to_lorm in run(); confirm that helper accepts it.
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None

        # Not read anywhere in this class; kept so the attribute still exists.
        self.device_state_preset = get_train_sd_device_state_preset(
            device=self.device_torch,
        )

        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        """Load the model, convert its UNet to LoRM and generate one image per prompt.

        ``self.sd`` is deleted at the end to release memory, so ``run`` can only
        be called once per instance.
        """
        super().run()
        print("Loading model...")
        with torch.no_grad():
            self._load_and_place_models()

            print("Converting to LoRM UNet")
            # The return value is not used, so the conversion presumably
            # rewrites self.sd.unet in place; confirm against toolkit.lorm.
            convert_diffusers_unet_to_lorm(
                self.sd.unet,
                config=self.lorm_config,
            )

            sample_config = self.generate_config
            gen_img_config_list = self._build_image_configs(sample_config)

            self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
            print("Done generating images")

        # Drop the model and return freed GPU memory: collect first, then empty
        # the CUDA cache so the collected tensors' memory is actually released.
        del self.sd
        gc.collect()
        torch.cuda.empty_cache()

    def _load_and_place_models(self):
        """Load ``self.sd`` and put its UNet and text encoder(s) in eval mode on the target device."""
        self.sd.load_model()
        self.sd.unet.eval()
        self.sd.unet.to(self.device_torch)

        # text_encoder is either a list of modules or a single module. Treat both
        # the same way so the single-encoder case also moves the encoder itself.
        text_encoders = self.sd.text_encoder
        if not isinstance(text_encoders, list):
            text_encoders = [text_encoders]
        for te in text_encoders:
            te.eval()
            te.to(self.device_torch)

    def _build_image_configs(self, sample_config):
        """Return one ``GenerateImageConfig`` per prompt in ``sample_config.prompts``."""
        # The same path template is used for every image. "[time]" and "[count]"
        # are literal placeholders here; presumably the image generator
        # substitutes them per image (confirm in StableDiffusion.generate_images).
        output_path = os.path.join(
            self.output_folder, f"[time]_[count].{sample_config.ext}"
        )
        start_seed = sample_config.seed

        gen_img_config_list = []
        for i, prompt in enumerate(sample_config.prompts):
            # With walk_seed each prompt gets start_seed + its index; otherwise
            # every prompt reuses start_seed.
            seed = start_seed + i if sample_config.walk_seed else start_seed
            gen_img_config_list.append(GenerateImageConfig(
                prompt=prompt,
                width=sample_config.width,
                height=sample_config.height,
                negative_prompt=sample_config.neg,
                seed=seed,
                guidance_scale=sample_config.guidance_scale,
                guidance_rescale=sample_config.guidance_rescale,
                num_inference_steps=sample_config.sample_steps,
                network_multiplier=sample_config.network_multiplier,
                output_path=output_path,
                output_ext=sample_config.ext,
                adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
            ))
        return gen_img_config_list
|
|