RageshAntony committed on
Commit
02919e4
·
verified ·
1 Parent(s): 6e5e1d5

multi image gen

Browse files
Files changed (1) hide show
  1. app.py +127 -50
app.py CHANGED
@@ -1,27 +1,69 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
-
5
- import spaces
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/stable-diffusion-3.5-large"
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.bfloat16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- @spaces.GPU(duration=65)
24
- def infer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  prompt,
26
  negative_prompt="",
27
  seed=42,
@@ -30,13 +72,23 @@ def infer(
30
  height=1024,
31
  guidance_scale=4.5,
32
  num_inference_steps=40,
33
- progress=gr.Progress(track_tqdm=True),
34
  ):
 
 
 
 
 
 
 
 
35
  if randomize_seed:
36
  seed = random.randint(0, MAX_SEED)
37
-
38
- generator = torch.Generator().manual_seed(seed)
39
-
 
 
40
  image = pipe(
41
  prompt=prompt,
42
  negative_prompt=negative_prompt,
@@ -46,25 +98,22 @@ def infer(
46
  height=height,
47
  generator=generator,
48
  ).images[0]
49
-
 
50
  return image, seed
51
 
52
-
53
- examples = [
54
- "A capybara wearing a suit holding a sign that reads Hello World",
55
- ]
56
-
57
  css = """
58
  #col-container {
59
  margin: 0 auto;
60
- max-width: 640px;
61
  }
62
  """
63
 
64
  with gr.Blocks(css=css) as demo:
65
  with gr.Column(elem_id="col-container"):
66
- gr.Markdown(" # [Stable Diffusion 3.5 Large (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)")
67
- gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about the Stable Diffusion 3.5 series. Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), or [download model](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) to run locally with ComfyUI or diffusers.")
68
  with gr.Row():
69
  prompt = gr.Text(
70
  label="Prompt",
@@ -73,19 +122,15 @@ with gr.Blocks(css=css) as demo:
73
  placeholder="Enter your prompt",
74
  container=False,
75
  )
76
-
77
- run_button = gr.Button("Run", scale=0, variant="primary")
78
-
79
- result = gr.Image(label="Result", show_label=False)
80
-
81
  with gr.Accordion("Advanced Settings", open=False):
82
  negative_prompt = gr.Text(
83
  label="Negative prompt",
84
  max_lines=1,
85
  placeholder="Enter a negative prompt",
86
- visible=False,
87
  )
88
-
89
  seed = gr.Slider(
90
  label="Seed",
91
  minimum=0,
@@ -93,18 +138,17 @@ with gr.Blocks(css=css) as demo:
93
  step=1,
94
  value=0,
95
  )
96
-
97
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
-
99
  with gr.Row():
100
  width = gr.Slider(
101
  label="Width",
102
  minimum=512,
103
  maximum=MAX_IMAGE_SIZE,
104
  step=32,
105
- value=1024,
106
  )
107
-
108
  height = gr.Slider(
109
  label="Height",
110
  minimum=512,
@@ -112,7 +156,7 @@ with gr.Blocks(css=css) as demo:
112
  step=32,
113
  value=1024,
114
  )
115
-
116
  with gr.Row():
117
  guidance_scale = gr.Slider(
118
  label="Guidance scale",
@@ -121,19 +165,52 @@ with gr.Blocks(css=css) as demo:
121
  step=0.1,
122
  value=4.5,
123
  )
124
-
125
  num_inference_steps = gr.Slider(
126
  label="Number of inference steps",
127
  minimum=1,
128
  maximum=50,
129
  step=1,
130
- value=40,
131
  )
132
-
133
- gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True, cache_mode="lazy")
134
- gr.on(
135
- triggers=[run_button.click, prompt.submit],
136
- fn=infer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  inputs=[
138
  prompt,
139
  negative_prompt,
@@ -144,8 +221,8 @@ with gr.Blocks(css=css) as demo:
144
  guidance_scale,
145
  num_inference_steps,
146
  ],
147
- outputs=[result, seed],
148
  )
149
 
150
  if __name__ == "__main__":
151
- demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
 
 
4
  import torch
