import cv2
import torch
from PIL import Image
import numpy as np
import yaml
import argparse
from controlnet_aux import OpenposeDetector
from diffusers import (
    StableDiffusionControlNetPipeline, 
    ControlNetModel, 
    UniPCMultistepScheduler
)

from utils.download import load_image
from utils.plot import image_grid
import os
from tqdm import tqdm
import re
import uuid

def load_config(config_path):
    try:
        with open(config_path, 'r') as file:
            return yaml.safe_load(file)
    except Exception as e:
        raise ValueError(f"Error loading config file: {e}")

def initialize_controlnet(config):
    # local_dir falls back to model_id when no local checkpoint is configured
    local_dir = config.get('local_dir', config['model_id'])
    return ControlNetModel.from_pretrained(
        local_dir,
        torch_dtype=torch.float16
    )

def initialize_pipeline(controlnet, config):
    # local_dir falls back to model_id when no local checkpoint is configured
    local_dir = config.get('local_dir', config['model_id'])
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        local_dir,
        controlnet=controlnet,
        torch_dtype=torch.float16
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe
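
# Note: UniPCMultistepScheduler is a fast multistep solver, so the pipeline
# can produce usable samples in roughly 20 steps; this is consistent with the
# --num_steps default of 20 in the CLI below.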

def setup_device(pipe):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # enable_model_cpu_offload manages device placement itself via
        # accelerate hooks; calling pipe.to("cuda") afterwards would undo them
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)
    return device

def generate_images(pipe, prompts, pose_images, generators, negative_prompts, num_steps, guidance_scale, controlnet_conditioning_scale, width, height):
    return pipe(
        prompts,
        pose_images,
        negative_prompt=negative_prompts,
        generator=generators,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        width=width,
        height=height
    ).images
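
# generate_images treats its inputs as parallel batches: prompts, pose_images,
# negative_prompts, and generators must all be lists of the same length. A
# minimal usage sketch (pose_img stands for any PIL pose map; the values are
# illustrative, not defaults of this script):
#
#   images = generate_images(
#       pipe,
#       prompts=["a man is doing yoga"],
#       pose_images=[pose_img],
#       generators=[torch.Generator(device="cpu").manual_seed(0)],
#       negative_prompts=["lowres, bad anatomy"],
#       num_steps=20,
#       guidance_scale=7.5,
#       controlnet_conditioning_scale=1.0,
#       width=512,
#       height=512,
#   )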

def infer(args):
    # Load configuration
    configs = load_config(args.config_path)
    
    # Initialize models
    controlnet_detector = OpenposeDetector.from_pretrained(
        configs[2]['model_id']  # lllyasviel/ControlNet
    )
    controlnet = initialize_controlnet(configs[0])
    pipe = initialize_pipeline(controlnet, configs[1])
    
    # Setup device
    device = setup_device(pipe)
    
    # Load and process image
    try:
        if args.input_image:
            demo_image = Image.open(args.input_image).convert("RGB")
        elif args.image_url:
            demo_image = load_image(args.image_url)
        else:
            raise ValueError("Either --input_image or --image_url must be provided")
    except Exception as e:
        raise ValueError(f"Error loading image: {e}")
    
    # Run OpenPose on the input image to get the pose conditioning map
    poses = [controlnet_detector(demo_image)]
    
    # Generate images: one seeded CPU generator per pose map keeps results reproducible
    generators = [torch.Generator(device="cpu").manual_seed(args.seed + i) for i in range(len(poses))]
    
    output_images = generate_images(
        pipe,
        [args.prompt] * len(generators),
        poses,
        generators,
        [args.negative_prompt] * len(generators),
        args.num_steps,
        args.guidance_scale,
        args.controlnet_conditioning_scale,
        args.width,
        args.height
    )
    
    # Save images if save_output is True
    if args.save_output:
        os.makedirs(args.output_dir, exist_ok=True)
        for i, img in enumerate(tqdm(output_images, desc="Saving images")):
            if args.use_prompt_as_output_name:
                # Sanitize prompt for filename (replace spaces and special characters)
                sanitized_prompt = re.sub(r'[^\w\s-]', '', args.prompt).replace(' ', '_').lower()
                filename = f"{sanitized_prompt}_{i}.png"
            else:
                # Use UUID for filename
                filename = f"{uuid.uuid4()}_{i}.png"
            img.save(os.path.join(args.output_dir, filename))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection")
    # Create mutually exclusive group for input_image and image_url
    image_group = parser.add_mutually_exclusive_group(required=True)
    image_group.add_argument("--input_image", type=str, default=None,
                             help="Path to local input image (default: tests/test_data/yoga1.jpg)")
    image_group.add_argument("--image_url", type=str, default=None,
                             help="URL of input image (e.g., https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg)")
    
    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml", 
                        help="Path to configuration YAML file")
    parser.add_argument("--prompt", type=str, default="a man is doing yoga",
                        help="Text prompt for image generation")
    parser.add_argument("--negative_prompt", type=str, 
                        default="monochrome, lowres, bad anatomy, worst quality, low quality",
                        help="Negative prompt for image generation")
    parser.add_argument("--num_steps", type=int, default=20,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=2,
                        help="Random seed for generation")
    parser.add_argument("--width", type=int, default=512,
                        help="Width of the generated image")
    parser.add_argument("--height", type=int, default=512,
                        help="Height of the generated image")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale for prompt adherence")
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0,
                        help="ControlNet conditioning scale")
    parser.add_argument("--output_dir", type=str, default="tests/test_data",
                        help="Directory to save generated images")
    parser.add_argument("--use_prompt_as_output_name", action="store_true",
                        help="Use prompt as part of output image filename")
    parser.add_argument("--save_output", action="store_true artr",
                        help="Save generated images to output directory")
    
    args = parser.parse_args()
    infer(args)

# Using image_url
# python script.py \
#     --config_path configs/model_ckpts.yaml \
#     --image_url https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg \
#     --prompt "a man is doing yoga in a serene park" \
#     --negative_prompt "monochrome, lowres, bad anatomy" \
#     --num_steps 30 \
#     --seed 42 \
#     --width 512 \
#     --height 512 \
#     --guidance_scale 7.5 \
#     --controlnet_conditioning_scale 0.8 \
#     --output_dir "tests/test_data" \
#     --save_output

# Using input_image
# python script.py \
#     --config_path configs/model_ckpts.yaml \
#     --input_image "tests/test_data/yoga1.jpg" \
#     --prompt "a man is doing yoga in a serene park" \
#     --negative_prompt "monochrome, lowres, bad anatomy" \
#     --num_steps 30 \
#     --seed 42 \
#     --width 512 \
#     --height 512 \
#     --guidance_scale 7.5 \
#     --controlnet_conditioning_scale 0.8 \
#     --output_dir "tests/test_data" \
#     --save_output