# color_transfer / app.py
# Last commit: "Update app.py" by JamesL404 (4c94334, verified)
import argparse
import sys
import numpy as np
import os
import cv2
import skimage
import gradio as gr
from pygltflib import GLTF2
from pygltflib.utils import ImageFormat, Image
import PIL.Image
import io
import base64
# Short names of the colour-transfer algorithms implemented in this file
# (each has a matching transfer_<name> function).
METHODS = (
    "lhm",
    "pccm",
    "reinhard",
)
def transfer_lhm(content, reference):
    """Transfers colors from a reference image to a content image using
    Linear Histogram Matching.

    Every content pixel vector ``x`` is mapped to
    ``sqrt(C_ref) @ pinv(sqrt(C_content)) @ (x - mu_content) + mu_ref``,
    where ``mu`` is the per-channel mean and ``C`` the channel covariance of
    each image. Before clipping/rounding, the result therefore has the mean
    of *reference* and, when the content covariance is full rank, its
    covariance as well.

    content: NumPy array (HxWxC)
    reference: NumPy array (HxWxC) with the same C as *content*; H and W may
        differ.

    Returns a uint8 array with the shape of *content*, clipped to [0, 255]
    and rounded.

    Raises ValueError if *reference* is not 3-D or has a different channel
    count than *content*.
    """
    shape = content.shape
    assert len(shape) == 3
    # Reshaping a mismatched reference with content's channel count would
    # silently scramble its pixels, so reject it explicitly.
    if reference.ndim != 3 or reference.shape[-1] != shape[-1]:
        raise ValueError(
            f"reference must have shape HxWx{shape[-1]}, got {reference.shape}"
        )
    # Convert HxWxC images to (H*W)xC matrices of pixel vectors.
    content = content.reshape(-1, shape[-1]).astype(np.float32)
    reference = reference.reshape(-1, shape[-1]).astype(np.float32)

    def matrix_power(X, power):
        """Raise the symmetric covariance matrix X to *power* (+0.5 or -0.5).

        eigh (unlike eig) guarantees real eigenvalues and orthonormal
        eigenvectors, which V @ diag(f(w)) @ V.T relies on. Eigenvalues that
        are zero up to round-off (flat or grayscale images give a
        rank-deficient covariance) are dropped rather than square-rooted
        while slightly negative or inverted into inf, so power=-0.5 yields a
        pseudo-inverse square root.
        """
        eig_val, eig_vec = np.linalg.eigh(X)
        tol = max(eig_val.max() * 1e-10, 1e-12)
        keep = eig_val > tol
        powered = np.zeros_like(eig_val)
        powered[keep] = eig_val[keep] ** power
        # Scaling the columns of V is V @ diag(powered).
        return (eig_vec * powered).dot(eig_vec.T)

    mu_content = np.mean(content, axis=0)
    mu_reference = np.mean(reference, axis=0)
    # atleast_2d: np.cov returns a 0-d array for single-channel input.
    cov_content = np.atleast_2d(np.cov(content, rowvar=False))
    cov_reference = np.atleast_2d(np.cov(reference, rowvar=False))

    transform = matrix_power(cov_reference, 0.5).dot(
        matrix_power(cov_content, -0.5))
    # Row-vector form of transform @ (x - mu_content) for every pixel.
    result = (content - mu_content).dot(transform.T) + mu_reference
    # Restore image dimensions.
    result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
    return result
