naonauno committed
Commit 0071020 · verified · 1 Parent(s): 65d6cf3

Update app.py

Files changed (1):
  1. app.py +175 -138
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import torch
 import numpy as np
-import cv2
 from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
 from model import UNet2DConditionModelEx
 from pipeline import StableDiffusionControlLoraV3Pipeline
@@ -11,189 +10,229 @@ from huggingface_hub import login
 import spaces
 import random
 from pathlib import Path
+import hashlib
+import datetime
+import json
+from tqdm import tqdm
 
 # Login using the token
 login(token=os.environ.get("HF_TOKEN"))
 
-# For deterministic generation
-torch.manual_seed(42)
-torch.backends.cudnn.deterministic = True
-
-# Initialize the models
-base_model = "runwayml/stable-diffusion-v1-5"
-
-# Load the custom UNet
-unet = UNet2DConditionModelEx.from_pretrained(
-    base_model,
-    subfolder="unet"
-)
-
-unet = unet.add_extra_conditions("ow-gbi-control-lora")
-
-pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
-    base_model,
-    unet=unet
-)
-
-# Performance optimizations
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-pipe.enable_attention_slicing()
-pipe.enable_vae_slicing()
-
-pipe.load_lora_weights(
-    "models",
-    weight_name="40kHalf.safetensors"
-)
-
-def get_random_condition_image():
-    conditions_dir = Path("conditions")
-    if conditions_dir.exists():
-        image_files = list(conditions_dir.glob("*.[jp][pn][g]"))
-        if image_files:
-            random_image = random.choice(image_files)
-            return str(random_image)
-    return None
-
-def get_canny_image(image, low_threshold=100, high_threshold=200):
-    if isinstance(image, Image.Image):
-        image = np.array(image)
-    elif isinstance(image, str):
-        image = np.array(Image.open(image))
-
-    if len(image.shape) == 2:
-        image = np.stack([image] * 3, axis=-1)
-    elif image.shape[2] == 4:
-        image = image[..., :3]
-
-    canny_image = cv2.Canny(image, low_threshold, high_threshold)
-    canny_image = np.stack([canny_image] * 3, axis=-1)
-    return Image.fromarray(canny_image)
+# Setup directories
+HF_SPACE_ID = "naonauno/groundbi-factory"
+OUTPUT_DIR = "/home/user/outputs"
+
+os.makedirs('outputs', exist_ok=True)
+os.makedirs('metadata', exist_ok=True)
+metadata_dir = 'metadata'
+
+class AdvancedGenerationTracker:
+    def __init__(self, total_steps):
+        self.progress_bar = tqdm(total=total_steps, desc="Image Generation")
+        self.current_step = 0
+        self.memory_usage_log = []
+
+    def update_progress(self, step_size=1):
+        self.current_step += step_size
+        self.progress_bar.update(step_size)
+        self._log_memory_usage()
+
+    def _log_memory_usage(self):
+        if torch.cuda.is_available():
+            memory_info = {
+                'step': self.current_step,
+                'cuda_allocated': torch.cuda.memory_allocated(),
+                'cuda_reserved': torch.cuda.memory_reserved(),
+                'cuda_max_allocated': torch.cuda.max_memory_allocated()
+            }
+            self.memory_usage_log.append(memory_info)
+
+    def finalize(self):
+        self.progress_bar.close()
+        return self.memory_usage_log
+
+def setup_pipeline():
+    unet = UNet2DConditionModelEx.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        subfolder="unet"
+    )
+    unet = unet.add_extra_conditions("ow-gbi-control-lora")
+
+    pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        unet=unet
+    )
+
+    # Performance optimizations
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    pipe.enable_attention_slicing()
+    pipe.enable_vae_slicing()
 
-@spaces.GPU(duration=180)  # Reduced to 3 minutes
-def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold, seed, progress=gr.Progress()):
-    if input_image is None:
+    pipe.load_lora_weights(
+        "models",
+        weight_name="40kHalf.safetensors"
+    )
+    return pipe
+
+pipe = setup_pipeline()
+
+def save_to_space(image, filename):
+    path = os.path.join(OUTPUT_DIR, filename)
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    image.save(path)
+    return path
+
+def generate_advanced_filename(prompt, seed, style=None):
+    hash_input = f"{prompt}_{seed}"
+    filename_hash = hashlib.md5(hash_input.encode()).hexdigest()[:8]
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    style_prefix = f"{style}_" if style else ""
+    return f"{style_prefix}{timestamp}_{filename_hash}"
+
+def export_generation_metadata(metadata, output_path):
+    with open(output_path, 'w') as f:
+        json.dump(metadata, f, indent=2)
+    return output_path
+
+@spaces.GPU(duration=180)
+def generate_image(
+    image,
+    prompt,
+    negative_prompt,
+    guidance_scale,
+    steps,
+    seed,
+    strength=0.8,
+    num_images=1,
+    progress=gr.Progress()
+):
+    if image is None:
         raise gr.Error("Please provide an input image!")
 
     try:
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        output_base_dir = os.path.join('outputs', timestamp)
+        os.makedirs(output_base_dir, exist_ok=True)
+
         if seed is not None and seed != "":
            try:
                generator = torch.Generator().manual_seed(int(seed))
+               current_seed = int(seed)
            except ValueError:
                generator = torch.Generator()
+               current_seed = random.randint(1, 1000000)
        else:
            generator = torch.Generator()
+           current_seed = random.randint(1, 1000000)
+
+        tracker = AdvancedGenerationTracker(steps)
+
+        def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
+            tracker.update_progress()
+            if progress is not None:
+                progress(step/steps)
+            return {}
 
-        progress(0.1, desc="Processing input image...")
-        canny_image = get_canny_image(input_image, low_threshold, high_threshold)
-
         progress(0.3, desc="Generating image...")
         with torch.no_grad():
-            image = pipe(
+            result = pipe(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=int(steps),
                 guidance_scale=float(guidance_scale),
-                image=canny_image,
+                image=image,
+                strength=strength,
                 extra_condition_scale=1.0,
-                generator=generator
-            ).images[0]
+                generator=generator,
+                num_images_per_prompt=num_images,
+                callback_on_step_end=callback_on_step_end
+            )
+
+            generated_image = result.images[0]
 
+        # Save the image
+        filename = generate_advanced_filename(prompt, current_seed)
+        image_path = os.path.join(output_base_dir, f"{filename}.png")
+        generated_image.save(image_path)
+        save_to_space(generated_image, f"{filename}.png")
+
+        # Save metadata
+        generation_metadata = {
+            "generation_timestamp": timestamp,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "seed": current_seed,
+            "generation_parameters": {
+                "guidance_scale": guidance_scale,
+                "steps": steps,
+                "strength": strength,
+                "num_images": num_images
+            },
+            "image_file": os.path.basename(image_path)
+        }
+
+        metadata_path = os.path.join(metadata_dir, f"{filename}_metadata.json")
+        export_generation_metadata(generation_metadata, metadata_path)
+
+        memory_log = tracker.finalize()
         progress(1.0, desc="Done!")
-        return canny_image, image
-
+
+        return generated_image
+
     except Exception as e:
         raise gr.Error(f"An error occurred: {str(e)}")
 
-def random_image_click():
-    image_path = get_random_condition_image()
-    if image_path:
-        return Image.open(image_path)
-    return None
-
-# Example data with reduced steps
-examples = [
-    [
-        "conditions/example1.jpg",
-        "a futuristic cyberpunk city",
-        "blurry, bad quality",
-        7.5,
-        25,  # Reduced steps
-        100,
-        200,
-        42
-    ],
-    [
-        "conditions/example2.jpg",
-        "a serene mountain landscape",
-        "dark, gloomy",
-        7.0,
-        25,  # Reduced steps
-        120,
-        180,
-        123
-    ]
-]
+css = """
+.container { max-width: 900px; margin: auto; }
+.parameter-hint { font-size: 0.8em; color: #666; margin-top: -5px; }
+"""
 
 # Create the Gradio interface
-with gr.Blocks() as demo:
+with gr.Blocks(css=css) as demo:
     gr.Markdown(
         """
-        # Control LoRA v3 Demo
-        ⚠️ Warning: This is a demo of Control LoRA v3. Generation might take a few minutes.
-        For better results with ZeroGPU, it's recommended to use 20-30 steps.
-        The model uses edge detection to guide the image generation process.
+        # Terrain Generator
+        ⚠️ Warning: This is a demo running on ZeroGPU. Generation might take a few minutes.
+        For best results, use 15-20 steps for generation.
         """
     )
 
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(label="Input Image", type="numpy")
-            random_image_btn = gr.Button("Load Random Reference Image")
-            status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
+            input_image = gr.Image(label="Input Image", type="pil")
 
             prompt = gr.Textbox(
                 label="Prompt",
-                placeholder="Enter your prompt here... (e.g., 'a futuristic cyberpunk city')"
+                placeholder="Describe the terrain..."
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
-                placeholder="Enter things you don't want to see... (e.g., 'blurry, bad quality')"
+                placeholder="What to avoid..."
+            )
+            guidance_scale = gr.Slider(
+                label="Guidance Scale",
+                minimum=1,
+                maximum=20,
+                value=7.5,
+                info="Higher = more prompt adherence, Lower = more creativity"
+            )
+            steps = gr.Slider(
+                label="Steps",
+                minimum=1,
+                maximum=50,
+                value=20,
+                info="More steps = higher quality but slower"
+            )
+            seed = gr.Textbox(
+                label="Seed (empty for random)",
+                placeholder="Enter a number for reproducible results",
+                info="Controls randomness. Same seed = same output."
            )
-            with gr.Row():
-                low_threshold = gr.Slider(minimum=1, maximum=255, value=100, label="Canny Low Threshold")
-                high_threshold = gr.Slider(minimum=1, maximum=255, value=200, label="Canny High Threshold")
-            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
-            steps = gr.Slider(minimum=1, maximum=50, value=25, label="Steps")  # Reduced max steps
-            seed = gr.Textbox(label="Seed (empty for random)", placeholder="Enter a number for reproducible results")
            generate = gr.Button("Generate")
 
        with gr.Column():
-            canny_output = gr.Image(label="Canny Edge Detection")
            result = gr.Image(label="Generated Image")
 
-    # Set up example gallery
-    gr.Examples(
-        examples=examples,
-        inputs=[
-            input_image,
-            prompt,
-            negative_prompt,
-            guidance_scale,
-            steps,
-            low_threshold,
-            high_threshold,
-            seed
-        ],
-        outputs=[canny_output, result],
-        fn=generate_image,
-        cache_examples=True
-    )
-
-    random_image_btn.click(
-        fn=random_image_click,
-        outputs=input_image
-    )
-
    generate.click(
        fn=generate_image,
        inputs=[
@@ -202,11 +241,9 @@ with gr.Blocks() as demo:
            negative_prompt,
            guidance_scale,
            steps,
-            low_threshold,
-            high_threshold,
            seed
        ],
-        outputs=[canny_output, result]
+        outputs=result
    )
 
    demo.queue()
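For reference, the per-step hook passed into pipe(...) above follows the diffusers callback_on_step_end convention. A minimal sketch of that contract, shown against the stock Stable Diffusion pipeline and assuming the custom StableDiffusionControlLoraV3Pipeline forwards the callback the same way:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Called once per denoising step; callback_kwargs carries the tensors
    # named in callback_on_step_end_tensor_inputs (default: ["latents"]).
    print(f"step {step}, timestep {timestep}")
    # Return the kwargs (optionally modified) to feed back into the loop;
    # returning {}, as the commit's callback does, leaves everything unchanged.
    return callback_kwargs

image = pipe(
    "a terrain heightmap",
    num_inference_steps=20,
    callback_on_step_end=on_step_end,
).images[0]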
 
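The metadata files written by export_generation_metadata make each run self-describing. A hypothetical read-back helper, not part of this commit, written against the schema above (the filename is an invented instance of the {timestamp}_{hash} pattern produced by generate_advanced_filename):

import json
import os

def load_generation_metadata(metadata_dir, filename):
    # Hypothetical helper: recover the prompt, seed, and parameters
    # recorded for an earlier run by generate_image.
    path = os.path.join(metadata_dir, f"{filename}_metadata.json")
    with open(path) as f:
        return json.load(f)

meta = load_generation_metadata("metadata", "20250101_120000_ab12cd34")
print(meta["prompt"], meta["seed"], meta["generation_parameters"]["steps"])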