5
+ from diffusers import (
6
+ DiffusionPipeline, FluxPipeline, PixArtSigmaPipeline,
7
+ AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
8
+ LuminaText2ImgPipeline
9
+ )
10
+ import spaces
11
 
12
+ # Constants
 
 
 
 
 
 
 
 
 
 
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 1024
15
# Pick the compute device and matching weight dtype in one place:
# bfloat16 on CUDA, float32 when falling back to the CPU.
if torch.cuda.is_available():
    DEVICE, TORCH_DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, TORCH_DTYPE = "cpu", torch.float32
17
+
18
+ # Model configurations
19
# Registry of selectable models. Each display name maps to the Hub
# repository to download and the diffusers pipeline class used to load it.
# Insertion order is kept because the UI builds one tab per entry in order.
MODEL_CONFIGS = {
    display_name: {"repo_id": repo_id, "pipeline_class": pipeline_class}
    for display_name, repo_id, pipeline_class in (
        ("Stable Diffusion 3.5", "stabilityai/stable-diffusion-3.5-large", DiffusionPipeline),
        ("FLUX", "black-forest-labs/FLUX.1-dev", FluxPipeline),
        ("PixArt", "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", PixArtSigmaPipeline),
        ("AuraFlow", "fal/AuraFlow", AuraFlowPipeline),
        ("Kandinsky", "kandinsky-community/kandinsky-3", Kandinsky3Pipeline),
        ("Hunyuan", "Tencent-Hunyuan/HunyuanDiT-Diffusers", HunyuanDiTPipeline),
        ("Lumina", "Alpha-VLLM/Lumina-Next-SFT-diffusers", LuminaText2ImgPipeline),
    )
}
49
 
50
+ # Initialize model pipelines
51
+ pipes = {}
52
+
53
def load_pipeline(model_name):
    """Instantiate the diffusers pipeline registered under ``model_name``.

    The entry is looked up in ``MODEL_CONFIGS``; its ``pipeline_class`` is
    loaded from ``repo_id`` with ``TORCH_DTYPE`` weights and then placed for
    inference.

    Args:
        model_name: Key into ``MODEL_CONFIGS``.

    Returns:
        The loaded pipeline, ready to be called.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_CONFIGS``.
    """
    config = MODEL_CONFIGS[model_name]
    pipe = config["pipeline_class"].from_pretrained(
        config["repo_id"],
        torch_dtype=TORCH_DTYPE,
    )

    # Model CPU offload and an explicit ``.to(device)`` are alternative
    # placement strategies: offload keeps weights on the CPU and moves each
    # sub-model to the GPU only while it runs, so moving the whole pipeline to
    # the GPU first defeats it. Offload also needs an accelerator, so it is
    # only enabled when running on CUDA; otherwise fall back to a plain move.
    # NOTE(review): behavior per the diffusers memory-optimization docs —
    # confirm against the installed diffusers version.
    if DEVICE == "cuda" and hasattr(pipe, "enable_model_cpu_offload"):
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)
    return pipe
63
+
64
+ @spaces.GPU(duration=180)
65
+ def generate_image(
66
+ model_name,
67
  prompt,
68
  negative_prompt="",
69
  seed=42,
 
72
  height=1024,
73
  guidance_scale=4.5,
74
  num_inference_steps=40,
75
+ progress=gr.Progress(track_tqdm=True)
76
  ):
77
+ progress(0, desc=f"Loading {model_name} model...")
78
+
79
+ # Load model if not already loaded
80
+ if model_name not in pipes:
81
+ pipes[model_name] = load_pipeline(model_name)
82
+
83
+ pipe = pipes[model_name]
84
+
85
  if randomize_seed:
86
  seed = random.randint(0, MAX_SEED)
87
+
88
+ generator = torch.Generator(DEVICE).manual_seed(seed)
89
+
90
+ progress(0.3, desc=f"Generating image with {model_name}...")
91
+
92
  image = pipe(
93
  prompt=prompt,
94
  negative_prompt=negative_prompt,
 
98
  height=height,
99
  generator=generator,
100
  ).images[0]
101
+
102
+ progress(1.0, desc=f"Generation complete with {model_name}")
103
  return image, seed
