File size: 5,704 Bytes
f15e0fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus
from insightface.app import FaceAnalysis
from insightface.utils import face_align

from huggingface_hub import hf_hub_download
import torch

from PIL import Image
import cv2

import gradio as gr

# Fetch the FaceID-Plus adapter checkpoint plus the image-encoder config and
# weights into the local directories that get_ip_model() reads from.
_MODEL_FILES = (
    ('h94/IP-Adapter-FaceID', 'ip-adapter-faceid-plus_sd15.bin', 'IP-Adapter-FaceID'),
    ('h94/IP-Adapter', 'models/image_encoder/config.json', 'IP-Adapter'),
    ('h94/IP-Adapter', 'models/image_encoder/pytorch_model.bin', 'IP-Adapter'),
)
for _repo_id, _filename, _local_dir in _MODEL_FILES:
    hf_hub_download(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)

def get_ip_model():
    """Build an IPAdapterFaceIDPlus wrapper around a Stable Diffusion pipeline.

    Runs on CUDA with float16 weights when a GPU is available, otherwise on
    CPU with float32. The adapter checkpoint and image-encoder paths are the
    local directories populated by the module-level ``hf_hub_download`` calls.

    Returns:
        IPAdapterFaceIDPlus: the ready-to-use model exposing ``generate``.
    """
    has_cuda = torch.cuda.is_available()
    device = 'cuda' if has_cuda else 'cpu'
    torch_dtype = torch.float16 if has_cuda else torch.float32
    print(f'Using device: {device}')

    base_model = "SG161222/Realistic_Vision_V4.0_noVAE"
    vae_model = "stabilityai/sd-vae-ft-mse"
    encoder_dir = "IP-Adapter/models/image_encoder"
    adapter_ckpt = "IP-Adapter-FaceID/ip-adapter-faceid-plus_sd15.bin"

    # DDIM sampler with a "scaled_linear" beta schedule over 1000 train steps.
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )

    # A separately loaded VAE is injected into the pipeline (the base model
    # path names a "noVAE" variant).
    autoencoder = AutoencoderKL.from_pretrained(vae_model).to(dtype=torch_dtype)
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        scheduler=scheduler,
        vae=autoencoder,
        feature_extractor=None,
        safety_checker=None,
    )

    return IPAdapterFaceIDPlus(
        sd_pipe,
        encoder_dir,
        adapter_ckpt,
        device,
        num_tokens=4,
        torch_dtype=torch_dtype,
    )


def generate_images(prompt, img_filepath, negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality, blurry",
                    img_prompt_scale=0.5, num_inference_steps=30, seed=None, n_images=1):
    """Generate an image from a text prompt conditioned on the face in a photo.

    Uses the module-level ``app`` (insightface ``FaceAnalysis``) and
    ``ip_model`` (built by ``get_ip_model``) created in the ``__main__`` block.

    Args:
        prompt: Text prompt for generation.
        img_filepath: Path to the reference image. Gradio passes ``None`` when
            the user has not uploaded an image.
        negative_prompt: Text describing content to avoid.
        img_prompt_scale: Strength of the face conditioning, forwarded to
            ``ip_model.generate`` as ``scale``.
        num_inference_steps: Number of denoising steps.
        seed: Seed forwarded to ``ip_model.generate`` unchanged.
        n_images: Number of samples requested; only the first is returned.

    Returns:
        ``[generated_image, face_crop]``, where ``face_crop`` is the aligned
        face converted from OpenCV's BGR channel order to an RGB PIL image.

    Raises:
        gr.Error: If no image was provided, the file cannot be decoded, or no
            face is detected. These messages are shown to the user in the UI.
    """
    if not img_filepath:
        raise gr.Error("Please upload an image prompt containing a face.")

    image = cv2.imread(img_filepath)
    # cv2.imread reports an unreadable/unsupported file by returning None
    # rather than raising, so check explicitly before using the array.
    if image is None:
        raise gr.Error(f"Could not read image file: {img_filepath}")

    faces = app.get(image)
    # Without this guard faces[0] fails with a bare IndexError.
    if not faces:
        raise gr.Error("No face detected in the image prompt; try a clearer, front-facing photo.")
    face = faces[0]

    faceid_embeds = torch.from_numpy(face.normed_embedding).unsqueeze(0)
    face_image = face_align.norm_crop(image, landmark=face.kps, image_size=200)
    images = ip_model.generate(
        prompt=prompt, negative_prompt=negative_prompt, face_image=face_image, faceid_embeds=faceid_embeds,
        num_samples=n_images, width=512, height=512, num_inference_steps=num_inference_steps, seed=seed,
        scale=img_prompt_scale,
    )
    # Reorder channels BGR -> RGB for display.
    return [images[0], Image.fromarray(face_image[..., [2, 1, 0]])]

if __name__ == "__main__":
    # Load both models once at startup. generate_images() reads them as the
    # module-level globals `ip_model` and `app`.
    ip_model = get_ip_model()
    app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(512, 512), det_thresh=0.2)

    with gr.Blocks() as demo:
        gr.Markdown(
    """
    # ✨ Image Prompt Adapter With FaceID 🧙‍♂️
    
    Unleash the magic of generating whimsical images with just an image and a sprinkle of text! Learn the secrets here: [Magic Link](https://huggingface.co/h94/IP-Adapter-FaceID)
    
    🚀 This enchanting demo is designed to soar on GPU. While it can still dance on CPU, conjuring just one image might take up to 600 seconds—compared to the blink-of-an-eye magic on GPU! ✨
    """)
        with gr.Row():
            # Left column: input widgets, created in the positional order
            # that generate_images() receives their values.
            with gr.Column():
                prompt_box = gr.Textbox(label='text prompt', value='A bold rider in a white horse')
                image_box = gr.Image(type='filepath', label='image prompt')
                with gr.Accordion(label='Advanced options', open=False):
                    negative_box = gr.Textbox(
                        label='negative text prompt',
                        value="deformed hands,  watermark, text, deformed fingers, blurred faces, irregular face, irrregular body shape, ugly eyes, deformed face, squint, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, poorly framed, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, ugly eyes, squint, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, poorly framed, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, disfigured, kitsch, ugly, oversaturated, grain, low-res, Deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, ugly, disgusting, poorly drawn, childish, mutilated, mangled, old, surreal, 2 heads, 2 faces",
                    )
                    scale_slider = gr.Slider(maximum=1, minimum=0, value=0.5, step=0.05, label='image prompt scale')
                btn = gr.Button("Generate")

            # Right column: the generated image and the aligned face crop.
            with gr.Column():
                generated_view = gr.Image(label='generated image')
                face_view = gr.Image(label='detected face', height=200, width=200)

        demo_inputs = [prompt_box, image_box, negative_box, scale_slider]
        demo_outputs = [generated_view, face_view]
        btn.click(generate_images, inputs=demo_inputs, outputs=demo_outputs)

        # Clickable examples that fill in the text prompt box only.
        gr.Examples(
            [
                'A wizard casting spells in a coffee shop',
                'A penguin teaching a yoga class',
                'A robot composing a symphony',
                'A giraffe participating in a slam poetry contest',
                'A bold rider in a white horse',
            ],
            inputs=prompt_box,
            label='Sample prompts',
        )

    demo.launch(share=True, debug=True)