Spaces:
Runtime error
Runtime error
import os | |
from PIL import Image | |
import torch | |
import gradio as gr | |
import torch | |
torch.backends.cudnn.benchmark = True | |
from torchvision import transforms, utils | |
from util import * | |
from PIL import Image | |
import math | |
import random | |
import numpy as np | |
from torch import nn, autograd, optim | |
from torch.nn import functional as F | |
from tqdm import tqdm | |
import lpips | |
from model import * | |
#from e4e_projection import projection as e4e_projection | |
from copy import deepcopy | |
import imageio | |
import os | |
import sys | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
from argparse import Namespace | |
from e4e.models.psp import pSp | |
from e4e.models.encoders import psp_encoders | |
from util import * | |
from huggingface_hub import hf_hub_download | |
device = 'cpu'


def _strip_prefix(state_dict, prefix):
    """Return the entries of *state_dict* whose key starts with *prefix*.

    The prefix and the single separator character that follows it are removed
    from each kept key, e.g. ``'encoder.body.0.weight'`` -> ``'body.0.weight'``.
    """
    cut = len(prefix) + 1  # prefix plus its trailing '.' separator
    return {k[cut:]: v for k, v in state_dict.items() if k.startswith(prefix)}


def _load_e4e(filename):
    """Download one PetBreeder e4e checkpoint and build its encoder/decoder pair.

    Args:
        filename: checkpoint file name inside the
            ``bankholdup/stylegan_petbreeder`` Hugging Face Hub repository.

    Returns:
        ``(model_path, opts, latent_avg, encoder, decoder)``:
        ``model_path`` is the local path returned by ``hf_hub_download``;
        ``opts`` is the checkpoint's ``'opts'`` dict, with ``checkpoint_path``
        set, wrapped in a ``Namespace``; ``latent_avg`` is the checkpoint's
        ``'latent_avg'`` tensor moved to ``device``; ``encoder`` and ``decoder``
        are weight-loaded modules in eval mode on ``device``.

    The full checkpoint is only referenced locally, so it can be freed once the
    weights have been copied into the modules. It is not kept as a global.
    """
    model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename=filename)
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because it comes from this fixed model repo, never from user input.
    ckpt = torch.load(model_path, map_location='cpu')
    state_dict = ckpt['state_dict']

    latent_avg = ckpt['latent_avg'].to(device)
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)

    encoder = psp_encoders.Encoder4Editing(50, 'ir_se', opts)
    encoder.load_state_dict(_strip_prefix(state_dict, 'encoder'), strict=True)
    encoder.eval()
    encoder.to(device)

    decoder = Generator(512, 512, 8, channel_multiplier=2)
    decoder.load_state_dict(_strip_prefix(state_dict, 'decoder'), strict=True)
    decoder.eval()
    decoder.to(device)

    # NOTE(review): clear_output is not imported by name in this file. It is
    # presumably IPython's notebook helper re-exported by one of the
    # `from ... import *` lines. Confirm, otherwise this raises NameError.
    clear_output()
    return model_path, opts, latent_avg, encoder, decoder


# Load order (ffhq, dog, cat) matches the original script.
ffhq_model_path, ffhq_opts, ffhq_latent_avg, ffhq_encoder, ffhq_decoder = _load_e4e("e4e_ffhq512.pt")
dog_model_path, dog_opts, dog_latent_avg, dog_encoder, dog_decoder = _load_e4e("e4e_ffhq512_dog.pt")
cat_model_path, cat_opts, cat_latent_avg, cat_encoder, cat_decoder = _load_e4e("e4e_ffhq512_cat.pt")
def gen_im(model_type='ffhq', codes=None):
    """Decode latent codes with the generator selected by *model_type*.

    Args:
        model_type: ``'ffhq'``, ``'dog'`` or ``'cat'``. Any other value selects
            the module-level ``custom_decoder``, which is not defined in this
            file and must be provided elsewhere.
        codes: latent codes to decode, i.e. an encoder's output plus its latent
            average (see ``inference``). When omitted, the legacy behaviour
            applies: the module-level ``<model_type>_codes`` global is used.
            Callers should pass *codes* explicitly, because ``inference``
            computes its codes as locals that this function cannot see.

    Returns:
        The value ``tensor2im`` returns for the first image of the batch.
    """
    if model_type == 'ffhq':
        decoder = ffhq_decoder
        if codes is None:
            codes = ffhq_codes  # legacy global fallback
    elif model_type == 'dog':
        decoder = dog_decoder
        if codes is None:
            codes = dog_codes  # legacy global fallback
    elif model_type == 'cat':
        decoder = cat_decoder
        if codes is None:
            codes = cat_codes  # legacy global fallback
    else:
        decoder = custom_decoder
        if codes is None:
            codes = custom_codes  # legacy global fallback
    # Inference only: don't build an autograd graph for the generated image.
    with torch.no_grad():
        imgs, _ = decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
    return tensor2im(imgs[0])
def inference(img, animal="cat"):
    """Generate the *animal* version of an uploaded face photo.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input.
        animal: which encoder/generator pair to use: ``"cat"`` (the default,
            used by the Gradio demo), ``"dog"``, or ``"ffhq"``.

    Returns:
        The path ``'filename.jpeg'`` (relative to the working directory) of the
        written result.

    Raises:
        KeyError: if *animal* is not one of the names above.
    """
    encoders = {
        'ffhq': (ffhq_encoder, ffhq_latent_avg),
        'dog': (dog_encoder, dog_latent_avg),
        'cat': (cat_encoder, cat_latent_avg),
    }
    encoder, latent_avg = encoders[animal]

    # JPEG cannot store RGBA/palette images, so normalise the mode before saving.
    img.convert('RGB').save('out.jpg')
    aligned_face = align_face('out.jpg')

    with torch.no_grad():
        batch = aligned_face.unsqueeze(0).to(device).float()
        codes = encoder(batch)
        # Offset the encoder output by the latent average stored in the SAME
        # checkpoint as this encoder. The previous code added ffhq_latent_avg
        # to the cat/dog codes and left cat/dog_latent_avg unused.
        codes = codes + latent_avg.repeat(codes.shape[0], 1, 1)
        # Pass codes explicitly: gen_im cannot see this function's locals.
        npimage = gen_im(animal, codes)

    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
title = "PetBreeder v1.1"
description = "Gradio Demo for PetBreeder."

# Web UI: a single PIL image goes in, and `inference` returns the path of the
# generated JPEG, which is shown as the output image. The server is started
# immediately at import time.
gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil")],
    outputs=gr.outputs.Image(type="file"),
    title=title,
    description=description,
).launch()