lifeng committed on
Commit
70ca38a
·
1 Parent(s): a442a5c

修改计算类型

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -28,10 +28,14 @@ lora_base_path = "./models"
28
  pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
29
  transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
30
  pipe.transformer = transformer
31
- pipe.to("cuda")
32
- # 在初始化模型后立即清理GPU缓存和启用注意力切片
33
- torch.cuda.empty_cache() # 清理GPU缓存
34
- pipe.enable_attention_slicing() # 启用注意力切片以减少内存使用
 
 
 
 
35
 
36
  def clear_cache(transformer):
37
  for name, attn_processor in transformer.attn_processors.items():
 
28
  pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
29
  transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
30
  pipe.transformer = transformer
31
+ try:
32
+ pipe.to("cuda")
33
+ # 在初始化模型后立即清理GPU缓存和启用注意力切片
34
+ torch.cuda.empty_cache() # 清理GPU缓存
35
+ pipe.enable_attention_slicing() # 启用注意力切片以减少内存使用
36
+ except torch.cuda.OutOfMemoryError:
37
+ print("CUDA out of memory. Switching to CPU.")
38
+ pipe.to("cpu")
39
 
40
  def clear_cache(transformer):
41
  for name, attn_processor in transformer.attn_processors.items():