104
 
105
+ # Gradio Interface
 
 
 
 
106
  css = """
107
  #col-container {
108
  margin: 0 auto;
109
+ max-width: 1024px;
110
  }
111
  """
112
 
113
  with gr.Blocks(css=css) as demo:
114
  with gr.Column(elem_id="col-container"):
115
+ gr.Markdown("# Multi-Model Image Generation")
116
+
117
  with gr.Row():
118
  prompt = gr.Text(
119
  label="Prompt",
 
122
  placeholder="Enter your prompt",
123
  container=False,
124
  )
125
+ run_button = gr.Button("Generate", scale=0, variant="primary")
126
+
 
 
 
127
  with gr.Accordion("Advanced Settings", open=False):
128
  negative_prompt = gr.Text(
129
  label="Negative prompt",
130
  max_lines=1,
131
  placeholder="Enter a negative prompt",
 
132
  )
133
+
134
  seed = gr.Slider(
135
  label="Seed",
136
  minimum=0,
 
138
  step=1,
139
  value=0,
140
  )
141
+
142
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
143
+
144
  with gr.Row():
145
  width = gr.Slider(
146
  label="Width",
147
  minimum=512,
148
  maximum=MAX_IMAGE_SIZE,
149
  step=32,
150
+ value=1024,
151
  )
 
152
  height = gr.Slider(
153
  label="Height",
154
  minimum=512,
 
156
  step=32,
157
  value=1024,
158
  )
159
+
160
  with gr.Row():
161
  guidance_scale = gr.Slider(
162
  label="Guidance scale",
 
165
  step=0.1,
166
  value=4.5,
167
  )
 
168
  num_inference_steps = gr.Slider(
169
  label="Number of inference steps",
170
  minimum=1,
171
  maximum=50,
172
  step=1,
173
+ value=40,
174
  )
175
+
176
+ # Create tabs for each model
177
+ with gr.Tabs() as tabs:
178
+ results = {}
179
+ seeds = {}
180
+ for model_name in MODEL_CONFIGS.keys():
181
+ with gr.Tab(model_name):
182
+ results[model_name] = gr.Image(label=f"{model_name} Result")
183
+ seeds[model_name] = gr.Number(label="Seed used", visible=False)
184
+
185
+ examples = [
186
+ "A capybara wearing a suit holding a sign that reads Hello World",
187
+ "A serene landscape with mountains and a lake at sunset",
188
+ ]
189
+ gr.Examples(examples=examples, inputs=[prompt])
190
+
191
+ # Handle generation for each model
192
def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
    """Run the same request through every model in ``MODEL_CONFIGS``.

    Models are processed in registry order. For each one, ``generate_image``
    is called with the shared settings. A failure in one model is printed
    and does not stop the remaining models.

    Returns:
        A flat list ``[image, seed, image, seed, ...]`` with one pair per
        model, in ``MODEL_CONFIGS`` order. A model that raised contributes
        ``[None, None]`` so the list stays aligned with the output
        components.
    """
    flat_outputs = []
    for model_name in MODEL_CONFIGS:
        try:
            picture, seed_used = generate_image(
                model_name,
                prompt,
                negative_prompt,
                seed,
                randomize_seed,
                width,
                height,
                guidance_scale,
                num_inference_steps,
                progress,
            )
        except Exception as e:
            # Keep the slot pair so later models still line up with their tabs.
            flat_outputs += [None, None]
            print(f"Error generating with {model_name}: {str(e)}")
        else:
            flat_outputs += [picture, seed_used]
    return flat_outputs
206
+
207
+ # Set up the generation trigger
208
+ output_components = []
209
+ for model_name in MODEL_CONFIGS.keys():
210
+ output_components.extend([results[model_name], seeds[model_name]])
211
+
212
+ run_button.click(
213
+ fn=generate_all,
214
  inputs=[
215
  prompt,
216
  negative_prompt,
 
221
  guidance_scale,
222
  num_inference_steps,
223
  ],
224
+ outputs=output_components,
225
  )
226
 
227
  if __name__ == "__main__":
228
+ demo.launch()