---
base_model: stabilityai/stable-diffusion-xl-base-1.0
library_name: diffusers
license: openrail++
tags:
- text-to-image
- diffusers-training
- diffusers
- lora
- template:sd-lora
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
instance_prompt: <leaf microstructure>
widget: []
---

# SDXL Fine-tuned with Leaf Images

## Model description

These are LoRA adaptation weights for the SDXL base 1.0 model (stabilityai/stable-diffusion-xl-base-1.0), fine-tuned on leaf microstructure images.

## Trigger keywords

The following images were used during fine-tuning, with the keyword \<leaf microstructure\>:

![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/sI_exTnLy6AtOFDX1-7eq.png)

You should use \<leaf microstructure\> in the prompt to trigger the image generation.

## How to use

Defining some helper functions:

```python
from diffusers import DiffusionPipeline
import torch
import os
from datetime import datetime
from PIL import Image

def generate_filename(base_name, extension=".png"):
    # Timestamped file names so repeated runs do not overwrite earlier results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{base_name}_{timestamp}{extension}"

def save_image(image, directory, base_name="image_grid"):
    filename = generate_filename(base_name)
    file_path = os.path.join(directory, filename)
    image.save(file_path)
    print(f"Image saved as {file_path}")

def image_grid(imgs, rows, cols, save=True, save_dir='generated_images', base_name="image_grid",
               save_individual_files=False):
    # Paste `rows * cols` images into a single grid image; optionally save the grid
    # and/or the individual images to `save_dir`
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
        if save_individual_files:
            save_image(img, save_dir, base_name=base_name + f'_{i}-of-{len(imgs)}_')

    if save and save_dir:
        save_image(grid, save_dir, base_name)

    return grid
```

### Text-to-image

Model loading:

```python
import torch
from diffusers import DiffusionPipeline, AutoencoderKL

repo_id = 'lamm-mit/SDXL-leaf-inspired'

# fp16-safe SDXL VAE
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# SDXL base pipeline with the LoRA weights from this repository
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.load_lora_weights(repo_id)
_ = base.to("cuda")

# SDXL refiner, sharing the second text encoder and the VAE with the base pipeline
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")
```
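
If GPU memory is tight, one alternative to moving both pipelines to the GPU with ```.to("cuda")``` is diffusers' model CPU offloading. This is a minimal sketch of that option (it assumes the ```accelerate``` package is installed):

```python
# Optional: offload idle sub-models to the CPU instead of calling .to("cuda")
# on the pipelines. This trades some speed for a smaller GPU memory footprint.
base.enable_model_cpu_offload()
refiner.enable_model_cpu_offload()
```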

Image generation:

```python
prompt = "a vase that resembles a <leaf microstructure>, high quality"

num_samples = 4
num_rows = 4
guidance_scale = 15

all_images = []

for _ in range(num_rows):
    # Define how many steps to use and what fraction of them runs on each expert (80/20)
    n_steps = 25
    high_noise_frac = 0.8

    # Run both experts: the base model denoises up to high_noise_frac and returns latents,
    # which the refiner then denoises for the remaining steps
    image = base(
        prompt=prompt,
        num_inference_steps=n_steps, guidance_scale=guidance_scale,
        denoising_end=high_noise_frac, num_images_per_prompt=num_samples,
        output_type="latent",
    ).images
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps, guidance_scale=guidance_scale,
        denoising_start=high_noise_frac, num_images_per_prompt=num_samples,
        image=image,
    ).images

    all_images.extend(image)

grid = image_grid(all_images, num_rows, num_samples,
                  save_individual_files=True,
                  )
grid
```
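
Optionally, the influence of the LoRA can be dialed down at inference time. This is a sketch, assuming a recent diffusers version that reads the ```scale``` entry from ```cross_attention_kwargs``` (1.0 corresponds to the full LoRA effect):

```python
# Optional: reduce how strongly the LoRA weights affect the result
# (assumes a diffusers version that supports "scale" in cross_attention_kwargs)
image = base(
    prompt=prompt,
    num_inference_steps=25,
    guidance_scale=guidance_scale,
    cross_attention_kwargs={"scale": 0.7},
).images[0]
image
```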

![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/zYVqHRmPF7FHLJ8jIcOeT.png)

## Fine-tuning script

Download this notebook: [SDXL DreamBooth-LoRA_Fine-Tune.ipynb](https://huggingface.co/lamm-mit/SDXL-leaf-inspired/resolve/main/SDXL_DreamBooth_LoRA_Fine-Tune.ipynb)

You need to create a local folder ```leaf_concept_dir_SDXL``` and add the leaf images (provided in this repository, see the subfolder).

The notebook automatically downloads the training script ```train_dreambooth_lora_sdxl.py```.
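
If you prefer to fetch the training script yourself rather than relying on the notebook, this minimal sketch pulls it from the diffusers examples on the ```main``` branch (which may differ slightly from the revision the notebook pins):

```python
import requests

# Download train_dreambooth_lora_sdxl.py from the diffusers examples (main branch)
url = ("https://raw.githubusercontent.com/huggingface/diffusers/main/"
       "examples/dreambooth/train_dreambooth_lora_sdxl.py")
with open("train_dreambooth_lora_sdxl.py", "w") as f:
    f.write(requests.get(url).text)
```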

The training script can handle a custom prompt for each image; these per-image captions are generated with BLIP.

For instance, for the images used here, they are:

```raw
['<leaf microstructure>, a close up of a green plant with a lot of small holes',
 '<leaf microstructure>, a close up of a leaf with a small insect on it',
 '<leaf microstructure>, a close up of a plant with a lot of green leaves',
 '<leaf microstructure>, a close up of a green plant with a yellow light',
 '<leaf microstructure>, a close up of a green plant with a white center',
 '<leaf microstructure>, arafed leaf with a white line on the center',
 '<leaf microstructure>, a close up of a leaf with a yellow light shining through it',
 '<leaf microstructure>, arafed image of a green plant with a yellow cross']
```

Training then proceeds as:

```python
HF_username = 'lamm-mit'

pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"

instance_prompt = "<leaf microstructure>"
instance_data_dir = "./leaf_concept_dir_SDXL/"

val_prompt = "a vase that resembles a <leaf microstructure>, high quality"
val_epochs = 100

instance_output_dir = "leaf_LoRA_SDXL_V10"  # for checkpointing
```

Dataset generation with custom per-image captions:

```python
import glob
import json
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the processor and the captioning model
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
                                                          torch_dtype=torch.float16).to(device)

# captioning utility
def caption_images(input_image):
    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
    pixel_values = inputs.pixel_values

    generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption

# collect (path, image) pairs for the training images
# (assumes .jpg/.jpeg files in instance_data_dir, matching the metadata shown below)
imgs_and_paths = [(path, Image.open(path))
                  for path in sorted(glob.glob(f"{instance_data_dir}*.jpg") +
                                     glob.glob(f"{instance_data_dir}*.jpeg"))]

# prepend the trigger keyword and write one JSON record per image
caption_prefix = f"{instance_prompt}, "
with open(f'{instance_data_dir}metadata.jsonl', 'w') as outfile:
    for img in imgs_and_paths:
        caption = caption_prefix + caption_images(img[1]).split("\n")[0]
        entry = {"file_name": img[0].split("/")[-1], "prompt": caption}
        json.dump(entry, outfile)
        outfile.write('\n')
```

This produces a ```metadata.jsonl``` file in the ```instance_data_dir``` directory:

```json
{"file_name": "0.jpeg", "prompt": "<leaf microstructure>, a close up of a green plant with a lot of small holes"}
{"file_name": "1.jpeg", "prompt": "<leaf microstructure>, a close up of a leaf with a small insect on it"}
{"file_name": "2.jpeg", "prompt": "<leaf microstructure>, a close up of a plant with a lot of green leaves"}
{"file_name": "3.jpeg", "prompt": "<leaf microstructure>, a close up of a leaf with a yellow substance in it"}
{"file_name": "87.jpg", "prompt": "<leaf microstructure>, a close up of a green plant with a yellow light"}
{"file_name": "88.jpg", "prompt": "<leaf microstructure>, a close up of a green plant with a white center"}
{"file_name": "90.jpg", "prompt": "<leaf microstructure>, arafed leaf with a white line on the center"}
{"file_name": "91.jpg", "prompt": "<leaf microstructure>, arafed image of a green leaf with a white spot"}
{"file_name": "92.jpg", "prompt": "<leaf microstructure>, a close up of a leaf with a yellow light shining through it"}
{"file_name": "94.jpg", "prompt": "<leaf microstructure>, arafed image of a green plant with a yellow cross"}
```
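
Since the launch command below points ```--dataset_name``` at this local folder, the training script loads it with the Hugging Face ```datasets``` library (typically resolved to the imagefolder builder, which picks up the extra ```prompt``` column from ```metadata.jsonl```). A quick, optional sanity check, assuming the ```datasets``` package is installed:

```python
from datasets import load_dataset

# Optional check: the imagefolder builder reads metadata.jsonl and exposes
# "image" plus the extra "prompt" column defined above
ds = load_dataset("imagefolder", data_dir=instance_data_dir, split="train")
print(ds)
print(ds[0]["prompt"])
```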

Finally, launch the fine-tuning run (the ```{...}``` placeholders are substituted from the Python variables defined above when this is run in a notebook cell):

```raw
!accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="{pretrained_model_name_or_path}" \
  --pretrained_vae_model_name_or_path="{pretrained_vae_model_name_or_path}" \
  --dataset_name="{instance_data_dir}" \
  --output_dir="{instance_output_dir}" \
  --caption_column="prompt" \
  --instance_prompt="{instance_prompt}" \
  --validation_prompt="{val_prompt}" \
  --validation_epochs="{val_epochs}" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=3 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --max_train_steps=500 \
  --checkpointing_steps=500 \
  --seed="0"
```
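
After training, the LoRA weights written to ```instance_output_dir``` can be loaded the same way as the Hub weights in the "How to use" section above. A minimal sketch, assuming the final LoRA file (typically ```pytorch_lora_weights.safetensors```) ends up directly in that folder:

```python
from diffusers import DiffusionPipeline
import torch

# Load the SDXL base model and attach the locally trained LoRA weights
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
pipe.load_lora_weights("leaf_LoRA_SDXL_V10")  # local output directory from the config above

image = pipe("a vase that resembles a <leaf microstructure>, high quality").images[0]
image
```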