p1atdev committed on
Commit e518b27 · Parent: d310f55

feat: init tkg

Files changed (6):
  1. .python-version +1 -0
  2. app.py +104 -45
  3. pyproject.toml +21 -0
  4. requirements.txt +2 -2
  5. tkg.py +117 -0
  6. uv.lock +0 -0
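
In short, the commit swaps the stock SDXL-Turbo text-to-image template for a TKG chroma-key demo on AnimagineXL 4.0: every prompt is rendered twice, once with TKG noise applied to the initial latents and once without, so the two results can be compared side by side.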
.python-version ADDED
@@ -0,0 +1 @@
+3.10
app.py CHANGED
@@ -1,36 +1,44 @@
-import gradio as gr
-import numpy as np
 import random
 
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
+import numpy as np
 import torch
 
+from diffusers import StableDiffusionXLPipeline
+import spaces
+
+import gradio as gr
+
+from tkg import apply_tkg_noise, COLOR_SET_MAP
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace with the model you would like to use
+model_repo_id = "cagliostrolab/animagine-xl-4.0"  # Replace with the model you would like to use
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    model_repo_id,
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="lpw_stable_diffusion_xl",
+    add_watermarker=False,
+)
 pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
+MAX_IMAGE_SIZE = 2048
 
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
+@spaces.GPU
 def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
+    prompt: str,
+    negative_prompt: str,
+    seed: int,
+    randomize_seed: bool,
+    width: int,
+    height: int,
+    guidance_scale: float,
+    num_inference_steps: int,
+    tkg_channels: list[int] = [0, 1, 1, 0],
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
@@ -38,7 +46,27 @@ def infer(
 
     generator = torch.Generator().manual_seed(seed)
 
-    image = pipe(
+    # Sample one 4-channel SDXL latent on CPU (matching the CPU generator) and
+    # build a TKG-noised copy: image 0 gets the chroma-key background, image 1
+    # is the untouched baseline. Cast to the pipeline's device/dtype at the end.
+    latents = torch.randn(
+        (1, 4, height // 8, width // 8),
+        generator=generator,
+    )
+    tkg_latents = apply_tkg_noise(
+        latents,
+        shift=0.11,
+        delta_shift=0.1,
+        std_dev=0.5,
+        factor=8,
+        channels=tkg_channels,
+    )
+    latents = torch.cat([tkg_latents, latents], dim=0)
+    latents = latents.to(device=device, dtype=pipe.dtype)
+
+    images = pipe(
+        latents=latents,
+        num_images_per_prompt=2,  # one prompt, two latents: with and without TKG
         prompt=prompt,
         negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
@@ -46,27 +74,55 @@
         width=width,
         height=height,
        generator=generator,
-    ).images[0]
-
-    return image, seed
-
+    ).images
+
+    w_tkg, wo_tkg = images
+
+    return w_tkg, wo_tkg, seed
+
+def color_name_to_channels(color_name: str) -> list[int]:
+    if color_name in COLOR_SET_MAP:
+        return COLOR_SET_MAP[color_name].channels  # ColorSet -> its channel signs
+    else:
+        raise ValueError(f"Unknown color name: {color_name}")
+
+def on_generate(
+    prompt: str,
+    negative_prompt: str,
+    seed: int,
+    randomize_seed: bool,
+    width: int,
+    height: int,
+    guidance_scale: float,
+    num_inference_steps: int,
+    color_name: str,
+    *args,
+    **kwargs
+):
+    tkg_channels = color_name_to_channels(color_name)
+    # TODO: custom channels
+
+    return infer(
+        prompt,
+        negative_prompt,
+        seed,
+        randomize_seed,
+        width,
+        height,
+        guidance_scale,
+        num_inference_steps,
+        tkg_channels=tkg_channels,
+    )
 
 examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
+    # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
+    "1girl, solo, upper body, looking at viewer, straight-on, masterpiece, best quality",
 ]
 
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
 
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
+with gr.Blocks() as demo:
+    with gr.Column():
+        gr.Markdown("# TKG Chroma-Key with AnimagineXL 4.0")
 
     with gr.Row():
         prompt = gr.Text(
@@ -79,14 +135,16 @@ with gr.Blocks(css=css) as demo:
 
            run_button = gr.Button("Run", scale=0, variant="primary")
 
-        result = gr.Image(label="Result", show_label=False)
+        with gr.Row():
+            result_w_tkg = gr.Image(label="Result", show_label=False)
+            result_wo_tkg = gr.Image(label="Result", show_label=False)
 
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
-                visible=False,
+                value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
            )
 
            seed = gr.Slider(
@@ -105,7 +163,7 @@ with gr.Blocks(css=css) as demo:
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=832,
                )
 
                height = gr.Slider(
@@ -113,7 +171,7 @@ with gr.Blocks(css=css) as demo:
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=1152,
                )
 
            with gr.Row():
@@ -122,7 +180,7 @@ with gr.Blocks(css=css) as demo:
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
+                    value=5.0,
                )
 
                num_inference_steps = gr.Slider(
@@ -130,10 +188,11 @@ with gr.Blocks(css=css) as demo:
                    minimum=1,
                    maximum=50,
                    step=1,
-                    value=2,  # Replace with defaults that work for your model
+                    value=25,
                )
 
        gr.Examples(examples=examples, inputs=[prompt])
+
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
@@ -147,7 +206,7 @@ with gr.Blocks(css=css) as demo:
            guidance_scale,
            num_inference_steps,
        ],
-        outputs=[result, seed],
+        outputs=[result_w_tkg, result_wo_tkg, seed],
    )
 
 if __name__ == "__main__":
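
Note that the commit defines on_generate and color_name_to_channels but the gr.on call above still routes events to infer, so the background color cannot be chosen from the UI yet. A minimal sketch of how the wiring could be completed, assuming a gr.Dropdown named color_name is added next to the other controls (the dropdown is hypothetical, not part of this commit):

# Hypothetical follow-up, not in this commit: expose the TKG background color
# and route events through on_generate instead of infer.
color_name = gr.Dropdown(
    label="Background color",
    choices=list(COLOR_SET_MAP),  # "green", "cyan", "magenta", ...
    value="green",
)

gr.on(
    triggers=[run_button.click, prompt.submit],
    fn=on_generate,
    inputs=[
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        color_name,
    ],
    outputs=[result_w_tkg, result_wo_tkg, seed],
)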
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+[project]
+name = "app"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.10,<3.13"
+dependencies = [
+    "accelerate>=1.10.0",
+    "diffusers>=0.35.1",
+    "spaces>=0.40.1",
+    "torch>=2.8.0",
+    "transformers>=4.55.4",
+    "xformers>=0.0.32.post2",
+]
+
+[dependency-groups]
+dev = [
+    "gradio>=5.43.1",
+    "ruff>=0.12.10",
+    "ty>=0.0.1a19",
+]
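
Together with the new .python-version and uv.lock files, this pyproject.toml suggests the project is now managed with uv: presumably "uv sync" recreates the pinned environment and "uv run app.py" launches the demo locally, while the trimmed requirements.txt below remains for the Hugging Face Spaces runtime.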
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 accelerate
 diffusers
-invisible_watermark
 torch
 transformers
-xformers
+xformers
+spaces
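
Two changes worth noting: the new spaces dependency provides the @spaces.GPU decorator that app.py now applies to infer (ZeroGPU allocation on Spaces), and invisible_watermark is dropped, consistent with add_watermarker=False in the new pipeline setup.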
tkg.py ADDED
@@ -0,0 +1,117 @@
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+
+
+def get_mean_shifted_latents(
+    latents: torch.Tensor,
+    shift: float = 0.11,
+    delta_shift: float = 0.1,
+    channels: list[int] = [0, 1, 1, 0],  # list of {-1, 0, 1}
+) -> torch.Tensor:
+    shifted_latents = latents.clone()
+
+    for idx, sign in enumerate(channels):
+        if sign == 0:
+            # skip
+            continue
+
+        latent_channel = shifted_latents[:, idx, :, :]
+
+        positive_ratio = (latent_channel > 0).float().mean()
+        target_ratio = positive_ratio + shift * sign
+
+        # shift latent_channel until the positive ratio crosses the target in the direction of sign
+        while True:
+            latent_channel += delta_shift * sign
+            new_positive_ratio = (latent_channel > 0).float().mean()
+            if (new_positive_ratio - target_ratio) * sign >= 0:
+                break
+
+        # replace the channel in the original latents
+        shifted_latents[:, idx, :, :] = latent_channel
+
+    return shifted_latents
+
+
+def get_2d_gaussian(
+    latent_height: int,
+    latent_width: int,
+    std_dev: float,
+    device: torch.device,
+    center_x: float = 0.0,
+    center_y: float = 0.0,
+    factor: int = 8,  # build the grid at 1/factor resolution; upsampled later
+):
+    y = torch.linspace(-1, 1, steps=latent_height // factor, device=device)
+    x = torch.linspace(-1, 1, steps=latent_width // factor, device=device)
+
+    y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
+
+    x_grid = x_grid - center_x
+    y_grid = y_grid - center_y
+
+    gauss = torch.exp(-((x_grid**2 + y_grid**2) / (2 * std_dev**2)))
+    gauss = gauss[None, None, :, :]  # add batch and channel dimensions
+
+    return gauss
+
+
+def apply_tkg_noise(
+    latents: torch.Tensor,
+    shift: float = 0.11,
+    delta_shift: float = 0.1,
+    std_dev: float = 0.5,
+    factor: int = 8,
+    channels: list[int] = [0, 1, 1, 0],
+):
+    batch_size, num_channels, latent_height, latent_width = latents.shape
+
+    shifted_latents = get_mean_shifted_latents(
+        latents,
+        shift=shift,
+        delta_shift=delta_shift,
+        channels=channels,
+    )
+    gauss_mask = get_2d_gaussian(
+        latent_height=latent_height,
+        latent_width=latent_width,
+        std_dev=std_dev,
+        center_x=0.0,
+        center_y=0.0,
+        factor=factor,
+        device=latents.device,
+    )
+    gauss_mask = F.interpolate(
+        gauss_mask,
+        size=(latent_height, latent_width),
+        mode="bilinear",
+        align_corners=False,
+    )
+
+    gauss_mask = gauss_mask.expand(batch_size, num_channels, -1, -1)
+
+    noised_latents = shifted_latents * (1 - gauss_mask) + latents * gauss_mask
+
+    return noised_latents
+
+
+class ColorSet(NamedTuple):
+    name: str
+    channels: list[int]
+
+
+# ref: Figure 28. Additional Result in various color Background with SD
+COLOR_SETS: list[ColorSet] = [
+    ColorSet("green", [0, 1, 1, 0]),
+    ColorSet("cyan", [0, 1, 0, 0]),
+    ColorSet("magenta", [0, -1, -1, -1]),
+    ColorSet("purple", [0, 0, -1, -1]),
+    ColorSet("black", [-1, 0, 0, 1]),
+    ColorSet("orange", [-1, -1, 1, 0]),
+    ColorSet("white", [0, 0, 0, -1]),
+    ColorSet("yellow", [0, -1, 1, -1]),
+]
+
+COLOR_SET_MAP: dict[str, ColorSet] = {c.name: c for c in COLOR_SETS}
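
The final blend in apply_tkg_noise keeps the original noise where the Gaussian mask is close to 1 (the image center, where the subject usually sits) and the mean-shifted noise where it falls toward 0 (the borders), which is what pushes the background toward the keyed color. A quick way to see the per-channel effect (a sketch, assuming it runs next to tkg.py; the latent size corresponds to an 832x1152 image):

import torch

from tkg import COLOR_SET_MAP, apply_tkg_noise

# 832x1152 image -> 104x144 SDXL latent (width // 8, height // 8)
latents = torch.randn(1, 4, 144, 104)
channels = COLOR_SET_MAP["green"].channels  # [0, 1, 1, 0]
noised = apply_tkg_noise(latents, channels=channels)

# channels 1 and 2 should show a higher share of positive values after shifting
for idx in range(4):
    before = (latents[:, idx] > 0).float().mean().item()
    after = (noised[:, idx] > 0).float().mean().item()
    print(f"channel {idx}: positive ratio {before:.3f} -> {after:.3f}")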
uv.lock ADDED
The diff for this file is too large to render.