"""test_img_edit.py: image-editing demo for Ovis-U1-3B."""
import os
import argparse

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

def parse_args():
    parser = argparse.ArgumentParser(description="Test Image Editing")
    parser.add_argument(
        "--model_path",
        type=str,
        default="AIDC-AI/Ovis-U1-3B",
        help="Hugging Face model id or local path.",
    )
    parser.add_argument(
        "--steps", type=int, default=50,
        help="Number of diffusion sampling steps.",
    )
    parser.add_argument(
        "--img_cfg", type=float, default=1.5,
        help="Classifier-free guidance scale for the image condition.",
    )
    parser.add_argument(
        "--txt_cfg", type=float, default=6,
        help="Classifier-free guidance scale for the text condition.",
    )
    return parser.parse_args()

def load_blank_image(width, height):
    # Plain white canvas used as the image input for the fully unconditional pass.
    return Image.new("RGB", (width, height), (255, 255, 255))

def build_inputs(model, text_tokenizer, visual_tokenizer, prompt, pil_image, target_width, target_height):
    """Tokenize the prompt and prepare the visual tensors for one conditioning pass."""
    if pil_image is not None:
        target_size = (int(target_width), int(target_height))
        # Fit the conditioning image to the target size and get the pixel tensor
        # that the diffusion VAE consumes as the edit source.
        pil_image, vae_pixel_values, cond_img_ids = model.visual_generator.process_image_aspectratio(pil_image, target_size)
        cond_img_ids[..., 0] = 1.0
        vae_pixel_values = vae_pixel_values.unsqueeze(0).to(device=model.device)
        width = pil_image.width
        height = pil_image.height
        # Downscale the copy that goes to the visual tokenizer, capping its area
        # at the image processor's min_pixels budget.
        resized_height, resized_width = visual_tokenizer.smart_resize(height, width, max_pixels=visual_tokenizer.image_processor.min_pixels)
        pil_image = pil_image.resize((resized_width, resized_height))
    else:
        vae_pixel_values = None
    prompt, input_ids, pixel_values, grid_thws = model.preprocess_inputs(
        prompt,
        [pil_image],
        generation_preface=None,
        return_labels=False,
        propagate_exception=False,
        multimodal_type='single_image',
        fix_sample_overall_length_navit=False,
    )
    # Mask out padding tokens and move everything onto the model's devices.
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if pixel_values is not None:
        pixel_values = pixel_values.to(device=visual_tokenizer.device, dtype=torch.bfloat16)
    if grid_thws is not None:
        grid_thws = grid_thws.to(device=visual_tokenizer.device)
    return input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values
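
# What build_inputs hands back for one conditioning pass (a descriptive sketch;
# exact shapes depend on the prompt and image and are an assumption here):
#   input_ids        - token ids for "<image>\n<prompt>", batched to [1, seq_len]
#   pixel_values     - bfloat16 patch tensor for the visual tokenizer
#   attention_mask   - True wherever input_ids is not the pad token
#   grid_thws        - the tokenizer's per-image (t, h, w) patch-grid descriptor
#   vae_pixel_values - the resized source image for the diffusion decoder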

def pipe_img_edit(model, input_img, prompt, steps, txt_cfg, img_cfg, seed=42):
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()

    # Snap the output resolution to multiples of 32, as the generator requires.
    width, height = input_img.size
    height, width = visual_tokenizer.smart_resize(height, width, factor=32)

    gen_kwargs = dict(
        max_new_tokens=1024,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
        height=height,
        width=width,
        num_steps=steps,
        seed=seed,
        img_cfg=img_cfg,
        txt_cfg=txt_cfg,
    )

    # Pass 1: fully unconditional (blank image, generic prompt).
    uncond_image = load_blank_image(width, height)
    uncond_prompt = "<image>\nGenerate an image."
    input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(
        model, text_tokenizer, visual_tokenizer, uncond_prompt, uncond_image, width, height)
    with torch.inference_mode():
        no_both_cond = model.generate_condition(
            input_ids, pixel_values=pixel_values, attention_mask=attention_mask,
            grid_thws=grid_thws, **gen_kwargs)

    # Pass 2: image-only condition (input image, generic prompt).
    input_img = input_img.resize((width, height))
    prompt = "<image>\n" + prompt.strip()
    with torch.inference_mode():
        input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(
            model, text_tokenizer, visual_tokenizer, uncond_prompt, input_img, width, height)
        no_txt_cond = model.generate_condition(
            input_ids, pixel_values=pixel_values, attention_mask=attention_mask,
            grid_thws=grid_thws, **gen_kwargs)

    # Pass 3: full condition (input image plus the edit prompt).
    input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values = build_inputs(
        model, text_tokenizer, visual_tokenizer, prompt, input_img, width, height)
    with torch.inference_mode():
        cond = model.generate_condition(
            input_ids, pixel_values=pixel_values, attention_mask=attention_mask,
            grid_thws=grid_thws, **gen_kwargs)
        cond["vae_pixel_values"] = vae_pixel_values
        images = model.generate_img(cond=cond, no_both_cond=no_both_cond,
                                    no_txt_cond=no_txt_cond, **gen_kwargs)
    return images
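
# The three passes feed a two-axis classifier-free guidance blend inside
# model.generate_img. A common formulation for this setup (an illustrative
# assumption; the exact combination is internal to Ovis-U1) is:
#   out = no_both + img_cfg * (no_txt - no_both) + txt_cfg * (cond - no_txt)
# so img_cfg steers fidelity to the source image and txt_cfg steers adherence
# to the edit instruction.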

def main():
    args = parse_args()
    model, loading_info = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        output_loading_info=True,
        trust_remote_code=True,
    )
    print(f'Loading info of Ovis-U1:\n{loading_info}')
    model = model.eval().to("cuda")
    model = model.to(torch.bfloat16)

    image_path = os.path.join(os.path.dirname(__file__), "docs", "imgs", "cat.png")
    pil_img = Image.open(image_path).convert('RGB')
    prompt = "add a hat to this cat."
    image = pipe_img_edit(model, pil_img, prompt,
                          args.steps, args.txt_cfg, args.img_cfg)[0]
    image.save("test_image_edit.png")


if __name__ == "__main__":
    main()
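
# Example invocation (flags and defaults as defined in parse_args above;
# requires a CUDA GPU and the repo's docs/imgs/cat.png asset):
#   python test_img_edit.py --model_path AIDC-AI/Ovis-U1-3B \
#       --steps 50 --img_cfg 1.5 --txt_cfg 6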