# LIA-X / gradio_tabs/img_edit.py
# YaohuiW's picture
# Update gradio_tabs/img_edit.py
# 3e81f41 verified
import tempfile
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import imageio
import spaces
from einops import rearrange
# labels
# One row per control: the caption shown on its slider and the integer that
# is sent to ``gen.edit_img`` at the same position. Both public lists are
# derived from this table so they always stay index-aligned.
# NOTE(review): the integers presumably pick out directions inside the
# generator -- confirm against ``gen.edit_img``.
_LABEL_TABLE = (
    ('yaw1', 37),
    ('yaw2', 39),
    ('pitch', 28),
    ('roll1', 15),
    ('roll2', 33),
    ('neck', 31),
    ('pout', 6),
    ('open->close', 25),
    ('"O" Mouth', 16),
    ('smile', 19),
    ('close->open', 13),
    ('eyebrows', 24),
    ('eyeballs1', 17),
    ('eyeballs2', 26),
)
labels_k = [caption for caption, _ in _LABEL_TABLE]
labels_v = [index for _, index in _LABEL_TABLE]
def load_image(img, size):
    """Load an image file as a square, channel-first float array.

    Args:
        img: Path or file object accepted by ``PIL.Image.open``.
        size: Edge length, in pixels, of the square the image is resized to.

    Returns:
        A tuple ``(pixels, w, h)``: ``pixels`` is a ``(3, size, size)`` array
        scaled to [0, 1]; ``w`` and ``h`` are the image's width and height
        before resizing.
    """
    # The context manager releases the underlying file as soon as the pixel
    # data has been converted, instead of leaving that to garbage collection.
    with Image.open(img) as source:
        rgb = source.convert('RGB')
    w, h = rgb.size
    rgb = rgb.resize((size, size))
    # np.array copies, so the result is writable and independent of PIL.
    pixels = np.array(rgb)
    pixels = np.transpose(pixels, (2, 0, 1))  # HWC -> CHW
    return pixels / 255.0, w, h
def img_preprocessing(img_path, size):
    """Read an image and turn it into a normalised single-item batch.

    Args:
        img_path: Location of the image, forwarded to ``load_image``.
        size: Edge length forwarded to ``load_image`` for the square resize.

    Returns:
        ``(batch, w, h)``: a float tensor with a leading batch axis whose
        values are mapped from [0, 1] to [-1, 1], plus the ``w`` and ``h``
        reported by ``load_image``.
    """
    pixels, w, h = load_image(img_path, size)  # values in [0, 1]
    batch = torch.from_numpy(pixels).float().unsqueeze(0)
    normalised = batch.sub(0.5).mul(2.0)  # values in [-1, 1]
    return normalised, w, h
# def resize(img, size):
# transform = torchvision.transforms.Compose([
# torchvision.transforms.Resize((size,size), antialias=True),
# ])
# return transform(img)
# def resize_back(img, w, h):
# transform = torchvision.transforms.Compose([
# torchvision.transforms.Resize((h, w), antialias=True),
# ])
# return transform(img)
# Resize transforms that have already been built, keyed by the ``size``
# argument they were created with.
# NOTE(review): entries are never evicted, and ``resize_back`` keys on each
# image's own (h, w), so the dict grows with every new image size seen.
resize_transform_cache = {}


def get_resize_transform(size):
    """Return the bilinear, antialiased ``Resize`` transform for ``size``.

    The transform is built on the first request for a given ``size``; later
    requests for the same ``size`` get the same object back.
    """
    try:
        return resize_transform_cache[size]
    except KeyError:
        pass
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    built = torchvision.transforms.Resize(
        size,
        interpolation=bilinear,
        antialias=True,
    )
    resize_transform_cache[size] = built
    return built
def resize(img, size):
    """Scale ``img`` to a ``size`` x ``size`` square via the cached transform."""
    square = (size, size)
    return get_resize_transform(square)(img)
def resize_back(img, w, h):
    """Scale ``img`` to height ``h`` and width ``w`` via the cached transform."""
    target = (h, w)  # Resize takes (height, width)
    return get_resize_transform(target)(img)
def img_denorm(img):
    """Rescale a tensor to [0, 1] using its own value range.

    Values are first clamped to [-1, 1]; the smallest remaining value then
    maps to 0 and the largest to 1.

    Args:
        img: Float tensor, nominally in [-1, 1].

    Returns:
        A tensor of the same shape with values in [0, 1]. A constant input
        yields all zeros rather than NaN.
    """
    img = img.clamp(-1, 1)
    low = img.min()
    # Floor the range so a constant image does not divide by zero. This is a
    # tensor op, so no host round-trip is needed to branch on the value.
    span = (img.max() - low).clamp_min(1e-5)
    return (img - low) / span
