Gemini899 commited on
Commit
ce193d2
·
verified ·
1 Parent(s): eb62af0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -104
app.py CHANGED
@@ -1,108 +1,42 @@
1
  import gradio as gr
2
  import re
3
- import torch
4
  from PIL import Image
5
-
 
 
6
  import spaces
7
- from diffusers import StableDiffusionXLImg2ImgPipeline
8
-
9
- #
10
- # Load the two SDXL pipelines (base + refiner) globally, so they only load once.
11
- #
12
- BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
13
- REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
14
 
 
15
  dtype = torch.float16
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
- pipe_base = StableDiffusionXLImg2ImgPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=dtype).to(device)
19
- pipe_refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(REFINER_MODEL_ID, torch_dtype=dtype).to(device)
 
 
20
 
21
- #
22
- # Helper functions
23
- #
24
- def sanitize_prompt(prompt: str) -> str:
25
- # Simple sanitation: remove suspicious characters
26
  allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
27
  return allowed_chars.sub("", prompt)
28
 
29
- def resize_to_multiple_of_64(image: Image.Image, max_dim: int = 1024):
30
- """
31
- Resizes the image so that both width/height <= max_dim,
32
- and each dimension is a multiple of 64.
33
- (SDXL often uses 1024x1024. You can do multiples of 128 if you prefer.)
34
- """
35
- w, h = image.size
36
-
37
- # If image is bigger than max_dim in any dimension, scale it down
38
- ratio = min(max_dim / w, max_dim / h, 1.0)
39
- new_w = int(w * ratio)
40
- new_h = int(h * ratio)
41
-
42
- # Round down to multiples of 64 for best results in SDXL
43
- new_w = new_w - (new_w % 64)
44
- new_h = new_h - (new_h % 64)
45
-
46
- new_w = max(new_w, 64)
47
- new_h = max(new_h, 64)
48
- return image.resize((new_w, new_h), Image.LANCZOS)
49
-
50
- @spaces.GPU(duration=240) # Increase time if needed (SDXL can be slow)
51
- def run_img2img_sdxl(
52
- init_image,
53
- prompt: str,
54
- strength: float,
55
- seed: int,
56
- steps_base: int,
57
- steps_refiner: int,
58
- ):
59
- """
60
- Runs a two-step SDXL (base + refiner) pass for high-quality img2img.
61
- """
62
- if init_image is None:
63
- print("No input image provided.")
64
  return None
65
-
66
- # Clean up prompt
67
- prompt = sanitize_prompt(prompt)
68
-
69
- # Ensure reproducibility
70
  generator = torch.Generator(device).manual_seed(seed)
71
-
72
- # Possibly resize the input to a smaller multiple-of-64 dimension
73
- # (1024x1024 or smaller is typical for SDXL)
74
- init_image = resize_to_multiple_of_64(init_image, max_dim=1024)
75
-
76
- # 1) Base pass
77
- base_output = pipe_base(
78
  prompt=prompt,
79
- image=init_image,
80
  strength=strength,
81
- guidance_scale=8.0, # Adjust if you want more or less adherence to prompt
82
- num_inference_steps=steps_base,
83
  generator=generator
84
- )
85
- base_image = base_output.images[0]
86
 
87
- # 2) Refiner pass
88
- # Typically set strength=0.0 for the refiner to do final detailing,
89
- # and possibly a slightly higher guidance scale.
90
- refiner_output = pipe_refiner(
91
- prompt=prompt,
92
- image=base_image,
93
- strength=0.0, # strictly refine
94
- guidance_scale=9.0,
95
- num_inference_steps=steps_refiner,
96
- generator=generator
97
- )
98
- final_image = refiner_output.images[0]
99
-
100
- return final_image
101
 
