{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "133a2e2e", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import time\n", "from tqdm import trange" ] }, { "cell_type": "code", "execution_count": 2, "id": "d3259fde", "metadata": {}, "outputs": [], "source": [ "def test_multihead_attention_loop():\n", " # Set device\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", " # Parameters\n", " embed_dim = 64*10\n", " num_heads = 64\n", " seq_len = 1024\n", " batch_size = 64\n", " iterations = 100\n", "\n", " # Create MultiheadAttention module\n", " mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True).to(device)\n", "\n", " # for i in trange(iterations, ncols=70, colour='blue',):\n", " for i in range(iterations):\n", " # Generate fresh input for each iteration\n", " query = torch.rand(batch_size, seq_len, embed_dim, device=device)\n", " key = torch.rand(batch_size, seq_len, embed_dim, device=device)\n", " value = torch.rand(batch_size, seq_len, embed_dim, device=device)\n", "\n", " # Optional causal mask\n", " attn_mask = torch.triu(torch.ones(seq_len, seq_len, device=device) * float('-inf'), diagonal=1)\n", "\n", " # Run forward pass\n", " output, attn_weights = mha(query, key, value, attn_mask=attn_mask)\n", "\n", " # Print memory stats\n", " if device.type == \"cuda\":\n", " torch.cuda.synchronize()\n", " allocated = torch.cuda.memory_allocated() / (1024 ** 2)\n", " reserved = torch.cuda.memory_reserved() / (1024 ** 2)\n", " print(f\"[Iteration {i+1}] CUDA Memory - Allocated: {allocated:.2f} MB, Reserved: {reserved:.2f} MB\")\n", " else:\n", " print(f\"[Iteration {i+1}] CPU run, no CUDA memory to display.\")\n", "\n", " # Optional: simulate some delay\n", " time.sleep(0.1)" ] }, { "cell_type": "code", "execution_count": null, "id": "c40d9048", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Iteration 1] CPU run, no CUDA memory to display.\n", "[Iteration 2] CPU run, no CUDA memory to display.\n" ] } ], "source": [ "test_multihead_attention_loop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }