Fabrice-TIERCELIN committed
Commit e5e0923 · verified · 1 Parent(s): 5e1edf7

Upload 11 files

hyvideo/modules/attenion.py CHANGED
@@ -178,7 +178,7 @@ def parallel_attention(
         joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
         joint_strategy="rear",
     )
-    if flash_attn.__version__ >= "2.7.0":
+    if flash_attn.__version__ >= '2.7.0':
         attn2, *_ = _flash_attn_forward(
             q[:,cu_seqlens_q[1]:],
             k[:,cu_seqlens_kv[1]:],
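
A side note, not part of this commit: the quote-style change leaves the comparison itself as a plain string comparison, which is lexicographic (for example, '2.10.0' >= '2.7.0' evaluates to False). A minimal sketch of a numeric version check, assuming packaging is installed and flash_attn exposes __version__ as in the diff above:

from packaging import version

import flash_attn

# Parse both sides so '2.10.0' compares as greater than '2.7.0'.
if version.parse(flash_attn.__version__) >= version.parse("2.7.0"):
    ...  # use the post-2.7.0 _flash_attn_forward signature
else:
    ...  # fall back to the older signature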
hyvideo/modules/fp8_optimization.py CHANGED
@@ -83,7 +83,7 @@ def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={
     setattr(module, "fp8_matmul_enabled", True)
 
     # loading fp8 mapping file
-    fp8_map_path = dit_weight_path.replace(".pt", "_map.pt")
+    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
     if os.path.exists(fp8_map_path):
         fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
     else:
@@ -91,7 +91,7 @@ def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={
 
     fp8_layers = []
     for key, layer in module.named_modules():
-        if isinstance(layer, nn.Linear) and ("double_blocks" in key or "single_blocks" in key):
+        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
             fp8_layers.append(key)
             original_forward = layer.forward
             layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
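
A side note, not part of this commit: dit_weight_path.replace('.pt', '_map.pt') rewrites every occurrence of '.pt' in the path, not just the file extension, so a directory name containing '.pt' would also be altered. A minimal sketch of an extension-only variant (fp8_map_path_for is a hypothetical helper, not from the repo):

import os

def fp8_map_path_for(dit_weight_path: str) -> str:
    # Hypothetical helper: split off the extension so only it is rewritten.
    root, ext = os.path.splitext(dit_weight_path)  # e.g. ('weights/model', '.pt')
    return f"{root}_map{ext}"                      # 'weights/model_map.pt'

# Usage: fp8_map_path_for('weights/model.pt') -> 'weights/model_map.pt'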