lifeng commited on
Commit
a442a5c
·
1 Parent(s): 806fa90

修改参数

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -29,6 +29,9 @@ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
29
  transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
30
  pipe.transformer = transformer
31
  pipe.to("cuda")
 
 
 
32
 
33
  def clear_cache(transformer):
34
  for name, attn_processor in transformer.attn_processors.items():
@@ -37,28 +40,27 @@ def clear_cache(transformer):
37
  # Define the Gradio interface
38
  @spaces.GPU()
39
  def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
40
- with torch.no_grad():
41
  # Set the control type
42
- if control_type == "Ghibli":
43
- lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
44
- set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
45
-
46
- # Process the image
47
- spatial_imgs = [spatial_img] if spatial_img else []
48
- image = pipe(
49
- prompt,
50
- height=int(height),
51
- width=int(width),
52
- guidance_scale=3.5,
53
- num_inference_steps=25,
54
- max_sequence_length=512,
55
- generator=torch.Generator("cpu").manual_seed(seed),
56
- subject_images=[],
57
- spatial_images=spatial_imgs,
58
- cond_size=512,
59
- ).images[0]
60
- clear_cache(pipe.transformer)
61
- return image
62
 
63
  # Define the Gradio interface components
64
  control_types = ["Ghibli"]
 
29
  transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
30
  pipe.transformer = transformer
31
  pipe.to("cuda")
32
+ # 在初始化模型后立即清理GPU缓存和启用注意力切片
33
+ torch.cuda.empty_cache() # 清理GPU缓存
34
+ pipe.enable_attention_slicing() # 启用注意力切片以减少内存使用
35
 
36
  def clear_cache(transformer):
37
  for name, attn_processor in transformer.attn_processors.items():
 
40
  # Define the Gradio interface
41
  @spaces.GPU()
42
  def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
 
43
  # Set the control type
44
+ if control_type == "Ghibli":
45
+ lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
46
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
47
+
48
+ # Process the image
49
+ spatial_imgs = [spatial_img] if spatial_img else []
50
+ image = pipe(
51
+ prompt,
52
+ height=int(height),
53
+ width=int(width),
54
+ guidance_scale=3.5,
55
+ num_inference_steps=25,
56
+ max_sequence_length=512,
57
+ generator=torch.Generator("cpu").manual_seed(seed),
58
+ subject_images=[],
59
+ spatial_images=spatial_imgs,
60
+ cond_size=512,
61
+ ).images[0]
62
+ clear_cache(pipe.transformer)
63
+ return image
64
 
65
  # Define the Gradio interface components
66
  control_types = ["Ghibli"]