102
-
103
- #
104
- # Gradio UI
105
- #
106
  css = """
107
  #col-left {
108
  margin: 0 auto;
@@ -115,34 +49,32 @@ css = """
115
  """
116
 
117
  with gr.Blocks(css=css) as demo:
118
- gr.Markdown("## SDXL Img2Img (Base + Refiner) High Quality Demo")
119
 
120
  with gr.Row():
121
  with gr.Column():
122
- init_image = gr.Image(
123
- label="Init Image (Img2Img)",
124
  type="pil",
125
  image_mode="RGB",
126
  height=512
127
  )
128
- prompt = gr.Textbox(
129
- label="Prompt",
130
- placeholder="Describe what you want to see"
131
  )
132
- run_button = gr.Button("Generate")
133
- with gr.Accordion("Advanced Options", open=False):
134
- strength = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Strength (img2img)")
135
- seed = gr.Number(value=42, label="Seed", precision=0)
136
- steps_base = gr.Slider(1, 100, value=50, step=1, label="Steps (Base)")
137
- steps_refiner = gr.Slider(1, 100, value=30, step=1, label="Steps (Refiner)")
138
-
139
  with gr.Column():
140
- result_image = gr.Image(label="Result", height=512)
141
 
142
- # Link the button to our function
143
- run_button.click(
144
- fn=run_img2img_sdxl,
145
- inputs=[init_image, prompt, strength, seed, steps_base, steps_refiner],
146
  outputs=[result_image]
147
  )
148
 
 
1
  import gradio as gr
2
  import re
 
3
  from PIL import Image
4
+ import os
5
+ import torch
6
+ from diffusers import StableDiffusionImg2ImgPipeline
7
  import spaces
 
 
 
 
 
 
 
8
 
9
+ model_id = "SG161222/Realistic_Vision_V2.0"
10
  dtype = torch.float16
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Load the pipeline once at startup
14
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
15
+ model_id, torch_dtype=dtype
16
+ ).to(device)
17
 
18
+ def sanitize_prompt(prompt):
 
 
 
 
19
  allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
20
  return allowed_chars.sub("", prompt)
21
 
22
+ def process_img2img(img, prompt, strength, seed, steps):
23
+ if img is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  return None
 
 
 
 
 
25
  generator = torch.Generator(device).manual_seed(seed)
26
+ return pipe(
 
 
 
 
 
 
27
  prompt=prompt,
28
+ image=img,
29
  strength=strength,
30
+ guidance_scale=7.5, # typical for Realistic Vision
31
+ num_inference_steps=steps,
32
  generator=generator
33
+ ).images[0]
 
34
 
35
+ @spaces.GPU(duration=120)
36
+ def run_app_inference(image, prompt, strength, seed, steps, progress=gr.Progress(track_tqdm=True)):
37
+ progress(0, desc="Starting Inference")
38
+ return process_img2img(image, prompt, strength, seed, steps)
 
 
 
 
 
 
 
 
 
 
39
 
 
 
 
 
40
  css = """
41
  #col-left {
42
  margin: 0 auto;
 
49
  """
50
 
51
  with gr.Blocks(css=css) as demo:
52
+ gr.Markdown("## Realistic Vision v2.0 Img2ImgNo License Acceptance Required")
53
 
54
  with gr.Row():
55
  with gr.Column():
56
+ image_input = gr.Image(
57
+ label="Initial Image (Img2Img)",
58
  type="pil",
59
  image_mode="RGB",
60
  height=512
61
  )
62
+ prompt_input = gr.Textbox(
63
+ label="Prompt",
64
+ placeholder="Describe desired result"
65
  )
66
+ generate_button = gr.Button("Generate")
67
+ with gr.Accordion("Advanced Settings", open=False):
68
+ strength_slider = gr.Slider(0.0, 1.0, value=0.75, step=0.05, label="Strength")
69
+ seed_box = gr.Number(value=0, label="Seed", precision=0)
70
+ steps_box = gr.Slider(1, 100, value=30, step=1, label="Steps")
71
+
 
72
  with gr.Column():
73
+ result_image = gr.Image(label="Output", height=512)
74
 
75
+ generate_button.click(
76
+ fn=run_app_inference,
77
+ inputs=[image_input, prompt_input, strength_slider, seed_box, steps_box],
 
78
  outputs=[result_image]
79
  )
80