lifeng commited on
Commit
806fa90
·
1 Parent(s): 333b9ce

修改内存大小占用

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -25,8 +25,8 @@ lora_base_path = "./models"
25
 
26
 
27
  # pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
28
- pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat8)
29
- transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat8)
30
  pipe.transformer = transformer
31
  pipe.to("cuda")
32
 
@@ -37,27 +37,28 @@ def clear_cache(transformer):
37
  # Define the Gradio interface
38
  @spaces.GPU()
39
  def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
 
40
  # Set the control type
41
- if control_type == "Ghibli":
42
- lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
43
- set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
44
-
45
- # Process the image
46
- spatial_imgs = [spatial_img] if spatial_img else []
47
- image = pipe(
48
- prompt,
49
- height=int(height),
50
- width=int(width),
51
- guidance_scale=3.5,
52
- num_inference_steps=25,
53
- max_sequence_length=512,
54
- generator=torch.Generator("cpu").manual_seed(seed),
55
- subject_images=[],
56
- spatial_images=spatial_imgs,
57
- cond_size=512,
58
- ).images[0]
59
- clear_cache(pipe.transformer)
60
- return image
61
 
62
  # Define the Gradio interface components
63
  control_types = ["Ghibli"]
 
25
 
26
 
27
  # pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
28
+ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
29
+ transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
30
  pipe.transformer = transformer
31
  pipe.to("cuda")
32
 
 
37
  # Define the Gradio interface
38
  @spaces.GPU()
39
  def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
40
+ with torch.no_grad():
41
  # Set the control type
42
+ if control_type == "Ghibli":
43
+ lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
44
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
45
+
46
+ # Process the image
47
+ spatial_imgs = [spatial_img] if spatial_img else []
48
+ image = pipe(
49
+ prompt,
50
+ height=int(height),
51
+ width=int(width),
52
+ guidance_scale=3.5,
53
+ num_inference_steps=25,
54
+ max_sequence_length=512,
55
+ generator=torch.Generator("cpu").manual_seed(seed),
56
+ subject_images=[],
57
+ spatial_images=spatial_imgs,
58
+ cond_size=512,
59
+ ).images[0]
60
+ clear_cache(pipe.transformer)
61
+ return image
62
 
63
  # Define the Gradio interface components
64
  control_types = ["Ghibli"]