{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"L4","authorship_tag":"ABX9TyO9WIr+dMkZzui0zEfQ5GlL"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DrPDyW5bjJh2","executionInfo":{"status":"ok","timestamp":1746937541098,"user_tz":-420,"elapsed":207,"user":{"displayName":"Laam Pham","userId":"04566654796696849937"}},"outputId":"e7fd988c-82b4-4256-e5c7-5b4231d49e1f"},"outputs":[{"output_type":"stream","name":"stdout","text":["Sun May 11 04:25:41 2025 \n","+-----------------------------------------------------------------------------------------+\n","| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |\n","|-----------------------------------------+------------------------+----------------------+\n","| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n","| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n","| | | MIG M. 
|\n","|=========================================+========================+======================|\n","| 0 NVIDIA L4 Off | 00000000:00:03.0 Off | 0 |\n","| N/A 64C P8 19W / 72W | 0MiB / 23034MiB | 0% Default |\n","| | | N/A |\n","+-----------------------------------------+------------------------+----------------------+\n"," \n","+-----------------------------------------------------------------------------------------+\n","| Processes: |\n","| GPU GI CI PID Type Process name GPU Memory |\n","| ID ID Usage |\n","|=========================================================================================|\n","| No running processes found |\n","+-----------------------------------------------------------------------------------------+\n"]}],"source":["!nvidia-smi"]},{"cell_type":"code","source":["import torch\n","import torch.nn as nn\n","from transformers import GPT2Model, GPT2Config, GPT2Tokenizer\n","import time\n","import copy\n","\n","# 0. Define the Tanh \"Normalization\" Layer\n","class TanhReplacement(nn.Module):\n"," def __init__(self, original_ln_config=None): # original_ln_config is for API compatibility if needed by some architectures\n"," super().__init__()\n"," # Tanh has no learnable parameters or specific configuration like hidden_size or eps\n"," # If the original LayerNorm had learnable parameters, they are now gone.\n"," # The shape of the output will be the same as the input, just like LayerNorm.\n"," if original_ln_config:\n"," self.normalized_shape = original_ln_config.normalized_shape\n"," self.eps = original_ln_config.eps\n"," # We don't actually use these, but store them for potential inspection\n","\n"," def forward(self, x):\n"," return torch.tanh(x)\n","\n","# 1. Setup: Model and Tokenizer\n","model_name = \"gpt2\" # Smallest GPT-2 for quicker testing\n","config = GPT2Config.from_pretrained(model_name)\n","tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n","\n","# 2. 
Load Original GPT-2 Model\n",
"# Keep the untouched checkpoint in inference mode; .eval() hands back the module itself.\n",
"model_orig = GPT2Model.from_pretrained(model_name, config=config).eval()\n",
"\n",
"# 3. Create Modified Model (DyT-like)\n",
"# Work on an independent copy so the baseline model is never modified.\n",
"model_dyt = copy.deepcopy(model_orig)\n",
"\n",
"# Each transformer block carries two LayerNorms: ln_1 feeds attention, ln_2 feeds the MLP.\n",
"for block in model_dyt.h:\n",
"    for ln_name in (\"ln_1\", \"ln_2\"):\n",
"        replaced_ln = getattr(block, ln_name)\n",
"        setattr(block, ln_name, TanhReplacement(original_ln_config=replaced_ln))\n",
"\n",
"# The final LayerNorm (ln_f) is swapped too, but only if it exists and really is a LayerNorm.\n",
"final_ln = getattr(model_dyt, 'ln_f', None)\n",
"if isinstance(final_ln, nn.LayerNorm):\n",
"    model_dyt.ln_f = TanhReplacement(original_ln_config=final_ln)\n",
"\n",
"model_dyt.eval()\n",
"\n",
"# Optional sanity check (uncomment to inspect which modules were swapped):\n",
"# for label, m in ((\"Original\", model_orig), (\"Modified\", model_dyt)):\n",
"#     for idx, blk in enumerate(m.h):\n",
"#         print(f\"{label} block {idx}: ln_1={blk.ln_1}, ln_2={blk.ln_2}\")\n",
"#     print(f\"{label} final ln_f={m.ln_f}\")\n",
"\n",
"\n",
"# 4. Prepare Input Data\n",
"text = \"Remember that William Gibson quote The future is already here, it's just not evenly distributed? Surprise - the future is already here, and it is shockingly distributed. Power to the people. 
Personally, I love it.\"\n","# Using a moderately long sequence for better timing\n","text = \" \".join([\"test\"] * 100)\n","\n","text = \"\"\"\n","Transformative technologies usually follow a top-down diffusion path: originating in government or military contexts, passing through corporations, and eventually reaching individuals - think electricity, cryptography, computers, flight, the internet, or GPS. This progression feels intuitive, new and powerful technologies are usually scarce, capital-intensive, and their use requires specialized technical expertise in the early stages.\n","\n","So it strikes me as quite unique and remarkable that LLMs display a dramatic reversal of this pattern - they generate disproportionate benefit for regular people, while their impact is a lot more muted and lagging in corporations and governments. ChatGPT is the fastest growing consumer application in history, with 400 million weekly active users who use it for writing, coding, translation, tutoring, summarization, deep research, brainstorming, etc. This isn't a minor upgrade to what existed before, it is a major multiplier to an individual's power level across a broad range of capabilities. And the barrier to use is incredibly low - the models are cheap (free, even), fast, available to anyone on demand behind a url (or even local machine), and they speak anyone's native language, including tone, slang or emoji. This is insane. As far as I can tell, the average person has never experienced a technological unlock this dramatic, this fast.\n","\n","Why then are the benefits a lot more muted in the corporate and government realms? I think the first reason is that LLMs offer a very specific profile of capability - that of merely quasi-expert knowledge/performance, but simultaneously across a very wide variety of domains. In other words, they are simultaneously versatile but also shallow and fallible. 
Meanwhile, an organization's unique superpower is the ability to concentrate diverse expertise into a single entity by employing engineers, researchers, analysts, lawyers, marketers, etc. While LLMs can certainly make these experts more efficient individually (e.g. drafting initial legal clauses, generating boilerplate code, etc.), the improvement to the organization takes the form of becoming a bit better at the things it could already do. In contrast, an individual will usually only be an expert in at most one thing, so the broad quasi-expertise offered by the LLM fundamentally allows them to do things they couldn't do before. People can now vibe code apps. They can approach legal documents. They can grok esoteric research papers. They can do data analytics. They can generate multimodal content for branding and marketing. They can do all of this at an adequate capability without involving an additional expert.\n","\n","Second, organizations deal with problems of a lot greater complexity and necessary coordination, think: various integrations, legacy systems, corporate brand or style guides, stringent security protocols, privacy considerations, internationalization, regulatory compliance and legal risk. There are a lot more variables, a lot more constraints, a lot more considerations, and a lot lower margin for error. It's not so easy to put all of it into a context window. You can't just vibe code something. You might be one disastrous hallucination away from losing your job. And third, there is the well-documented inertia of a larger organization, featuring culture, historical precedents, political turf wars that escalate in periods of rapid change, communication overhead, re-training challenges of a distributed workforce and good old-fashioned bureaucracy. These are major headwinds when it comes to rapid adoption of a sparkling new, versatile-but-shallow-and-fallible tool. 
I don't wish to downplay the impacts of LLMs in corporations or governments, but at least for the moment and in aggregate across society, they have been significantly more life altering for individuals than they have been for organizations. Mary, Jim and Joes are experiencing the majority of the benefit, not Google or the government of the United States.\n","\n","Looking forward, the continued diffusion of LLMs of course depends on continued performance improvement and its capability profile. The \"benefit distribution\" overall is particularly interesting to chart, and depends heavily on the dynamic range of the performance as a function of capital expenditure. Today, frontier-grade LLM performance is very accessible and cheap. Beyond this point, you cannot spend a marginal dollar to get better performance, reliability or autonomy. Money can't buy better ChatGPT. Bill Gates talks to GPT 4o just like you do. But can this be expected to last? Train-time scaling (increase parameters, data), test-time scaling (increase time) and model ensembles (increase batch) are forces increasing the dynamic range. On the other hand, model distillation (the ability to train disproportionately powerful small models by training to mimic the big model) has been a force decreasing dynamic range. Certainly, the moment money can buy dramatically better ChatGPT, things change. Large organizations get to concentrate their vast resources to buy more intelligence. And within the category of \"individual\" too, the elite may once again split away from the rest of society. Their child will be tutored by GPT-8-pro-max-high, yours by GPT-6 mini.\n","\n","But at least at this moment in time, we find ourselves in a unique and unprecedented situation in the history of technology. If you go back through various sci-fi you'll see that very few would have predicted that the AI revolution would feature this progression. 
It was supposed to be a top secret government megabrain project wielded by the generals, not ChatGPT appearing basically overnight and for free on a device already in everyone's pocket. Remember that William Gibson quote \"The future is already here, it's just not evenly distributed\"? Surprise - the future is already here, and it is shockingly distributed. Power to the people. Personally, I love it.\n",
"\"\"\"\n",
"\n",
"# GPT-2 only has config.n_positions learned position embeddings (1024 for \"gpt2\"). This text\n",
"# tokenizes to more tokens than that, and the surplus positions index past the end of the\n",
"# position-embedding table, which surfaces on GPU as an opaque device-side assert.\n",
"# Truncate to the model's limit so the forward pass is always valid.\n",
"inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=config.n_positions)\n",
"input_ids = inputs[\"input_ids\"]\n",
"attention_mask = inputs[\"attention_mask\"]\n",
"\n",
"# 5. Move models and data to device (GPU if available, else CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"\n",
"model_orig.to(device)\n",
"model_dyt.to(device)\n",
"input_ids = input_ids.to(device)\n",
"attention_mask = attention_mask.to(device)\n",
"\n",
"# 6. Inference Time Measurement Function\n",
"def measure_inference_time(model, input_ids_tensor, attention_mask_tensor, num_runs=100, warmup_runs=10):\n",
"    \"\"\"Return the mean time, in seconds, of one forward pass of `model`.\n",
"\n",
"    Args:\n",
"        model: module called as model(input_ids=..., attention_mask=...).\n",
"        input_ids_tensor: token ids; this function does not move them, so they are\n",
"            assumed to already be on the same device as `model`.\n",
"        attention_mask_tensor: attention mask passed alongside input_ids_tensor.\n",
"        num_runs: number of timed forward passes (must be >= 1).\n",
"        warmup_runs: number of untimed forward passes executed first.\n",
"\n",
"    Returns:\n",
"        Average seconds per forward pass over num_runs runs.\n",
"\n",
"    Raises:\n",
"        ValueError: if num_runs is smaller than 1.\n",
"    \"\"\"\n",
"    if num_runs < 1:\n",
"        raise ValueError(\"num_runs must be at least 1\")\n",
"\n",
"    # Pick the timing strategy from the device the inputs live on rather than from the\n",
"    # notebook-global `device`, so the function does not depend on hidden state.\n",
"    use_cuda = input_ids_tensor.device.type == 'cuda'\n",
"\n",
"    with torch.no_grad():\n",
"        for _ in range(warmup_runs):\n",
"            model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"    if use_cuda:\n",
"        torch.cuda.synchronize()  # make sure warmup work has finished before timing\n",
"\n",
"    total_time = 0.0\n",
"    if use_cuda:\n",
"        # CUDA kernels launch asynchronously, so time with events recorded on the GPU stream.\n",
"        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_runs)]\n",
"        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_runs)]\n",
"        with torch.no_grad():\n",
"            for i in range(num_runs):\n",
"                start_events[i].record()\n",
"                model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"                end_events[i].record()\n",
"        torch.cuda.synchronize()  # wait for all runs before reading the events\n",
"        for start_event, end_event in zip(start_events, end_events):\n",
"            total_time += start_event.elapsed_time(end_event) / 1000.0  # elapsed_time is in ms\n",
"    else:\n",
"        with torch.no_grad():\n",
"            for _ in range(num_runs):\n",
"                start_time = time.perf_counter()\n",
"                model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"                total_time += time.perf_counter() - start_time\n",
"\n",
"    return total_time / num_runs\n",
"\n",
"# 7. Run and Compare\n",
"print(f\"\\nBenchmarking with sequence length: {input_ids.shape[1]}\")\n",
"num_runs = 200\n",
"warmup_runs = 20\n",
"\n",
"# --- Test outputs to ensure they are different (as expected) ---\n",
"# with torch.no_grad():\n",
"#     out_orig = model_orig(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
"#     out_dyt = model_dyt(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
"# print(f\"Output norm difference: {torch.norm(out_orig - out_dyt)}\")\n",
"# assert not torch.allclose(out_orig, out_dyt), \"Outputs should be different!\"\n",
"# ---\n",
"\n",
"avg_time_orig = measure_inference_time(model_orig, input_ids, attention_mask, num_runs, warmup_runs)\n",
"print(f\"Original GPT-2 avg inference time: {avg_time_orig*1000:.4f} ms\")\n",
"\n",
"avg_time_dyt = measure_inference_time(model_dyt, input_ids, attention_mask, num_runs, warmup_runs)\n",
"print(f\"DyT-like (Tanh) GPT-2 avg inference time: {avg_time_dyt*1000:.4f} ms\")\n",
"\n",
"if avg_time_dyt < avg_time_orig:\n",
"    speedup_percentage = ((avg_time_orig - avg_time_dyt) / avg_time_orig) * 100\n",
"    print(f\"Speedup with Tanh replacement: {speedup_percentage:.2f}%\")\n",
"else:\n",
"    slowdown_percentage = ((avg_time_dyt - avg_time_orig) / avg_time_orig) * 100\n",
"    print(f\"Slowdown with Tanh replacement: {slowdown_percentage:.2f}% (This is unexpected if Tanh is truly faster)\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":474},"id":"BqnapHTijcaV","executionInfo":{"status":"error","timestamp":1746937609166,"user_tz":-420,"elapsed":1241,"user":{"displayName":"Laam 
Pham","userId":"04566654796696849937"}},"outputId":"7da6806f-e1a2-4dba-bc57-4ecb03a9a201"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stderr","text":["Token indices sequence length is longer than the specified maximum sequence length for this model (1171 > 1024). Running this sequence through the model will result in indexing errors\n"]},{"output_type":"stream","name":"stdout","text":["Using device: cuda\n","\n","Benchmarking with sequence length: 1171\n"]},{"output_type":"error","ename":"RuntimeError","evalue":"CUDA error: device-side assert triggered\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;31m# ---\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 144\u001b[0;31m \u001b[0mavg_time_orig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeasure_inference_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_orig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_runs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwarmup_runs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 145\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Original GPT-2 avg inference time: {avg_time_orig*1000:.4f} ms\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m\u001b[0m in \u001b[0;36mmeasure_inference_time\u001b[0;34m(model, input_ids_tensor, attention_mask_tensor, num_runs, warmup_runs)\u001b[0m\n\u001b[1;32m 97\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwarmup_runs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_ids_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'cuda'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1737\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1739\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1740\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1741\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1748\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1750\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1751\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/transformers/models/gpt2/modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 818\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minputs_embeds\u001b[0m \u001b[0;32mis\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 819\u001b[0m \u001b[0minputs_embeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwte\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 820\u001b[0;31m \u001b[0mposition_embeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwpe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 821\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minputs_embeds\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mposition_embeds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs_embeds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 822\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1737\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1739\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1740\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1741\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1748\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1750\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1751\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m 
\u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m return F.embedding(\n\u001b[0m\u001b[1;32m 191\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2549\u001b[0m \u001b[0;31m# remove once script supports set_grad_enabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2550\u001b[0m \u001b[0m_no_grad_embedding_renorm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2551\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_grad_by_freq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2552\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n"]}]},{"cell_type":"code","source":["import torch\n","import torch.nn as nn\n","from transformers import GPT2Model, 
GPT2Config, GPT2Tokenizer\n",
"import time\n",
"import copy\n",
"\n",
"# 0. Define the Tanh \"Normalization\" Layer\n",
"class TanhReplacement(nn.Module):\n",
"    \"\"\"Parameter-free stand-in for a LayerNorm that applies an element-wise tanh.\n",
"\n",
"    The replaced LayerNorm's learnable parameters are not carried over. Only its\n",
"    normalized_shape and eps are copied onto this module, purely for inspection.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, original_ln_config=None):\n",
"        super().__init__()\n",
"        if not original_ln_config:\n",
"            return\n",
"        # Kept for reference only; forward() never reads these.\n",
"        self.normalized_shape = original_ln_config.normalized_shape\n",
"        self.eps = original_ln_config.eps\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return tanh(x); the output has the same shape as the input.\"\"\"\n",
"        squashed = torch.tanh(x)\n",
"        return squashed\n",
"\n",
"# 1. Setup: Model and Tokenizer\n",
"model_name = \"gpt2\"  # smallest GPT-2 checkpoint, keeps the benchmark quick\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"config = GPT2Config.from_pretrained(model_name)\n",
"\n",
"# 2. Load Original GPT-2 Model\n",
"# .eval() returns the module itself, so loading and switching to evaluation mode can be chained.\n",
"model_orig = GPT2Model.from_pretrained(model_name, config=config).eval()\n",
"\n",
"# 3. 
Create Modified Model (DyT-like)\n",
"model_dyt = copy.deepcopy(model_orig)  # independent copy; the baseline must stay untouched\n",
"\n",
"# Replace LayerNorms in the transformer blocks\n",
"for block in model_dyt.h:\n",
"    # LayerNorm before MultiHeadAttention\n",
"    block.ln_1 = TanhReplacement(original_ln_config=block.ln_1)\n",
"    # LayerNorm before MLP\n",
"    block.ln_2 = TanhReplacement(original_ln_config=block.ln_2)\n",
"\n",
"# Replace final LayerNorm (if present in GPT2Model's main structure, GPT2Model has ln_f)\n",
"if hasattr(model_dyt, 'ln_f') and isinstance(model_dyt.ln_f, nn.LayerNorm):\n",
"    model_dyt.ln_f = TanhReplacement(original_ln_config=model_dyt.ln_f)\n",
"\n",
"model_dyt.eval()\n",
"\n",
"# --- Sanity check: Print model structures (optional) ---\n",
"# print(\"Original Model Structure (relevant parts):\")\n",
"# for i, block in enumerate(model_orig.h):\n",
"#     print(f\"Block {i}: ln_1={block.ln_1}, ln_2={block.ln_2}\")\n",
"# print(f\"Final ln_f={model_orig.ln_f}\")\n",
"\n",
"# print(\"\\nModified Model Structure (relevant parts):\")\n",
"# for i, block in enumerate(model_dyt.h):\n",
"#     print(f\"Block {i}: ln_1={block.ln_1}, ln_2={block.ln_2}\")\n",
"# print(f\"Final ln_f={model_dyt.ln_f}\")\n",
"# --- End Sanity check ---\n",
"\n",
"\n",
"# 4. Prepare Input Data\n",
"# Replace `text` with your own input for benchmarking; a moderately long sequence gives\n",
"# steadier timings. (The previous placeholder assignment was overwritten immediately, so it\n",
"# has been dropped.)\n",
"text = \" \".join([\"test\"] * 100)\n",
"# GPT-2 only has config.n_positions learned position embeddings. Truncate so that a longer\n",
"# replacement text cannot index past the position-embedding table and crash the forward pass.\n",
"inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=config.n_positions)\n",
"input_ids = inputs[\"input_ids\"]\n",
"attention_mask = inputs[\"attention_mask\"]\n",
"\n",
"# 5. 
Move models and data to device (GPU if available, else CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"\n",
"model_orig.to(device)\n",
"model_dyt.to(device)\n",
"input_ids = input_ids.to(device)\n",
"attention_mask = attention_mask.to(device)\n",
"\n",
"# 6. Inference Time Measurement Function\n",
"def measure_inference_time(model, input_ids_tensor, attention_mask_tensor, num_runs=100, warmup_runs=10):\n",
"    \"\"\"Return the mean time, in seconds, of one forward pass of `model`.\n",
"\n",
"    Args:\n",
"        model: module called as model(input_ids=..., attention_mask=...).\n",
"        input_ids_tensor: token ids; this function does not move them, so they are\n",
"            assumed to already be on the same device as `model`.\n",
"        attention_mask_tensor: attention mask passed alongside input_ids_tensor.\n",
"        num_runs: number of timed forward passes (must be >= 1).\n",
"        warmup_runs: number of untimed forward passes executed first.\n",
"\n",
"    Returns:\n",
"        Average seconds per forward pass over num_runs runs.\n",
"\n",
"    Raises:\n",
"        ValueError: if num_runs is smaller than 1.\n",
"    \"\"\"\n",
"    if num_runs < 1:\n",
"        raise ValueError(\"num_runs must be at least 1\")\n",
"\n",
"    # Pick the timing strategy from the device the inputs live on rather than from the\n",
"    # notebook-global `device`, so the function does not depend on hidden state.\n",
"    use_cuda = input_ids_tensor.device.type == 'cuda'\n",
"\n",
"    with torch.no_grad():\n",
"        for _ in range(warmup_runs):\n",
"            model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"    if use_cuda:\n",
"        torch.cuda.synchronize()  # make sure warmup work has finished before timing\n",
"\n",
"    total_time = 0.0\n",
"    if use_cuda:\n",
"        # CUDA kernels launch asynchronously, so time with events recorded on the GPU stream.\n",
"        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_runs)]\n",
"        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_runs)]\n",
"        with torch.no_grad():\n",
"            for i in range(num_runs):\n",
"                start_events[i].record()\n",
"                model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"                end_events[i].record()\n",
"        torch.cuda.synchronize()  # wait for all runs before reading the events\n",
"        for start_event, end_event in zip(start_events, end_events):\n",
"            total_time += start_event.elapsed_time(end_event) / 1000.0  # elapsed_time is in ms\n",
"    else:\n",
"        with torch.no_grad():\n",
"            for _ in range(num_runs):\n",
"                start_time = time.perf_counter()\n",
"                model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"                total_time += time.perf_counter() - start_time\n",
"\n",
"    return total_time / num_runs\n",
"\n",
"# 7. 
Run and Compare\n",
"num_runs = 200\n",
"warmup_runs = 20\n",
"print(f\"\\nBenchmarking with sequence length: {input_ids.shape[1]}\")\n",
"\n",
"# Optional check that the two models really produce different outputs (uncomment to run):\n",
"# with torch.no_grad():\n",
"#     out_orig = model_orig(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
"#     out_dyt = model_dyt(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
"# print(f\"Output norm difference: {torch.norm(out_orig - out_dyt)}\")\n",
"# assert not torch.allclose(out_orig, out_dyt), \"Outputs should be different!\"\n",
"\n",
"# Time the baseline first, then the tanh variant, with identical inputs and run counts.\n",
"avg_time_orig = measure_inference_time(\n",
"    model_orig, input_ids, attention_mask, num_runs=num_runs, warmup_runs=warmup_runs\n",
")\n",
"print(f\"Original GPT-2 avg inference time: {avg_time_orig*1000:.4f} ms\")\n",
"\n",
"avg_time_dyt = measure_inference_time(\n",
"    model_dyt, input_ids, attention_mask, num_runs=num_runs, warmup_runs=warmup_runs\n",
")\n",
"print(f\"DyT-like (Tanh) GPT-2 avg inference time: {avg_time_dyt*1000:.4f} ms\")\n",
"\n",
"# Report the relative difference against the baseline; anything that is not strictly\n",
"# faster is reported as a slowdown.\n",
"if not (avg_time_dyt < avg_time_orig):\n",
"    slowdown_percentage = ((avg_time_dyt - avg_time_orig) / avg_time_orig) * 100\n",
"    print(f\"Slowdown with Tanh replacement: {slowdown_percentage:.2f}% (This is unexpected if Tanh is truly faster)\")\n",
"else:\n",
"    speedup_percentage = ((avg_time_orig - avg_time_dyt) / avg_time_orig) * 100\n",
"    print(f\"Speedup with Tanh replacement: {speedup_percentage:.2f}%\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"IiIYSDjYjzhn","executionInfo":{"status":"ok","timestamp":1746937583133,"user_tz":-420,"elapsed":6848,"user":{"displayName":"Laam Pham","userId":"04566654796696849937"}},"outputId":"04721b36-199e-47d3-b414-080b24b81c4c"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Using device: cuda\n","\n","Benchmarking with sequence length: 100\n","Original GPT-2 avg inference time: 13.5885 ms\n","DyT-like (Tanh) GPT-2 avg inference time: 12.9903 ms\n","Speedup with Tanh replacement: 
4.40%\n"]}]},{"cell_type":"code","source":["import torch\n",
"import torch.nn as nn\n",
"from transformers import GPT2Model, GPT2Config, GPT2Tokenizer\n",
"import time\n",
"import copy\n",
"import os\n",
"\n",
"# IMPORTANT FOR DEBUGGING CUDA ERRORS:\n",
"# This makes CUDA kernel launches synchronous. If a kernel errors,\n",
"# the Python stack trace will now point to the exact line that launched the faulty kernel.\n",
"# NOTE(review): synchronous launches presumably also change what the benchmark below measures;\n",
"# remove this line (and restart the kernel) before quoting timings.\n",
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
"\n",
"# 0. Define the Tanh \"Normalization\" Layer\n",
"class TanhReplacement(nn.Module):\n",
"    \"\"\"Parameter-free drop-in for nn.LayerNorm that applies tanh element-wise.\n",
"\n",
"    The output has the same shape as the input. Any learnable parameters of the\n",
"    replaced LayerNorm (gamma, beta) are discarded.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, original_ln_config=None):  # original_ln_config is for API compatibility if needed\n",
"        \"\"\"Remember the replaced layer's settings for inspection only.\n",
"\n",
"        Args:\n",
"            original_ln_config: the module being replaced (anything exposing\n",
"                `normalized_shape` and `eps`), or None.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        # Compare against None explicitly instead of relying on the module's truthiness,\n",
"        # and always define both attributes so inspection never raises AttributeError.\n",
"        has_source = original_ln_config is not None\n",
"        self.normalized_shape = original_ln_config.normalized_shape if has_source else None\n",
"        self.eps = original_ln_config.eps if has_source else None\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return tanh(x).\"\"\"\n",
"        # When debugging, check x with torch.isnan / torch.isinf here:\n",
"        # tanh(nan) -> nan; tanh(inf) -> 1.0; tanh(-inf) -> -1.0\n",
"        return torch.tanh(x)\n",
"\n",
"# Function to recursively replace LayerNorm modules\n",
"def replace_layernorm_with_tanh_recursive(module):\n",
"    \"\"\"Replace every nn.LayerNorm below `module` with a TanhReplacement, in place.\n",
"\n",
"    Returns:\n",
"        True if at least one LayerNorm was replaced anywhere in the tree, else False.\n",
"    \"\"\"\n",
"    has_replaced = False\n",
"    # Snapshot the children so the setattr below never touches the mapping being iterated.\n",
"    for name, child_module in list(module.named_children()):\n",
"        if isinstance(child_module, nn.LayerNorm):\n",
"            setattr(module, name, TanhReplacement(original_ln_config=child_module))\n",
"            has_replaced = True\n",
"        elif replace_layernorm_with_tanh_recursive(child_module):  # Recurse\n",
"            has_replaced = True\n",
"    return has_replaced\n",
"\n",
"# 1. Setup: Model and Tokenizer\n",
"model_name = \"gpt2\"  # Smallest GPT-2 for quicker testing\n",
"config = GPT2Config.from_pretrained(model_name)\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"if tokenizer.pad_token is None:  # GPT-2 often doesn't have a pad token; add if missing for batching\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"\n",
"# 2. Load Original GPT-2 Model\n",
"model_orig = GPT2Model.from_pretrained(model_name, config=config)\n",
"model_orig.eval()  # Set to evaluation mode\n",
"\n",
"# 3. 
Create Modified Model (DyT-like)\n",
"model_dyt = copy.deepcopy(model_orig)  # Deep copy to modify independently\n",
"replaced_in_dyt = replace_layernorm_with_tanh_recursive(model_dyt)\n",
"model_dyt.eval()\n",
"\n",
"# Sanity checks: the copy must have lost every LayerNorm, the original must still have its own.\n",
"assert replaced_in_dyt, \"No LayerNorm module was found to replace in the copied model\"\n",
"assert not any(isinstance(m, nn.LayerNorm) for m in model_dyt.modules()), \"LayerNorm left in DyT model\"\n",
"assert any(isinstance(m, nn.LayerNorm) for m in model_orig.modules()), \"Original model was modified\"\n",
"\n",
"\n",
"# 4. Prepare Input Data\n",
"# text = \"Replace this with your desired input text for benchmarking. Longer sequences might show more difference.\"\n",
"text = \" \".join([\"test\"] * 50)  # Using a shorter sequence for faster debugging if issues persist\n",
"inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True, max_length=128)\n",
"input_ids = inputs[\"input_ids\"]\n",
"attention_mask = inputs[\"attention_mask\"]\n",
"\n",
"# 5. Move models and data to device (GPU if available, else CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"\n",
"model_orig.to(device)\n",
"model_dyt.to(device)\n",
"input_ids = input_ids.to(device)\n",
"attention_mask = attention_mask.to(device)\n",
"\n",
"# 6. 
Inference Time Measurement Function\n",
"def measure_inference_time(model, model_desc, input_ids_tensor, attention_mask_tensor, num_runs=100, warmup_runs=10):\n",
"    \"\"\"Return the average duration, in seconds, of one forward pass of `model`.\n",
"\n",
"    Inputs living on a CUDA device are timed with CUDA events; anything else\n",
"    is timed with time.perf_counter().\n",
"\n",
"    Args:\n",
"        model: called as model(input_ids=..., attention_mask=...).\n",
"        model_desc: label for the model; not used by the measurement itself,\n",
"            kept so existing calls keep working.\n",
"        input_ids_tensor: token ids; its device selects the timing method.\n",
"        attention_mask_tensor: attention mask passed alongside the ids.\n",
"        num_runs: number of timed forward passes (at least 1).\n",
"        warmup_runs: number of untimed forward passes executed first.\n",
"\n",
"    Raises:\n",
"        ValueError: if num_runs is smaller than 1.\n",
"    \"\"\"\n",
"    if num_runs < 1:\n",
"        raise ValueError(f\"num_runs must be >= 1, got {num_runs}\")\n",
"\n",
"    # Read the device from the inputs so the function does not depend on a\n",
"    # notebook-level `device` variable being defined (and up to date).\n",
"    run_device = input_ids_tensor.device\n",
"    on_cuda = run_device.type == 'cuda'\n",
"\n",
"    with torch.no_grad():\n",
"        # Warmup runs\n",
"        for _ in range(warmup_runs):\n",
"            model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"\n",
"        if not on_cuda:\n",
"            total_time = 0.0\n",
"            for _ in range(num_runs):\n",
"                start_time = time.perf_counter()\n",
"                model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"                total_time += time.perf_counter() - start_time\n",
"            return total_time / num_runs\n",
"\n",
"        torch.cuda.synchronize(run_device)  # Ensure warmup is complete\n",
"        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_runs)]\n",
"        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_runs)]\n",
"        for start_event, end_event in zip(start_events, end_events):\n",
"            start_event.record()\n",
"            model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)\n",
"            end_event.record()\n",
"        torch.cuda.synchronize(run_device)  # Wait for all runs to complete\n",
"\n",
"    # elapsed_time() reports milliseconds; convert each run to seconds.\n",
"    total_time = sum(start_event.elapsed_time(end_event) / 1000.0\n",
"                     for start_event, end_event in zip(start_events, end_events))\n",
"    return total_time / num_runs\n",
"\n",
"# 7. 
Run and Compare\n",
"import math  # local to this step: only the NaN checks below need it\n",
"\n",
"print(f\"\\nBenchmarking with sequence length: {input_ids.shape[1]}\")\n",
"num_runs = 50  # Reduced for potentially faster debugging cycles, increase for stable benchmarks\n",
"warmup_runs = 5\n",
"\n",
"# NaN marks a benchmark that did not finish; the comparison at the end is skipped in that case.\n",
"avg_time_orig = float('nan')\n",
"avg_time_dyt = float('nan')\n",
"\n",
"try:\n",
"    avg_time_orig = measure_inference_time(model_orig, \"Original GPT-2\", input_ids, attention_mask, num_runs, warmup_runs)\n",
"    print(f\"Original GPT-2 avg inference time: {avg_time_orig*1000:.4f} ms\")\n",
"except Exception as e:\n",
"    # This would be very unexpected; report it and still try the modified model.\n",
"    print(f\"Error benchmarking original model: {e}\")\n",
"\n",
"try:\n",
"    avg_time_dyt = measure_inference_time(model_dyt, \"DyT-like (Tanh) GPT-2\", input_ids, attention_mask, num_runs, warmup_runs)\n",
"    print(f\"DyT-like (Tanh) GPT-2 avg inference time: {avg_time_dyt*1000:.4f} ms\")\n",
"except RuntimeError as e:\n",
"    print(f\"\\nERROR during benchmarking of DyT-like model: {e}\")\n",
"    if \"CUDA error: device-side assert triggered\" in str(e):\n",
"        print(\"\\n\" + \"=\"*50)\n",
"        print(\"A 'device-side assert' was triggered in the DyT-like model.\")\n",
"        print(\"This LIKELY means that replacing LayerNorm with tanh in the pre-trained GPT-2\")\n",
"        print(\"caused numerical instability (e.g., NaNs or Infs), which then made a GPU kernel fail.\")\n",
"        print(\"The Python stack trace (above, with CUDA_LAUNCH_BLOCKING=1) should now pinpoint\")\n",
"        print(\"the exact operation in the Hugging Face code where the problem occurred.\")\n",
"        print(\"Common culprits are operations like softmax in attention if inputs become too large/NaN.\")\n",
"        print(\"This experiment highlights that LayerNorm is critical for the stability of pre-trained models;\")\n",
"        print(\"simply swapping it out without retraining is generally not viable for correct model output.\")\n",
"        print(\"The DyT paper's hypothesis is about training models with tanh units from scratch.\")\n",
"        print(\"=\"*50 + \"\\n\")\n",
"    # No further comparison if DyT model failed\n",
"    avg_time_dyt = float('nan')  # Ensure it's NaN if it failed\n",
"\n",
"# math.isnan works directly on Python floats; no tensor needs to be built for the check.\n",
"if math.isnan(avg_time_orig) or math.isnan(avg_time_dyt):\n",
"    print(\"\\nCould not complete the performance comparison due to errors or incomplete runs.\")\n",
"elif avg_time_dyt < avg_time_orig:\n",
"    speedup_percentage = ((avg_time_orig - avg_time_dyt) / avg_time_orig) * 100\n",
"    print(f\"Speedup with Tanh replacement: {speedup_percentage:.2f}%\")\n",
"else:\n",
"    slowdown_percentage = ((avg_time_dyt - avg_time_orig) / avg_time_orig) * 100\n",
"    print(f\"Slowdown with Tanh replacement: {slowdown_percentage:.2f}% (This is unexpected if Tanh is truly faster and model runs)\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ELbp7_sUk6no","executionInfo":{"status":"ok","timestamp":1746937556758,"user_tz":-420,"elapsed":10717,"user":{"displayName":"Laam Pham","userId":"04566654796696849937"}},"outputId":"2f46402a-3965-45e1-f03d-70f99fa6d1b8"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Using device: cuda\n","\n","Benchmarking with sequence length: 50\n","Original GPT-2 avg inference time: 12.8963 ms\n","DyT-like (Tanh) GPT-2 avg inference time: 12.3657 ms\n","Speedup with Tanh replacement: 4.11%\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"Ngv4Pof_lvmb"},"execution_count":null,"outputs":[]}]}