diff --git "a/data/basic_loading_and_analysing.ipynb" "b/data/basic_loading_and_analysing.ipynb" new file mode 100644--- /dev/null +++ "b/data/basic_loading_and_analysing.ipynb" @@ -0,0 +1,3767 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "MNk7IylTv610" + }, + "source": [ + "# Loading and Analysing Pre-Trained Sparse Autoencoders" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i_DusoOvwV0M" + }, + "source": [ + "## Imports & Installs" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "yfDUxRx0wSRl" + }, + "outputs": [], + "source": [ + "import os\n", + "try:\n", + " import google.colab # type: ignore\n", + " from google.colab import output\n", + "\n", + " COLAB = True\n", + " %pip install sae-lens transformer-lens sae-dashboard\n", + "except:\n", + " COLAB = False\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "# Standard imports\n", + "import os\n", + "import torch\n", + "from tqdm import tqdm\n", + "import plotly.express as px\n", + "\n", + "# Imports for displaying vis in Colab / notebook\n", + "import webbrowser\n", + "import http.server\n", + "import socketserver\n", + "import threading\n", + "\n", + "PORT = 8000\n", + "\n", + "torch.set_grad_enabled(False);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7aGgWkbav610" + }, + "source": [ + "## Set Up" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rQSD7trbv610", + "outputId": "222a40c4-75d4-46e2-ed3f-991841144926" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Device: cuda:0\n" + ] + } + ], + "source": [ + "# For the most part I'll try to import functions and classes near where they are used\n", + "# to make it clear where they come from.\n", + "\n", + "if torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "cPUq_bdW8mcp" + }, + "outputs": [], + "source": [ + "def display_vis_inline(filename: str, height: int = 850):\n", + " \"\"\"\n", + " Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each\n", + " vis has a unique port without having to define a port within the function.\n", + " \"\"\"\n", + " if not (COLAB):\n", + " webbrowser.open(filename)\n", + "\n", + " else:\n", + " global PORT\n", + "\n", + " def serve(directory):\n", + " os.chdir(directory)\n", + "\n", + " # Create a handler for serving files\n", + " handler = http.server.SimpleHTTPRequestHandler\n", + "\n", + " # Create a socket server with the handler\n", + " with socketserver.TCPServer((\"\", PORT), handler) as httpd:\n", + " print(f\"Serving files from {directory} on port {PORT}\")\n", + " httpd.serve_forever()\n", + "\n", + " thread = threading.Thread(target=serve, args=(\"/content\",))\n", + " thread.start()\n", + "\n", + " output.serve_kernel_port_as_iframe(\n", + " PORT, path=f\"/{filename}\", height=height, cache_in_notebook=True\n", + " )\n", + " PORT += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XoMx3VZpv611" + }, + "source": [ + "# Loading a pretrained Sparse Autoencoder\n", + "\n", + "Below we load a Transformerlens model, a pretrained SAE and a dataset from huggingface." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sNSfL80Uv611" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-01-13 14:07:35.384788: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2025-01-13 14:07:35.395937: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2025-01-13 14:07:35.407588: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2025-01-13 14:07:35.411060: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2025-01-13 14:07:35.422719: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2025-01-13 14:07:36.128010: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "/home/watchtower/.pyenv/versions/venv1/lib/python3.10/site-packages/sae_lens/sae.py:145: UserWarning: \n", + "This SAE has non-empty model_from_pretrained_kwargs. \n", + "For optimal performance, load the model like so:\n", + "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)\n", + " warnings.warn(\n", + "/home/watchtower/.pyenv/versions/venv1/lib/python3.10/site-packages/sae_lens/sae.py:635: UserWarning: norm_scaling_factor not found for saidines12/sae-telugu-features and blocks.8.hook_resid_post, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding.\n", + " warnings.warn(\n", + "WARNING:root:You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constantSetting center_unembed=False instead.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "230e7c03481b4479b558cbefb005015f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00value=%{x}
count=%{y}", + "legendgroup": "0", + "marker": { + "color": "#636efa", + "pattern": { + "shape": "" + } + }, + "name": "0", + "offsetgroup": "0", + "orientation": "v", + "showlegend": true, + "type": "histogram", + "x": [ + 14300, + 11917, + 12035, + 11422, + 10833, + 11250, + 11891, + 10705, + 11326, + 10154, + 11129, + 10761, + 11463, + 11350, + 11813, + 11784, + 11891, + 11720, + 11526, + 11250, + 10940, + 11182, + 11176, + 11108, + 11367, + 11911, + 12002, + 11221, + 11897, + 11790, + 12199, + 10533, + 11564, + 10842, + 11062, + 11854, + 12024, + 11462, + 11538, + 12068, + 11225, + 11637, + 11892, + 11050, + 11307, + 10969, + 11015, + 10934, + 10988, + 11254, + 10045, + 11506, + 11500, + 11403, + 11036, + 11058, + 10977, + 11190, + 11857, + 11459, + 11489, + 11801, + 11886, + 11784, + 11746, + 10951, + 11371, + 11777, + 11702, + 11131, + 11616, + 10629, + 10729, + 10421, + 10581, + 10984, + 11809, + 10997, + 11581, + 12207, + 11956, + 11200, + 11421, + 11258, + 11525, + 10728, + 12150, + 11663, + 11316, + 11330, + 11797, + 12115, + 10676, + 11072, + 12149, + 10742, + 11394, + 11635, + 11599, + 11273, + 11091, + 10741, + 11099, + 11486, + 11826, + 11262, + 10474, + 10784, + 11309, + 11155, + 11346, + 11893, + 11583, + 11459, + 12281, + 10555, + 11432, + 12388, + 11508, + 10154, + 11072, + 10997, + 10544, + 11161, + 11105, + 11230, + 11400, + 11676, + 11099, + 11301, + 11660, + 11338, + 11134, + 10124, + 10443, + 10298, + 11326, + 11026, + 11238, + 11896, + 11776, + 11191, + 10712, + 11303, + 11426, + 11863, + 11906, + 11108, + 11129, + 10534, + 11267, + 11445, + 10884, + 11637, + 11843, + 10994, + 10003, + 10419, + 10652, + 11748, + 11736, + 11975, + 10812, + 11441, + 11503, + 11313, + 11586, + 10790, + 11627, + 11870, + 10817, + 11102, + 11598, + 11827, + 12047, + 11048, + 10082, + 10357, + 11290, + 10261, + 11647, + 10696, + 11254, + 11051, + 10639, + 10334, + 10451, + 11340, + 10591, + 10812, + 11105, + 10849, + 9376, + 11713, + 11413, + 10866, + 10281, + 10167, + 11044, + 11192, + 11711, + 11629, + 11243, + 11204, + 9996, + 10796, + 11488, + 11133, + 10886, + 11094, + 10952, + 10829, + 11171, + 11560, + 11633, + 11679, + 10672, + 10522, + 10753, + 11149, + 10101, + 10799, + 11378, + 9892, + 10844, + 11664, + 11937, + 12020, + 11691, + 11534, + 10025, + 11238, + 10366, + 10503, + 11152, + 10987, + 11375, + 11033, + 10428, + 10600, + 10711, + 10796, + 10598, + 10313, + 11078, + 10929, + 11074, + 10835, + 11325, + 10982, + 11478, + 10190, + 10388, + 10550, + 11761, + 11706, + 11075, + 9996, + 10624, + 11820, + 11311, + 10461, + 11305, + 11572, + 11732, + 11785, + 11248, + 10693, + 11552, + 11426, + 10683, + 10369, + 11148, + 11237, + 10343, + 10781, + 10886, + 10884, + 11477, + 11599, + 11641, + 10590, + 10752, + 11789, + 10869, + 10576, + 11078, + 11028, + 10591, + 10943, + 10807, + 10793, + 11017, + 11362, + 10278, + 10747, + 10927, + 11332, + 11106, + 10436, + 11103, + 11387, + 11235, + 10736, + 10913, + 11696, + 11630, + 10865, + 11037, + 11783, + 11068, + 11249, + 11637, + 10387, + 11370, + 11008, + 11141, + 11565, + 11459, + 11030, + 11332, + 11347, + 11001, + 10972, + 11465, + 11541, + 11694, + 11365, + 11301, + 10999, + 11297, + 11078, + 11191, + 11211, + 11035, + 10744, + 11608, + 11215, + 11625, + 10606, + 10889, + 11637, + 10763, + 10611, + 11212, + 11394, + 10667, + 11450, + 11413, + 11208, + 11215, + 10831, + 10871, + 11631, + 11149, + 12073, + 11709, + 10866, + 11806, + 10742, + 10704, + 11487, + 11630, + 11361, + 11658, + 11654, + 10788, + 11436, + 12092, + 11459, + 11596, + 10246, + 10477, + 10696, + 10984, + 11066, + 10780, + 10644, + 11315, + 11660, + 11638, + 10765, + 10884, + 11073, + 10332, + 11073, + 11057, + 11622, + 11006, + 11015, + 11809, + 11501, + 10511, + 10965, + 12005, + 11675, + 10864, + 10356, + 10842, + 11911, + 11602, + 11376, + 11047, + 10433, + 11561, + 11025, + 11965, + 12085, + 11876, + 10615, + 10973, + 11197, + 11447, + 10900, + 10436, + 10883, + 11500, + 10904, + 11544, + 11708, + 10654, + 10188, + 10555, + 11423, + 10564, + 11380, + 11072, + 11008, + 10784, + 10743, + 11302, + 10486, + 11211, + 11727, + 11501, + 11000, + 10693, + 11604, + 10677, + 11142, + 10991, + 10801, + 10894, + 11775, + 11483, + 10796, + 11252, + 11006, + 11285, + 11956, + 11829, + 11780, + 11093, + 11116, + 11174, + 10046, + 10926, + 11678, + 12158, + 11844, + 11926, + 10625, + 10443, + 10736, + 11381, + 11232, + 12146, + 11829, + 11743, + 11049, + 10728, + 11567, + 11121, + 10856, + 11486, + 10664, + 10919, + 11021, + 10785, + 11198, + 11146, + 10965, + 10428, + 11044, + 10747, + 11185, + 11190, + 11363, + 11068, + 11058, + 10836, + 11369, + 11028, + 11576, + 11234, + 10568, + 11060, + 11052, + 11647, + 11528, + 10966, + 11655, + 11735, + 11570, + 11442, + 11149, + 11788, + 11207, + 10263, + 10788, + 11476, + 10703, + 11634, + 10923, + 10312, + 11447, + 11571, + 11086, + 11157, + 11257, + 11118, + 10643, + 10956, + 11721, + 12030, + 11953, + 11649, + 11535, + 10450, + 11060, + 11256, + 10797, + 11150, + 11234, + 11072, + 11157, + 11557, + 11492, + 10587, + 11383, + 10885, + 10960, + 11390, + 10659, + 11258, + 11570, + 11317, + 10814, + 10766, + 11485, + 11004, + 10077, + 10651, + 10812, + 10961, + 11004, + 11249, + 10917, + 10624, + 10582, + 11004, + 11008, + 11026, + 11191, + 11477, + 10435, + 11084, + 11853, + 10970, + 11616, + 11925, + 10685, + 11050, + 10964, + 11537, + 11055, + 11151, + 11030, + 11607, + 11792, + 10897, + 11714, + 11500, + 11185, + 11279, + 10591, + 11474, + 10899, + 11368, + 11597, + 11262, + 10674, + 10378, + 11393, + 10792, + 11757, + 11513, + 11413, + 12131, + 11260, + 11396, + 11483, + 11538, + 11816, + 11516, + 10580, + 10698, + 11429, + 10252, + 10679, + 11713, + 11086, + 11505, + 10419, + 11815, + 12026, + 11483, + 10855, + 11512, + 11596, + 11694, + 11268, + 11153, + 11144, + 10287, + 10393, + 10604, + 11285, + 10645, + 10346, + 11547, + 11704, + 11121, + 11444, + 11692, + 11400, + 10200, + 10896, + 11143, + 10826, + 11555, + 11593, + 11709, + 11881, + 11628, + 10567, + 11539, + 11995, + 11383, + 11237, + 11912, + 11419, + 10367, + 11067, + 10317, + 10918, + 11618, + 11437, + 10333, + 10751, + 11248, + 10755, + 11149, + 11240, + 10315, + 11064, + 11456, + 11584, + 10416, + 10630, + 10332, + 10797, + 10306, + 10374, + 9870, + 10148, + 11088, + 10648, + 11146, + 11087, + 11118, + 11080, + 10688, + 10859, + 10904, + 10731, + 11016, + 11076, + 10486, + 11272, + 11696, + 11069, + 11465, + 11176, + 10862, + 11642, + 11210, + 10262, + 10948, + 10926, + 11047, + 11593, + 11429, + 10480, + 10910, + 11497, + 10973, + 11169, + 11206, + 10689, + 10066, + 11160, + 11632, + 11619, + 10871, + 10724, + 10743, + 11283, + 10910, + 11802, + 11561, + 11301, + 11361, + 11353, + 11815, + 10802, + 11535, + 11603, + 10940, + 10555, + 10780, + 11374, + 11110, + 11386, + 11269, + 10885, + 10910, + 10623, + 9314, + 10267, + 11150, + 10828, + 11780, + 9697, + 10972, + 10931, + 11044, + 11478, + 11188, + 11799, + 11162, + 11369, + 11177, + 11209, + 11920, + 11420, + 11585, + 10441, + 10212, + 11722, + 10255, + 11123, + 11331, + 10896, + 10777, + 10797, + 10517, + 10966, + 11555, + 11489, + 11520, + 10726, + 11651, + 11512, + 10361, + 8953, + 10603, + 13196, + 11105, + 10976, + 12010, + 10571, + 12106, + 11313, + 11899, + 11931, + 11167, + 11042, + 12206, + 11482, + 10613, + 11566, + 11169, + 10974, + 11109, + 10715, + 11284, + 11404, + 11644, + 11913, + 10945, + 10978, + 11838, + 11531, + 11756, + 11581, + 11789, + 11159, + 10889, + 11396, + 12233, + 12324, + 11507, + 11144, + 12174, + 10666, + 11536, + 10702, + 11480, + 11270, + 10789, + 11209, + 10886, + 10748, + 11460, + 11522, + 11016, + 10327, + 10824, + 11965, + 11557, + 11463, + 11050, + 11117, + 11495, + 11203, + 11428, + 11262, + 10579, + 11297, + 11270, + 11533, + 10975, + 11405, + 11137, + 11405, + 11669, + 10840, + 11123, + 10864, + 10060, + 11042, + 11832, + 12017, + 10978, + 11599, + 11456, + 12245, + 12048, + 11321, + 11746, + 11358, + 10943, + 10688, + 11655, + 11345, + 11648, + 11347, + 10806, + 10885, + 11059, + 11344, + 11974, + 12028, + 10794, + 11937, + 11427, + 12016, + 11596, + 11615, + 10718, + 11544, + 11220, + 10771, + 10759, + 10955, + 10731, + 10877, + 11394, + 10848, + 11138, + 10926, + 11371, + 11546, + 11029, + 11541, + 11220, + 10667, + 11323, + 11542, + 11058, + 11520, + 11282, + 11672, + 10899, + 11388, + 11493, + 11301, + 11316, + 10549, + 11250, + 11438, + 11777, + 11473, + 11143, + 11650, + 11508, + 10674, + 10654, + 11150, + 11371, + 10865, + 11557, + 10183, + 11694, + 11432, + 11330, + 11345, + 11159, + 11150, + 10865, + 11305, + 11095, + 10894, + 10404, + 11076, + 11329, + 10428, + 10978, + 11558, + 10542, + 10878, + 11898, + 12367, + 12347, + 11077, + 11470, + 12024, + 11010, + 12117, + 11521, + 11071, + 11739, + 10762, + 11661, + 12216, + 11175, + 10692, + 11922, + 12059, + 11452, + 10559, + 11231, + 11862, + 11101, + 11336, + 11071, + 11556, + 10748, + 11126, + 11130, + 11378, + 10631, + 11627, + 11399, + 11496, + 11358, + 10863, + 10726, + 10792, + 11996, + 10663, + 10595, + 11623, + 10774, + 11541, + 11665, + 11127, + 11106, + 10163, + 11091, + 10415, + 10782, + 11799, + 11589, + 11648, + 11966, + 11658, + 11319, + 11895, + 11242, + 10218, + 9991, + 10327, + 10120, + 11261, + 11053, + 10759, + 12043, + 10775, + 11008, + 11013, + 11382, + 11095, + 10530, + 10705, + 11531, + 10947, + 11656, + 10945, + 11052, + 11780, + 10616, + 10453, + 11002, + 11231, + 11412, + 10654, + 11392, + 11530, + 11319, + 10828, + 10246, + 11212, + 9532, + 11396, + 10736, + 10469, + 10575, + 11309, + 11451, + 10525, + 10238, + 10964, + 11145, + 10946, + 11070, + 11666, + 12309, + 12279, + 11752, + 10828, + 11867, + 11444, + 11543, + 10348, + 10925, + 11210, + 10937, + 10588, + 11746, + 11258, + 11344, + 11143, + 11151, + 10906, + 10912, + 11046, + 11029, + 11532, + 10554, + 11030, + 11430, + 11352, + 11324, + 10111, + 10401, + 10656, + 11181, + 10495, + 10788, + 11524, + 11153, + 11418, + 10664, + 11306, + 10212, + 11744, + 11150, + 11044, + 11066, + 10546, + 11598, + 10498, + 10776, + 11531, + 12068, + 12260, + 11338, + 10743, + 11139, + 10825, + 10609, + 11262, + 11248, + 10974, + 11444, + 11448, + 11523, + 10244, + 11097, + 11539, + 10739, + 11900, + 10847, + 11690, + 10406, + 10310, + 10533, + 11623, + 11693, + 10745, + 10743, + 12322, + 11542, + 11739, + 10920, + 11340, + 11266, + 10724, + 11411, + 11253, + 10332, + 10876, + 10904, + 11545, + 11333, + 11354, + 11865, + 10938, + 11382, + 11298, + 11084, + 11102, + 10154, + 10771, + 10975, + 10564, + 11817, + 11213, + 11977, + 11707, + 11109, + 10789, + 10718, + 10535, + 11364, + 11527, + 11291, + 11502, + 11341, + 10068, + 11554, + 11715, + 11525, + 12070, + 12495, + 12106, + 11691, + 11738, + 12029, + 11114, + 10819, + 11026, + 12003, + 11878, + 11537, + 10982, + 11496, + 11575, + 10710, + 11471, + 11417, + 10437, + 11030, + 11548, + 12227, + 12283, + 11182, + 10954, + 10644, + 10537, + 11677, + 11690, + 11534, + 11740, + 10646, + 11676, + 12381, + 11855, + 10817, + 11987, + 12188, + 10315, + 10942, + 11540, + 11408, + 11047, + 11426, + 11634, + 11674, + 10551, + 10479, + 11321, + 11379, + 11645, + 11688, + 11122, + 11335, + 11515, + 10775, + 11194, + 11612, + 11602, + 11472, + 10045, + 10325, + 11172, + 11268, + 10129, + 10927, + 10681, + 10947, + 10795, + 11045, + 11151, + 10810, + 10201, + 10975, + 11743, + 11067, + 11501, + 10746, + 11615, + 11316, + 12269, + 12226, + 11617, + 10320, + 10855, + 11406, + 11542, + 11852, + 11497, + 11495, + 11151, + 11682, + 11417, + 10623, + 10463, + 11573, + 12785, + 11387, + 8944, + 10846, + 11483, + 11979, + 11548, + 10621, + 11011, + 10737, + 11331, + 11233, + 11371, + 11781, + 11619, + 11662, + 11887, + 11425, + 10500, + 11259, + 11739, + 11367, + 11044, + 11504, + 11502, + 11760, + 11295, + 11790, + 11602, + 11737, + 11045, + 11723, + 10979, + 11432, + 10999, + 12760, + 11653, + 11309, + 10985, + 11274, + 11653, + 11556, + 11451, + 11798, + 11268, + 11651, + 11122, + 10462, + 11135, + 10698, + 10518, + 11235, + 11709, + 10862, + 11016, + 10847, + 11576, + 11914, + 11823, + 11719, + 10920, + 10615, + 10233, + 10894, + 11165, + 11305, + 11320, + 10978, + 11046, + 10795, + 10805, + 11376, + 11268, + 11018, + 11213, + 11337, + 11375, + 11918, + 10632, + 11897, + 10831, + 11039, + 10997, + 11444, + 10982, + 11189, + 10603, + 11506, + 10661, + 11366, + 10128, + 10862, + 11530, + 11157, + 11027, + 10943, + 10857, + 11280, + 11693, + 11911, + 10180, + 11220, + 11177, + 12045, + 10384, + 11105, + 10793, + 11002, + 10801, + 11737, + 10624, + 11078, + 11471, + 11843, + 10787, + 10454, + 11285, + 10380, + 10790, + 11109, + 11803, + 11639, + 11251, + 11702, + 11999, + 10872, + 11888, + 10795, + 11201, + 11305, + 10702, + 10328, + 11318, + 10986, + 10818, + 11072, + 11649, + 11559, + 11851, + 11749, + 10580, + 11344, + 10824, + 10975, + 11439, + 11034, + 11314, + 11003, + 11285, + 11456, + 11561, + 11892, + 11526, + 11462, + 11657, + 11677, + 9984, + 10640, + 10739, + 11528, + 11138, + 11148, + 11819, + 10455, + 11011, + 12077, + 12101, + 10431, + 11131, + 11315, + 10831, + 10471, + 11563, + 11369, + 10378, + 11337, + 11354, + 11612, + 10745, + 11049, + 11548, + 11729, + 11628, + 11639, + 11324, + 10701, + 10934, + 10884, + 11311, + 11312, + 10952, + 10844, + 11324, + 11438, + 10963, + 11362, + 11576, + 11996, + 11458, + 10729, + 10636, + 10618, + 10149, + 10496, + 10878, + 10768, + 10764, + 10386, + 11180, + 11593, + 11534, + 11841, + 11681, + 10390, + 10427, + 11028, + 11599, + 10857, + 10785, + 10792, + 11780, + 11312, + 11007, + 10477, + 11069, + 11639, + 11801, + 10348, + 11142, + 10815, + 11159, + 11020, + 11516, + 10614, + 11353, + 11265, + 11410, + 10991, + 11214, + 10044, + 10927, + 11363, + 11878, + 11317, + 11166, + 11131, + 11797, + 11340, + 11486, + 11512, + 10048, + 11020, + 10289, + 11172, + 11134, + 11450, + 11557, + 10289, + 11004, + 11570, + 11940, + 11529, + 10952, + 11973, + 9563, + 9553, + 10230, + 11469, + 11384, + 11396, + 10855, + 10377, + 10788, + 10866, + 11284, + 11502, + 11713, + 10964, + 10473, + 11338, + 11684, + 11485, + 11923, + 11780, + 11394, + 10535, + 11702, + 10197, + 11443, + 11386, + 11671, + 11602, + 10451, + 11014, + 10424, + 10912, + 11720, + 10692, + 11034, + 11393, + 11703, + 10940, + 11080, + 10562, + 11314, + 11729, + 11082, + 10787, + 10625, + 10113, + 10915, + 11320, + 11030, + 11320, + 11808, + 10849, + 11360, + 11113, + 10772, + 10372, + 10995, + 10951, + 11685, + 11700, + 10673, + 11353, + 11267, + 11647, + 10777, + 10980, + 11512, + 11999, + 11766, + 11343, + 10893, + 10779, + 10125, + 10171, + 10977, + 10866, + 11592, + 10171, + 11160, + 10678, + 11292, + 11820, + 11273, + 9821, + 10441, + 11078, + 11478, + 10939, + 11646, + 10852, + 10711, + 10645, + 11196, + 11002, + 11057, + 10961, + 11177, + 11195, + 10463, + 10969, + 11527, + 11324, + 10443, + 10617, + 10973, + 11081, + 10450, + 10531, + 10673, + 11206, + 10679, + 10967, + 10377, + 11611, + 11624, + 11390, + 11546, + 10432, + 10734, + 10909, + 10474, + 11201, + 11360, + 11514, + 10129, + 11459, + 10663, + 11284, + 10773, + 11212, + 11020, + 11390, + 10338, + 10856, + 10566, + 11494, + 10414, + 11015, + 11823, + 11026, + 11029, + 11579, + 10634, + 11884, + 10374, + 11404, + 11282, + 10441, + 10056, + 11280, + 11075, + 10787, + 11114, + 11484, + 11552, + 11325, + 11171, + 11304, + 10242, + 11553, + 11540, + 10890, + 10905, + 11140, + 11638, + 11515, + 11621, + 11548, + 11194, + 9434, + 10826, + 11482, + 11767, + 11890, + 11373, + 11093, + 10882, + 10532, + 10672, + 11422, + 10321, + 10801, + 11006, + 11386, + 11382, + 10377, + 11055, + 11376, + 11413, + 11239, + 10932, + 11314, + 11028, + 10939, + 11387, + 10992, + 11480, + 10644, + 10473, + 11182, + 11088, + 11230, + 10397, + 10938, + 10818, + 11471, + 10844, + 11107, + 11046, + 10993, + 10604, + 11250, + 11271, + 10739, + 11166, + 11360, + 10330, + 10427, + 10420, + 10513, + 11506, + 10615, + 11137, + 11865, + 10857, + 11379, + 11374, + 11259, + 10328, + 10764, + 11579, + 10638, + 10757, + 11178, + 11126, + 8860, + 11052, + 11646, + 11620, + 11706, + 11541, + 11880, + 11719, + 11647, + 11029, + 11557, + 10876, + 11470, + 10783, + 12285, + 11670, + 11288, + 10524, + 11151, + 11569, + 11351, + 11355, + 11746, + 11224, + 11711, + 11134, + 10864, + 10861, + 11084, + 11120, + 11102, + 11290, + 10380, + 10121, + 10615, + 10698, + 11587, + 11407, + 11562, + 11766, + 10244, + 10704, + 10996, + 11752, + 10904, + 11222, + 11236, + 11955, + 12020, + 10880, + 10426, + 11106, + 11730, + 9815, + 11525, + 10978, + 11387, + 12137, + 11771, + 10245, + 11565, + 11353, + 11542, + 10967, + 10040, + 11422, + 11432, + 10917, + 11518, + 11619, + 11768, + 11593, + 11750, + 11761, + 11079, + 10640, + 11664, + 10865, + 10560, + 11347, + 11571, + 11544, + 11901, + 10878, + 10445, + 11663, + 10420, + 10686, + 10686, + 10912, + 10699, + 10796, + 11840, + 11787, + 11362, + 11131, + 11044, + 10883, + 11580, + 10983, + 10153, + 10963, + 11346, + 11296, + 10437, + 11050, + 11050, + 11651, + 11051, + 10484, + 11291, + 12008, + 11272, + 10399, + 10538, + 10279, + 11336, + 11277, + 11235, + 11621, + 11000, + 11249, + 11708, + 11479, + 11552, + 10842, + 11065, + 10976, + 11343, + 11024, + 10128, + 10331, + 11296, + 11667, + 10739, + 11440, + 11145, + 11277, + 10677, + 10910, + 11207, + 11217, + 10661, + 11538, + 10079, + 11411, + 11164, + 11699, + 10223, + 10687, + 10927, + 10055, + 10802, + 10844, + 11163, + 10447, + 10652, + 10963, + 11074, + 10611, + 11374, + 11207, + 11465, + 11263, + 10529, + 10873, + 11191, + 10970, + 11257, + 11283, + 11265, + 11672, + 10679, + 11404, + 11345, + 12297, + 11943, + 10261, + 11238, + 11334, + 10983, + 11317, + 10931, + 10505, + 11014, + 11756, + 11801, + 11441, + 11456, + 10028, + 11081, + 11859, + 11450, + 11814, + 11359, + 10683, + 11686, + 9834, + 11170, + 11662, + 11654, + 10123, + 10853, + 11724, + 10878, + 10902, + 11583, + 11168, + 11689, + 10694, + 11251, + 11911, + 10522, + 10883, + 10857, + 11405, + 11467, + 10573, + 10597, + 10895, + 11020, + 10362, + 10514, + 11368, + 11616, + 11162, + 11313, + 11493, + 11996, + 11668, + 10477, + 11768, + 11550, + 11299, + 11021, + 10916, + 11538, + 11464, + 11967, + 11934, + 11055, + 11476, + 11856, + 11990, + 11963, + 11738, + 10800, + 11194, + 10990, + 11110, + 10123, + 11441, + 11549, + 11070, + 10076, + 11360, + 11030, + 10112, + 11281, + 11918, + 12103, + 11594, + 11563, + 10418, + 11001, + 10925, + 11120, + 11697, + 11602, + 10451, + 10678, + 10749, + 11367, + 11894, + 11681, + 10263, + 9912, + 10283, + 10640, + 11081, + 10279, + 11090, + 11828, + 12172, + 11395, + 11079, + 11634 + ], + "xaxis": "x", + "yaxis": "y" + } + ], + "layout": { + "barmode": "relative", + "legend": { + "title": { + "text": "variable" + }, + "tracegroupgap": 0 + }, + "margin": { + "t": 60 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "value" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "count" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sae.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", + "\n", + "with torch.no_grad():\n", + " # activation store can give us tokens.\n", + " batch_tokens = token_dataset_tensors[:1] #[token_dataset[\"input_ids\"][:10]]\n", + " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n", + "\n", + " # Use the SAE\n", + " feature_acts = sae.encode(cache[sae.cfg.hook_name])\n", + " sae_out = sae.decode(feature_acts)\n", + "\n", + " # save some room\n", + " del cache\n", + "\n", + " # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position\n", + " l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()\n", + " print(\"average l0\", l0.mean().item())\n", + " px.histogram(l0.flatten().cpu().numpy()).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ijoelLtdv611" + }, + "source": [ + "Note that while the mean L0 is 11317, it varies with the specific activation.\n", + "\n", + "To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "fwrSvREJv612" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Orig 1.8834620714187622\n", + "reconstr 1.9335556030273438\n", + "Zero 12.452932357788086\n" + ] + } + ], + "source": [ + "from transformer_lens import utils\n", + "from functools import partial\n", + "\n", + "\n", + "# next we want to do a reconstruction test.\n", + "def reconstr_hook(activation, hook, sae_out):\n", + " return sae_out\n", + "\n", + "\n", + "def zero_abl_hook(activation, hook):\n", + " return torch.zeros_like(activation)\n", + "\n", + "\n", + "print(\"Orig\", model(batch_tokens, return_type=\"loss\").item())\n", + "print(\n", + " \"reconstr\",\n", + " model.run_with_hooks(\n", + " batch_tokens,\n", + " fwd_hooks=[\n", + " (\n", + " sae.cfg.hook_name,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ],\n", + " return_type=\"loss\",\n", + " ).item(),\n", + ")\n", + "print(\n", + " \"Zero\",\n", + " model.run_with_hooks(\n", + " batch_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],\n", + " ).item(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B_TRq_lFv612" + }, + "source": [ + "## Specific Capability Test\n", + "\n", + "Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "npxKip_Qv612" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']\n", + "Tokenized answer: [' Mary']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 0        Logit: 17.90 Prob: 75.62% Token: | Mary|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m17.90\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m75.62\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 17.90 Prob: 75.62% Token: | Mary|\n", + "Top 1th token. Logit: 16.47 Prob: 18.18% Token: | the|\n", + "Top 2th token. Logit: 14.11 Prob: 1.71% Token: | his|\n", + "Top 3th token. Logit: 13.87 Prob: 1.34% Token: | a|\n", + "Top 4th token. Logit: 12.59 Prob: 0.38% Token: | their|\n", + "Top 5th token. Logit: 11.97 Prob: 0.20% Token: | John|\n", + "Top 6th token. Logit: 11.68 Prob: 0.15% Token: | Sarah|\n", + "Top 7th token. Logit: 11.45 Prob: 0.12% Token: | Mrs|\n", + "Top 8th token. Logit: 11.17 Prob: 0.09% Token: | someone|\n", + "Top 9th token. Logit: 11.10 Prob: 0.08% Token: | her|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Mary', 0)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Orig 3.7073116302490234\n", + "reconstr 3.7859463691711426\n", + "Zero 12.452932357788086\n", + "Tokenized prompt: ['', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']\n", + "Tokenized answer: [' Mary']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 0        Logit: 17.26 Prob: 67.92% Token: | Mary|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m17.26\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m67.92\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 17.26 Prob: 67.92% Token: | Mary|\n", + "Top 1th token. Logit: 16.21 Prob: 23.79% Token: | the|\n", + "Top 2th token. Logit: 13.92 Prob: 2.42% Token: | a|\n", + "Top 3th token. Logit: 12.88 Prob: 0.86% Token: | his|\n", + "Top 4th token. Logit: 11.99 Prob: 0.35% Token: | her|\n", + "Top 5th token. Logit: 11.77 Prob: 0.28% Token: | John|\n", + "Top 6th token. Logit: 11.71 Prob: 0.27% Token: | shop|\n", + "Top 7th token. Logit: 11.62 Prob: 0.24% Token: | their|\n", + "Top 8th token. Logit: 11.44 Prob: 0.20% Token: | him|\n", + "Top 9th token. Logit: 11.31 Prob: 0.18% Token: | Mrs|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Mary', 0)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "example_prompt = \"When John and Mary went to the shops, John gave the bag to\"\n", + "example_answer = \" Mary\"\n", + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)\n", + "\n", + "logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)\n", + "tokens = model.to_tokens(example_prompt)\n", + "sae_out = sae(cache[sae.cfg.hook_name])\n", + "\n", + "\n", + "def reconstr_hook(activations, hook, sae_out):\n", + " return sae_out\n", + "\n", + "\n", + "def zero_abl_hook(mlp_out, hook):\n", + " return torch.zeros_like(mlp_out)\n", + "\n", + "\n", + "hook_name = sae.cfg.hook_name\n", + "\n", + "print(\"Orig\", model(tokens, return_type=\"loss\").item())\n", + "print(\n", + " \"reconstr\",\n", + " model.run_with_hooks(\n", + " tokens,\n", + " fwd_hooks=[\n", + " (\n", + " hook_name,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ],\n", + " return_type=\"loss\",\n", + " ).item(),\n", + ")\n", + "print(\n", + " \"Zero\",\n", + " model.run_with_hooks(\n", + " tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(hook_name, zero_abl_hook)],\n", + " ).item(),\n", + ")\n", + "\n", + "\n", + "with model.hooks(\n", + " fwd_hooks=[\n", + " (\n", + " hook_name,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ]\n", + "):\n", + " utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1swj9KA7v612" + }, + "source": [ + "# Generating Feature Interfaces\n", + "\n", + "Feature dashboards are an important part of SAE Evaluation. They work by:\n", + "- 1. Collecting feature activations over a larger number of examples.\n", + "- 2. Aggregating feature specific statistics (such as max activating examples).\n", + "- 3. Representing that information in a standardized way\n", + "\n", + "For our feature visualizations, we will use a separate library called SAEDashboard." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "id": "edt8ag4fv612" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6c9cf17032c45de85bd0b6b18adaed5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Forward passes to cache data for vis: 0%| | 0/2 [00:00┏━━━━━━┳━━━━━━┳━━━━━━━┓\n", + "┃ Task Time Pct % ┃\n", + "┡━━━━━━╇━━━━━━╇━━━━━━━┩\n", + "└──────┴──────┴───────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━┳━━━━━━┳━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTask\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mTime\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPct %\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━╇━━━━━━╇━━━━━━━┩\n", + "└──────┴──────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# %pip install sae_dashboard\n", + "from sae_dashboard.sae_vis_data import SaeVisConfig\n", + "from sae_dashboard.sae_vis_runner import SaeVisRunner\n", + "\n", + "test_feature_idx_gpt = list(range(100)) + [14057] # TODO: test all the features\n", + "\n", + "feature_vis_config_gpt = SaeVisConfig(\n", + " hook_point=hook_name,\n", + " features=test_feature_idx_gpt,\n", + " minibatch_size_features=64,\n", + " minibatch_size_tokens=256,\n", + " verbose=True,\n", + " device=device,\n", + ")\n", + "\n", + "visualization_data_gpt = SaeVisRunner(\n", + " feature_vis_config_gpt\n", + ").run(\n", + " encoder=sae, # type: ignore\n", + " model=model,\n", + " tokens=token_dataset_tensors[:1], # type: ignore \n", + ")\n", + "# SaeVisData.create(\n", + "# encoder=sae,\n", + "# model=model, # type: ignore\n", + "# tokens=token_dataset[:10000][\"tokens\"], # type: ignore\n", + "# cfg=feature_vis_config_gpt,\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "yQ94Frzbv612" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d55a7d33346e47559a5f60f961a543c6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Saving feature-centric vis: 0%| | 0/101 [00:00