more examples
- OmniGen/__pycache__/__init__.cpython-310.pyc +0 -0
- OmniGen/__pycache__/model.cpython-310.pyc +0 -0
- OmniGen/__pycache__/pipeline.cpython-310.pyc +0 -0
- OmniGen/__pycache__/processor.cpython-310.pyc +0 -0
- OmniGen/__pycache__/scheduler.cpython-310.pyc +0 -0
- OmniGen/__pycache__/transformer.cpython-310.pyc +0 -0
- OmniGen/__pycache__/utils.cpython-310.pyc +0 -0
- OmniGen/pipeline.py +5 -2
- OmniGen/train_helper/__init__.py +2 -0
- OmniGen/train_helper/data.py +116 -0
- OmniGen/train_helper/loss.py +68 -0
- app.py +87 -16
- imgs/demo_cases/edit.png +2 -2
- imgs/demo_cases/entity.png +2 -2
- imgs/demo_cases/reasoning.png +2 -2
- imgs/demo_cases/same_pose.png +2 -2
- imgs/demo_cases/skeletal.png +2 -2
- imgs/demo_cases/skeletal2img.png +2 -2
- imgs/{demo_cases.png → test_cases/1.jpg} +2 -2
- imgs/{overall.jpg → test_cases/2.jpg} +2 -2
- imgs/test_cases/3.jpg +3 -0
- imgs/test_cases/4.jpg +3 -0
- imgs/test_cases/Amanda.jpg +3 -0
- imgs/test_cases/icl1.jpg +3 -0
- imgs/test_cases/icl2.jpg +3 -0
- imgs/test_cases/icl3.jpg +3 -0
- imgs/test_cases/mckenna.jpg +3 -0
- imgs/test_cases/rose.jpg +3 -0
- imgs/test_cases/vase.jpg +3 -0
- imgs/test_cases/zhang.png +3 -0
OmniGen/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/__init__.cpython-310.pyc and b/OmniGen/__pycache__/__init__.cpython-310.pyc differ

OmniGen/__pycache__/model.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/model.cpython-310.pyc and b/OmniGen/__pycache__/model.cpython-310.pyc differ

OmniGen/__pycache__/pipeline.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/pipeline.cpython-310.pyc and b/OmniGen/__pycache__/pipeline.cpython-310.pyc differ

OmniGen/__pycache__/processor.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/processor.cpython-310.pyc and b/OmniGen/__pycache__/processor.cpython-310.pyc differ

OmniGen/__pycache__/scheduler.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/scheduler.cpython-310.pyc and b/OmniGen/__pycache__/scheduler.cpython-310.pyc differ

OmniGen/__pycache__/transformer.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/transformer.cpython-310.pyc and b/OmniGen/__pycache__/transformer.cpython-310.pyc differ

OmniGen/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/utils.cpython-310.pyc and b/OmniGen/__pycache__/utils.cpython-310.pyc differ
OmniGen/pipeline.py
CHANGED

@@ -16,6 +16,7 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
+from safetensors.torch import load_file
 
 from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
 
@@ -59,12 +60,12 @@ class OmniGenPipeline:
 
     @classmethod
     def from_pretrained(cls, model_name, vae_path: str=None):
-        if not os.path.exists(model_name):
+        if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
             logger.info("Model not found, downloading...")
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
                                            cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
             logger.info(f"Downloaded model to {model_name}")
         model = OmniGen.from_pretrained(model_name)
         processor = OmniGenProcessor.from_pretrained(model_name)
@@ -82,6 +83,8 @@ class OmniGenPipeline:
     def merge_lora(self, lora_path: str):
         model = PeftModel.from_pretrained(self.model, lora_path)
         model.merge_and_unload()
+
+
         self.model = model
 
     def to(self, device: Union[str, torch.device]):
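For reference, a minimal, hedged usage sketch of the updated loading path. Names follow the code above; the LoRA directory is a placeholder and merging is optional, not something shipped by this commit.

from OmniGen import OmniGenPipeline

# First call downloads the checkpoint (skipping flax/tf/rust weights and model.pt)
# into HF_HUB_CACHE, then builds the pipeline from the local snapshot.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

# Optional: fold a locally trained LoRA adapter into the base weights before inference.
# "path/to/lora_checkpoint" is a placeholder directory, not part of this commit.
pipe.merge_lora("path/to/lora_checkpoint")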
OmniGen/train_helper/__init__.py
ADDED

@@ -0,0 +1,2 @@
+from .data import DatasetFromJson, TrainDataCollator
+from .loss import training_losses
OmniGen/train_helper/data.py
ADDED

@@ -0,0 +1,116 @@
+import os
+import datasets
+from datasets import load_dataset, ClassLabel, concatenate_datasets
+import torch
+import numpy as np
+import random
+from PIL import Image
+import json
+import copy
+# import torchvision.transforms as T
+from torchvision import transforms
+import pickle
+import re
+
+from OmniGen import OmniGenProcessor
+from OmniGen.processor import OmniGenCollator
+
+
+class DatasetFromJson(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        json_file: str,
+        image_path: str,
+        processer: OmniGenProcessor,
+        image_transform,
+        max_input_length_limit: int = 18000,
+        condition_dropout_prob: float = 0.1,
+        keep_raw_resolution: bool = True,
+    ):
+
+        self.image_transform = image_transform
+        self.processer = processer
+        self.condition_dropout_prob = condition_dropout_prob
+        self.max_input_length_limit = max_input_length_limit
+        self.keep_raw_resolution = keep_raw_resolution
+
+        self.data = load_dataset('json', data_files=json_file)['train']
+        self.image_path = image_path
+
+    def process_image(self, image_file):
+        if self.image_path is not None:
+            image_file = os.path.join(self.image_path, image_file)
+        image = Image.open(image_file).convert('RGB')
+        return self.image_transform(image)
+
+    def get_example(self, index):
+        example = self.data[index]
+
+        instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
+        if random.random() < self.condition_dropout_prob:
+            instruction = '<cfg>'
+            input_images = None
+        if input_images is not None:
+            input_images = [self.process_image(x) for x in input_images]
+        mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)
+
+        output_image = self.process_image(output_image)
+
+        return (mllm_input, output_image)
+
+
+    def __getitem__(self, index):
+        return self.get_example(index)
+        for _ in range(8):
+            try:
+                mllm_input, output_image = self.get_example(index)
+                if len(mllm_input['input_ids']) > self.max_input_length_limit:
+                    raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
+                return mllm_input, output_image
+            except Exception as e:
+                print("error when loading data: ", e)
+                print(self.data[index])
+                index = random.randint(0, len(self.data)-1)
+        raise RuntimeError("Too many bad data.")
+
+
+    def __len__(self):
+        return len(self.data)
+
+
+
+class TrainDataCollator(OmniGenCollator):
+    def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
+        self.pad_token_id = pad_token_id
+        self.hidden_size = hidden_size
+        self.keep_raw_resolution = keep_raw_resolution
+
+    def __call__(self, features):
+        mllm_inputs = [f[0] for f in features]
+
+        output_images = [f[1].unsqueeze(0) for f in features]
+        target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]
+
+        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+
+        if not self.keep_raw_resolution:
+            output_image = torch.cat(output_image, dim=0)
+            if len(pixel_values) > 0:
+                all_pixel_values = torch.cat(all_pixel_values, dim=0)
+            else:
+                all_pixel_values = None
+
+        data = {"input_ids": all_padded_input_ids,
+                "attention_mask": all_attention_mask,
+                "position_ids": all_position_ids,
+                "input_pixel_values": all_pixel_values,
+                "input_image_sizes": all_image_sizes,
+                "padding_images": all_padding_images,
+                "output_images": output_images,
+                }
+        return data
+
+
+
+
+
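To make the expected inputs concrete, here is a hedged sketch of wiring the new dataset and collator into a PyTorch DataLoader. The JSON keys mirror those read in get_example(); the file names, transform, pad_token_id, and hidden_size values are illustrative placeholders, not values shipped by this commit.

import json
from torch.utils.data import DataLoader
from torchvision import transforms

from OmniGen import OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator

# One training record per entry: an instruction, optional conditioning images, and a target image.
records = [{
    "instruction": "A woman holds a bouquet of flowers. The woman is <img><|image_1|></img>.",
    "input_images": ["woman.png"],      # placeholder files, resolved relative to image_path below
    "output_image": "target.png",
}]
with open("train.json", "w") as f:
    json.dump(records, f)

processor = OmniGenProcessor.from_pretrained("Shitao/OmniGen-v1")
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = DatasetFromJson(
    json_file="train.json",
    image_path="./imgs",                # prefix joined onto every image file name
    processer=processor,
    image_transform=image_transform,
)
# pad_token_id and hidden_size are placeholders; use the actual tokenizer/model values.
collator = TrainDataCollator(pad_token_id=0, hidden_size=3072, keep_raw_resolution=True)
loader = DataLoader(dataset, batch_size=1, collate_fn=collator)
batch = next(iter(loader))              # dict with input_ids, attention_mask, output_images, ...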
OmniGen/train_helper/loss.py
ADDED

@@ -0,0 +1,68 @@
+import torch
+
+
+def sample_x0(x1):
+    """Sampling x0 & t based on shape of x1 (if needed)
+    Args:
+      x1 - data point; [batch, *dim]
+    """
+    if isinstance(x1, (list, tuple)):
+        x0 = [torch.randn_like(img_start) for img_start in x1]
+    else:
+        x0 = torch.randn_like(x1)
+
+    return x0
+
+def sample_timestep(x1):
+    u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
+    t = 1 / (1 + torch.exp(-u))
+    t = t.to(x1[0])
+    return t
+
+
+def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
+    """Loss for training torche score model
+    Args:
+    - model: backbone model; could be score, noise, or velocity
+    - x1: datapoint
+    - model_kwargs: additional arguments for torche model
+    """
+    if model_kwargs == None:
+        model_kwargs = {}
+
+    B = len(x1)
+
+    x0 = sample_x0(x1)
+    t = sample_timestep(x1)
+
+    if isinstance(x1, (list, tuple)):
+        xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
+        ut = [x1[i] - x0[i] for i in range(B)]
+    else:
+        dims = [1] * (len(x1.size()) - 1)
+        t_ = t.view(t.size(0), *dims)
+        xt = t_ * x1 + (1 - t_) * x0
+        ut = x1 - x0
+
+    model_output = model(xt, t, **model_kwargs)
+
+    terms = {}
+
+    if isinstance(x1, (list, tuple)):
+        assert len(model_output) == len(ut) == len(x1)
+        for i in range(B):
+            terms["loss"] = torch.stack(
+                [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
+                dim=0,
+            )
+    else:
+        terms["loss"] = mean_flat(((model_output - ut) ** 2))
+
+    return terms
+
+
+def mean_flat(x):
+    """
+    Take torche mean over all non-batch dimensions.
+    """
+    return torch.mean(x, dim=list(range(1, len(x.size()))))
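For orientation, a hedged, self-contained sketch of how training_losses could be driven in one training step; the toy network below stands in for the real OmniGen transformer and is not part of this commit.

import torch
import torch.nn as nn
from OmniGen.train_helper import training_losses

class ToyVelocityNet(nn.Module):
    # Placeholder for the OmniGen transformer: maps a noisy latent xt (and timestep t)
    # to a velocity prediction of the same shape.
    def __init__(self, channels: int = 4):
        super().__init__()
        self.net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, xt, t, **kwargs):
        return self.net(xt)

model = ToyVelocityNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x1 = torch.randn(4, 4, 64, 64)       # target latents, shape [batch, *dim]
terms = training_losses(model, x1)   # samples x0 ~ N(0, I) and t = sigmoid(N(0, 1)) internally
loss = terms["loss"].mean()          # flow-matching-style MSE per sample, averaged over the batch
loss.backward()
optimizer.step()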
app.py
CHANGED

@@ -11,7 +11,7 @@ pipe = OmniGenPipeline.from_pretrained(
 
 @spaces.GPU(duration=180)
 # Example handler: generate an image
-def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
+def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
     input_images = [img1, img2, img3]
     # Drop None entries
     input_images = [img for img in input_images if img is not None]
@@ -26,7 +26,7 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, infere
         guidance_scale=guidance_scale,
         img_guidance_scale=1.6,
         num_inference_steps=inference_steps,
-        separate_cfg_infer=True,
+        separate_cfg_infer=True,  # set False can speed up the inference process
         use_kv_cache=False,
         seed=seed,
     )
@@ -47,26 +47,28 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, infere
 def get_example():
     case = [
         [
-            "A
+            "A curly-haired man in a red shirt is drinking tea.",
             None,
             None,
             None,
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "
+            "The woman in <img><|image_1|></img> waves her hand happily in the crowd",
             "./imgs/test_cases/yifei2.png",
             None,
             None,
             1024,
             1024,
             2.5,
+            1.9,
             50,
-
+            128,
         ],
         [
             "A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
@@ -76,17 +78,55 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "Two
-            "./imgs/test_cases/
-            "./imgs/test_cases/
+            "Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
+            "./imgs/test_cases/mckenna.jpg",
+            "./imgs/test_cases/Amanda.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.8,
+            50,
+            168,
+        ],
+        [
+            "A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
+            "./imgs/test_cases/1.jpg",
+            "./imgs/test_cases/2.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.6,
+            50,
+            60,
+        ],
+        [
+            "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
+            "./imgs/test_cases/3.jpg",
+            "./imgs/test_cases/4.jpg",
             None,
             1024,
             1024,
             2.5,
+            1.8,
+            50,
+            66,
+        ],
+        [
+            "The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
+            "./imgs/test_cases/rose.jpg",
+            "./imgs/test_cases/vase.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.6,
             50,
             0,
         ],
@@ -98,6 +138,7 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             222,
         ],
@@ -109,6 +150,7 @@ def get_example():
             1024,
             1024,
             2.0,
+            1.6,
             50,
             0,
         ],
@@ -120,6 +162,7 @@ def get_example():
             1024,
             1024,
             2,
+            1.6,
             50,
             42,
         ],
@@ -131,9 +174,22 @@ def get_example():
             1024,
             1024,
             2.0,
+            1.6,
             50,
             123,
         ],
+        [
+            "Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
+            "./imgs/demo_cases/edit.png",
+            None,
+            None,
+            1024,
+            1024,
+            2.0,
+            1.6,
+            50,
+            1,
+        ],
         [
             "<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
             "./imgs/test_cases/watch.jpg",
@@ -142,25 +198,27 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "
-            "./imgs/test_cases/
-            "./imgs/test_cases/
-            "./imgs/test_cases/
+            "According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
+            "./imgs/test_cases/icl1.jpg",
+            "./imgs/test_cases/icl2.jpg",
+            "./imgs/test_cases/icl3.jpg",
             1024,
             1024,
             2.5,
+            1.6,
             50,
-
+            1,
         ],
     ]
     return case
 
-def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
-    return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
+def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
+    return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed)
 
 description = """
 OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
@@ -168,6 +226,13 @@ OmniGen is a unified image generation model that you can use to perform various
 For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
 For example, use an image of a woman to generate a new image:
 prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
+
+Tips:
+- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
+- Low-quality: More detailed prompt will lead to better results.
+- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
+- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
+- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
 """
 
 # Gradio interface
@@ -197,7 +262,11 @@ with gr.Blocks() as demo:
 
             # Guidance scale input
             guidance_scale_input = gr.Slider(
-                label="Guidance Scale", minimum=1.0, maximum=
+                label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
+            )
+
+            img_guidance_scale_input = gr.Slider(
+                label="img_guidance_scale", minimum=1.0, maximum=2.0, value=1.6, step=0.1
             )
 
             num_inference_steps = gr.Slider(
@@ -226,6 +295,7 @@ with gr.Blocks() as demo:
                 height_input,
                 width_input,
                 guidance_scale_input,
+                img_guidance_scale_input,
                 num_inference_steps,
                 seed_input,
             ],
@@ -243,6 +313,7 @@ with gr.Blocks() as demo:
                 height_input,
                 width_input,
                 guidance_scale_input,
+                img_guidance_scale_input,
                 num_inference_steps,
                 seed_input,
             ],
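Outside Gradio, the extended signature maps onto a direct pipeline call roughly as follows. A hedged sketch: keyword names are taken from the diff above, the prompt and image path come from the new example cases, and the assumption that the pipeline returns a list of PIL images follows the project README.

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

images = pipe(
    prompt="The woman in <img><|image_1|></img> waves her hand happily in the crowd",
    input_images=["./imgs/test_cases/yifei2.png"],
    height=1024,
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.9,    # the new slider exposed by this commit (range 1.0-2.0, default 1.6)
    num_inference_steps=50,
    separate_cfg_infer=True,   # per the added comment, False can speed up inference
    use_kv_cache=False,
    seed=128,
)
images[0].save("waving_woman.png")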
imgs/demo_cases/edit.png
CHANGED
Git LFS Details

imgs/demo_cases/entity.png
CHANGED
Git LFS Details

imgs/demo_cases/reasoning.png
CHANGED
Git LFS Details

imgs/demo_cases/same_pose.png
CHANGED
Git LFS Details

imgs/demo_cases/skeletal.png
CHANGED
Git LFS Details

imgs/demo_cases/skeletal2img.png
CHANGED
Git LFS Details

imgs/{demo_cases.png → test_cases/1.jpg}
RENAMED
File without changes

imgs/{overall.jpg → test_cases/2.jpg}
RENAMED
File without changes

imgs/test_cases/3.jpg
ADDED
Git LFS Details

imgs/test_cases/4.jpg
ADDED
Git LFS Details

imgs/test_cases/Amanda.jpg
ADDED
Git LFS Details

imgs/test_cases/icl1.jpg
ADDED
Git LFS Details

imgs/test_cases/icl2.jpg
ADDED
Git LFS Details

imgs/test_cases/icl3.jpg
ADDED
Git LFS Details

imgs/test_cases/mckenna.jpg
ADDED
Git LFS Details

imgs/test_cases/rose.jpg
ADDED
Git LFS Details

imgs/test_cases/vase.jpg
ADDED
Git LFS Details

imgs/test_cases/zhang.png
ADDED
Git LFS Details