Update app.py
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import torch
 import numpy as np
-import cv2
 from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
 from model import UNet2DConditionModelEx
 from pipeline import StableDiffusionControlLoraV3Pipeline
@@ -11,189 +10,229 @@ from huggingface_hub import login
 import spaces
 import random
 from pathlib import Path
+import hashlib
+import datetime
+import json
+from tqdm import tqdm
 
 # Login using the token
 login(token=os.environ.get("HF_TOKEN"))
 
-# … [old lines 18-67 removed; not rendered in this view]
-    canny_image = cv2.Canny(image, low_threshold, high_threshold)
-    canny_image = np.stack([canny_image] * 3, axis=-1)
-    return Image.fromarray(canny_image)
+# Setup directories
+HF_SPACE_ID = "naonauno/groundbi-factory"
+OUTPUT_DIR = "/home/user/outputs"
+
+os.makedirs('outputs', exist_ok=True)
+os.makedirs('metadata', exist_ok=True)
+metadata_dir = 'metadata'
+
+class AdvancedGenerationTracker:
+    def __init__(self, total_steps):
+        self.progress_bar = tqdm(total=total_steps, desc="Image Generation")
+        self.current_step = 0
+        self.memory_usage_log = []
+
+    def update_progress(self, step_size=1):
+        self.current_step += step_size
+        self.progress_bar.update(step_size)
+        self._log_memory_usage()
+
+    def _log_memory_usage(self):
+        if torch.cuda.is_available():
+            memory_info = {
+                'step': self.current_step,
+                'cuda_allocated': torch.cuda.memory_allocated(),
+                'cuda_reserved': torch.cuda.memory_reserved(),
+                'cuda_max_allocated': torch.cuda.max_memory_allocated()
+            }
+            self.memory_usage_log.append(memory_info)
+
+    def finalize(self):
+        self.progress_bar.close()
+        return self.memory_usage_log
+
+def setup_pipeline():
+    unet = UNet2DConditionModelEx.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        subfolder="unet"
+    )
+    unet = unet.add_extra_conditions("ow-gbi-control-lora")
+
+    pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        unet=unet
+    )
+
+    # Performance optimizations
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    pipe.enable_attention_slicing()
+    pipe.enable_vae_slicing()
 
+    pipe.load_lora_weights(
+        "models",
+        weight_name="40kHalf.safetensors"
+    )
+    return pipe
+
+pipe = setup_pipeline()
+
+def save_to_space(image, filename):
+    path = os.path.join(OUTPUT_DIR, filename)
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    image.save(path)
+    return path
+
+def generate_advanced_filename(prompt, seed, style=None):
+    hash_input = f"{prompt}_{seed}"
+    filename_hash = hashlib.md5(hash_input.encode()).hexdigest()[:8]
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    style_prefix = f"{style}_" if style else ""
+    return f"{style_prefix}{timestamp}_{filename_hash}"
+
+def export_generation_metadata(metadata, output_path):
+    with open(output_path, 'w') as f:
+        json.dump(metadata, f, indent=2)
+    return output_path
+
+@spaces.GPU(duration=180)
+def generate_image(
+    image,
+    prompt,
+    negative_prompt,
+    guidance_scale,
+    steps,
+    seed,
+    strength=0.8,
+    num_images=1,
+    progress=gr.Progress()
+):
+    if image is None:
         raise gr.Error("Please provide an input image!")
 
     try:
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        output_base_dir = os.path.join('outputs', timestamp)
+        os.makedirs(output_base_dir, exist_ok=True)
+
         if seed is not None and seed != "":
             try:
                 generator = torch.Generator().manual_seed(int(seed))
+                current_seed = int(seed)
             except ValueError:
                 generator = torch.Generator()
+                current_seed = random.randint(1, 1000000)
         else:
             generator = torch.Generator()
+            current_seed = random.randint(1, 1000000)
+
+        tracker = AdvancedGenerationTracker(steps)
+
+        def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
+            tracker.update_progress()
+            if progress is not None:
+                progress(step/steps)
+            return {}
 
-        progress(0.1, desc="Processing input image...")
-        canny_image = get_canny_image(input_image, low_threshold, high_threshold)
-
         progress(0.3, desc="Generating image...")
         with torch.no_grad():
-            … [removed line truncated in this view]
+            result = pipe(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=int(steps),
                 guidance_scale=float(guidance_scale),
-                image=canny_image,
+                image=image,
+                strength=strength,
                 extra_condition_scale=1.0,
-                generator=generator
-            … [removed line truncated in this view]
+                generator=generator,
+                num_images_per_prompt=num_images,
+                callback_on_step_end=callback_on_step_end
+            )
+
+        generated_image = result.images[0]
 
+        # Save the image
+        filename = generate_advanced_filename(prompt, current_seed)
+        image_path = os.path.join(output_base_dir, f"{filename}.png")
+        generated_image.save(image_path)
+        save_to_space(generated_image, f"{filename}.png")
+
+        # Save metadata
+        generation_metadata = {
+            "generation_timestamp": timestamp,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "seed": current_seed,
+            "generation_parameters": {
+                "guidance_scale": guidance_scale,
+                "steps": steps,
+                "strength": strength,
+                "num_images": num_images
+            },
+            "image_file": os.path.basename(image_path)
+        }
+
+        metadata_path = os.path.join(metadata_dir, f"{filename}_metadata.json")
+        export_generation_metadata(generation_metadata, metadata_path)
+
+        memory_log = tracker.finalize()
         progress(1.0, desc="Done!")
-        … [old lines 102-103 removed; not rendered in this view]
+
+        return generated_image
+
     except Exception as e:
         raise gr.Error(f"An error occurred: {str(e)}")
 
-… [old lines 107-110 removed; not rendered in this view]
-    return None
-
-# Example data with reduced steps
-examples = [
-    [
-        "conditions/example1.jpg",
-        "a futuristic cyberpunk city",
-        "blurry, bad quality",
-        7.5,
-        25,  # Reduced steps
-        100,
-        200,
-        42
-    ],
-    [
-        "conditions/example2.jpg",
-        "a serene mountain landscape",
-        "dark, gloomy",
-        7.0,
-        25,  # Reduced steps
-        120,
-        180,
-        123
-    ]
-]
+css = """
+.container { max-width: 900px; margin: auto; }
+.parameter-hint { font-size: 0.8em; color: #666; margin-top: -5px; }
+"""
 
 # Create the Gradio interface
-with gr.Blocks() as demo:
+with gr.Blocks(css=css) as demo:
     gr.Markdown(
         """
-        # …
-        ⚠️ Warning: This is a demo …
-        For …
-        The model uses edge detection to guide the image generation process.
+        # Terrain Generator
+        ⚠️ Warning: This is a demo running on ZeroGPU. Generation might take a few minutes.
+        For best results, use 15-20 steps for generation.
        """
     )
 
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(label="Input Image", type="…")
-            random_image_btn = gr.Button("Load Random Reference Image")
-            status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
+            input_image = gr.Image(label="Input Image", type="pil")
 
             prompt = gr.Textbox(
                 label="Prompt",
-                placeholder="…"
+                placeholder="Describe the terrain..."
             )
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
-                placeholder="…"
+                placeholder="What to avoid..."
+            )
+            guidance_scale = gr.Slider(
+                label="Guidance Scale",
+                minimum=1,
+                maximum=20,
+                value=7.5,
+                info="Higher = more prompt adherence, Lower = more creativity"
+            )
+            steps = gr.Slider(
+                label="Steps",
+                minimum=1,
+                maximum=50,
+                value=20,
+                info="More steps = higher quality but slower"
+            )
+            seed = gr.Textbox(
+                label="Seed (empty for random)",
+                placeholder="Enter a number for reproducible results",
+                info="Controls randomness. Same seed = same output."
             )
-            with gr.Row():
-                low_threshold = gr.Slider(minimum=1, maximum=255, value=100, label="Canny Low Threshold")
-                high_threshold = gr.Slider(minimum=1, maximum=255, value=200, label="Canny High Threshold")
-            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
-            steps = gr.Slider(minimum=1, maximum=50, value=25, label="Steps")  # Reduced max steps
-            seed = gr.Textbox(label="Seed (empty for random)", placeholder="Enter a number for reproducible results")
             generate = gr.Button("Generate")
 
         with gr.Column():
-            canny_output = gr.Image(label="Canny Edge Detection")
             result = gr.Image(label="Generated Image")
 
-    # Set up example gallery
-    gr.Examples(
-        examples=examples,
-        inputs=[
-            input_image,
-            prompt,
-            negative_prompt,
-            guidance_scale,
-            steps,
-            low_threshold,
-            high_threshold,
-            seed
-        ],
-        outputs=[canny_output, result],
-        fn=generate_image,
-        cache_examples=True
-    )
-
-    random_image_btn.click(
-        fn=random_image_click,
-        outputs=input_image
-    )
-
     generate.click(
         fn=generate_image,
         inputs=[
@@ -202,11 +241,9 @@ with gr.Blocks() as demo:
             negative_prompt,
             guidance_scale,
             steps,
-            low_threshold,
-            high_threshold,
             seed
         ],
-        outputs=…
+        outputs=result
     )
 
 demo.queue()
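The AdvancedGenerationTracker added in this commit pairs a tqdm progress bar with per-step CUDA memory snapshots; on a CPU-only machine the memory log simply stays empty. A minimal standalone sketch of how the class is driven, assuming its definition from the diff above is in scope (the step count here is illustrative):

    tracker = AdvancedGenerationTracker(total_steps=20)
    for _ in range(20):
        tracker.update_progress()    # advances the bar and snapshots CUDA memory
    memory_log = tracker.finalize()  # closes the bar, returns the per-step log
    if memory_log:
        print(memory_log[-1]["cuda_max_allocated"])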
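setup_pipeline() swaps the stock Stable Diffusion UNet for UNet2DConditionModelEx, registers the extra "ow-gbi-control-lora" condition, and loads the control-LoRA weights on top of the UniPC scheduler plus attention/VAE slicing. A minimal inference sketch against that pipeline; the prompt, resolution, and blank condition image are illustrative, while the image, extra_condition_scale, and generator parameters mirror the commit's own call inside generate_image:

    import torch
    from PIL import Image

    pipe = setup_pipeline()             # pipeline built in the diff above
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")

    condition = Image.new("RGB", (512, 512))   # placeholder condition image
    out = pipe(
        prompt="an aerial view of mountainous terrain",
        num_inference_steps=20,
        image=condition,
        extra_condition_scale=1.0,
        generator=torch.Generator().manual_seed(42),
    )
    out.images[0].save("sample.png")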
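The inline callback_on_step_end follows the diffusers step-end callback contract: the pipeline invokes it after every denoising step with (pipeline, step, timestep, callback_kwargs) and expects a dict back whose entries overwrite the matching tensors in callback_kwargs, so returning {} leaves everything untouched. A sketch of a variant that also inspects the current latents, assuming the custom pipeline follows the standard contract and its default callback_on_step_end_tensor_inputs includes "latents" (inspecting_callback is a hypothetical name):

    def inspecting_callback(pipeline, step, timestep, callback_kwargs):
        # Read-only peek at the denoising state; nothing is modified.
        latents = callback_kwargs.get("latents")
        if latents is not None:
            print(f"step {step}: latents shape {tuple(latents.shape)}")
        return {}  # no tensors overridden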
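The new persistence path names each image with a timestamp plus an eight-character MD5 prefix of the prompt and seed, and writes a JSON sidecar mirroring every generation parameter. A hypothetical round trip with the helpers defined in this commit (the timestamp portion varies per run):

    name = generate_advanced_filename("rolling green hills", seed=42)
    # -> "<YYYYMMDD_HHMMSS>_<8-char md5 prefix>"

    export_generation_metadata(
        {"prompt": "rolling green hills", "seed": 42},
        f"metadata/{name}_metadata.json",
    )

Note that outputs/ and metadata/ live on the Space's local disk, and save_to_space() writes a second copy under OUTPUT_DIR; on a ZeroGPU Space that storage is typically ephemeral across restarts unless persistent storage is attached.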