File size: 127,037 Bytes
daf8ad1
1
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"collapsed_sections":["erBNP6hvU724","PiehSbGeWmor","EKLDJqI17Tkm"],"machine_shape":"hm","gpuType":"L4","mount_file_id":"1Tsf5s2FZEHr9S5ja8MqkIA3DtPGNb0Cy","authorship_tag":"ABX9TyOx4EAq5vmbjU6gpbNxmViP"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"9bf9176fc21b4f15896d7281ced5fa17":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_76cf4a2d556842e497e2e64c2ba0b3fe","IPY_MODEL_a6b3d0c5d84b477c82545fb809fdda1c","IPY_MODEL_456b3fd649ab4c0fbb5348a765fb9613"],"layout":"IPY_MODEL_32bcb31386ca4187ba6a554d26189502"}},"76cf4a2d556842e497e2e64c2ba0b3fe":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d4696d3bbdc149be9bab0cf4664c68c4","placeholder":"​","style":"IPY_MODEL_1575e84b9dc5454d94258035abf26da1","value":"Loading checkpoint shards: 
100%"}},"a6b3d0c5d84b477c82545fb809fdda1c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_b4e41d9dcb2c4be8bf94919c93c56a83","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_d96c33b98b874cf8862a2bedb3596b4c","value":2}},"456b3fd649ab4c0fbb5348a765fb9613":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d7f4744342ec4ebd8d2798cae2e90518","placeholder":"​","style":"IPY_MODEL_2f9c4dbc55d04276a5b9fa92de613bf5","value":" 2/2 [00:00&lt;00:00,  
2.19it/s]"}},"32bcb31386ca4187ba6a554d26189502":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d4696d3bbdc149be9bab0cf4664c68c4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"
overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1575e84b9dc5454d94258035abf26da1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b4e41d9dcb2c4be8bf94919c93c56a83":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d96c33b98b874cf8862a2bedb3596b4c":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"d7f4744342ec4ebd8d2798cae2e90518":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2f9c4dbc55d04276a5b9fa92de613bf5":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["## Config and Module"],"metadata":{"id":"erBNP6hvU724"}},{"cell_type":"code","execution_count":1,"metadata":{"id":"9WGrocZdPBY_","executionInfo":{"status":"ok","timestamp":1748636886799,"user_tz":420,"elapsed":19,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"outputs":[],"source":["ADAPTER_CONFIG = {\n","    \"adapter_id\": \"003\",\n","    \"name\": \"DualShuntAdapter-G\",\n","\n","    \"t5\": {\n","        \"model\": \"google/flan-t5-base\",\n","        \"hidden_size\": 768,\n","    },\n","    \"clip\": 
{\n",
"        \"model\": \"AbstractPhil/omega-vit-g-reformed\",\n",
"        \"hidden_size\": 1280,\n",
"    },\n",
"\n",
"    \"bottleneck\": 640,\n",
"    \"heads\": 20,\n",
"\n",
"    \"tau_init\": 0.1,\n",
"    \"max_guidance\": 10.0,\n",
"\n",
"    \"proj_layers\": 2,\n",
"    \"layer_norm\": True,\n",
"    \"dropout\": 0.1,\n",
"    \"use_dropout\": True,\n",
"    # NOTE(review): \"use_proj_stack\" and \"routing\" are not read by TwoStreamShuntAdapter.\n",
"    \"use_proj_stack\": True,\n",
"    \"assert_input_dims\": True,\n",
"\n",
"    \"routing\": {\n",
"        \"type\": \"cross_attention\",\n",
"        \"enable_causal_mask\": False,\n",
"        \"bidirectional\": True\n",
"    },\n",
"\n",
"    \"version\": \"v0.3.2\",\n",
"    \"description\": \"Final Dual Shunt Adapter with projection stack, dropout, and stacked residual refinement pocket.\"\n",
"}\n"]},{"cell_type":"code","source":["import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# ─── Residual Pocket Block ───────────────────────────────────\n",
"class BottleneckResBlock(nn.Module):\n",
"    \"\"\"\n",
"    Pre-norm residual block: LayerNorm -> Conv1d over the sequence axis -> MLP.\n",
"\n",
"    Input and output are (batch, seq, dim): LayerNorm acts on the last axis and the\n",
"    transposes hand Conv1d the (batch, dim, seq) layout it expects.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dim, kernel=3, dropout=0.1):\n",
"        super().__init__()\n",
"        # padding = kernel // 2 keeps the sequence length only for odd kernels; an even\n",
"        # kernel would emit seq + 1 positions and the residual add would fail.\n",
"        if kernel % 2 == 0:\n",
"            raise ValueError(f\"kernel must be odd to preserve sequence length, got {kernel}\")\n",
"        self.norm = nn.LayerNorm(dim)\n",
"        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)\n",
"        self.proj = nn.Sequential(\n",
"            nn.Linear(dim, dim * 2),\n",
"            nn.GELU(),\n",
"            nn.Linear(dim * 2, dim),\n",
"            nn.Dropout(dropout)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        residual = x\n",
"        x = self.norm(x)\n",
"        x = x.transpose(1, 2)             # (batch, dim, seq) for Conv1d\n",
"        x = self.conv(x).transpose(1, 2)  # back to (batch, seq, dim)\n",
"        return residual + self.proj(x)\n",
"\n",
"# ─── Two Stream Shunt Adapter ──────────────────────────────────────\n",
"class TwoStreamShuntAdapter(nn.Module):\n",
"    def __init__(self, config: dict):\n",
"        super().__init__()\n",
"        self.config = config\n",
"        self.t5_dim = config[\"t5\"][\"hidden_size\"]\n",
"        self.clip_dim = config[\"clip\"][\"hidden_size\"]\n",
"        self.bneck = 
config[\"bottleneck\"]\n","        self.heads = config[\"heads\"]\n","        self.tau_init = config[\"tau_init\"]\n","        self.max_guidance = config[\"max_guidance\"]\n","\n","        use_norm   = config.get(\"layer_norm\", True)\n","        use_do     = config.get(\"use_dropout\", True)\n","        do_p       = config.get(\"dropout\", 0.1)\n","        proj_depth = config.get(\"proj_layers\", 2)\n","\n","        def build_projection(input_dim, output_dim):\n","            layers = []\n","            last_dim = input_dim\n","            if use_norm:\n","                layers.append(nn.LayerNorm(last_dim))\n","            for i in range(proj_depth):\n","                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)\n","                layers.append(nn.Linear(last_dim, next_dim))\n","                layers.append(nn.GELU())\n","                if use_do:\n","                    layers.append(nn.Dropout(do_p))\n","                last_dim = next_dim\n","            layers.append(nn.Linear(last_dim, output_dim))\n","            return nn.Sequential(*layers)\n","\n","        # Projections\n","        self.proj_t5   = build_projection(self.t5_dim, self.bneck)\n","        self.proj_clip = build_projection(self.clip_dim, self.bneck)\n","\n","        # Attention\n","        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)\n","        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)\n","        self.tau       = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))\n","\n","        # Residual Pocket\n","        self.pocket_blocks = nn.Sequential(\n","            BottleneckResBlock(self.bneck, dropout=do_p),\n","            BottleneckResBlock(self.bneck, dropout=do_p)\n","        )\n","\n","        # Fuse\n","        self.fuse = nn.Sequential(\n","            nn.LayerNorm(2 * self.bneck),\n","            nn.Linear(2 * self.bneck, self.bneck * 2),\n","         
   nn.GELU(),\n",
"            nn.Linear(self.bneck * 2, self.bneck)\n",
"        )\n",
"\n",
"        # Output Projections\n",
"        self.anchor_proj = build_projection(self.bneck, self.clip_dim)\n",
"        self.delta_proj  = build_projection(self.bneck, self.clip_dim)\n",
"        self.logsig_proj = build_projection(self.bneck, self.clip_dim)\n",
"\n",
"        # NOTE(review): Sigmoid(Tanh(z)) confines the gate to (sigmoid(-1), sigmoid(1)) ~= (0.27, 0.73),\n",
"        # so it can never fully close or fully open. Changing it would alter the outputs of\n",
"        # already-trained checkpoints, so confirm this range is intended before touching it.\n",
"        self.gate_proj = nn.Sequential(\n",
"            nn.LayerNorm(self.bneck),\n",
"            nn.Linear(self.bneck, self.bneck),\n",
"            nn.GELU(),\n",
"            nn.Linear(self.bneck, 1),\n",
"            nn.Tanh(),\n",
"            nn.Sigmoid()\n",
"        )\n",
"\n",
"        self.guidance_proj = nn.Sequential(\n",
"            nn.LayerNorm(self.bneck),\n",
"            nn.Linear(self.bneck, 1),\n",
"            nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):\n",
"        \"\"\"\n",
"        Cross-attend batch-first T5 and CLIP token sequences and predict CLIP-space outputs.\n",
"\n",
"        Returns (anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate):\n",
"          anchor, delta, log_sigma : (batch, clip_len, clip_dim); delta is already multiplied by gate\n",
"          attn_t2c, attn_c2t       : per-head attention weights (average_attn_weights=False)\n",
"          tau                      : the learnable (heads, 1, 1) parameter; not used inside forward\n",
"          g_pred                   : (batch, 1) mean per-token guidance scaled by max_guidance\n",
"          gate                     : (batch, clip_len, 1) output of gate_proj\n",
"        \"\"\"\n",
"        if self.config.get(\"assert_input_dims\", True):\n",
"            assert t5_seq.size(-1) == self.t5_dim, f\"t5_seq last dim {t5_seq.size(-1)} != {self.t5_dim}\"\n",
"            assert clip_seq.size(-1) == self.clip_dim, f\"clip_seq last dim {clip_seq.size(-1)} != {self.clip_dim}\"\n",
"\n",
"        t5_b   = self.proj_t5(t5_seq)\n",
"        clip_b = self.proj_clip(clip_seq)\n",
"\n",
"        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)\n",
"        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)\n",
"\n",
"        pocket = self.pocket_blocks(t2c)\n",
"\n",
"        # Pool the T5-side pocket over its tokens and broadcast it to every CLIP position.\n",
"        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)\n",
"        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))\n",
"\n",
"        # gate_proj has no dropout, so a single evaluation serves both the gated delta\n",
"        # and the returned gate.\n",
"        gate      = self.gate_proj(h)\n",
"        anchor    = self.anchor_proj(h)\n",
"        delta     = self.delta_proj(h) * gate\n",
"        log_sigma = self.logsig_proj(h)\n",
"\n",
"        g_tok  = self.guidance_proj(h).squeeze(-1)\n",
"        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance\n",
"\n",
"        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate\n"],"metadata":{"id":"qjb_vFZRTQaC","executionInfo":{"status":"ok","timestamp":1748636888396,"user_tz":420,"elapsed":1586,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":2,"outputs":[]},{"cell_type":"code","source":["from safetensors.torch import save_file, load_file\n",
"\n",
"def save_safetensors(adapter: nn.Module, path: str, metadata: dict = None):\n",
"    \"\"\"\n",
"    Save the current adapter state to safetensors format.\n",
"\n",
"    All tensors are moved to CPU, cast to float32 and made contiguous (safetensors\n",
"    rejects non-contiguous tensors). Optional metadata (e.g., version, prompt_mode)\n",
"    is embedded with keys and values coerced to str, as the safetensors header requires.\n",
"    \"\"\"\n",
"    state = {k: v.float().cpu().contiguous() for k, v in adapter.state_dict().items()}\n",
"    header = {str(k): str(v) for k, v in (metadata or {}).items()}\n",
"    save_file(state, path, metadata=header)\n",
"    print(f\"✅ Model saved to {path}\")\n",
"\n",
"def load_safetensors(adapter: nn.Module, path: str, map_location=\"cpu\"):\n",
"    \"\"\"\n",
"    Load a safetensors checkpoint into the adapter.\n",
"\n",
"    Uses strict key matching. 
Tensors are loaded to the specified device.\n",
"    \"\"\"\n",
"    state = load_file(path, device=map_location)\n",
"    adapter.load_state_dict(state, strict=True)\n",
"    print(f\"✅ Model loaded from {path}\")\n",
"\n",
"\n"],"metadata":{"id":"zpOi5svciXJ6","executionInfo":{"status":"ok","timestamp":1748636888425,"user_tz":420,"elapsed":13,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":3,"outputs":[]},{"cell_type":"markdown","source":["## Data Loader"],"metadata":{"id":"PiehSbGeWmor"}},{"cell_type":"code","source":["import torch\n",
"import csv\n",
"\n",
"# ─────────────────────────────────────────────────────────────\n",
"# ░ Streaming Caption Dataset – prompt-prefix descriptor injection\n",
"# ─────────────────────────────────────────────────────────────\n",
"import csv, random, re\n",
"from typing import List\n",
"from torch.utils.data import IterableDataset, get_worker_info\n",
"from huggingface_hub import hf_hub_download\n",
"from torch.utils.data import DataLoader\n",
"from pathlib import Path\n",
"\n",
"\n",
"class ParsedMultiCharDataset(IterableDataset):\n",
"    \"\"\"\n",
"    Streams HF-hosted caption shards and, for each delimiter-separated text chunk:\n",
"\n",
"      * If it starts with \"a \" (case-insensitive, ignoring leading spaces),\n",
"        replace that leading token with a random descriptor from\n",
"        _QUALITY_DESCRIPTORS followed by a single space (most descriptors\n",
"        already end with a comma).\n",
"\n",
"    No preliminary file scanning; every CSV line is read exactly once.\n",
"    Files are split across DataLoader workers by index (worker.id::num_workers).\n",
"    \"\"\"\n",
"\n",
"    # match a leading \"a \" only\n",
"    _PAT_START_A = re.compile(r\"^\\s*a\\s+\", re.IGNORECASE)\n",
"\n",
"    # Pool of prompt-prefix descriptors (quality, style and layout tags).\n",
"    # Repeated entries (e.g. \"2d,\") are drawn proportionally more often by random.choice.\n",
"    _QUALITY_DESCRIPTORS: List[str] = [\n",
"        \"masterpiece,\",\n",
"        \"very aesthetic,\",\n",
"        \"most aesthetic,\",\n",
"        \"an absolutely perfect depiction of\",\n",
"       
 \"awa,\",\n","        \"very awa,\",\n","        \"dimly lit,\",\n","        \"beautifuly lit,\",\n","        \"very beautiful,\",\n","        \"masterful depiction of\",\n","        \"dedicated masterpiece,\",\n","        \"warmly lit\",\n","        \"best quality, most aesthetic,\",\n","        \"beautiful depiction of\",\n","        \"masterful artwork of\",\n","        \"high-resolution photograph,\",\n","        \"hyper-realistic image,\",\n","        \"ultra-detailed photo,\",\n","        \"studio-quality photograph,\",\n","        \"cinematic shot,\",\n","        \"4K HDR image,\",\n","        \"sharp-focus photo,\",\n","        \"professionally lit photograph,\",\n","        \"DSLR capture,\",\n","        \"film-grain photograph,\",\n","        \"bokeh-rich shot,\",\n","        \"medium-format scan,\",\n","        \"analog film still,\",\n","        \"moody cinematic frame,\",\n","        \"dramatic-lighting photo,\",\n","        \"vibrant editorial image,\",\n","        \"macro-lens close-up,\",\n","        \"aerial drone photo,\",\n","        \"soft-focus dreamlike photo,\",\n","        \"low-key studio shot,\",\n","        \"overhead product shot,\",\n","        \"golden-hour photograph,\",\n","        \"noir-style monochrome shot,\",\n","        \"vintage Polaroid scan,\",\n","        \"infrared photograph,\",\n","        \"ultra-wide panorama,\",\n","        \"tilt-shift miniature photo,\",\n","        \"long-exposure night shot,\",\n","        \"time-lapse still,\",\n","        \"splash-photography frame,\",\n","        \"fine-art print scan,\",\n","        \"astrophotography capture,\",\n","        \"score_9, score_8, score_7, score_6,\",\n","        \"score_1, score_2, score_3, score_4,\",\n","        \"masterpiece, most aesthetic,\",\n","        \"most aesthetic, very aesthetic,\",\n","        \"masterpiece, most aesthetic, very aesthetic, realistic, real,\",\n","        \"most aesthetic, realistic, real,\",\n","        \"very aesthetic, 
realistic, real,\",\n",
"        # NOTE: every entry needs its trailing comma - adjacent string literals are\n",
"        # otherwise silently concatenated into a single bogus descriptor.\n",
"        \"masterpiece, very aesthetic, realistic, real,\",\n",
"        \"most aesthetic, very aesthetic, realistic, real,\",\n",
"        \"masterpiece, very aesthetic, realistic, anime,\",\n",
"        \"most aesthetic, very aesthetic, realistic, anime,\",\n",
"        \"very aesthetic, realistic, anime,\",\n",
"        \"masterpiece, very aesthetic, realistic, anime,\",\n",
"        \"most aesthetic, very aesthetic, realistic, anime,\",\n",
"        \"very aesthetic, realistic, anime, anime,\",\n",
"        \"2d,\",\n",
"        \"3d,\",\n",
"        \"anime,\",\n",
"        \"real,\",\n",
"        \"cartoon,\",\n",
"        \"realistic,\",\n",
"        \"2d,\",\n",
"        \"3d,\",\n",
"        \"anime,\",\n",
"        \"real,\",\n",
"        \"cartoon,\",\n",
"        \"realistic,\",\n",
"        \"2d,\",\n",
"        \"3d,\",\n",
"        \"anime,\",\n",
"        \"real,\",\n",
"        \"cartoon,\",\n",
"        \"realistic,\",\n",
"        \"2d,\",\n",
"        \"3d,\",\n",
"        \"anime,\",\n",
"        \"real,\",\n",
"        \"cartoon,\",\n",
"        \"realistic,\",\n",
"        \"masterpiece, 2d,\",\n",
"        \"masterpiece, 3d,\",\n",
"        \"masterpiece, anime,\",\n",
"        \"masterpiece, real,\",\n",
"        \"masterpiece, cartoon,\",\n",
"        \"3d, anime,\",\n",
"        \"3d, anime, real,\",\n",
"        \"3d, anime, real, realistic,\",\n",
"        \"masterpiece, 3d,\",\n",
"        \"masterpiece, 3d, anime,\",\n",
"        \"masterpiece, 3d, anime, real,\",\n",
"        \"masterpiece, 3d, anime, real, realistic,\",\n",
"        \"masterpiece, anime,\",\n",
"        \"masterpiece, anime, real,\",\n",
"        \"masterpiece, anime, real, realistic,\",\n",
"        \"masterpiece, 3d, anime, real,\",\n",
"        \"very aesthetic, 3d,\",\n",
"        \"very aesthetic, 3d, anime,\",\n",
"        \"very aesthetic, 3d, anime, real,\",\n",
"        \"very aesthetic, 3d, anime, real, realistic,\",\n",
"        \"very aesthetic, 3d, anime, real,\",\n",
"        \"most aesthetic, 3d,\",\n",
"        \"most aesthetic, 3d, anime,\",\n",
"        \"most aesthetic, 3d, anime, real,\",\n",
"        \"most aesthetic, 3d, anime, real, realistic,\",\n",
"        \"anime, comic,\",\n",
"        \"manga, anime\",\n",
"        \"masterpiece, cartoon,\",\n",
"        \"masterpiece, cartoon, real,\",\n",
"        \"masterpiece, cartoon, real, realistic,\",\n",
"        \"masterpiece, cartoon, real,\",\n",
"        \"most aesthetic, cartoon,\",\n",
"        \"most aesthetic, cartoon, real,\",\n",
"        \"most aesthetic, cartoon, real, realistic,\",\n",
"        \"most aesthetic, cartoon, real,\",\n",
"        \"grid_a1 head,\",\n",
"        \"grid_a2 head,\",\n",
"        \"grid_a3 head,\",\n",
"        \"grid_a4 head,\",\n",
"        \"grid_a5 head,\",\n",
"        \"grid_b1 head,\",\n",
"        \"grid_b2 head,\",\n",
"        \"grid_b3 head,\",\n",
"        \"grid_b4 head,\",\n",
"        \"grid_b5 head,\",\n",
"        \"grid_c1 head,\",\n",
"        \"grid_c2 head,\",\n",
"        \"grid_c3 head,\",\n",
"        \"grid_c4 head,\",\n",
"        \"grid_c5 head,\",\n",
"        \"grid_d1 head,\",\n",
"        \"grid_d2 head,\",\n",
"        \"grid_d3 head,\",\n",
"        \"grid_d4 head,\",\n",
"        \"grid_d5 head,\",\n",
"        \"grid_e1 head,\",\n",
"        \"grid_e2 head,\",\n",
"        \"grid_e3 head,\",\n",
"        \"grid_e4 head,\",\n",
"        \"grid_e5 head,\",\n",
"        \"grid_a1 upper body,\",\n",
"        \"grid_a2 upper body,\",\n",
"        \"grid_a3 upper body,\",\n",
"        \"grid_a4 upper body,\",\n",
"        \"grid_a5 upper body,\",\n",
"        \"grid_b1 upper body,\",\n",
"        \"grid_b2 upper body,\",\n",
"        \"grid_b3 upper body,\",\n",
"        \"grid_b4 upper body,\",\n",
"        \"grid_b5 upper body,\",\n",
"        \"grid_c1 upper body,\",\n",
"        \"grid_c2 upper body,\",\n",
"        \"grid_c3 upper body,\",\n",
"        \"grid_c4 upper body,\",\n",
"        \"grid_c5 upper body,\",\n",
"        \"grid_d1 upper body,\",\n",
"        \"grid_d2 upper body,\",\n",
"        \"grid_d3 upper body,\",\n",
"        \"grid_d4 upper body,\",\n",
"        \"grid_d5 upper body,\",\n",
"        \"grid_e1 upper body,\",\n",
"        \"grid_e2 upper body,\",\n",
"        \"grid_e3 upper body,\",\n",
"        \"grid_e4 upper body,\",\n",
"        \"grid_e5 upper body,\",\n",
"        \"zone_ul upper body,\",\n",
"        \"zone_ur upper body,\",\n",
"        \"zone_ll upper body,\",\n",
"        \"zone_lr upper body,\",\n",
"        \"zone_ul head,\",\n",
"        \"zone_ur head,\",\n",
"        \"zone_ll head,\",\n",
"        \"zone_lr head,\",\n",
"\n",
"        #\"disgusting, 3d, lowres,\",\n",
"        #\"disgusting, 2d, lowres,\",\n",
"        #\"disgusting, cartoon, lowres,\",\n",
"        #\"disgusting, 3d, cartoon, lowres,\",\n",
"        #\"disgusting, 2d, cartoon, lowres,\",\n",
"        #\"disgusting\n",
"\n",
"\n",
"    ]\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        repo_id: str,\n",
"        delimiter: str = \".,|,.\",\n",
"        start_file: int = 20,\n",
"        num_files: int = 80,\n",
"        shuffle: bool = False\n",
"    ):\n",
"        super().__init__()\n",
"        self.delimiter = delimiter\n",
"        self.paths = [\n",
"            hf_hub_download(\n",
"                repo_id,\n",
"                f\"captions/caption_{i + start_file:03d}.csv\",\n",
"                repo_type=\"dataset\",\n",
"            )\n",
"            for i in range(num_files)\n",
"        ]\n",
"        # NOTE(review): assumes ~5M rows per shard, so __len__ is only an estimate; it also\n",
"        # counts CSV rows, whereas __iter__ yields one item per non-empty delimited chunk.\n",
"        self.total_rows = 5_000_000 * num_files\n",
"        self.shuffle = shuffle\n",
"\n",
"    def __len__(self):\n",
"        return self.total_rows\n",
"\n",
"    def __iter__(self):\n",
"        worker = get_worker_info()\n",
"        paths = (\n",
"            self.paths\n",
"            if worker is None\n",
"            else self.paths[worker.id :: worker.num_workers]\n",
"        )\n",
"\n",
"        pat_a   = self._PAT_START_A\n",
"        choose  = random.choice\n",
"        # Private copy: random.shuffle() below works in place and must not reorder the\n",
"        # class-level list shared by every instance and every iterator.\n",
"        q_pool  = list(self._QUALITY_DESCRIPTORS)\n",
"        delim   = self.delimiter\n",
"\n",
"        for path in paths:\n",
"            with open(path, encoding=\"utf-8\", newline=\"\") as f:\n",
"                for row in csv.DictReader(f):\n",
"                    # DictReader fills short rows with None, so guard before splitting.\n",
"                    for chunk in (row.get(\"text\") or \"\").split(delim):\n",
"                        chunk = chunk.strip()\n",
"                        if not chunk:\n",
"                            continue\n",
"                        if self.shuffle:\n",
"                            # random.choice is uniform either way; this only changes the RNG draws.\n",
"                            random.shuffle(q_pool)\n",
"                        # replace a leading \"a \" with the descriptor plus a space\n",
"                        chunk = pat_a.sub(choose(q_pool) + \" \", chunk, count=1)\n",
"\n",
"                        yield chunk\n",
"\n",
"\n"],"metadata":{"id":"mWRkuGO5U5Ma","executionInfo":{"status":"ok","timestamp":1748636889529,"user_tz":420,"elapsed":267,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"_FBrEkfzbB0M","executionInfo":{"status":"ok","timestamp":1748636889550,"user_tz":420,"elapsed":19,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":4,"outputs":[]},{"cell_type":"markdown","source":["## Load Models"],"metadata":{"id":"RiRNq1OjXQ4R"}},{"cell_type":"code","source":["import torch\n",
"from transformers import (\n",
"    T5EncoderModel, T5TokenizerFast,\n",
"    CLIPTextModel, CLIPTokenizerFast\n",
")\n",
"\n",
"import torch\n",
"from transformers import (\n",
"    T5EncoderModel, T5TokenizerFast,\n",
"    CLIPTextModel, CLIPTokenizerFast\n",
")\n",
"\n",
"from safetensors.torch import save_file, load_file\n",
"\n",
"\n",
"# ─── Runtime Settings ────────────────────────────────────────\n",
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"DTYPE = torch.float32  # 🔒 Force full FP32 precision\n",
"\n",
"# ─── Load Tokenizers ─────────────────────────────────────────\n",
"t5_tok = T5TokenizerFast.from_pretrained(\"google/flan-t5-base\")\n",
"clip_tok = 
CLIPTokenizerFast.from_pretrained(\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\")\n",
"\n",
"# ─── Load and Freeze T5 Model ───────────────────────────────\n",
"t5_mod = T5EncoderModel.from_pretrained(\"google/flan-t5-base\").to(DEVICE, dtype=DTYPE)\n",
"t5_mod.eval().requires_grad_(False)\n",
"\n",
"# ─── Load and Freeze CLIP Model ─────────────────────────────\n",
"clip_mod = CLIPTextModel.from_pretrained(\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\").to(DEVICE, dtype=DTYPE)\n",
"\n",
"# Overlay the OMEGA-24 weights. strict=False tolerates a partial checkpoint, so report what\n",
"# did not line up instead of dumping every key: a key-prefix mismatch would otherwise\n",
"# silently leave the stock LAION weights in place.\n",
"temp_clip = load_file(\"/content/drive/MyDrive/clips/OMEGA-24-CLIP_G.safetensors\")\n",
"load_report = clip_mod.load_state_dict(temp_clip, strict=False)\n",
"print(f\"CLIP overlay: {len(temp_clip)} tensors, {len(load_report.missing_keys)} missing, \"\n",
"      f\"{len(load_report.unexpected_keys)} unexpected\")\n",
"if load_report.missing_keys:\n",
"    print(\"  missing (first 5):\", load_report.missing_keys[:5])\n",
"if load_report.unexpected_keys:\n",
"    print(\"  unexpected (first 5):\", load_report.unexpected_keys[:5])\n",
"if len(load_report.unexpected_keys) == len(temp_clip):\n",
"    print(\"WARNING: no checkpoint keys matched CLIPTextModel; stock weights are still in use.\")\n",
"clip_mod.eval().requires_grad_(False)\n",
"\n",
"# ─── Initialize Adapter from Config ──────────────────────────\n",
"adapter = TwoStreamShuntAdapter(config=ADAPTER_CONFIG).to(DEVICE, dtype=DTYPE)\n",
"load_safetensors(adapter, \"/content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_e1_step_10000.safetensors\")\n",
"adapter.train()\n",
"print(\"✅ T5, CLIP, and Dual Shunt Adapter loaded and cast.\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":138,"referenced_widgets":["9bf9176fc21b4f15896d7281ced5fa17","76cf4a2d556842e497e2e64c2ba0b3fe","a6b3d0c5d84b477c82545fb809fdda1c","456b3fd649ab4c0fbb5348a765fb9613","32bcb31386ca4187ba6a554d26189502","d4696d3bbdc149be9bab0cf4664c68c4","1575e84b9dc5454d94258035abf26da1","b4e41d9dcb2c4be8bf94919c93c56a83","d96c33b98b874cf8862a2bedb3596b4c","d7f4744342ec4ebd8d2798cae2e90518","2f9c4dbc55d04276a5b9fa92de613bf5"]},"id":"NOpbg28ZXQH5","executionInfo":{"status":"ok","timestamp":1748636906923,"user_tz":420,"elapsed":12771,"user":{"displayName":"P 
C","userId":"00707517734723903966"}},"outputId":"8ebbca72-abed-49df-ca20-04250b38be55"},"execution_count":5,"outputs":[{"output_type":"display_data","data":{"text/plain":["Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9bf9176fc21b4f15896d7281ced5fa17"}},"metadata":{}},{"output_type":"stream","name":"stdout","text":["odict_keys(['text_model.embeddings.token_embedding.weight', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.1.mlp.fc1.bias', 
'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 
'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.6.layer_norm1.bias', 
'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 
'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 
'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.12.self_attn.k_proj.weight', 'text_model.encoder.layers.12.self_attn.k_proj.bias', 'text_model.encoder.layers.12.self_attn.v_proj.weight', 'text_model.encoder.layers.12.self_attn.v_proj.bias', 'text_model.encoder.layers.12.self_attn.q_proj.weight', 'text_model.encoder.layers.12.self_attn.q_proj.bias', 'text_model.encoder.layers.12.self_attn.out_proj.weight', 'text_model.encoder.layers.12.self_attn.out_proj.bias', 'text_model.encoder.layers.12.layer_norm1.weight', 'text_model.encoder.layers.12.layer_norm1.bias', 'text_model.encoder.layers.12.mlp.fc1.weight', 'text_model.encoder.layers.12.mlp.fc1.bias', 'text_model.encoder.layers.12.mlp.fc2.weight', 'text_model.encoder.layers.12.mlp.fc2.bias', 'text_model.encoder.layers.12.layer_norm2.weight', 'text_model.encoder.layers.12.layer_norm2.bias', 'text_model.encoder.layers.13.self_attn.k_proj.weight', 'text_model.encoder.layers.13.self_attn.k_proj.bias', 'text_model.encoder.layers.13.self_attn.v_proj.weight', 'text_model.encoder.layers.13.self_attn.v_proj.bias', 'text_model.encoder.layers.13.self_attn.q_proj.weight', 'text_model.encoder.layers.13.self_attn.q_proj.bias', 'text_model.encoder.layers.13.self_attn.out_proj.weight', 'text_model.encoder.layers.13.self_attn.out_proj.bias', 'text_model.encoder.layers.13.layer_norm1.weight', 'text_model.encoder.layers.13.layer_norm1.bias', 'text_model.encoder.layers.13.mlp.fc1.weight', 'text_model.encoder.layers.13.mlp.fc1.bias', 'text_model.encoder.layers.13.mlp.fc2.weight', 'text_model.encoder.layers.13.mlp.fc2.bias', 
'text_model.encoder.layers.13.layer_norm2.weight', 'text_model.encoder.layers.13.layer_norm2.bias', 'text_model.encoder.layers.14.self_attn.k_proj.weight', 'text_model.encoder.layers.14.self_attn.k_proj.bias', 'text_model.encoder.layers.14.self_attn.v_proj.weight', 'text_model.encoder.layers.14.self_attn.v_proj.bias', 'text_model.encoder.layers.14.self_attn.q_proj.weight', 'text_model.encoder.layers.14.self_attn.q_proj.bias', 'text_model.encoder.layers.14.self_attn.out_proj.weight', 'text_model.encoder.layers.14.self_attn.out_proj.bias', 'text_model.encoder.layers.14.layer_norm1.weight', 'text_model.encoder.layers.14.layer_norm1.bias', 'text_model.encoder.layers.14.mlp.fc1.weight', 'text_model.encoder.layers.14.mlp.fc1.bias', 'text_model.encoder.layers.14.mlp.fc2.weight', 'text_model.encoder.layers.14.mlp.fc2.bias', 'text_model.encoder.layers.14.layer_norm2.weight', 'text_model.encoder.layers.14.layer_norm2.bias', 'text_model.encoder.layers.15.self_attn.k_proj.weight', 'text_model.encoder.layers.15.self_attn.k_proj.bias', 'text_model.encoder.layers.15.self_attn.v_proj.weight', 'text_model.encoder.layers.15.self_attn.v_proj.bias', 'text_model.encoder.layers.15.self_attn.q_proj.weight', 'text_model.encoder.layers.15.self_attn.q_proj.bias', 'text_model.encoder.layers.15.self_attn.out_proj.weight', 'text_model.encoder.layers.15.self_attn.out_proj.bias', 'text_model.encoder.layers.15.layer_norm1.weight', 'text_model.encoder.layers.15.layer_norm1.bias', 'text_model.encoder.layers.15.mlp.fc1.weight', 'text_model.encoder.layers.15.mlp.fc1.bias', 'text_model.encoder.layers.15.mlp.fc2.weight', 'text_model.encoder.layers.15.mlp.fc2.bias', 'text_model.encoder.layers.15.layer_norm2.weight', 'text_model.encoder.layers.15.layer_norm2.bias', 'text_model.encoder.layers.16.self_attn.k_proj.weight', 'text_model.encoder.layers.16.self_attn.k_proj.bias', 'text_model.encoder.layers.16.self_attn.v_proj.weight', 'text_model.encoder.layers.16.self_attn.v_proj.bias', 
'text_model.encoder.layers.16.self_attn.q_proj.weight', 'text_model.encoder.layers.16.self_attn.q_proj.bias', 'text_model.encoder.layers.16.self_attn.out_proj.weight', 'text_model.encoder.layers.16.self_attn.out_proj.bias', 'text_model.encoder.layers.16.layer_norm1.weight', 'text_model.encoder.layers.16.layer_norm1.bias', 'text_model.encoder.layers.16.mlp.fc1.weight', 'text_model.encoder.layers.16.mlp.fc1.bias', 'text_model.encoder.layers.16.mlp.fc2.weight', 'text_model.encoder.layers.16.mlp.fc2.bias', 'text_model.encoder.layers.16.layer_norm2.weight', 'text_model.encoder.layers.16.layer_norm2.bias', 'text_model.encoder.layers.17.self_attn.k_proj.weight', 'text_model.encoder.layers.17.self_attn.k_proj.bias', 'text_model.encoder.layers.17.self_attn.v_proj.weight', 'text_model.encoder.layers.17.self_attn.v_proj.bias', 'text_model.encoder.layers.17.self_attn.q_proj.weight', 'text_model.encoder.layers.17.self_attn.q_proj.bias', 'text_model.encoder.layers.17.self_attn.out_proj.weight', 'text_model.encoder.layers.17.self_attn.out_proj.bias', 'text_model.encoder.layers.17.layer_norm1.weight', 'text_model.encoder.layers.17.layer_norm1.bias', 'text_model.encoder.layers.17.mlp.fc1.weight', 'text_model.encoder.layers.17.mlp.fc1.bias', 'text_model.encoder.layers.17.mlp.fc2.weight', 'text_model.encoder.layers.17.mlp.fc2.bias', 'text_model.encoder.layers.17.layer_norm2.weight', 'text_model.encoder.layers.17.layer_norm2.bias', 'text_model.encoder.layers.18.self_attn.k_proj.weight', 'text_model.encoder.layers.18.self_attn.k_proj.bias', 'text_model.encoder.layers.18.self_attn.v_proj.weight', 'text_model.encoder.layers.18.self_attn.v_proj.bias', 'text_model.encoder.layers.18.self_attn.q_proj.weight', 'text_model.encoder.layers.18.self_attn.q_proj.bias', 'text_model.encoder.layers.18.self_attn.out_proj.weight', 'text_model.encoder.layers.18.self_attn.out_proj.bias', 'text_model.encoder.layers.18.layer_norm1.weight', 'text_model.encoder.layers.18.layer_norm1.bias', 
'text_model.encoder.layers.18.mlp.fc1.weight', 'text_model.encoder.layers.18.mlp.fc1.bias', 'text_model.encoder.layers.18.mlp.fc2.weight', 'text_model.encoder.layers.18.mlp.fc2.bias', 'text_model.encoder.layers.18.layer_norm2.weight', 'text_model.encoder.layers.18.layer_norm2.bias', 'text_model.encoder.layers.19.self_attn.k_proj.weight', 'text_model.encoder.layers.19.self_attn.k_proj.bias', 'text_model.encoder.layers.19.self_attn.v_proj.weight', 'text_model.encoder.layers.19.self_attn.v_proj.bias', 'text_model.encoder.layers.19.self_attn.q_proj.weight', 'text_model.encoder.layers.19.self_attn.q_proj.bias', 'text_model.encoder.layers.19.self_attn.out_proj.weight', 'text_model.encoder.layers.19.self_attn.out_proj.bias', 'text_model.encoder.layers.19.layer_norm1.weight', 'text_model.encoder.layers.19.layer_norm1.bias', 'text_model.encoder.layers.19.mlp.fc1.weight', 'text_model.encoder.layers.19.mlp.fc1.bias', 'text_model.encoder.layers.19.mlp.fc2.weight', 'text_model.encoder.layers.19.mlp.fc2.bias', 'text_model.encoder.layers.19.layer_norm2.weight', 'text_model.encoder.layers.19.layer_norm2.bias', 'text_model.encoder.layers.20.self_attn.k_proj.weight', 'text_model.encoder.layers.20.self_attn.k_proj.bias', 'text_model.encoder.layers.20.self_attn.v_proj.weight', 'text_model.encoder.layers.20.self_attn.v_proj.bias', 'text_model.encoder.layers.20.self_attn.q_proj.weight', 'text_model.encoder.layers.20.self_attn.q_proj.bias', 'text_model.encoder.layers.20.self_attn.out_proj.weight', 'text_model.encoder.layers.20.self_attn.out_proj.bias', 'text_model.encoder.layers.20.layer_norm1.weight', 'text_model.encoder.layers.20.layer_norm1.bias', 'text_model.encoder.layers.20.mlp.fc1.weight', 'text_model.encoder.layers.20.mlp.fc1.bias', 'text_model.encoder.layers.20.mlp.fc2.weight', 'text_model.encoder.layers.20.mlp.fc2.bias', 'text_model.encoder.layers.20.layer_norm2.weight', 'text_model.encoder.layers.20.layer_norm2.bias', 'text_model.encoder.layers.21.self_attn.k_proj.weight', 
'text_model.encoder.layers.21.self_attn.k_proj.bias', 'text_model.encoder.layers.21.self_attn.v_proj.weight', 'text_model.encoder.layers.21.self_attn.v_proj.bias', 'text_model.encoder.layers.21.self_attn.q_proj.weight', 'text_model.encoder.layers.21.self_attn.q_proj.bias', 'text_model.encoder.layers.21.self_attn.out_proj.weight', 'text_model.encoder.layers.21.self_attn.out_proj.bias', 'text_model.encoder.layers.21.layer_norm1.weight', 'text_model.encoder.layers.21.layer_norm1.bias', 'text_model.encoder.layers.21.mlp.fc1.weight', 'text_model.encoder.layers.21.mlp.fc1.bias', 'text_model.encoder.layers.21.mlp.fc2.weight', 'text_model.encoder.layers.21.mlp.fc2.bias', 'text_model.encoder.layers.21.layer_norm2.weight', 'text_model.encoder.layers.21.layer_norm2.bias', 'text_model.encoder.layers.22.self_attn.k_proj.weight', 'text_model.encoder.layers.22.self_attn.k_proj.bias', 'text_model.encoder.layers.22.self_attn.v_proj.weight', 'text_model.encoder.layers.22.self_attn.v_proj.bias', 'text_model.encoder.layers.22.self_attn.q_proj.weight', 'text_model.encoder.layers.22.self_attn.q_proj.bias', 'text_model.encoder.layers.22.self_attn.out_proj.weight', 'text_model.encoder.layers.22.self_attn.out_proj.bias', 'text_model.encoder.layers.22.layer_norm1.weight', 'text_model.encoder.layers.22.layer_norm1.bias', 'text_model.encoder.layers.22.mlp.fc1.weight', 'text_model.encoder.layers.22.mlp.fc1.bias', 'text_model.encoder.layers.22.mlp.fc2.weight', 'text_model.encoder.layers.22.mlp.fc2.bias', 'text_model.encoder.layers.22.layer_norm2.weight', 'text_model.encoder.layers.22.layer_norm2.bias', 'text_model.encoder.layers.23.self_attn.k_proj.weight', 'text_model.encoder.layers.23.self_attn.k_proj.bias', 'text_model.encoder.layers.23.self_attn.v_proj.weight', 'text_model.encoder.layers.23.self_attn.v_proj.bias', 'text_model.encoder.layers.23.self_attn.q_proj.weight', 'text_model.encoder.layers.23.self_attn.q_proj.bias', 'text_model.encoder.layers.23.self_attn.out_proj.weight', 
'text_model.encoder.layers.23.self_attn.out_proj.bias', 'text_model.encoder.layers.23.layer_norm1.weight', 'text_model.encoder.layers.23.layer_norm1.bias', 'text_model.encoder.layers.23.mlp.fc1.weight', 'text_model.encoder.layers.23.mlp.fc1.bias', 'text_model.encoder.layers.23.mlp.fc2.weight', 'text_model.encoder.layers.23.mlp.fc2.bias', 'text_model.encoder.layers.23.layer_norm2.weight', 'text_model.encoder.layers.23.layer_norm2.bias', 'text_model.encoder.layers.24.self_attn.k_proj.weight', 'text_model.encoder.layers.24.self_attn.k_proj.bias', 'text_model.encoder.layers.24.self_attn.v_proj.weight', 'text_model.encoder.layers.24.self_attn.v_proj.bias', 'text_model.encoder.layers.24.self_attn.q_proj.weight', 'text_model.encoder.layers.24.self_attn.q_proj.bias', 'text_model.encoder.layers.24.self_attn.out_proj.weight', 'text_model.encoder.layers.24.self_attn.out_proj.bias', 'text_model.encoder.layers.24.layer_norm1.weight', 'text_model.encoder.layers.24.layer_norm1.bias', 'text_model.encoder.layers.24.mlp.fc1.weight', 'text_model.encoder.layers.24.mlp.fc1.bias', 'text_model.encoder.layers.24.mlp.fc2.weight', 'text_model.encoder.layers.24.mlp.fc2.bias', 'text_model.encoder.layers.24.layer_norm2.weight', 'text_model.encoder.layers.24.layer_norm2.bias', 'text_model.encoder.layers.25.self_attn.k_proj.weight', 'text_model.encoder.layers.25.self_attn.k_proj.bias', 'text_model.encoder.layers.25.self_attn.v_proj.weight', 'text_model.encoder.layers.25.self_attn.v_proj.bias', 'text_model.encoder.layers.25.self_attn.q_proj.weight', 'text_model.encoder.layers.25.self_attn.q_proj.bias', 'text_model.encoder.layers.25.self_attn.out_proj.weight', 'text_model.encoder.layers.25.self_attn.out_proj.bias', 'text_model.encoder.layers.25.layer_norm1.weight', 'text_model.encoder.layers.25.layer_norm1.bias', 'text_model.encoder.layers.25.mlp.fc1.weight', 'text_model.encoder.layers.25.mlp.fc1.bias', 'text_model.encoder.layers.25.mlp.fc2.weight', 'text_model.encoder.layers.25.mlp.fc2.bias', 
'text_model.encoder.layers.25.layer_norm2.weight', 'text_model.encoder.layers.25.layer_norm2.bias', 'text_model.encoder.layers.26.self_attn.k_proj.weight', 'text_model.encoder.layers.26.self_attn.k_proj.bias', 'text_model.encoder.layers.26.self_attn.v_proj.weight', 'text_model.encoder.layers.26.self_attn.v_proj.bias', 'text_model.encoder.layers.26.self_attn.q_proj.weight', 'text_model.encoder.layers.26.self_attn.q_proj.bias', 'text_model.encoder.layers.26.self_attn.out_proj.weight', 'text_model.encoder.layers.26.self_attn.out_proj.bias', 'text_model.encoder.layers.26.layer_norm1.weight', 'text_model.encoder.layers.26.layer_norm1.bias', 'text_model.encoder.layers.26.mlp.fc1.weight', 'text_model.encoder.layers.26.mlp.fc1.bias', 'text_model.encoder.layers.26.mlp.fc2.weight', 'text_model.encoder.layers.26.mlp.fc2.bias', 'text_model.encoder.layers.26.layer_norm2.weight', 'text_model.encoder.layers.26.layer_norm2.bias', 'text_model.encoder.layers.27.self_attn.k_proj.weight', 'text_model.encoder.layers.27.self_attn.k_proj.bias', 'text_model.encoder.layers.27.self_attn.v_proj.weight', 'text_model.encoder.layers.27.self_attn.v_proj.bias', 'text_model.encoder.layers.27.self_attn.q_proj.weight', 'text_model.encoder.layers.27.self_attn.q_proj.bias', 'text_model.encoder.layers.27.self_attn.out_proj.weight', 'text_model.encoder.layers.27.self_attn.out_proj.bias', 'text_model.encoder.layers.27.layer_norm1.weight', 'text_model.encoder.layers.27.layer_norm1.bias', 'text_model.encoder.layers.27.mlp.fc1.weight', 'text_model.encoder.layers.27.mlp.fc1.bias', 'text_model.encoder.layers.27.mlp.fc2.weight', 'text_model.encoder.layers.27.mlp.fc2.bias', 'text_model.encoder.layers.27.layer_norm2.weight', 'text_model.encoder.layers.27.layer_norm2.bias', 'text_model.encoder.layers.28.self_attn.k_proj.weight', 'text_model.encoder.layers.28.self_attn.k_proj.bias', 'text_model.encoder.layers.28.self_attn.v_proj.weight', 'text_model.encoder.layers.28.self_attn.v_proj.bias', 
'text_model.encoder.layers.28.self_attn.q_proj.weight', 'text_model.encoder.layers.28.self_attn.q_proj.bias', 'text_model.encoder.layers.28.self_attn.out_proj.weight', 'text_model.encoder.layers.28.self_attn.out_proj.bias', 'text_model.encoder.layers.28.layer_norm1.weight', 'text_model.encoder.layers.28.layer_norm1.bias', 'text_model.encoder.layers.28.mlp.fc1.weight', 'text_model.encoder.layers.28.mlp.fc1.bias', 'text_model.encoder.layers.28.mlp.fc2.weight', 'text_model.encoder.layers.28.mlp.fc2.bias', 'text_model.encoder.layers.28.layer_norm2.weight', 'text_model.encoder.layers.28.layer_norm2.bias', 'text_model.encoder.layers.29.self_attn.k_proj.weight', 'text_model.encoder.layers.29.self_attn.k_proj.bias', 'text_model.encoder.layers.29.self_attn.v_proj.weight', 'text_model.encoder.layers.29.self_attn.v_proj.bias', 'text_model.encoder.layers.29.self_attn.q_proj.weight', 'text_model.encoder.layers.29.self_attn.q_proj.bias', 'text_model.encoder.layers.29.self_attn.out_proj.weight', 'text_model.encoder.layers.29.self_attn.out_proj.bias', 'text_model.encoder.layers.29.layer_norm1.weight', 'text_model.encoder.layers.29.layer_norm1.bias', 'text_model.encoder.layers.29.mlp.fc1.weight', 'text_model.encoder.layers.29.mlp.fc1.bias', 'text_model.encoder.layers.29.mlp.fc2.weight', 'text_model.encoder.layers.29.mlp.fc2.bias', 'text_model.encoder.layers.29.layer_norm2.weight', 'text_model.encoder.layers.29.layer_norm2.bias', 'text_model.encoder.layers.30.self_attn.k_proj.weight', 'text_model.encoder.layers.30.self_attn.k_proj.bias', 'text_model.encoder.layers.30.self_attn.v_proj.weight', 'text_model.encoder.layers.30.self_attn.v_proj.bias', 'text_model.encoder.layers.30.self_attn.q_proj.weight', 'text_model.encoder.layers.30.self_attn.q_proj.bias', 'text_model.encoder.layers.30.self_attn.out_proj.weight', 'text_model.encoder.layers.30.self_attn.out_proj.bias', 'text_model.encoder.layers.30.layer_norm1.weight', 'text_model.encoder.layers.30.layer_norm1.bias', 
'text_model.encoder.layers.30.mlp.fc1.weight', 'text_model.encoder.layers.30.mlp.fc1.bias', 'text_model.encoder.layers.30.mlp.fc2.weight', 'text_model.encoder.layers.30.mlp.fc2.bias', 'text_model.encoder.layers.30.layer_norm2.weight', 'text_model.encoder.layers.30.layer_norm2.bias', 'text_model.encoder.layers.31.self_attn.k_proj.weight', 'text_model.encoder.layers.31.self_attn.k_proj.bias', 'text_model.encoder.layers.31.self_attn.v_proj.weight', 'text_model.encoder.layers.31.self_attn.v_proj.bias', 'text_model.encoder.layers.31.self_attn.q_proj.weight', 'text_model.encoder.layers.31.self_attn.q_proj.bias', 'text_model.encoder.layers.31.self_attn.out_proj.weight', 'text_model.encoder.layers.31.self_attn.out_proj.bias', 'text_model.encoder.layers.31.layer_norm1.weight', 'text_model.encoder.layers.31.layer_norm1.bias', 'text_model.encoder.layers.31.mlp.fc1.weight', 'text_model.encoder.layers.31.mlp.fc1.bias', 'text_model.encoder.layers.31.mlp.fc2.weight', 'text_model.encoder.layers.31.mlp.fc2.bias', 'text_model.encoder.layers.31.layer_norm2.weight', 'text_model.encoder.layers.31.layer_norm2.bias', 'text_model.final_layer_norm.weight', 'text_model.final_layer_norm.bias'])\n","dict_keys(['logit_scale', 'text_model.embeddings.position_embedding.weight', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 
'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 
'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.12.layer_norm1.bias', 'text_model.encoder.layers.12.layer_norm1.weight', 'text_model.encoder.layers.12.layer_norm2.bias', 'text_model.encoder.layers.12.layer_norm2.weight', 'text_model.encoder.layers.12.mlp.fc1.bias', 'text_model.encoder.layers.12.mlp.fc1.weight', 'text_model.encoder.layers.12.mlp.fc2.bias', 'text_model.encoder.layers.12.mlp.fc2.weight', 'text_model.encoder.layers.12.self_attn.k_proj.bias', 'text_model.encoder.layers.12.self_attn.k_proj.weight', 'text_model.encoder.layers.12.self_attn.out_proj.bias', 'text_model.encoder.layers.12.self_attn.out_proj.weight', 'text_model.encoder.layers.12.self_attn.q_proj.bias', 'text_model.encoder.layers.12.self_attn.q_proj.weight', 'text_model.encoder.layers.12.self_attn.v_proj.bias', 'text_model.encoder.layers.12.self_attn.v_proj.weight', 'text_model.encoder.layers.13.layer_norm1.bias', 'text_model.encoder.layers.13.layer_norm1.weight', 'text_model.encoder.layers.13.layer_norm2.bias', 'text_model.encoder.layers.13.layer_norm2.weight', 'text_model.encoder.layers.13.mlp.fc1.bias', 'text_model.encoder.layers.13.mlp.fc1.weight', 'text_model.encoder.layers.13.mlp.fc2.bias', 'text_model.encoder.layers.13.mlp.fc2.weight', 'text_model.encoder.layers.13.self_attn.k_proj.bias', 
'text_model.encoder.layers.13.self_attn.k_proj.weight', 'text_model.encoder.layers.13.self_attn.out_proj.bias', 'text_model.encoder.layers.13.self_attn.out_proj.weight', 'text_model.encoder.layers.13.self_attn.q_proj.bias', 'text_model.encoder.layers.13.self_attn.q_proj.weight', 'text_model.encoder.layers.13.self_attn.v_proj.bias', 'text_model.encoder.layers.13.self_attn.v_proj.weight', 'text_model.encoder.layers.14.layer_norm1.bias', 'text_model.encoder.layers.14.layer_norm1.weight', 'text_model.encoder.layers.14.layer_norm2.bias', 'text_model.encoder.layers.14.layer_norm2.weight', 'text_model.encoder.layers.14.mlp.fc1.bias', 'text_model.encoder.layers.14.mlp.fc1.weight', 'text_model.encoder.layers.14.mlp.fc2.bias', 'text_model.encoder.layers.14.mlp.fc2.weight', 'text_model.encoder.layers.14.self_attn.k_proj.bias', 'text_model.encoder.layers.14.self_attn.k_proj.weight', 'text_model.encoder.layers.14.self_attn.out_proj.bias', 'text_model.encoder.layers.14.self_attn.out_proj.weight', 'text_model.encoder.layers.14.self_attn.q_proj.bias', 'text_model.encoder.layers.14.self_attn.q_proj.weight', 'text_model.encoder.layers.14.self_attn.v_proj.bias', 'text_model.encoder.layers.14.self_attn.v_proj.weight', 'text_model.encoder.layers.15.layer_norm1.bias', 'text_model.encoder.layers.15.layer_norm1.weight', 'text_model.encoder.layers.15.layer_norm2.bias', 'text_model.encoder.layers.15.layer_norm2.weight', 'text_model.encoder.layers.15.mlp.fc1.bias', 'text_model.encoder.layers.15.mlp.fc1.weight', 'text_model.encoder.layers.15.mlp.fc2.bias', 'text_model.encoder.layers.15.mlp.fc2.weight', 'text_model.encoder.layers.15.self_attn.k_proj.bias', 'text_model.encoder.layers.15.self_attn.k_proj.weight', 'text_model.encoder.layers.15.self_attn.out_proj.bias', 'text_model.encoder.layers.15.self_attn.out_proj.weight', 'text_model.encoder.layers.15.self_attn.q_proj.bias', 'text_model.encoder.layers.15.self_attn.q_proj.weight', 'text_model.encoder.layers.15.self_attn.v_proj.bias', 
'text_model.encoder.layers.15.self_attn.v_proj.weight', 'text_model.encoder.layers.16.layer_norm1.bias', 'text_model.encoder.layers.16.layer_norm1.weight', 'text_model.encoder.layers.16.layer_norm2.bias', 'text_model.encoder.layers.16.layer_norm2.weight', 'text_model.encoder.layers.16.mlp.fc1.bias', 'text_model.encoder.layers.16.mlp.fc1.weight', 'text_model.encoder.layers.16.mlp.fc2.bias', 'text_model.encoder.layers.16.mlp.fc2.weight', 'text_model.encoder.layers.16.self_attn.k_proj.bias', 'text_model.encoder.layers.16.self_attn.k_proj.weight', 'text_model.encoder.layers.16.self_attn.out_proj.bias', 'text_model.encoder.layers.16.self_attn.out_proj.weight', 'text_model.encoder.layers.16.self_attn.q_proj.bias', 'text_model.encoder.layers.16.self_attn.q_proj.weight', 'text_model.encoder.layers.16.self_attn.v_proj.bias', 'text_model.encoder.layers.16.self_attn.v_proj.weight', 'text_model.encoder.layers.17.layer_norm1.bias', 'text_model.encoder.layers.17.layer_norm1.weight', 'text_model.encoder.layers.17.layer_norm2.bias', 'text_model.encoder.layers.17.layer_norm2.weight', 'text_model.encoder.layers.17.mlp.fc1.bias', 'text_model.encoder.layers.17.mlp.fc1.weight', 'text_model.encoder.layers.17.mlp.fc2.bias', 'text_model.encoder.layers.17.mlp.fc2.weight', 'text_model.encoder.layers.17.self_attn.k_proj.bias', 'text_model.encoder.layers.17.self_attn.k_proj.weight', 'text_model.encoder.layers.17.self_attn.out_proj.bias', 'text_model.encoder.layers.17.self_attn.out_proj.weight', 'text_model.encoder.layers.17.self_attn.q_proj.bias', 'text_model.encoder.layers.17.self_attn.q_proj.weight', 'text_model.encoder.layers.17.self_attn.v_proj.bias', 'text_model.encoder.layers.17.self_attn.v_proj.weight', 'text_model.encoder.layers.18.layer_norm1.bias', 'text_model.encoder.layers.18.layer_norm1.weight', 'text_model.encoder.layers.18.layer_norm2.bias', 'text_model.encoder.layers.18.layer_norm2.weight', 'text_model.encoder.layers.18.mlp.fc1.bias', 
'text_model.encoder.layers.18.mlp.fc1.weight', 'text_model.encoder.layers.18.mlp.fc2.bias', 'text_model.encoder.layers.18.mlp.fc2.weight', 'text_model.encoder.layers.18.self_attn.k_proj.bias', 'text_model.encoder.layers.18.self_attn.k_proj.weight', 'text_model.encoder.layers.18.self_attn.out_proj.bias', 'text_model.encoder.layers.18.self_attn.out_proj.weight', 'text_model.encoder.layers.18.self_attn.q_proj.bias', 'text_model.encoder.layers.18.self_attn.q_proj.weight', 'text_model.encoder.layers.18.self_attn.v_proj.bias', 'text_model.encoder.layers.18.self_attn.v_proj.weight', 'text_model.encoder.layers.19.layer_norm1.bias', 'text_model.encoder.layers.19.layer_norm1.weight', 'text_model.encoder.layers.19.layer_norm2.bias', 'text_model.encoder.layers.19.layer_norm2.weight', 'text_model.encoder.layers.19.mlp.fc1.bias', 'text_model.encoder.layers.19.mlp.fc1.weight', 'text_model.encoder.layers.19.mlp.fc2.bias', 'text_model.encoder.layers.19.mlp.fc2.weight', 'text_model.encoder.layers.19.self_attn.k_proj.bias', 'text_model.encoder.layers.19.self_attn.k_proj.weight', 'text_model.encoder.layers.19.self_attn.out_proj.bias', 'text_model.encoder.layers.19.self_attn.out_proj.weight', 'text_model.encoder.layers.19.self_attn.q_proj.bias', 'text_model.encoder.layers.19.self_attn.q_proj.weight', 'text_model.encoder.layers.19.self_attn.v_proj.bias', 'text_model.encoder.layers.19.self_attn.v_proj.weight', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 
'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.20.layer_norm1.bias', 'text_model.encoder.layers.20.layer_norm1.weight', 'text_model.encoder.layers.20.layer_norm2.bias', 'text_model.encoder.layers.20.layer_norm2.weight', 'text_model.encoder.layers.20.mlp.fc1.bias', 'text_model.encoder.layers.20.mlp.fc1.weight', 'text_model.encoder.layers.20.mlp.fc2.bias', 'text_model.encoder.layers.20.mlp.fc2.weight', 'text_model.encoder.layers.20.self_attn.k_proj.bias', 'text_model.encoder.layers.20.self_attn.k_proj.weight', 'text_model.encoder.layers.20.self_attn.out_proj.bias', 'text_model.encoder.layers.20.self_attn.out_proj.weight', 'text_model.encoder.layers.20.self_attn.q_proj.bias', 'text_model.encoder.layers.20.self_attn.q_proj.weight', 'text_model.encoder.layers.20.self_attn.v_proj.bias', 'text_model.encoder.layers.20.self_attn.v_proj.weight', 'text_model.encoder.layers.21.layer_norm1.bias', 'text_model.encoder.layers.21.layer_norm1.weight', 'text_model.encoder.layers.21.layer_norm2.bias', 'text_model.encoder.layers.21.layer_norm2.weight', 'text_model.encoder.layers.21.mlp.fc1.bias', 'text_model.encoder.layers.21.mlp.fc1.weight', 'text_model.encoder.layers.21.mlp.fc2.bias', 'text_model.encoder.layers.21.mlp.fc2.weight', 'text_model.encoder.layers.21.self_attn.k_proj.bias', 'text_model.encoder.layers.21.self_attn.k_proj.weight', 'text_model.encoder.layers.21.self_attn.out_proj.bias', 'text_model.encoder.layers.21.self_attn.out_proj.weight', 'text_model.encoder.layers.21.self_attn.q_proj.bias', 'text_model.encoder.layers.21.self_attn.q_proj.weight', 'text_model.encoder.layers.21.self_attn.v_proj.bias', 'text_model.encoder.layers.21.self_attn.v_proj.weight', 'text_model.encoder.layers.22.layer_norm1.bias', 
'text_model.encoder.layers.22.layer_norm1.weight', 'text_model.encoder.layers.22.layer_norm2.bias', 'text_model.encoder.layers.22.layer_norm2.weight', 'text_model.encoder.layers.22.mlp.fc1.bias', 'text_model.encoder.layers.22.mlp.fc1.weight', 'text_model.encoder.layers.22.mlp.fc2.bias', 'text_model.encoder.layers.22.mlp.fc2.weight', 'text_model.encoder.layers.22.self_attn.k_proj.bias', 'text_model.encoder.layers.22.self_attn.k_proj.weight', 'text_model.encoder.layers.22.self_attn.out_proj.bias', 'text_model.encoder.layers.22.self_attn.out_proj.weight', 'text_model.encoder.layers.22.self_attn.q_proj.bias', 'text_model.encoder.layers.22.self_attn.q_proj.weight', 'text_model.encoder.layers.22.self_attn.v_proj.bias', 'text_model.encoder.layers.22.self_attn.v_proj.weight', 'text_model.encoder.layers.23.layer_norm1.bias', 'text_model.encoder.layers.23.layer_norm1.weight', 'text_model.encoder.layers.23.layer_norm2.bias', 'text_model.encoder.layers.23.layer_norm2.weight', 'text_model.encoder.layers.23.mlp.fc1.bias', 'text_model.encoder.layers.23.mlp.fc1.weight', 'text_model.encoder.layers.23.mlp.fc2.bias', 'text_model.encoder.layers.23.mlp.fc2.weight', 'text_model.encoder.layers.23.self_attn.k_proj.bias', 'text_model.encoder.layers.23.self_attn.k_proj.weight', 'text_model.encoder.layers.23.self_attn.out_proj.bias', 'text_model.encoder.layers.23.self_attn.out_proj.weight', 'text_model.encoder.layers.23.self_attn.q_proj.bias', 'text_model.encoder.layers.23.self_attn.q_proj.weight', 'text_model.encoder.layers.23.self_attn.v_proj.bias', 'text_model.encoder.layers.23.self_attn.v_proj.weight', 'text_model.encoder.layers.24.layer_norm1.bias', 'text_model.encoder.layers.24.layer_norm1.weight', 'text_model.encoder.layers.24.layer_norm2.bias', 'text_model.encoder.layers.24.layer_norm2.weight', 'text_model.encoder.layers.24.mlp.fc1.bias', 'text_model.encoder.layers.24.mlp.fc1.weight', 'text_model.encoder.layers.24.mlp.fc2.bias', 'text_model.encoder.layers.24.mlp.fc2.weight', 
'text_model.encoder.layers.24.self_attn.k_proj.bias', 'text_model.encoder.layers.24.self_attn.k_proj.weight', 'text_model.encoder.layers.24.self_attn.out_proj.bias', 'text_model.encoder.layers.24.self_attn.out_proj.weight', 'text_model.encoder.layers.24.self_attn.q_proj.bias', 'text_model.encoder.layers.24.self_attn.q_proj.weight', 'text_model.encoder.layers.24.self_attn.v_proj.bias', 'text_model.encoder.layers.24.self_attn.v_proj.weight', 'text_model.encoder.layers.25.layer_norm1.bias', 'text_model.encoder.layers.25.layer_norm1.weight', 'text_model.encoder.layers.25.layer_norm2.bias', 'text_model.encoder.layers.25.layer_norm2.weight', 'text_model.encoder.layers.25.mlp.fc1.bias', 'text_model.encoder.layers.25.mlp.fc1.weight', 'text_model.encoder.layers.25.mlp.fc2.bias', 'text_model.encoder.layers.25.mlp.fc2.weight', 'text_model.encoder.layers.25.self_attn.k_proj.bias', 'text_model.encoder.layers.25.self_attn.k_proj.weight', 'text_model.encoder.layers.25.self_attn.out_proj.bias', 'text_model.encoder.layers.25.self_attn.out_proj.weight', 'text_model.encoder.layers.25.self_attn.q_proj.bias', 'text_model.encoder.layers.25.self_attn.q_proj.weight', 'text_model.encoder.layers.25.self_attn.v_proj.bias', 'text_model.encoder.layers.25.self_attn.v_proj.weight', 'text_model.encoder.layers.26.layer_norm1.bias', 'text_model.encoder.layers.26.layer_norm1.weight', 'text_model.encoder.layers.26.layer_norm2.bias', 'text_model.encoder.layers.26.layer_norm2.weight', 'text_model.encoder.layers.26.mlp.fc1.bias', 'text_model.encoder.layers.26.mlp.fc1.weight', 'text_model.encoder.layers.26.mlp.fc2.bias', 'text_model.encoder.layers.26.mlp.fc2.weight', 'text_model.encoder.layers.26.self_attn.k_proj.bias', 'text_model.encoder.layers.26.self_attn.k_proj.weight', 'text_model.encoder.layers.26.self_attn.out_proj.bias', 'text_model.encoder.layers.26.self_attn.out_proj.weight', 'text_model.encoder.layers.26.self_attn.q_proj.bias', 'text_model.encoder.layers.26.self_attn.q_proj.weight', 
'text_model.encoder.layers.26.self_attn.v_proj.bias', 'text_model.encoder.layers.26.self_attn.v_proj.weight', 'text_model.encoder.layers.27.layer_norm1.bias', 'text_model.encoder.layers.27.layer_norm1.weight', 'text_model.encoder.layers.27.layer_norm2.bias', 'text_model.encoder.layers.27.layer_norm2.weight', 'text_model.encoder.layers.27.mlp.fc1.bias', 'text_model.encoder.layers.27.mlp.fc1.weight', 'text_model.encoder.layers.27.mlp.fc2.bias', 'text_model.encoder.layers.27.mlp.fc2.weight', 'text_model.encoder.layers.27.self_attn.k_proj.bias', 'text_model.encoder.layers.27.self_attn.k_proj.weight', 'text_model.encoder.layers.27.self_attn.out_proj.bias', 'text_model.encoder.layers.27.self_attn.out_proj.weight', 'text_model.encoder.layers.27.self_attn.q_proj.bias', 'text_model.encoder.layers.27.self_attn.q_proj.weight', 'text_model.encoder.layers.27.self_attn.v_proj.bias', 'text_model.encoder.layers.27.self_attn.v_proj.weight', 'text_model.encoder.layers.28.layer_norm1.bias', 'text_model.encoder.layers.28.layer_norm1.weight', 'text_model.encoder.layers.28.layer_norm2.bias', 'text_model.encoder.layers.28.layer_norm2.weight', 'text_model.encoder.layers.28.mlp.fc1.bias', 'text_model.encoder.layers.28.mlp.fc1.weight', 'text_model.encoder.layers.28.mlp.fc2.bias', 'text_model.encoder.layers.28.mlp.fc2.weight', 'text_model.encoder.layers.28.self_attn.k_proj.bias', 'text_model.encoder.layers.28.self_attn.k_proj.weight', 'text_model.encoder.layers.28.self_attn.out_proj.bias', 'text_model.encoder.layers.28.self_attn.out_proj.weight', 'text_model.encoder.layers.28.self_attn.q_proj.bias', 'text_model.encoder.layers.28.self_attn.q_proj.weight', 'text_model.encoder.layers.28.self_attn.v_proj.bias', 'text_model.encoder.layers.28.self_attn.v_proj.weight', 'text_model.encoder.layers.29.layer_norm1.bias', 'text_model.encoder.layers.29.layer_norm1.weight', 'text_model.encoder.layers.29.layer_norm2.bias', 'text_model.encoder.layers.29.layer_norm2.weight', 
'text_model.encoder.layers.29.mlp.fc1.bias', 'text_model.encoder.layers.29.mlp.fc1.weight', 'text_model.encoder.layers.29.mlp.fc2.bias', 'text_model.encoder.layers.29.mlp.fc2.weight', 'text_model.encoder.layers.29.self_attn.k_proj.bias', 'text_model.encoder.layers.29.self_attn.k_proj.weight', 'text_model.encoder.layers.29.self_attn.out_proj.bias', 'text_model.encoder.layers.29.self_attn.out_proj.weight', 'text_model.encoder.layers.29.self_attn.q_proj.bias', 'text_model.encoder.layers.29.self_attn.q_proj.weight', 'text_model.encoder.layers.29.self_attn.v_proj.bias', 'text_model.encoder.layers.29.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.30.layer_norm1.bias', 'text_model.encoder.layers.30.layer_norm1.weight', 'text_model.encoder.layers.30.layer_norm2.bias', 'text_model.encoder.layers.30.layer_norm2.weight', 'text_model.encoder.layers.30.mlp.fc1.bias', 'text_model.encoder.layers.30.mlp.fc1.weight', 'text_model.encoder.layers.30.mlp.fc2.bias', 'text_model.encoder.layers.30.mlp.fc2.weight', 'text_model.encoder.layers.30.self_attn.k_proj.bias', 'text_model.encoder.layers.30.self_attn.k_proj.weight', 'text_model.encoder.layers.30.self_attn.out_proj.bias', 
'text_model.encoder.layers.30.self_attn.out_proj.weight', 'text_model.encoder.layers.30.self_attn.q_proj.bias', 'text_model.encoder.layers.30.self_attn.q_proj.weight', 'text_model.encoder.layers.30.self_attn.v_proj.bias', 'text_model.encoder.layers.30.self_attn.v_proj.weight', 'text_model.encoder.layers.31.layer_norm1.bias', 'text_model.encoder.layers.31.layer_norm1.weight', 'text_model.encoder.layers.31.layer_norm2.bias', 'text_model.encoder.layers.31.layer_norm2.weight', 'text_model.encoder.layers.31.mlp.fc1.bias', 'text_model.encoder.layers.31.mlp.fc1.weight', 'text_model.encoder.layers.31.mlp.fc2.bias', 'text_model.encoder.layers.31.mlp.fc2.weight', 'text_model.encoder.layers.31.self_attn.k_proj.bias', 'text_model.encoder.layers.31.self_attn.k_proj.weight', 'text_model.encoder.layers.31.self_attn.out_proj.bias', 'text_model.encoder.layers.31.self_attn.out_proj.weight', 'text_model.encoder.layers.31.self_attn.q_proj.bias', 'text_model.encoder.layers.31.self_attn.q_proj.weight', 'text_model.encoder.layers.31.self_attn.v_proj.bias', 'text_model.encoder.layers.31.self_attn.v_proj.weight', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.5.layer_norm1.bias', 
'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 
'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 
'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.final_layer_norm.bias', 'text_model.final_layer_norm.weight', 'text_projection.weight'])\n","βœ… Model loaded from /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_e1_step_10000.safetensors\n","βœ… T5, CLIP, and Dual Shunt Adapter loaded and cast.\n"]}]},{"cell_type":"code","source":["\n","clip_mod.load_state_dict(temp_clip, strict=False)\n","clip_mod = torch.compile(clip_mod, mode=\"reduce-overhead\", dtype=torch.float32)\n","## ─── Load Dual Shunt Adapter ────────────────────────────────"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":176},"id":"URcbaWEJfiKJ","executionInfo":{"status":"error","timestamp":1748636908022,"user_tz":420,"elapsed":1097,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"81dc483e-e3e0-4295-98c0-190ac89a6237"},"execution_count":6,"outputs":[{"output_type":"error","ename":"TypeError","evalue":"compile() got an unexpected keyword argument 'dtype'","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)","\u001b[0;32m<ipython-input-6-56deb4d1f78d>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mclip_mod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp_clip\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mclip_mod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclip_mod\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"reduce-overhead\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;31m## ─── Load Dual Shunt Adapter ────────────────────────────────\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mTypeError\u001b[0m: compile() got an unexpected keyword argument 'dtype'"]}]},{"cell_type":"markdown","source":["# Noise Train"],"metadata":{"id":"pj2a104jbp42"}},{"cell_type":"code","source":["import torch\n","import torch.nn.functional as F\n","from tqdm import tqdm\n","from pathlib import Path\n","\n","\n","def train_dual_shunt(\n","    t5_tok, clip_tok, t5_mod, clip_mod, adapter,\n","    repo_id: str,\n","    batch_size: int = 256,\n","    epochs: int = 1,\n","    lr: float = 5e-4,\n","    device: str = \"cuda\",\n","    g_target: float = 7.5,\n","    Ξ»_delta: float = 1.0,\n","    Ξ»_gate: float  = 0.15,\n","    Ξ»_ent: float   = 0.02,\n","    Ξ»_tau: float   = 0.02,\n","    Ξ»_anchor: float= 0.12,\n","    Ξ»_guid: float  = 0.1,\n","    save_every: int = 1000,\n","    max_token_length: int = 77,\n","    seed: int = 42,\n","    t5_prompt_mode: str = \"real_raw\",\n","    output_dir: str = \"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n","):\n","    import torch\n","    import torch.nn.functional as F\n","    from tqdm import tqdm\n","    from pathlib import Path\n","    import random\n","\n","    def extract_sparse_tokens(text: str, keep_min: int = 1, keep_max: int = 5, keep_prob: float = 0.3) -> str:\n","        words = text.strip().split()\n","        keep_count = min(len(words), random.randint(keep_min, keep_max))\n","        indices = sorted(random.sample(range(len(words)), k=keep_count))\n","        return \" \".join(words[i] for i in indices)\n","\n","    def 
seed_everything(seed: int = 42):\n","        torch.manual_seed(seed)\n","        torch.cuda.manual_seed_all(seed)\n","        torch.backends.cudnn.deterministic = True\n","        torch.backends.cudnn.benchmark = False\n","\n","    LOG_2PI = 1.83787706641\n","\n","    def hetero_loss(delta, target, log_sigma):\n","        log_sigma = log_sigma.clamp(-5.0, 5.0)\n","        inv_var = torch.exp(-log_sigma)\n","        return 0.5 * (inv_var * (delta - target)**2 + log_sigma + LOG_2PI).mean()\n","\n","    def entropy(p: torch.Tensor) -> torch.Tensor:\n","        p = p / (p.sum(-1, keepdim=True) + 1e-9)\n","        return -(p * (p + 1e-9).log()).sum(-1).mean()\n","\n","    # ─── Setup ───────────────────────────────────────────────\n","    seed_everything(seed)\n","    device = torch.device(device)\n","\n","    ds = ParsedMultiCharDataset(repo_id, start_file=50, num_files=6, shuffle=True)\n","    dl = torch.utils.data.DataLoader(\n","        ds, batch_size=batch_size, num_workers=4,\n","        drop_last=True, persistent_workers=True,\n","        generator=torch.Generator().manual_seed(seed),\n","        prefetch_factor=8\n","    )\n","\n","    output_dir = Path(output_dir)\n","    output_dir.mkdir(parents=True, exist_ok=True)\n","    global_step = 0\n","\n","    opt = torch.optim.AdamW(adapter.parameters(), lr=lr)\n","    scaler = torch.amp.GradScaler(device=\"cuda\")\n","\n","    # ─── Training Loop ───────────────────────────────────────\n","    for epoch in range(1, epochs + 1):\n","        pbar = tqdm(dl, desc=f\"Epoch {epoch}/{epochs}\")\n","        for texts in pbar:\n","            global_step += 1\n","\n","            with torch.autocast(device.type, dtype=torch.bfloat16):\n","                # T5 prompt encoding\n","                t5_inputs = t5_tok(texts, padding=True, truncation=True,\n","                                   max_length=max_token_length, return_tensors=\"pt\").to(device)\n","                with torch.no_grad():\n","                    t5_seq = 
t5_mod(**t5_inputs).last_hidden_state\n","\n","                # CLIP (plain)\n","                clip_inputs = clip_tok(texts, padding=\"max_length\", truncation=True,\n","                                       max_length=max_token_length, return_tensors=\"pt\").to(device)\n","                with torch.no_grad():\n","                    clip_seq_plain = clip_mod(**clip_inputs).last_hidden_state\n","\n","                # CLIP (captioned target)\n","\n","                # for null noise training\n","                cap_texts = [f\"{extract_sparse_tokens(t)}\" for t in texts]\n","                clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n","                                           max_length=max_token_length, return_tensors=\"pt\").to(device)\n","                with torch.no_grad():\n","                    clip_seq_cap = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n","                #with torch.no_grad():\n","                #    summary_ids = t5_mod.generate(**t5_inputs, max_length=max_token_length)\n","                #    cap_texts = t5_tok.batch_decode(summary_ids, skip_special_tokens=True)\n","\n","                # Encode summaries into CLIP\n","                #clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n","                #                          max_length=max_token_length, return_tensors=\"pt\").to(device)\n","                #with torch.no_grad():\n","                #    clip_seq_cap = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n","                # Adapter forward\n","                anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq_plain)\n","                delta_tgt = clip_seq_cap - clip_seq_plain\n","\n","                # πŸ”§ SAFE attention projection for auxiliary losses\n","                with torch.no_grad():\n","                    t5_b   = adapter.proj_t5(t5_seq)\n","                    clip_b = 
adapter.proj_clip(clip_seq_plain)\n","                    t2c_base, _ = adapter.cross_t2c(t5_b, clip_b, clip_b)\n","                    c2t_base, _ = adapter.cross_c2t(clip_b, t5_b, t5_b)\n","\n","                pocket = adapter.pocket_blocks(t2c_base.detach())\n","                loss_pocket = F.mse_loss(pocket, t2c_base.detach()) * 0.05\n","\n","                fused_h = adapter.fuse(torch.cat([\n","                    pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1),\n","                    c2t_base\n","                ], dim=-1))\n","                loss_hidden = (fused_h.norm(dim=-1).mean() - 1.0).abs() * 0.01\n","\n","                # Core losses\n","                loss_delta  = hetero_loss(delta, delta_tgt, log_sigma) * Ξ»_delta\n","                loss_gate   = (gate.mean() - 0.25).abs() * Ξ»_gate\n","                loss_ent    = 0.5 * (entropy(attn_t2c) + entropy(attn_c2t)) * Ξ»_ent\n","                loss_tau    = tau.abs().mean() * Ξ»_tau\n","                loss_anchor = (1 - F.cosine_similarity(anchor.mean(1), clip_seq_plain.mean(1), dim=-1).mean()) * Ξ»_anchor\n","                loss_guid   = F.mse_loss(g_pred, torch.full_like(g_pred, g_target)) * Ξ»_guid\n","\n","                total_loss = (\n","                    loss_delta + loss_gate + loss_ent +\n","                    loss_tau + loss_anchor + loss_guid +\n","                    loss_pocket + loss_hidden\n","                )\n","\n","            # Backprop + step\n","            scaler.scale(total_loss).backward()\n","            scaler.step(opt)\n","            scaler.update()\n","            opt.zero_grad(set_to_none=True)\n","            if global_step % 100 == 0:\n","                print(f\"Gate mean: {gate.mean().item():.3f}  log_sigma mean: {log_sigma.mean().item():.3f}\")\n","            if global_step % save_every == 0:\n","                path = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_noised_e1_step_{global_step}.safetensors\"\n","                
save_safetensors(adapter, path)\n","\n","            pbar.set_postfix(loss=float(total_loss))\n","\n","        final = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_noised_e1_final.safetensors\"\n","        save_safetensors(adapter, final)\n","        print(f\"βœ… Epoch {epoch} complete. Final model saved.\")\n"],"metadata":{"id":"WoSAXzGDWsMv","executionInfo":{"status":"ok","timestamp":1748639025564,"user_tz":420,"elapsed":28,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["train_dual_shunt(\n","    t5_tok=t5_tok,\n","    clip_tok=clip_tok,\n","    t5_mod=t5_mod,\n","    clip_mod=clip_mod,\n","    adapter=adapter,\n","    repo_id=\"AbstractPhil/human-templated-captions-1b\",\n","    batch_size=256,\n","    epochs=1,\n","    lr=1e-4,\n","    device=\"cuda\",\n","    g_target=7.5,\n","    save_every=1000,\n","    seed=42,\n","    max_token_length=77,\n","    t5_prompt_mode=\"no_caption\",\n","    output_dir=\"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n",")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vJ8zeQ9-irqU","outputId":"0292845a-d502-476c-af48-0f4438145925"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   0%|          | 100/117187 [03:36<69:52:34,  2.15s/it, loss=1.4]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.672\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   0%|          | 200/117187 [07:12<70:05:53,  2.16s/it, loss=1.38]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.504\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   0%|          | 300/117187 [11:30<70:27:44,  2.17s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.641\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   0%|          | 400/117187 
[15:06<69:24:20,  2.14s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.648\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   0%|          | 500/117187 [18:42<70:10:52,  2.17s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.664\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 600/117187 [22:19<70:05:26,  2.16s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.695\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 700/117187 [25:55<70:00:11,  2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.703\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 800/117187 [29:31<68:42:09,  2.13s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 900/117187 [33:08<69:25:05,  2.15s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 999/117187 [36:41<69:35:16,  2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 1000/117187 [36:44<72:33:06,  2.25s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_1000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 1100/117187 [40:21<69:00:44,  2.14s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  
log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 1200/117187 [43:56<70:06:30,  2.18s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.762\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 1300/117187 [47:33<69:19:52,  2.15s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.785\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|          | 1400/117187 [51:10<69:26:39,  2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|▏         | 1500/117187 [54:47<69:38:53,  2.17s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.789\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|▏         | 1600/117187 [58:24<69:27:47,  2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.691\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   1%|▏         | 1700/117187 [1:02:00<70:12:30,  2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.715\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 1800/117187 [1:05:37<69:13:31,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.730\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 1900/117187 [1:09:13<68:57:24,  2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 1999/117187 [1:12:48<69:23:59,  2.17s/it, 
loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2000/117187 [1:12:50<72:07:19,  2.25s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_2000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2100/117187 [1:16:26<70:14:40,  2.20s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2200/117187 [1:20:03<68:38:11,  2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.777\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2300/117187 [1:23:39<69:11:14,  2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2400/117187 [1:27:15<70:12:39,  2.20s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.762\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2500/117187 [1:30:52<68:46:14,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2600/117187 [1:34:29<68:55:45,  2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.785\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2700/117187 [1:38:05<69:16:25,  2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma 
mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2800/117187 [1:41:42<69:14:21,  2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.734\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   2%|▏         | 2900/117187 [1:45:19<68:12:24,  2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 2999/117187 [1:48:52<68:32:54,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.812\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3000/117187 [1:48:54<71:14:59,  2.25s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_3000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3100/117187 [1:52:31<68:08:09,  2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.863\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3200/117187 [1:56:07<68:19:20,  2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.844\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3300/117187 [1:59:44<67:46:09,  2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.852\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3400/117187 [2:03:19<68:13:23,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.863\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:  
 3%|β–Ž         | 3500/117187 [2:06:56<68:10:13,  2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.887\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3600/117187 [2:10:33<68:50:05,  2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.727\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3700/117187 [2:14:10<68:26:47,  2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3800/117187 [2:17:46<69:02:10,  2.19s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3900/117187 [2:21:22<67:54:57,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 3999/117187 [2:24:56<68:50:27,  2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.770\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 4000/117187 [2:24:58<71:29:54,  2.27s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_4000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   3%|β–Ž         | 4100/117187 [2:28:35<68:26:47,  2.18s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|β–Ž         | 4200/117187 [2:32:12<67:47:18,  2.16s/it, 
loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|β–Ž         | 4300/117187 [2:35:49<67:43:33,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4400/117187 [2:39:25<68:20:20,  2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4500/117187 [2:43:01<68:01:47,  2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.652\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4600/117187 [2:46:38<67:19:53,  2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.605\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4700/117187 [2:50:15<68:15:17,  2.18s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.609\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4800/117187 [2:53:51<66:49:04,  2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.688\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4900/117187 [2:57:28<67:53:59,  2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 4999/117187 [3:01:03<67:34:57,  2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.719\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   
4%|▍         | 5000/117187 [3:01:05<70:09:23,  2.25s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_5000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 5100/117187 [3:04:42<67:32:22,  2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.734\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   4%|▍         | 5200/117187 [3:08:18<67:25:42,  2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|▍         | 5300/117187 [3:11:54<67:17:14,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|▍         | 5400/117187 [3:15:30<67:11:10,  2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.754\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|▍         | 5500/117187 [3:19:06<67:10:37,  2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.754\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|▍         | 5600/117187 [3:22:43<66:17:06,  2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|▍         | 5700/117187 [3:26:20<67:08:19,  2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|▍         | 5800/117187 [3:29:56<66:15:35,  2.14s/it, 
loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 5900/117187 [3:33:33<65:57:43,  2.13s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 5999/117187 [3:37:06<65:53:32,  2.13s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.770\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 6000/117187 [3:37:09<68:33:52,  2.22s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_6000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 6100/117187 [3:40:44<66:08:48,  2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.742\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 6200/117187 [3:44:19<68:30:20,  2.22s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.629\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 6300/117187 [3:47:56<66:18:20,  2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.605\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   5%|β–Œ         | 6400/117187 [3:51:32<66:18:28,  2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.570\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 6500/117187 [3:55:08<65:36:41,  2.13s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 
0.270  log_sigma mean: -0.621\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 6600/117187 [3:58:45<65:56:19,  2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.637\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 6700/117187 [4:02:21<66:20:04,  2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.664\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 6800/117187 [4:05:57<67:02:27,  2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.652\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 6900/117187 [4:09:33<65:58:51,  2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.570\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 6999/117187 [4:13:08<65:45:42,  2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.602\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 7000/117187 [4:13:11<68:46:49,  2.25s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["βœ… Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_7000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 7100/117187 [4:16:47<65:48:45,  2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270  log_sigma mean: -0.617\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1:   6%|β–Œ         | 7159/117187 [4:18:55<65:55:30,  2.16s/it, loss=1.34]"]}]},{"cell_type":"code","source":["from google.colab import 
drive\n","drive.mount('/content/drive')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fy3PbHL4XEv3","executionInfo":{"status":"ok","timestamp":1748024232657,"user_tz":420,"elapsed":23471,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"fdc85aa5-3920-44bc-8bc3-81f8d8a90cf2"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"markdown","source":["# summarization train"],"metadata":{"id":"EKLDJqI17Tkm"}},{"cell_type":"code","source":["from pathlib import Path\n",
"from tqdm import tqdm\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"\n",
"def train_dual_shunt(\n",
"    t5_tok, clip_tok, t5_mod, clip_mod, adapter,\n",
"    repo_id: str,\n",
"    batch_size: int = 256,\n",
"    epochs: int = 1,\n",
"    lr: float = 5e-4,\n",
"    device: str = \"cuda\",\n",
"    g_target: float = 7.5,\n",
"    λ_delta: float = 1.0,\n",
"    λ_gate: float = 0.15,\n",
"    λ_ent: float = 0.02,\n",
"    λ_tau: float = 0.02,\n",
"    λ_anchor: float = 0.12,\n",
"    λ_guid: float = 0.1,\n",
"    save_every: int = 1000,\n",
"    max_token_length: int = 77,\n",
"    seed: int = 42,\n",
"    t5_prompt_mode: str = \"summarization_curriculum\",\n",
"    output_dir: str = \"/content/drive/MyDrive/dual_shunt_runs_omega_g\",\n",
"    t5_gen_mod=None,\n",
"):\n",
"    \"\"\"Train the dual-shunt adapter toward a summary-driven CLIP delta target.\n",
"\n",
"    Per batch, the T5 encoder and CLIP text encoder embed the raw texts without\n",
"    gradients; a generator model writes a summary of each text, which CLIP\n",
"    re-encodes. The regression target is the curriculum blend of the empty-prompt\n",
"    and summary CLIP embeddings minus the plain CLIP embedding. Only\n",
"    `adapter.parameters()` are optimized.\n",
"\n",
"    Args:\n",
"        t5_tok: T5 tokenizer; also decodes the generated summary ids.\n",
"        clip_tok, clip_mod: CLIP tokenizer / text encoder (`.last_hidden_state` is used).\n",
"        t5_mod: T5 encoder (`.last_hidden_state` is used).\n",
"        adapter: returns (anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate)\n",
"            and exposes proj_t5, proj_clip, cross_t2c, cross_c2t, pocket_blocks and fuse.\n",
"        repo_id: dataset id passed to ParsedMultiCharDataset.\n",
"        λ_delta, λ_gate, λ_ent, λ_tau, λ_anchor, λ_guid: loss-term weights.\n",
"        save_every: checkpoint interval in optimizer steps.\n",
"        t5_prompt_mode: tag used only in checkpoint file names.\n",
"        output_dir: checkpoint directory (created if missing).\n",
"        t5_gen_mod: model with `.generate()` (e.g. T5ForConditionalGeneration) used to\n",
"            write the summaries. Defaults to `t5_mod` when that model can generate.\n",
"\n",
"    Raises:\n",
"        TypeError: if no model with `.generate()` is available. A T5EncoderModel has\n",
"            no decoder, so it cannot produce summaries by itself.\n",
"    \"\"\"\n",
"    def seed_everything(seed: int = 42):\n",
"        torch.manual_seed(seed)\n",
"        torch.cuda.manual_seed_all(seed)\n",
"        torch.backends.cudnn.deterministic = True\n",
"        torch.backends.cudnn.benchmark = False\n",
"\n",
"    def hetero_loss(delta, target, log_sigma):\n",
"        # Gaussian NLL that treats `log_sigma` as a log-variance (inv_var = exp(-log_sigma)).\n",
"        LOG_2PI = 1.83787706641\n",
"        log_sigma = log_sigma.clamp(-5.0, 5.0)\n",
"        inv_var = torch.exp(-log_sigma)\n",
"        return 0.5 * (inv_var * (delta - target) ** 2 + log_sigma + LOG_2PI).mean()\n",
"\n",
"    def entropy(p: torch.Tensor) -> torch.Tensor:\n",
"        # Mean entropy over the last axis after renormalizing p to sum to 1 along it.\n",
"        p = p / (p.sum(-1, keepdim=True) + 1e-9)\n",
"        return -(p * (p + 1e-9).log()).sum(-1).mean()\n",
"\n",
"    # Setup\n",
"    seed_everything(seed)\n",
"    device = torch.device(device)\n",
"\n",
"    # Fail fast, before DataLoader workers start: summaries need a model with `.generate()`.\n",
"    gen_mod = t5_gen_mod if t5_gen_mod is not None else t5_mod\n",
"    if not callable(getattr(gen_mod, \"generate\", None)):\n",
"        raise TypeError(\n",
"            f\"{type(gen_mod).__name__} has no `generate()`; pass a seq2seq model \"\n",
"            \"(e.g. T5ForConditionalGeneration) as `t5_gen_mod` to produce summaries.\"\n",
"        )\n",
"\n",
"    ds = ParsedMultiCharDataset(repo_id, start_file=50, num_files=6, shuffle=True)\n",
"    dl = DataLoader(ds, batch_size=batch_size, num_workers=4, drop_last=True,\n",
"                    persistent_workers=True, generator=torch.Generator().manual_seed(seed),\n",
"                    prefetch_factor=8)\n",
"\n",
"    output_dir = Path(output_dir)\n",
"    output_dir.mkdir(parents=True, exist_ok=True)\n",
"    global_step = 0\n",
"\n",
"    opt = torch.optim.AdamW(adapter.parameters(), lr=lr)\n",
"    # torch.cuda.amp.GradScaler() is deprecated; torch.amp.GradScaler takes the device type.\n",
"    scaler = torch.amp.GradScaler(device.type)\n",
"\n",
"    for epoch in range(1, epochs + 1):\n",
"        pbar = tqdm(dl, desc=f\"Epoch {epoch}/{epochs}\")\n",
"        for texts in pbar:\n",
"            global_step += 1\n",
"\n",
"            with torch.autocast(device.type, dtype=torch.bfloat16):\n",
"                # T5 encoding\n",
"                t5_inputs = t5_tok(texts, padding=True, truncation=True,\n",
"                                   max_length=max_token_length, return_tensors=\"pt\").to(device)\n",
"                with torch.no_grad():\n",
"                    t5_seq = t5_mod(**t5_inputs).last_hidden_state\n",
"\n",
"                # CLIP (plain)\n",
"                clip_inputs = clip_tok(texts, padding=\"max_length\", truncation=True,\n",
"                                       max_length=max_token_length, return_tensors=\"pt\").to(device)\n",
"                with torch.no_grad():\n",
"                    clip_seq_plain = clip_mod(**clip_inputs).last_hidden_state\n",
"\n",
"                # Generate summaries with the generation-capable model, then re-encode them with CLIP\n",
"                with torch.no_grad():\n",
"                    summary_ids = gen_mod.generate(**t5_inputs, max_length=max_token_length)\n",
"                    cap_texts = t5_tok.batch_decode(summary_ids, skip_special_tokens=True)\n",
"                    clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n",
"                                               max_length=max_token_length, return_tensors=\"pt\").to(device)\n",
"                    clip_seq_summary = clip_mod(**clip_inputs_cap).last_hidden_state\n",
"\n",
"                    null_clip_inputs = clip_tok([\"\"] * len(texts), padding=\"max_length\", truncation=True,\n",
"                                                max_length=max_token_length, return_tensors=\"pt\").to(device)\n",
"                    clip_seq_null = clip_mod(**null_clip_inputs).last_hidden_state\n",
"\n",
"                # Curriculum: the target moves linearly from the empty-prompt embedding (step 0)\n",
"                # to the summary embedding (step >= 10000).\n",
"                interpolation = min(global_step / 10000.0, 1.0)\n",
"                clip_seq_cap = (1.0 - interpolation) * clip_seq_null + interpolation * clip_seq_summary\n",
"                delta_tgt = clip_seq_cap - clip_seq_plain\n",
"\n",
"                # Adapter forward\n",
"                anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq_plain)\n",
"\n",
"                # Baseline projected / cross features (no grad) feeding the pocket and fuse auxiliary losses\n",
"                with torch.no_grad():\n",
"                    t5_b = adapter.proj_t5(t5_seq)\n",
"                    clip_b = adapter.proj_clip(clip_seq_plain)\n",
"                    t2c_base, _ = adapter.cross_t2c(t5_b, clip_b, clip_b)\n",
"                    c2t_base, _ = adapter.cross_c2t(clip_b, t5_b, t5_b)\n",
"\n",
"                pocket = adapter.pocket_blocks(t2c_base.detach())\n",
"                loss_pocket = F.mse_loss(pocket, t2c_base.detach()) * 0.05\n",
"\n",
"                fused_h = adapter.fuse(torch.cat([\n",
"                    pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1),\n",
"                    c2t_base\n",
"                ], dim=-1))\n",
"                # Keep the mean L2 norm of the fused hidden state near 1\n",
"                loss_hidden = (fused_h.norm(dim=-1).mean() - 1.0).abs() * 0.01\n",
"\n",
"                # Losses; loss_delta ramps in linearly over the first 5000 steps\n",
"                warmup = min(global_step / 5000.0, 1.0)\n",
"                loss_delta = hetero_loss(delta, delta_tgt, log_sigma) * λ_delta * warmup\n",
"\n",
"                # Gate target mean rises from 0.25 to 0.40 along with the curriculum\n",
"                gate_target_mean = 0.25 + 0.15 * interpolation\n",
"                loss_gate = (gate.mean() - gate_target_mean).abs() * λ_gate\n",
"\n",
"                loss_ent = 0.5 * (entropy(attn_t2c) + entropy(attn_c2t)) * λ_ent\n",
"                loss_tau = tau.abs().mean() * λ_tau\n",
"                loss_anchor = (1 - F.cosine_similarity(anchor.mean(1), clip_seq_plain.mean(1), dim=-1).mean()) * λ_anchor\n",
"                loss_guid = F.mse_loss(g_pred, torch.full_like(g_pred, g_target)) * λ_guid\n",
"\n",
"                total_loss = (\n",
"                    loss_delta + loss_gate + loss_ent +\n",
"                    loss_tau + loss_anchor + loss_guid +\n",
"                    loss_pocket + loss_hidden\n",
"                )\n",
"\n",
"            # Backprop + step\n",
"            scaler.scale(total_loss).backward()\n",
"            scaler.step(opt)\n",
"            scaler.update()\n",
"            opt.zero_grad(set_to_none=True)\n",
"\n",
"            if global_step % 100 == 0:\n",
"                print(f\"[Step {global_step}] Loss: {total_loss.item():.4f}  Gate Mean: {gate.mean().item():.3f}\")\n",
"\n",
"            if global_step % save_every == 0:\n",
"                save_path = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_e{epoch}_step_{global_step}.safetensors\"\n",
"                save_safetensors(adapter, save_path)\n",
"\n",
"        final = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_e{epoch}_final.safetensors\"\n",
"        save_safetensors(adapter, final)\n",
"        print(f\"✅ Epoch {epoch} complete. Final model saved to {final}\")\n"],"metadata":{"id":"iXmrn28M7L9L","executionInfo":{"status":"ok","timestamp":1748637860921,"user_tz":420,"elapsed":31,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["train_dual_shunt(\n","    t5_tok=t5_tok,\n","    clip_tok=clip_tok,\n","    t5_mod=t5_mod,\n","    clip_mod=clip_mod,\n","    adapter=adapter,\n","    repo_id=\"AbstractPhil/human-templated-captions-1b\",\n","    batch_size=256,\n","    epochs=1,\n","    lr=1e-4,\n","    device=\"cuda\",\n","    g_target=7.5,\n","    save_every=1000,\n","    seed=42,\n","    max_token_length=77,\n","    t5_prompt_mode=\"no_caption\",\n","    output_dir=\"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n",")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":408},"id":"aI2BH6Ja7YqZ","executionInfo":{"status":"error","timestamp":1748637873437,"user_tz":420,"elapsed":3426,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"67122a9c-64f5-4cc0-85fa-7f5b810aed77"},"execution_count":8,"outputs":[{"output_type":"stream","name":"stderr","text":["<ipython-input-7-50f64bd900da>:56: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. 
Please use `torch.amp.GradScaler('cuda', args...)` instead.\n","  scaler = torch.cuda.amp.GradScaler()\n","Epoch 1/1:   0%|          | 0/117187 [00:01<?, ?it/s]\n"]},{"output_type":"error","ename":"AttributeError","evalue":"'T5EncoderModel' object has no attribute 'generate'","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)","\u001b[0;32m<ipython-input-8-db93b9908267>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m train_dual_shunt(\n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mt5_tok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt5_tok\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mclip_tok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclip_tok\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mt5_mod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt5_mod\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mclip_mod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclip_mod\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m<ipython-input-7-50f64bd900da>\u001b[0m in \u001b[0;36mtrain_dual_shunt\u001b[0;34m(t5_tok, clip_tok, t5_mod, clip_mod, adapter, repo_id, batch_size, epochs, lr, device, g_target, Ξ»_delta, Ξ»_gate, Ξ»_ent, Ξ»_tau, Ξ»_anchor, Ξ»_guid, save_every, max_token_length, seed, t5_prompt_mode, output_dir)\u001b[0m\n\u001b[1;32m     76\u001b[0m                 \u001b[0;31m# Generate summaries\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     77\u001b[0m                 \u001b[0;32mwith\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m                     \u001b[0msummary_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt5_mod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mt5_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_length\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_token_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     79\u001b[0m                     \u001b[0mcap_texts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt5_tok\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip_special_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     80\u001b[0m                     clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1926\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1927\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1928\u001b[0;31m         raise AttributeError(\n\u001b[0m\u001b[1;32m   1929\u001b[0m             \u001b[0;34mf\"'{type(self).__name__}' object has no attribute 
'{name}'\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1930\u001b[0m         )\n","\u001b[0;31mAttributeError\u001b[0m: 'T5EncoderModel' object has no attribute 'generate'"]}]}]}