Spaces:
Runtime error
Runtime error
import os | |
from PIL import Image | |
import torch | |
import gradio as gr | |
import torch | |
torch.backends.cudnn.benchmark = True | |
from torchvision import transforms, utils | |
from util import * | |
from PIL import Image | |
import math | |
import random | |
import numpy as np | |
from torch import nn, autograd, optim | |
from torch.nn import functional as F | |
from tqdm import tqdm | |
import lpips | |
from model import * | |
#from e4e_projection import projection as e4e_projection | |
from copy import deepcopy | |
import imageio | |
import os | |
import sys | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
from argparse import Namespace | |
from e4e.models.psp import pSp | |
from util import * | |
from huggingface_hub import hf_hub_download | |
# The whole demo runs on the CPU.
device = 'cpu'

# Download the cat-domain e4e encoder checkpoint and get its local path.
model_path_e = hf_hub_download(
    repo_id="bankholdup/stylegan_petbreeder",
    filename="e4e_ffhq512_cat.pt",
)

# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
ckpt = torch.load(model_path_e, map_location='cpu')

# The checkpoint carries its options as a dict. Point it at the downloaded
# file, then wrap it in a Namespace so pSp can read options as attributes.
opts = ckpt['opts']
opts.update(checkpoint_path=model_path_e)
opts = Namespace(**opts)

# Build the encoder in evaluation mode on the chosen device.
net = pSp(opts, device)
net.eval()
net = net.to(device)
# Inference only: don't build an autograd graph through the encoder,
# which would just waste memory.
@torch.no_grad()
def projection(img, name, device='cuda'):
    """Invert an image into the encoder's latent space and save the code.

    Args:
        img: image accepted by torchvision's ``ToTensor`` (a PIL image in this app).
        name: destination path for ``torch.save``; receives ``{'latent': code}``.
        device: device the input tensor is moved to. It should match the
            device ``net`` lives on.

    Returns:
        The latent code for the single input image (``w_plus[0]``).
    """
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            # mean = std = 0.5 maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    # Add a batch dimension of 1 before feeding the encoder.
    batch = transform(img).unsqueeze(0).to(device)
    # Only the latents are needed here; the reconstructed images are discarded.
    _, w_plus = net(batch, randomize_noise=False, return_latents=True)
    latent = w_plus[0]
    torch.save({'latent': latent}, name)
    return latent
def inference(img):
    """Encode an uploaded photo with e4e and write its reconstruction to disk.

    Args:
        img: PIL image supplied by the Gradio input component.

    Returns:
        Path of the JPEG that was written ('filename.jpeg').
    """
    # JPEG cannot store alpha or palette modes (e.g. PNG uploads), so
    # normalise to RGB before writing the temporary file used for alignment.
    img.convert('RGB').save('out.jpg')
    # align_face comes from the star-imported util module. Presumably it
    # returns a cropped PIL image; confirm against util.
    aligned_face = align_face('out.jpg')
    # Latent code with a leading batch dim; projection also saves it to test.pt.
    my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
    # Render the latent code instead of writing the raw code as an image.
    # NOTE(review): assumes pSp exposes its StyleGAN generator as
    # `net.decoder`, with the e4e call signature returning (image, _).
    # Confirm against e4e/models/psp.py.
    with torch.no_grad():
        generated, _ = net.decoder(
            [my_w], input_is_latent=True, randomize_noise=False
        )
    # Assumes the decoder emits (1, C, H, W) values in [-1, 1]. Map them to
    # uint8 HWC, which imageio can write as JPEG.
    pixels = generated[0].clamp(-1, 1).add(1).mul(127.5).round()
    npimage = pixels.to(torch.uint8).permute(1, 2, 0).cpu().numpy()
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
title = "PetBreeder v1.1"
description = "Gradio Demo for PetBreeder."

# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4.
# gr.Image is the unified component for both directions:
#   - type="pil" hands inference() a PIL image;
#   - type="filepath" lets it return the path of the written JPEG.
gr.Interface(
    inference,
    [gr.Image(type="pil")],
    gr.Image(type="filepath"),
    title=title,
    description=description,
).launch()