def transfer_pccm(content, reference):
    """Transfers colors from a reference image to a content image using
    Principal Component Color Matching.

    The principal axes (eigenvectors of the channel covariance) of the
    content image are mapped onto the reference axes of the same rank, and
    the spread along each axis is rescaled by sqrt(lambda_ref / lambda_content).
    Means are matched as well.

    content: NumPy array (HxWxC)
    reference: NumPy array (HxWxC) with the same C as *content*; H and W may
        differ.

    Returns a uint8 array with the shape of *content*, clipped to [0, 255]
    and rounded.

    Raises ValueError if *reference* is not 3-D or has a different channel
    count than *content*.
    """
    shape = content.shape
    assert len(shape) == 3
    if reference.ndim != 3 or reference.shape[-1] != shape[-1]:
        raise ValueError(
            f"reference must have shape HxWx{shape[-1]}, got {reference.shape}"
        )
    # Convert HxWxC images to (H*W)xC matrices of pixel vectors.
    content = content.reshape(-1, shape[-1]).astype(np.float32)
    reference = reference.reshape(-1, shape[-1]).astype(np.float32)
    mu_content = np.mean(content, axis=0)
    mu_reference = np.mean(reference, axis=0)
    # atleast_2d: np.cov returns a 0-d array for single-channel input.
    cov_content = np.atleast_2d(np.cov(content, rowvar=False))
    cov_reference = np.atleast_2d(np.cov(reference, rowvar=False))
    # eigh returns real eigenvalues in ascending order for BOTH matrices, so
    # column k of each eigenvector matrix is the k-th ranked axis and the
    # axes pair up correctly (np.linalg.eig gives no ordering guarantee).
    eigval_content, eigvec_content = np.linalg.eigh(cov_content)
    eigval_reference, eigvec_reference = np.linalg.eigh(cov_reference)
    # Eigenvectors are only defined up to sign; flip each reference axis to
    # point the same way as its paired content axis so the mapping does not
    # needlessly invert colors along that axis.
    signs = np.sign(np.sum(eigvec_reference * eigvec_content, axis=0))
    signs[signs == 0] = 1
    eigvec_reference = eigvec_reference * signs
    # Round-off can make eigenvalues of a PSD matrix slightly negative.
    eigval_content = np.clip(eigval_content, 0, None)
    eigval_reference = np.clip(eigval_reference, 0, None)
    # Axes with (numerically) zero content variance carry no signal; give
    # them zero gain instead of dividing by zero.
    tol = max(eigval_content.max() * 1e-10, 1e-12)
    keep = eigval_content > tol
    ratio = np.zeros_like(eigval_content)
    ratio[keep] = eigval_reference[keep] / eigval_content[keep]
    scaling = np.diag(np.sqrt(ratio))
    transform = eigvec_reference.dot(scaling).dot(eigvec_content.T)
    result = (content - mu_content).dot(transform.T) + mu_reference
    # Restore image dimensions.
    result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
    return result
def transfer_reinhard(content, reference):
    """Transfers colors from a reference image to a content image using the
    technique from Reinhard et al. ("Color Transfer between Images", 2001).

    Both images are converted RGB -> LMS -> log10 -> l-alpha-beta. There,
    every content pixel is shifted and scaled so that each channel's mean and
    standard deviation match the reference's. The result is then converted
    back to RGB.

    content: NumPy array (HxWx3)
    reference: NumPy array (HxWx3); H and W may differ from *content*.

    Returns a uint8 array with the shape of *content*, clipped to [0, 255]
    and rounded.

    Raises ValueError if *reference* is not 3-D or has a different channel
    count than *content*.
    """
    shape = content.shape
    assert len(shape) == 3
    if reference.ndim != 3 or reference.shape[-1] != shape[-1]:
        raise ValueError(
            f"reference must have shape HxWx{shape[-1]}, got {reference.shape}"
        )
    # Convert HxWxC images to (H*W)xC matrices of pixel vectors.
    content = content.reshape(-1, shape[-1]).astype(np.float32)
    reference = reference.reshape(-1, shape[-1]).astype(np.float32)
    # The matrices are stored transposed because pixels are row vectors
    # (pixels.dot(M) == (M.T @ pixel.T).T).
    # m1: RGB -> LMS.
    m1 = np.array([
        [0.3811, 0.1967, 0.0241],
        [0.5783, 0.7244, 0.1288],
        [0.0402, 0.0782, 0.8444],
    ])
    # m2: log10(LMS) -> l-alpha-beta.
    m2 = np.array([
        [0.5774, 0.4082, 0.7071],
        [0.5774, 0.4082, -0.7071],
        [0.5774, -0.8165, 0.0000],
    ])
    # m3: l-alpha-beta -> log10(LMS) (inverse of m2).
    m3 = np.array([
        [0.5774, 0.5774, 0.5774],
        [0.4082, 0.4082, -0.8165],
        [0.7071, -0.7071, 0.0000],
    ])
    # m4: LMS -> RGB (inverse of m1).
    m4 = np.array([
        [4.4679, -1.2186, 0.0497],
        [-3.5873, 2.3809, -0.2439],
        [0.1193, -0.1624, 1.2045],
    ])
    # Avoid log of 0. Clipping is used instead of adding epsilon, to avoid
    # taking a log of a small number whose very low output distorts the results.
    # WARN: This differs from the Reinhard paper, where no adjustment is made.
    lab_content = np.log10(np.maximum(1.0, content.dot(m1))).dot(m2)
    lab_reference = np.log10(np.maximum(1.0, reference.dot(m1))).dot(m2)
    mu_content = lab_content.mean(axis=0)
    mu_reference = lab_reference.mean(axis=0)
    # The statistics must be measured in the same l-alpha-beta space as the
    # deviations they scale (the previous code used raw RGB stds here).
    std_content = lab_content.std(axis=0)
    std_reference = lab_reference.std(axis=0)
    # A (numerically) constant content channel has nothing to scale; keep a
    # gain of 1 there instead of computing 0/0 = NaN or amplifying round-off.
    scale = np.ones_like(std_content)
    np.divide(std_reference, std_content, out=scale, where=std_content > 1e-8)
    result = (lab_content - mu_content) * scale + mu_reference
    result = (10 ** result.dot(m3)).dot(m4)
    # Restore image dimensions.
    result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
    return result
