soiz1 commited on
Commit
2f729e1
·
verified ·
1 Parent(s): 106d5e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -97
app.py CHANGED
@@ -12,20 +12,19 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import os
16
 
17
  import gradio as gr
18
- import huggingface_hub
19
  import pillow_avif
20
- import spaces
21
  import torch
22
- import gc
23
  from huggingface_hub import snapshot_download
24
  from pillow_heif import register_heif_opener
25
- from PIL import Image, ImageDraw, ImageFont
26
 
27
  from pipelines.pipeline_infu_flux import InfUFluxPipeline
28
 
 
29
  # Register HEIF support for Pillow
30
  register_heif_opener()
31
 
@@ -47,9 +46,26 @@ loaded_pipeline_config = {
47
 
48
 
49
  def download_models():
50
- snapshot_download(repo_id='ByteDance/InfiniteYou', revision="61248f501725e31c16a90a1c7ffe63f6e2839c65", local_dir='./models/InfiniteYou', local_dir_use_symlinks=False)
 
 
 
 
 
 
 
 
 
 
 
 
51
  try:
52
- snapshot_download(repo_id='black-forest-labs/FLUX.1-dev', local_dir='./models/FLUX.1-dev', local_dir_use_symlinks=False)
 
 
 
 
 
53
  except Exception as e:
54
  print(e)
55
  print('\nYou are downloading `black-forest-labs/FLUX.1-dev` to `./models/FLUX.1-dev` but failed. '
@@ -61,36 +77,11 @@ def download_models():
61
  exit()
62
 
63
 
64
- def init_pipeline(model_version, enable_realism, enable_anti_blur):
65
- loaded_pipeline_config["enable_realism"] = enable_realism
66
- loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
67
- loaded_pipeline_config["model_version"] = model_version
68
-
69
- pipeline = loaded_pipeline_config['pipeline']
70
- gc.collect()
71
- torch.cuda.empty_cache()
72
-
73
- model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
74
- print(f'loading model from {model_path}')
75
-
76
- pipeline = InfUFluxPipeline(
77
- base_model_path='./models/FLUX.1-dev',
78
- infu_model_path=model_path,
79
- insightface_root_path='./models/InfiniteYou/supports/insightface',
80
- image_proj_num_tokens=8,
81
- infu_flux_version='v1.0',
82
- model_version=model_version,
83
- )
84
-
85
- loaded_pipeline_config['pipeline'] = pipeline
86
-
87
- pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
88
- loras = []
89
- if enable_realism: loras.append(['realism', 1.0])
90
- if enable_anti_blur: loras.append(['anti_blur', 1.0])
91
- pipeline.load_loras_state_dict(loras)
92
-
93
- return pipeline
94
 
95
 
96
  def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
@@ -107,62 +98,43 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
107
  loaded_pipeline_config["model_version"] = model_version
108
 
109
  pipeline = loaded_pipeline_config['pipeline']
110
- if pipeline is None or pipeline.model_version != model_version:
111
  print(f'Switching model to {model_version}')
112
- pipeline.model_version = model_version
 
 
 
 
113
  if model_version == 'aes_stage2':
114
- pipeline.infusenet_sim.cpu()
115
- pipeline.image_proj_model_sim.cpu()
116
- torch.cuda.empty_cache()
117
- pipeline.infusenet_aes.to('cuda')
118
- pipeline.pipe.controlnet = pipeline.infusenet_aes
119
- pipeline.image_proj_model_aes.to('cuda')
120
- pipeline.image_proj_model = pipeline.image_proj_model_aes
121
  else:
122
- pipeline.infusenet_aes.cpu()
123
- pipeline.image_proj_model_aes.cpu()
124
- torch.cuda.empty_cache()
125
- pipeline.infusenet_sim.to('cuda')
126
- pipeline.pipe.controlnet = pipeline.infusenet_sim
127
- pipeline.image_proj_model_sim.to('cuda')
128
- pipeline.image_proj_model = pipeline.image_proj_model_sim
 
 
 
 
129
 
130
  loaded_pipeline_config['pipeline'] = pipeline
131
 
132
  pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
133
  loras = []
134
- if enable_realism: loras.append(['realism', 1.0])
135
- if enable_anti_blur: loras.append(['anti_blur', 1.0])
136
- pipeline.load_loras_state_dict(loras)
 
 
137
 
138
  return pipeline
139
 
140
 
141
- def add_safety_watermark(image, text='AI Generated', font_path=None):
142
- width, height = image.size
143
- draw = ImageDraw.Draw(image)
144
-
145
- font_size = int(height * 0.028)
146
- if font_path:
147
- font = ImageFont.truetype(font_path, font_size)
148
- else:
149
- font = ImageFont.load_default(size=font_size)
150
-
151
- text_bbox = draw.textbbox((0, 0), text, font=font)
152
- text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
153
- x = width - text_width - 10
154
- y = height - text_height - 20
155
-
156
- shadow_offset = 2
157
- shadow_color = "black"
158
- draw.text((x + shadow_offset, y + shadow_offset), text, font=font, fill=shadow_color)
159
-
160
- draw.text((x, y), text, font=font, fill="white")
161
-
162
- return image
163
-
164
-
165
- @spaces.GPU(duration=120)
166
  def generate_image(
167
  input_image,
168
  control_image,
@@ -179,12 +151,12 @@ def generate_image(
179
  enable_anti_blur,
180
  model_version
181
  ):
182
- try:
183
- pipeline = prepare_pipeline(model_version=model_version, enable_realism=enable_realism, enable_anti_blur=enable_anti_blur)
184
 
185
- if seed == 0:
186
- seed = torch.seed() & 0xFFFFFFFF
187
 
 
188
  image = pipeline(
189
  id_image=input_image,
190
  prompt=prompt,
@@ -198,13 +170,12 @@ def generate_image(
198
  infusenet_guidance_start=infusenet_guidance_start,
199
  infusenet_guidance_end=infusenet_guidance_end,
200
  )
201
- image = add_safety_watermark(image)
202
  except Exception as e:
203
  print(e)
204
  gr.Error(f"An error occurred: {e}")
205
  return gr.update()
206
 
207
- return gr.update(value=image, label=f"Generated Image, seed = {seed}")
208
 
209
 
210
  def generate_examples(id_image, control_image, prompt_text, seed, enable_realism, enable_anti_blur, model_version):
@@ -233,8 +204,7 @@ with gr.Blocks() as demo:
233
  <a href="https://bytedance.github.io/InfiniteYou">[Project Page]</a>&ensp;
234
  <a href="https://arxiv.org/abs/2503.16418">[Paper]</a>&ensp;
235
  <a href="https://github.com/bytedance/InfiniteYou">[Code]</a>&ensp;
236
- <a href="https://huggingface.co/ByteDance/InfiniteYou">[Model]</a>&ensp;
237
- <a href="https://github.com/bytedance/ComfyUI_InfiniteYou">[ComfyUI]</a>
238
  </div>
239
  """)
240
 
@@ -298,7 +268,7 @@ with gr.Blocks() as demo:
298
  inputs=[ui_id_image, ui_control_image, ui_prompt_text, ui_seed, ui_enable_realism, ui_enable_anti_blur, ui_model_version],
299
  outputs=[image_output],
300
  fn=generate_examples,
301
- cache_examples=False
302
  )
303
 
304
  ui_btn_generate.click(
@@ -335,12 +305,12 @@ with gr.Blocks() as demo:
335
  The images used in this demo are sourced from consented subjects or generated by the models. These pictures are intended solely to show the capabilities of our research. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.
336
 
337
  The use of the released code, model, and demo must strictly adhere to the respective licenses.
338
- Our code is released under the [Apache License 2.0](https://github.com/bytedance/InfiniteYou/blob/main/LICENSE),
339
  and our model is released under the [Creative Commons Attribution-NonCommercial 4.0 International Public License](https://huggingface.co/ByteDance/InfiniteYou/blob/main/LICENSE)
340
  for academic research purposes only. Any manual or automatic downloading of the face models from [InsightFace](https://github.com/deepinsight/insightface),
341
  the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) base model, LoRAs, *etc.*, must follow their original licenses and be used only for academic research purposes.
342
 
343
- This research aims to positively impact the field of Generative AI. Any usage of this method must be responsible and comply with local laws. The developers do not assume any responsibility for any potential misuse. We added the "AI Generated" watermark for enhanced safety.
344
  """
345
  )
346
 
@@ -363,14 +333,13 @@ with gr.Blocks() as demo:
363
  """
364
  )
365
 
366
-
367
- huggingface_hub.login(os.getenv('PRIVATE_HF_TOKEN'))
368
-
369
  download_models()
370
 
371
- init_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
372
 
373
- # demo.queue()
374
- demo.launch()
375
  # demo.launch(server_name='0.0.0.0') # IPv4
376
  # demo.launch(server_name='[::]') # IPv6
 
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import gc
16
+ import shutil
17
  import os
18
 
19
  import gradio as gr
 
20
  import pillow_avif
 
21
  import torch
 
22
  from huggingface_hub import snapshot_download
23
  from pillow_heif import register_heif_opener
 
24
 
25
  from pipelines.pipeline_infu_flux import InfUFluxPipeline
26
 
27
+
28
  # Register HEIF support for Pillow
29
  register_heif_opener()
30
 
 
46
 
47
 
48
  def download_models():
49
+ # InfiniteYou モデル (必要な部分だけ)
50
+ snapshot_download(
51
+ repo_id='ByteDance/InfiniteYou',
52
+ local_dir='./models/InfiniteYou',
53
+ local_dir_use_symlinks=False,
54
+ allow_patterns=[
55
+ "infu_flux_v1.0/aes_stage2/**", # aes_stage2 モデル
56
+ "supports/insightface/**", # 顔認識に必要
57
+ "supports/optional_loras/**", # LoRA が必要なら残す
58
+ ]
59
+ )
60
+
61
+ # FLUX.1-dev (ベースモデルは必須なので別途取得)
62
  try:
63
+ snapshot_download(
64
+ repo_id='black-forest-labs/FLUX.1-dev',
65
+ local_dir='./models/FLUX.1-dev',
66
+ local_dir_use_symlinks=False,
67
+ allow_patterns=["*.safetensors", "*.json", "*.txt"]
68
+ )
69
  except Exception as e:
70
  print(e)
71
  print('\nYou are downloading `black-forest-labs/FLUX.1-dev` to `./models/FLUX.1-dev` but failed. '
 
77
  exit()
78
 
79
 
80
+ def clean_hf_cache():
81
+ cache_dir = os.path.expanduser("~/.cache/huggingface")
82
+ if os.path.exists(cache_dir):
83
+ print(f"Cleaning Hugging Face cache at {cache_dir}")
84
+ shutil.rmtree(cache_dir, ignore_errors=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
 
98
  loaded_pipeline_config["model_version"] = model_version
99
 
100
  pipeline = loaded_pipeline_config['pipeline']
101
+ if pipeline is None or pipeline.model_version != model_version:
102
  print(f'Switching model to {model_version}')
103
+ del pipeline
104
+ del loaded_pipeline_config['pipeline']
105
+ gc.collect()
106
+ torch.cuda.empty_cache()
107
+
108
  if model_version == 'aes_stage2':
109
+ model_path = f'./models/InfiniteYou/infu_flux_v1.0/aes_stage2'
110
+ elif model_version == 'sim_stage1':
111
+ model_path = f'./models/InfiniteYou/infu_flux_v1.0/sim_stage1'
 
 
 
 
112
  else:
113
+ raise ValueError(f'Model version {model_version} not supported.')
114
+ print(f'Loading model from {model_path}')
115
+
116
+ pipeline = InfUFluxPipeline(
117
+ base_model_path='./models/FLUX.1-dev',
118
+ infu_model_path=model_path,
119
+ insightface_root_path='./models/InfiniteYou/supports/insightface',
120
+ image_proj_num_tokens=8,
121
+ infu_flux_version='v1.0',
122
+ model_version=model_version,
123
+ )
124
 
125
  loaded_pipeline_config['pipeline'] = pipeline
126
 
127
  pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
128
  loras = []
129
+ if enable_realism:
130
+ loras.append(['./models/InfiniteYou/supports/optional_loras/flux_realism_lora.safetensors', 'realism', 1.0])
131
+ if enable_anti_blur:
132
+ loras.append(['./models/InfiniteYou/supports/optional_loras/flux_anti_blur_lora.safetensors', 'anti_blur', 1.0])
133
+ pipeline.load_loras(loras)
134
 
135
  return pipeline
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def generate_image(
139
  input_image,
140
  control_image,
 
151
  enable_anti_blur,
152
  model_version
153
  ):
154
+ pipeline = prepare_pipeline(model_version=model_version, enable_realism=enable_realism, enable_anti_blur=enable_anti_blur)
 
155
 
156
+ if seed == 0:
157
+ seed = torch.seed() & 0xFFFFFFFF
158
 
159
+ try:
160
  image = pipeline(
161
  id_image=input_image,
162
  prompt=prompt,
 
170
  infusenet_guidance_start=infusenet_guidance_start,
171
  infusenet_guidance_end=infusenet_guidance_end,
172
  )
 
173
  except Exception as e:
174
  print(e)
175
  gr.Error(f"An error occurred: {e}")
176
  return gr.update()
177
 
178
+ return gr.update(value = image, label=f"Generated Image, seed = {seed}")
179
 
180
 
181
  def generate_examples(id_image, control_image, prompt_text, seed, enable_realism, enable_anti_blur, model_version):
 
204
  <a href="https://bytedance.github.io/InfiniteYou">[Project Page]</a>&ensp;
205
  <a href="https://arxiv.org/abs/2503.16418">[Paper]</a>&ensp;
206
  <a href="https://github.com/bytedance/InfiniteYou">[Code]</a>&ensp;
207
+ <a href="https://huggingface.co/ByteDance/InfiniteYou">[Model]</a>
 
208
  </div>
209
  """)
210
 
 
268
  inputs=[ui_id_image, ui_control_image, ui_prompt_text, ui_seed, ui_enable_realism, ui_enable_anti_blur, ui_model_version],
269
  outputs=[image_output],
270
  fn=generate_examples,
271
+ cache_examples=True,
272
  )
273
 
274
  ui_btn_generate.click(
 
305
  The images used in this demo are sourced from consented subjects or generated by the models. These pictures are intended solely to show the capabilities of our research. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.
306
 
307
  The use of the released code, model, and demo must strictly adhere to the respective licenses.
308
+ Our code is released under the [Apache 2.0 License](https://github.com/bytedance/InfiniteYou/blob/main/LICENSE),
309
  and our model is released under the [Creative Commons Attribution-NonCommercial 4.0 International Public License](https://huggingface.co/ByteDance/InfiniteYou/blob/main/LICENSE)
310
  for academic research purposes only. Any manual or automatic downloading of the face models from [InsightFace](https://github.com/deepinsight/insightface),
311
  the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) base model, LoRAs, *etc.*, must follow their original licenses and be used only for academic research purposes.
312
 
313
+ This research aims to positively impact the field of Generative AI. Any usage of this method must be responsible and comply with local laws. The developers do not assume any responsibility for any potential misuse.
314
  """
315
  )
316
 
 
333
  """
334
  )
335
 
 
 
 
336
  download_models()
337
 
338
+ prepare_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
339
 
340
+ demo.queue()
341
+ demo.launch(server_name='localhost') # localhost
342
  # demo.launch(server_name='0.0.0.0') # IPv4
343
  # demo.launch(server_name='[::]') # IPv6
344
+
345
+ clean_hf_cache()