# def img_postprocessing(image, w, h):
# image = resize_back(image, w, h)
# image = image.permute(0, 2, 3, 1)
# edited_image = img_denorm(image)
# img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
# with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
# imageio.imwrite(temp_file.name, img_output, quality=8)
# return temp_file.name
def img_postprocessing(img, w, h):
    """Convert a generator output batch into a uint8 image array.

    Args:
        img: Image tensor with a leading batch axis of length 1, channel-first.
        w: Target width passed to ``resize_back``.
        h: Target height passed to ``resize_back``.

    Returns:
        A ``numpy.uint8`` array in height x width x channel order.
    """
    restored = img_denorm(resize_back(img, w, h))
    # Drop the batch axis and go channel-last; contiguous() for fast transfer.
    channel_last = restored.squeeze(0).permute(1, 2, 0).contiguous()
    scaled = channel_last.cpu().numpy() * 255
    return scaled.astype(np.uint8)
def img_edit(gen, device):
    """Build the "Image Editing" tab and wire up its controls.

    Args:
        gen: Object exposing ``edit_img(image_tensor, labels_v, values)``,
            where ``values`` are the current slider values in ``labels_k``
            order. Its result is handed to ``img_postprocessing``.
        device: Target passed to ``Tensor.to`` for the input image tensor.
    """

    @spaces.GPU
    @torch.no_grad()
    def edit_img(image, *selected_s):
        """Apply the slider values to ``image`` and return the edited picture."""
        if image is None:
            # Clicking "Edit" before choosing an image would otherwise fail
            # inside the image loader with an error that does not tell the
            # user what to do.
            raise gr.Error("Please provide an image before editing.")
        image_tensor, w, h = img_preprocessing(image, 512)
        image_tensor = image_tensor.to(device)
        edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
        # De-normalise and resize back to the uploaded width/height.
        return img_postprocessing(edited_image_tensor, w, h)

    def clear_media():
        """Return reset values: no output image and 0 for every slider."""
        return (None, *([0] * len(labels_k)))

    # (tab title, rows); each row is (start, stop, minimum, maximum), where
    # start/stop slice ``labels_k``. Order matters: sliders must be created
    # in ``labels_k`` order so their values line up with ``labels_v``.
    slider_layout = (
        ("Head", ((0, 3, -1.0, 0.5), (3, 6, -0.5, 0.5))),
        ("Mouth", ((6, 8, -0.4, 0.4), (8, 10, -0.4, 0.4))),
        ("Eyes", ((10, 12, -0.4, 0.4), (12, 14, -0.2, 0.2))),
    )

    with gr.Tab("Image Editing"):
        inputs_s = []

        with gr.Row():
            # Left column: source image, examples and action buttons.
            with gr.Column(scale=1):
                with gr.Row():
                    with gr.Accordion(open=True, label="Image"):
                        image_input = gr.Image(type="filepath", width=512)  # , height=550)
                        gr.Examples(
                            examples=[
                                ["./data/source/macron.png"],
                                ["./data/source/einstein.png"],
                                ["./data/source/taylor.png"],
                                ["./data/source/portrait1.png"],
                                ["./data/source/portrait2.png"],
                                ["./data/source/portrait3.png"],
                            ],
                            inputs=[image_input],
                            cache_examples=False,
                            visible=True,
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Row():  # Buttons now within a single Row
                            edit_btn = gr.Button("Edit")
                            clear_btn = gr.Button("Clear")
                        with gr.Row():
                            # NOTE(review): no handler is attached to this
                            # button anywhere in this function -- confirm
                            # whether it should be wired up or removed.
                            animate_btn = gr.Button("Generate")

            # Right column: edited result and the slider panel.
            with gr.Column(scale=1):
                with gr.Row():
                    with gr.Accordion(open=True, label="Edited Image"):
                        image_output = gr.Image(
                            label="Output Image",
                            type='numpy',
                            interactive=False,
                            width=512,
                        )

                with gr.Accordion("Control Panel", open=True):
                    for tab_title, rows in slider_layout:
                        with gr.Tab(tab_title):
                            for start, stop, lowest, highest in rows:
                                with gr.Row():
                                    for k in labels_k[start:stop]:
                                        slider = gr.Slider(
                                            minimum=lowest,
                                            maximum=highest,
                                            value=0,
                                            label=k,
                                        )
                                        inputs_s.append(slider)

        edit_btn.click(
            fn=edit_img,
            inputs=[image_input] + inputs_s,
            outputs=[image_output],
            show_progress=True
        )

        clear_btn.click(
            fn=clear_media,
            outputs=[image_output] + inputs_s
        )