# =================================================================================
def _embed_texture(content, texture, image_path, glb_path):
    """Save a copy of the model *content* whose first image is *texture*.

    content: path of the glTF/GLB model; it is re-loaded from disk, so each
        call starts from the unmodified model.
    texture: HxWx3 uint8 NumPy array used as the new first image.
    image_path: where the texture PNG is written before being embedded.
    glb_path: where the resulting model is saved.

    Returns *glb_path*.
    """
    PIL.Image.fromarray(texture).save(image_path)
    gltf = GLTF2().load(content)
    # Replace the first image entry with a fresh Image that only points at
    # the new PNG, as the original implementation did.
    replacement = Image()
    replacement.uri = image_path
    gltf.images[0] = replacement
    # Inline the image as a data URI so the saved model does not depend on
    # the temporary PNG.
    gltf.convert_images(ImageFormat.DATAURI)
    gltf.save(glb_path)
    return glb_path


def runModel(content, style):
    """Gradio callback: recolor a model's first texture after a style image.

    content: path of the uploaded glTF/GLB model (from gr.Model3D).
    style: PIL image supplying the target colors (from gr.Image(type='pil')).

    Returns the paths of three GLB files ordered (Reinhard, LHM, PCCM). This
    is the order of the output panels "Reinhard", "lhm", "pccm" declared in
    the gr.Interface of this file. The previous version returned pccm first,
    so the Reinhard and pccm panels showed each other's results.

    Raises gr.Error if an input is missing or the model has no image.
    """
    if content is None or style is None:
        raise gr.Error("Please provide both a 3D model and a style image.")
    gltf = GLTF2().load(content)
    if not gltf.images:
        raise gr.Error("The model has no texture image to recolor.")
    # Export the textures to /tmp and read back the file name pygltflib
    # actually assigned. Assuming '/tmp/0.png' could pick up a stale file
    # from an earlier request.
    gltf.convert_images(ImageFormat.FILE, "/tmp/", override=True)
    texture_uri = gltf.images[0].uri
    if not texture_uri:
        raise gr.Error("Could not extract the model's texture image.")
    texture_path = os.path.join("/tmp/", texture_uri)
    with PIL.Image.open(texture_path) as texture:
        content_img = np.array(texture.convert("RGB"))
    # convert('RGB') also accepts grayscale/palette style images, which are
    # 2-D arrays and broke the old `[:, :, :3]` slice. For RGBA it drops
    # alpha exactly like that slice did.
    style_img = np.array(style.convert("RGB"))
    methods = (
        ("reinhard", transfer_reinhard),
        ("lhm", transfer_lhm),
        ("pccm", transfer_pccm),
    )
    return tuple(
        _embed_texture(
            content,
            transfer(content_img, style_img),
            f"/tmp/{name}.png",
            f"/tmp/{name}.glb",
        )
        for name, transfer in methods
    )
# Web UI: a 3D model and a style image go in; one viewer per colour-transfer
# method comes out, each with a fully transparent background. Gradio assigns
# runModel's return values to these output panels positionally.
demo = gr.Interface(
    runModel,
    [gr.Model3D(), gr.Image(type='pil')],
    outputs=[
        gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label=panel_label)
        for panel_label in ("Reinhard", "lhm", "pccm")
    ],
)
demo.launch()