In [1]:
import torch
import torch.nn as nn
import time
from tqdm import trange

In [2]:
def test_multihead_attention_loop():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parameters
    embed_dim = 64*10
    num_heads = 64
    seq_len = 1024
    batch_size = 64
    iterations = 100

    # Create MultiheadAttention module
    mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True).to(device)

    # for i in trange(iterations, ncols=70, colour='blue',):
    for i in range(iterations):
        # Generate fresh input for each iteration
        query = torch.rand(batch_size, seq_len, embed_dim, device=device)
        key = torch.rand(batch_size, seq_len, embed_dim, device=device)
        value = torch.rand(batch_size, seq_len, embed_dim, device=device)

        # Optional causal mask
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, device=device) * float('-inf'), diagonal=1)

        # Run forward pass
        output, attn_weights = mha(query, key, value, attn_mask=attn_mask)

        # Print memory stats
        if device.type == "cuda":
            torch.cuda.synchronize()
            allocated = torch.cuda.memory_allocated() / (1024 ** 2)
            reserved = torch.cuda.memory_reserved() / (1024 ** 2)
            print(f"[Iteration {i+1}] CUDA Memory - Allocated: {allocated:.2f} MB, Reserved: {reserved:.2f} MB")
        else:
            print(f"[Iteration {i+1}] CPU run, no CUDA memory to display.")

        # Optional: simulate some delay
        time.sleep(0.1)

In [None]:
# Run the attention memory loop with embed_dim=640, num_heads=64, seq_len=1024,
# batch_size=64, 100 iterations and a 0.1 s sleep after each one.
# NOTE(review): at these sizes the per-head attention scores alone are roughly
# 64 batch x 64 heads x 1024 x 1024 float32 values (~16 GiB) per forward pass.
# That is presumably why the CPU run captured below printed only two
# iterations and this cell never finished (`In [None]`).
test_multihead_attention_loop()

[Iteration 1] CPU run, no CUDA memory to display.
[Iteration 2] CPU run, no CUDA memory to display.
