"""Distributed tests comparing HunyuanVideo's reference attention against
xDiT's hybrid sequence-parallel (Ulysses + ring) attention for the
double-stream and single-stream blocks."""
import os
import sys

import torch

# Make the project root importable so that `hyvideo` resolves when this test
# is run directly from the tests directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)

# "attenion" is the module's filename as spelled in the hyvideo package.
from hyvideo.modules.attenion import attention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)

def init_dist(backend="nccl"):
    """Initialize torch.distributed and xDiT's sequence-parallel groups."""
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print(
        f"Initializing distributed environment with rank {rank}, "
        f"world size {world_size}, local rank {local_rank}"
    )

    torch.cuda.set_device(local_rank)
    init_distributed_environment(rank=rank, world_size=world_size)

    # Factor the sequence-parallel group into ring and Ulysses sub-groups.
    # Their product must equal world_size, so world_size is assumed to be
    # either 1 or an even number.
    if world_size > 1:
        ring_degree = world_size // 2
        ulysses_degree = 2
    else:
        ring_degree = 1
        ulysses_degree = 1
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    return rank, world_size

def test_mm_double_stream_block_attention(rank, world_size):
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    # Image and text q/k/v in (batch, seq, heads, head_dim) layout.
    img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_v = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_v = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
        # Broadcast from rank 0 so every rank runs the reference attention and
        # the sequence-parallel attention on identical inputs.
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(img_v, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(txt_v, src=0)
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        # Cumulative sequence lengths for the reference attention:
        # 118811 = 118800 image tokens + 11 valid text tokens,
        # 119056 = 118800 image tokens + 256 padded text tokens.
        cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056
        mode = "torch"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        # Hybrid (ring + Ulysses) sequence-parallel attention; the text tokens
        # are appended to the rear of the image tokens as joint tensors.
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            img_q,
            img_k,
            img_v,
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy="rear",
        )

        b, s, a, d = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, \
            f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_double_stream_block_attention Passed")

def test_mm_single_stream_block_attention(rank, world_size):
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    txt_len = 256
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    with torch.no_grad():
        # In the single-stream block, q/k come from separate image and text
        # projections while v is already a single fused tensor.
        img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        v = torch.randn(batch_size, seq_len_img + seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

        # Broadcast from rank 0 so all ranks see identical inputs.
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(v, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)

        # Same cu_seqlens layout as the double-stream test:
        # 118811 = 118800 + 11 valid text tokens, 119056 = 118800 + 256.
        cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056
        mode = "torch"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        # The sequence-parallel path splits the fused q/k/v back into the
        # image part and the trailing txt_len text tokens (joint tensors).
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            q[:, :-txt_len, :, :],
            k[:, :-txt_len, :, :],
            v[:, :-txt_len, :, :],
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=q[:, -txt_len:, :, :],
            joint_tensor_key=k[:, -txt_len:, :, :],
            joint_tensor_value=v[:, -txt_len:, :, :],
            joint_strategy="rear",
        )
        b, s, a, d = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, \
            f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_single_stream_block_attention Passed")

if __name__ == "__main__":
    rank, world_size = init_dist()
    test_mm_double_stream_block_attention(rank, world_size)
    test_mm_single_stream_block_attention(rank, world_size)
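
# Example launch (a sketch, not part of the original test): init_dist() reads
# RANK, LOCAL_RANK, and WORLD_SIZE from the environment, which torchrun sets
# automatically. Assuming 8 GPUs and that this file is saved as
# test_hunyuan_video_xdit_attn.py (placeholder name), a run could look like:
#
#   torchrun --nproc_per_node=8 test_hunyuan_video_xdit_attn.py
#
# With 8 processes, init_dist() above factors the sequence-parallel group into
# ring_degree=4 and ulysses_degree=2.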