{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"collapsed_sections":["erBNP6hvU724","PiehSbGeWmor","EKLDJqI17Tkm"],"machine_shape":"hm","gpuType":"L4","mount_file_id":"1Tsf5s2FZEHr9S5ja8MqkIA3DtPGNb0Cy","authorship_tag":"ABX9TyOx4EAq5vmbjU6gpbNxmViP"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"9bf9176fc21b4f15896d7281ced5fa17":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_76cf4a2d556842e497e2e64c2ba0b3fe","IPY_MODEL_a6b3d0c5d84b477c82545fb809fdda1c","IPY_MODEL_456b3fd649ab4c0fbb5348a765fb9613"],"layout":"IPY_MODEL_32bcb31386ca4187ba6a554d26189502"}},"76cf4a2d556842e497e2e64c2ba0b3fe":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d4696d3bbdc149be9bab0cf4664c68c4","placeholder":"β","style":"IPY_MODEL_1575e84b9dc5454d94258035abf26da1","value":"Loadingβcheckpointβshards:β100%"}},"a6b3d0c5d84b477c82545fb809fdda1c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_b4e41d9dcb2c4be8bf94919c93c56a83","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_d96c33b98b874cf8862a2bedb3596b4c","value":2}},"456b3fd649ab4c0fbb5348a765fb9613":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d7f4744342ec4ebd8d2798cae2e90518","placeholder":"β","style":"IPY_MODEL_2f9c4dbc55d04276a5b9fa92de613bf5","value":"β2/2β[00:00<00:00,ββ2.19it/s]"}},"32bcb31386ca4187ba6a554d26189502":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":
null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d4696d3bbdc149be9bab0cf4664c68c4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1575e84b9dc5454d94258035abf26da1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b4e41d9dcb2c4be8bf94919c93c56a83":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d96c33b98b874cf8862a2bedb3596b4c":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"d7f4744342ec4ebd8d2798cae2e90518":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2
.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2f9c4dbc55d04276a5b9fa92de613bf5":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["## Config and Module"],"metadata":{"id":"erBNP6hvU724"}},{"cell_type":"code","execution_count":1,"metadata":{"id":"9WGrocZdPBY_","executionInfo":{"status":"ok","timestamp":1748636886799,"user_tz":420,"elapsed":19,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"outputs":[],"source":["ADAPTER_CONFIG = {\n"," \"adapter_id\": \"003\",\n"," \"name\": \"DualShuntAdapter-G\",\n","\n"," \"t5\": {\n"," \"model\": \"google/flan-t5-base\",\n"," \"hidden_size\": 768,\n"," },\n"," \"clip\": {\n"," \"model\": \"AbstractPhil/omega-vit-g-reformed\",\n"," \"hidden_size\": 1280,\n"," },\n","\n"," \"bottleneck\": 640,\n"," \"heads\": 20,\n","\n"," \"tau_init\": 0.1,\n"," \"max_guidance\": 10.0,\n","\n"," \"proj_layers\": 2,\n"," \"layer_norm\": True,\n"," \"dropout\": 0.1,\n"," \"use_dropout\": True,\n"," \"use_proj_stack\": True,\n"," \"assert_input_dims\": True,\n","\n"," \"routing\": {\n"," \"type\": \"cross_attention\",\n"," \"enable_causal_mask\": False,\n"," \"bidirectional\": True\n"," },\n","\n"," \"version\": \"v0.3.2\",\n"," \"description\": \"Final Dual Shunt Adapter with projection stack, dropout, and stacked residual refinement pocket.\"\n","}\n"]},{"cell_type":"code","source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","# βββ Residual Pocket Block βββββββββββββββββββββββββββββββββββ\n","class BottleneckResBlock(nn.Module):\n"," def __init__(self, dim, kernel=3, dropout=0.1):\n"," super().__init__()\n"," self.norm = nn.LayerNorm(dim)\n"," self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)\n"," self.proj = nn.Sequential(\n"," nn.Linear(dim, dim * 2),\n"," nn.GELU(),\n"," nn.Linear(dim * 2, dim),\n"," nn.Dropout(dropout)\n"," )\n","\n"," def forward(self, x):\n"," residual = x\n"," x = self.norm(x)\n"," x = x.transpose(1, 2)\n"," x = self.conv(x).transpose(1, 2)\n"," return residual + self.proj(x)\n","\n","# βββ Two Stream Shunt Adapter ββββββββββββββββββββββββββββββββββββββ\n","class TwoStreamShuntAdapter(nn.Module):\n"," def __init__(self, config: dict):\n"," super().__init__()\n"," self.config = config\n"," self.t5_dim = config[\"t5\"][\"hidden_size\"]\n"," self.clip_dim = config[\"clip\"][\"hidden_size\"]\n"," self.bneck = config[\"bottleneck\"]\n"," self.heads = config[\"heads\"]\n"," self.tau_init = config[\"tau_init\"]\n"," 
self.max_guidance = config[\"max_guidance\"]\n","\n"," use_norm = config.get(\"layer_norm\", True)\n"," use_do = config.get(\"use_dropout\", True)\n"," do_p = config.get(\"dropout\", 0.1)\n"," proj_depth = config.get(\"proj_layers\", 2)\n","\n"," def build_projection(input_dim, output_dim):\n"," layers = []\n"," last_dim = input_dim\n"," if use_norm:\n"," layers.append(nn.LayerNorm(last_dim))\n"," for i in range(proj_depth):\n"," next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)\n"," layers.append(nn.Linear(last_dim, next_dim))\n"," layers.append(nn.GELU())\n"," if use_do:\n"," layers.append(nn.Dropout(do_p))\n"," last_dim = next_dim\n"," layers.append(nn.Linear(last_dim, output_dim))\n"," return nn.Sequential(*layers)\n","\n"," # Projections\n"," self.proj_t5 = build_projection(self.t5_dim, self.bneck)\n"," self.proj_clip = build_projection(self.clip_dim, self.bneck)\n","\n"," # Attention\n"," self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)\n"," self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)\n"," self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))\n","\n"," # Residual Pocket\n"," self.pocket_blocks = nn.Sequential(\n"," BottleneckResBlock(self.bneck, dropout=do_p),\n"," BottleneckResBlock(self.bneck, dropout=do_p)\n"," )\n","\n"," # Fuse\n"," self.fuse = nn.Sequential(\n"," nn.LayerNorm(2 * self.bneck),\n"," nn.Linear(2 * self.bneck, self.bneck * 2),\n"," nn.GELU(),\n"," nn.Linear(self.bneck * 2, self.bneck)\n"," )\n","\n"," # Output Projections\n"," self.anchor_proj = build_projection(self.bneck, self.clip_dim)\n"," self.delta_proj = build_projection(self.bneck, self.clip_dim)\n"," self.logsig_proj = build_projection(self.bneck, self.clip_dim)\n","\n"," self.gate_proj = nn.Sequential(\n"," nn.LayerNorm(self.bneck),\n"," nn.Linear(self.bneck, self.bneck),\n"," nn.GELU(),\n"," nn.Linear(self.bneck, 1),\n"," nn.Tanh(),\n"," nn.Sigmoid()\n"," )\n","\n"," self.guidance_proj = nn.Sequential(\n"," nn.LayerNorm(self.bneck),\n"," nn.Linear(self.bneck, 1),\n"," nn.Sigmoid()\n"," )\n","\n"," def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):\n"," if self.config.get(\"assert_input_dims\", True):\n"," assert t5_seq.size(-1) == self.t5_dim\n"," assert clip_seq.size(-1) == self.clip_dim\n","\n"," t5_b = self.proj_t5(t5_seq)\n"," clip_b = self.proj_clip(clip_seq)\n","\n"," t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)\n"," c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)\n","\n"," pocket = self.pocket_blocks(t2c)\n","\n"," pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)\n"," h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))\n","\n"," anchor = self.anchor_proj(h)\n"," delta = self.delta_proj(h) * self.gate_proj(h)\n"," log_sigma = self.logsig_proj(h)\n","\n"," g_tok = self.guidance_proj(h).squeeze(-1)\n"," g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance\n","\n"," return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)\n"],"metadata":{"id":"qjb_vFZRTQaC","executionInfo":{"status":"ok","timestamp":1748636888396,"user_tz":420,"elapsed":1586,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":2,"outputs":[]},{"cell_type":"code","source":["from safetensors.torch import save_file, load_file\n","\n","def save_safetensors(adapter: nn.Module, path: str, metadata: dict = 
None):\n"," \"\"\"\n"," Save the current adapter state to safetensors format.\n","\n"," All tensors are moved to CPU and saved as float32 for compatibility.\n"," Optional metadata may be embedded (e.g., version, prompt_mode).\n"," \"\"\"\n"," state = {k: v.float().cpu() for k, v in adapter.state_dict().items()}\n"," save_file(state, path, metadata=metadata or {})\n"," print(f\"✅ Model saved to {path}\")\n","\n","def load_safetensors(adapter: nn.Module, path: str, map_location=\"cpu\"):\n",
" \"\"\"\n"," Load a safetensors checkpoint into the adapter.\n","\n"," Uses strict key matching. Tensors are loaded to the specified device.\n"," \"\"\"\n"," state = load_file(path, device=map_location)\n"," adapter.load_state_dict(state, strict=True)\n"," print(f\"✅ 
Model loaded from {path}\")\n","\n","\n"],"metadata":{"id":"zpOi5svciXJ6","executionInfo":{"status":"ok","timestamp":1748636888425,"user_tz":420,"elapsed":13,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":3,"outputs":[]},{"cell_type":"markdown","source":["## Data Loader"],"metadata":{"id":"PiehSbGeWmor"}},{"cell_type":"code","source":["import torch\n","import csv\n","\n","# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n","# β Streaming Caption Dataset\n","# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n","# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n","# β Streaming Caption Dataset β 32 quality descriptors\n","# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n","import csv, random, re\n","from typing import List\n","from torch.utils.data import IterableDataset, get_worker_info\n","from huggingface_hub import hf_hub_download\n","from torch.utils.data import DataLoader\n","from pathlib import Path\n","\n","\n","class ParsedMultiCharDataset(IterableDataset):\n"," \"\"\"\n"," Streams HF-hosted caption shards and, for each text chunk:\n","\n"," β’ If it starts with βa β (case-insensitive, ignoring leading spaces),\n"," replace that leading token with a random photo/video quality\n"," descriptor followed by β, β.\n","\n"," No preliminary file scanning; every CSV line is read exactly once.\n"," \"\"\"\n","\n"," # match leading βa β only\n"," _PAT_START_A = re.compile(r\"^\\s*a\\s+\", re.IGNORECASE)\n","\n"," # 32 diverse quality descriptors\n"," _QUALITY_DESCRIPTORS: List[str] = [\n"," \"masterpiece,\",\n"," \"very aesthetic,\",\n"," \"most aesthetic,\",\n"," \"an absolutely perfect depiction of\",\n"," \"awa,\",\n"," \"very awa,\",\n"," \"dimly lit,\",\n"," \"beautifuly lit,\",\n"," \"very beautiful,\",\n"," \"masterful depiction of\",\n"," \"dedicated masterpiece,\",\n"," \"warmly lit\",\n"," \"best quality, most aesthetic,\",\n"," \"beautiful depiction of\",\n"," \"masterful artwork of\",\n"," \"high-resolution photograph,\",\n"," \"hyper-realistic image,\",\n"," \"ultra-detailed photo,\",\n"," \"studio-quality photograph,\",\n"," \"cinematic shot,\",\n"," \"4K HDR image,\",\n"," \"sharp-focus photo,\",\n"," \"professionally lit photograph,\",\n"," \"DSLR capture,\",\n"," \"film-grain photograph,\",\n"," \"bokeh-rich shot,\",\n"," \"medium-format scan,\",\n"," \"analog film still,\",\n"," \"moody cinematic frame,\",\n"," \"dramatic-lighting photo,\",\n"," \"vibrant editorial image,\",\n"," \"macro-lens close-up,\",\n"," \"aerial drone photo,\",\n"," \"soft-focus dreamlike photo,\",\n"," \"low-key studio shot,\",\n"," \"overhead product shot,\",\n"," \"golden-hour photograph,\",\n"," \"noir-style monochrome shot,\",\n"," \"vintage Polaroid scan,\",\n"," \"infrared photograph,\",\n"," \"ultra-wide panorama,\",\n"," \"tilt-shift miniature photo,\",\n"," \"long-exposure night shot,\",\n"," \"time-lapse still,\",\n"," \"splash-photography frame,\",\n"," \"fine-art print scan,\",\n"," \"astrophotography capture,\",\n"," \"score_9, score_8, score_7, score_6,\",\n"," \"score_1, score_2, score_3, score_4,\",\n"," \"masterpiece, most aesthetic,\",\n"," \"most aesthetic, very aesthetic,\",\n"," \"masterpiece, most aesthetic, very aesthetic, realistic, real,\",\n"," \"most aesthetic, realistic, real,\",\n"," \"very aesthetic, realistic, real,\",\n"," \"masterpiece, very aesthetic, realistic, real,\",\n"," \"most aesthetic, very aesthetic, realistic, real,\",\n"," \"masterpiece, very aesthetic, realistic, 
anime,\",\n"," \"most aesthetic, very aesthetic, realistic, anime,\",\n"," \"very aesthetic, realistic, anime,\",\n"," \"masterpiece, very aesthetic, realistic, anime,\",\n"," \"most aesthetic, very aesthetic, realistic, anime,\",\n"," \"very aesthetic, realistic, anime, anime,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"masterpiece, 2d,\",\n"," \"masterpiece, 3d,\",\n"," \"masterpiece, anime,\",\n"," \"masterpiece, real,\",\n"," \"masterpiece, cartoon,\",\n"," \"3d, anime,\",\n"," \"3d, anime, real,\",\n"," \"3d, anime, real, realistic,\",\n"," \"masterpiece, 3d,\",\n"," \"masterpiece, 3d, anime,\",\n"," \"masterpiece, 3d, anime, real,\",\n"," \"masterpiece, 3d, anime, real, realistic,\",\n"," \"masterpiece, anime,\",\n"," \"masterpiece, anime, real,\",\n"," \"masterpiece, anime, real, realistic,\",\n"," \"masterpiece, 3d, anime, real,\",\n"," \"very aesthetic, 3d,\",\n"," \"very aesthetic, 3d, anime,\",\n"," \"very aesthetic, 3d, anime, real,\",\n"," \"very aesthetic, 3d, anime, real, realistic,\",\n"," \"very aesthetic, 3d, anime, real,\",\n"," \"most aesthetic, 3d,\",\n"," \"most aesthetic, 3d, anime,\",\n"," \"most aesthetic, 3d, anime, real,\",\n"," \"most aesthetic, 3d, anime, real, realistic,\",\n"," \"anime, comic,\"\n"," \"manga, anime\",\n"," \"masterpiece, cartoon,\",\n"," \"masterpiece, cartoon, real,\",\n"," \"masterpiece, cartoon, real, realistic,\",\n"," \"masterpiece, cartoon, real,\",\n"," \"most aesthetic, cartoon,\",\n"," \"most aesthetic, cartoon, real,\",\n"," \"most aesthetic, cartoon, real, realistic,\",\n"," \"most aesthetic, cartoon, real,\",\n"," \"grid_a1 head,\"\n"," \"grid_a2 head,\",\n"," \"grid_a3 head,\",\n"," \"grid_a4 head,\",\n"," \"grid_a5 head,\",\n"," \"grid_b1 head,\"\n"," \"grid_b2 head,\",\n"," \"grid_b3 head,\",\n"," \"grid_b4 head,\",\n"," \"grid_b5 head,\",\n"," \"grid_c1 head,\"\n"," \"grid_c2 head,\",\n"," \"grid_c3 head,\",\n"," \"grid_c4 head,\",\n"," \"grid_c5 head,\",\n"," \"grid_d1 head,\"\n"," \"grid_d2 head,\",\n"," \"grid_d3 head,\",\n"," \"grid_d4 head,\",\n"," \"grid_d5 head,\",\n"," \"grid_e1 head,\"\n"," \"grid_e2 head,\",\n"," \"grid_e3 head,\",\n"," \"grid_e4 head,\",\n"," \"grid_e5 head,\",\n"," \"grid_a1 upper body,\"\n"," \"grid_a2 upper body,\",\n"," \"grid_a3 upper body,\",\n"," \"grid_a4 upper body,\",\n"," \"grid_a5 upper body,\",\n"," \"grid_b1 upper body,\"\n"," \"grid_b2 upper body,\",\n"," \"grid_b3 upper body,\",\n"," \"grid_b4 upper body,\",\n"," \"grid_b5 upper body,\",\n"," \"grid_c1 upper body,\"\n"," \"grid_c2 upper body,\",\n"," \"grid_c3 upper body,\",\n"," \"grid_c4 upper body,\",\n"," \"grid_c5 upper body,\",\n"," \"grid_d1 upper body,\"\n"," \"grid_d2 upper body,\",\n"," \"grid_d3 upper body,\",\n"," \"grid_d4 upper body,\",\n"," \"grid_d5 upper body,\",\n"," \"grid_e1 upper body,\"\n"," \"grid_e2 upper body,\",\n"," \"grid_e3 upper body,\",\n"," \"grid_e4 upper body,\",\n"," \"grid_e5 upper body,\",\n"," \"zone_ul upper body,\"\n"," \"zone_ur upper body,\"\n"," \"zone_ll upper body,\"\n"," \"zone_lr upper body,\",\n"," \"zone_ul head,\"\n"," \"zone_ur head,\"\n"," \"zone_ll head,\"\n"," \"zone_lr head,\",\n","\n"," #\"disgusting, 3d, 
lowres,\",\n"," #\"disgusting, 2d, lowres,\",\n"," #\"disgusting, cartoon, lowres,\",\n"," #\"disgusting, 3d, cartoon, lowres,\",\n"," #\"disgusting, 2d, cartoon, lowres,\",\n"," #\"disgusting\n","\n","\n"," ]\n","\n"," def __init__(\n"," self,\n"," repo_id: str,\n"," delimiter: str = \".,|,.\",\n"," start_file: int = 20,\n"," num_files: int = 80,\n"," shuffle: bool = False\n"," ):\n"," super().__init__()\n"," self.delimiter = delimiter\n"," self.paths = [\n"," hf_hub_download(\n"," repo_id,\n"," f\"captions/caption_{i + start_file:03d}.csv\",\n"," repo_type=\"dataset\",\n"," )\n"," for i in range(num_files)\n"," ]\n"," self.total_rows = 5_000_000 * num_files\n"," self.shuffle = shuffle\n","\n"," def __len__(self):\n"," return self.total_rows\n","\n"," def __iter__(self):\n"," worker = get_worker_info()\n"," paths = (\n"," self.paths\n"," if worker is None\n"," else self.paths[worker.id :: worker.num_workers]\n"," )\n","\n"," pat_a = self._PAT_START_A\n"," choose = random.choice\n"," q_pool = self._QUALITY_DESCRIPTORS\n"," delim = self.delimiter\n","\n"," for path in paths:\n"," with open(path, encoding=\"utf-8\", newline=\"\") as f:\n"," for row in csv.DictReader(f):\n"," for chunk in row.get(\"text\", \"\").split(delim):\n"," chunk = chunk.strip()\n"," if not chunk:\n"," continue\n"," if self.shuffle:\n"," random.shuffle(q_pool)\n"," # replace leading βa β with descriptor + comma\n"," chunk = pat_a.sub(choose(q_pool) + \" \", chunk, count=1)\n","\n"," yield chunk\n","\n","\n"],"metadata":{"id":"mWRkuGO5U5Ma","executionInfo":{"status":"ok","timestamp":1748636889529,"user_tz":420,"elapsed":267,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"_FBrEkfzbB0M","executionInfo":{"status":"ok","timestamp":1748636889550,"user_tz":420,"elapsed":19,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":4,"outputs":[]},{"cell_type":"markdown","source":["## Load Models"],"metadata":{"id":"RiRNq1OjXQ4R"}},{"cell_type":"code","source":["import torch\n","from transformers import (\n"," T5EncoderModel, T5TokenizerFast,\n"," CLIPTextModel, CLIPTokenizerFast\n",")\n","\n","import torch\n","from transformers import (\n"," T5EncoderModel, T5TokenizerFast,\n"," CLIPTextModel, CLIPTokenizerFast\n",")\n","\n","from safetensors.torch import save_file, load_file\n","\n","\n","# βββ Runtime Settings ββββββββββββββββββββββββββββββββββββββββ\n","DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","DTYPE = torch.float32 # π Force full FP32 precision\n","\n","# βββ Load Tokenizers βββββββββββββββββββββββββββββββββββββββββ\n","t5_tok = T5TokenizerFast.from_pretrained(\"google/flan-t5-base\")\n","clip_tok = CLIPTokenizerFast.from_pretrained(\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\")\n","\n","# βββ Load and Freeze T5 Model βββββββββββββββββββββββββββββββ\n","t5_mod = T5EncoderModel.from_pretrained(\"google/flan-t5-base\").to(DEVICE, dtype=DTYPE)\n","t5_mod.eval().requires_grad_(False)\n","\n","# βββ Load and Freeze CLIP Model βββββββββββββββββββββββββββββ\n","clip_mod = CLIPTextModel.from_pretrained(\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\").to(DEVICE, dtype=DTYPE)\n","\n","\n","print(clip_mod.state_dict().keys())\n","\n","temp_clip = load_file(\"/content/drive/MyDrive/clips/OMEGA-24-CLIP_G.safetensors\")\n","\n","print(temp_clip.keys())\n","\n","clip_mod.load_state_dict(temp_clip, strict=False)\n","clip_mod.eval().requires_grad_(False)\n","## βββ Load Dual Shunt Adapter 
ββββββββββββββββββββββββββββββββ\n","#adapter = DualShuntAdapter(config=ADAPTER_CONFIG).to(DEVICE, dtype=DTYPE)\n","#\n","#\n","#print(\"✅ All models loaded at float32 precision and ready.\")\n","\n",
"# ─── Initialize Adapter from Config ──────────────────────────\n","adapter = TwoStreamShuntAdapter(config=ADAPTER_CONFIG).to(DEVICE, dtype=DTYPE)\n","load_safetensors(adapter, \"/content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_e1_step_10000.safetensors\")\n","adapter.train()\n","print(\"✅ 
T5, CLIP, and Dual Shunt Adapter loaded and cast.\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":138,"referenced_widgets":["9bf9176fc21b4f15896d7281ced5fa17","76cf4a2d556842e497e2e64c2ba0b3fe","a6b3d0c5d84b477c82545fb809fdda1c","456b3fd649ab4c0fbb5348a765fb9613","32bcb31386ca4187ba6a554d26189502","d4696d3bbdc149be9bab0cf4664c68c4","1575e84b9dc5454d94258035abf26da1","b4e41d9dcb2c4be8bf94919c93c56a83","d96c33b98b874cf8862a2bedb3596b4c","d7f4744342ec4ebd8d2798cae2e90518","2f9c4dbc55d04276a5b9fa92de613bf5"]},"id":"NOpbg28ZXQH5","executionInfo":{"status":"ok","timestamp":1748636906923,"user_tz":420,"elapsed":12771,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"8ebbca72-abed-49df-ca20-04250b38be55"},"execution_count":5,"outputs":[{"output_type":"display_data","data":{"text/plain":["Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9bf9176fc21b4f15896d7281ced5fa17"}},"metadata":{}},{"output_type":"stream","name":"stdout","text":["odict_keys(['text_model.embeddings.token_embedding.weight', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 
'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 
'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 
'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.12.self_attn.k_proj.weight', 'text_model.encoder.layers.12.self_attn.k_proj.bias', 'text_model.encoder.layers.12.self_attn.v_proj.weight', 'text_model.encoder.layers.12.self_attn.v_proj.bias', 'text_model.encoder.layers.12.self_attn.q_proj.weight', 'text_model.encoder.layers.12.self_attn.q_proj.bias', 'text_model.encoder.layers.12.self_attn.out_proj.weight', 'text_model.encoder.layers.12.self_attn.out_proj.bias', 'text_model.encoder.layers.12.layer_norm1.weight', 'text_model.encoder.layers.12.layer_norm1.bias', 'text_model.encoder.layers.12.mlp.fc1.weight', 'text_model.encoder.layers.12.mlp.fc1.bias', 'text_model.encoder.layers.12.mlp.fc2.weight', 'text_model.encoder.layers.12.mlp.fc2.bias', 'text_model.encoder.layers.12.layer_norm2.weight', 'text_model.encoder.layers.12.layer_norm2.bias', 'text_model.encoder.layers.13.self_attn.k_proj.weight', 'text_model.encoder.layers.13.self_attn.k_proj.bias', 'text_model.encoder.layers.13.self_attn.v_proj.weight', 'text_model.encoder.layers.13.self_attn.v_proj.bias', 'text_model.encoder.layers.13.self_attn.q_proj.weight', 'text_model.encoder.layers.13.self_attn.q_proj.bias', 'text_model.encoder.layers.13.self_attn.out_proj.weight', 'text_model.encoder.layers.13.self_attn.out_proj.bias', 'text_model.encoder.layers.13.layer_norm1.weight', 'text_model.encoder.layers.13.layer_norm1.bias', 'text_model.encoder.layers.13.mlp.fc1.weight', 'text_model.encoder.layers.13.mlp.fc1.bias', 'text_model.encoder.layers.13.mlp.fc2.weight', 'text_model.encoder.layers.13.mlp.fc2.bias', 'text_model.encoder.layers.13.layer_norm2.weight', 'text_model.encoder.layers.13.layer_norm2.bias', 'text_model.encoder.layers.14.self_attn.k_proj.weight', 'text_model.encoder.layers.14.self_attn.k_proj.bias', 'text_model.encoder.layers.14.self_attn.v_proj.weight', 'text_model.encoder.layers.14.self_attn.v_proj.bias', 'text_model.encoder.layers.14.self_attn.q_proj.weight', 'text_model.encoder.layers.14.self_attn.q_proj.bias', 'text_model.encoder.layers.14.self_attn.out_proj.weight', 'text_model.encoder.layers.14.self_attn.out_proj.bias', 'text_model.encoder.layers.14.layer_norm1.weight', 'text_model.encoder.layers.14.layer_norm1.bias', 'text_model.encoder.layers.14.mlp.fc1.weight', 'text_model.encoder.layers.14.mlp.fc1.bias', 'text_model.encoder.layers.14.mlp.fc2.weight', 'text_model.encoder.layers.14.mlp.fc2.bias', 'text_model.encoder.layers.14.layer_norm2.weight', 'text_model.encoder.layers.14.layer_norm2.bias', 'text_model.encoder.layers.15.self_attn.k_proj.weight', 'text_model.encoder.layers.15.self_attn.k_proj.bias', 'text_model.encoder.layers.15.self_attn.v_proj.weight', 'text_model.encoder.layers.15.self_attn.v_proj.bias', 'text_model.encoder.layers.15.self_attn.q_proj.weight', 'text_model.encoder.layers.15.self_attn.q_proj.bias', 'text_model.encoder.layers.15.self_attn.out_proj.weight', 'text_model.encoder.layers.15.self_attn.out_proj.bias', 'text_model.encoder.layers.15.layer_norm1.weight', 'text_model.encoder.layers.15.layer_norm1.bias', 'text_model.encoder.layers.15.mlp.fc1.weight', 'text_model.encoder.layers.15.mlp.fc1.bias', 'text_model.encoder.layers.15.mlp.fc2.weight', 
'text_model.encoder.layers.15.mlp.fc2.bias', 'text_model.encoder.layers.15.layer_norm2.weight', 'text_model.encoder.layers.15.layer_norm2.bias', 'text_model.encoder.layers.16.self_attn.k_proj.weight', 'text_model.encoder.layers.16.self_attn.k_proj.bias', 'text_model.encoder.layers.16.self_attn.v_proj.weight', 'text_model.encoder.layers.16.self_attn.v_proj.bias', 'text_model.encoder.layers.16.self_attn.q_proj.weight', 'text_model.encoder.layers.16.self_attn.q_proj.bias', 'text_model.encoder.layers.16.self_attn.out_proj.weight', 'text_model.encoder.layers.16.self_attn.out_proj.bias', 'text_model.encoder.layers.16.layer_norm1.weight', 'text_model.encoder.layers.16.layer_norm1.bias', 'text_model.encoder.layers.16.mlp.fc1.weight', 'text_model.encoder.layers.16.mlp.fc1.bias', 'text_model.encoder.layers.16.mlp.fc2.weight', 'text_model.encoder.layers.16.mlp.fc2.bias', 'text_model.encoder.layers.16.layer_norm2.weight', 'text_model.encoder.layers.16.layer_norm2.bias', 'text_model.encoder.layers.17.self_attn.k_proj.weight', 'text_model.encoder.layers.17.self_attn.k_proj.bias', 'text_model.encoder.layers.17.self_attn.v_proj.weight', 'text_model.encoder.layers.17.self_attn.v_proj.bias', 'text_model.encoder.layers.17.self_attn.q_proj.weight', 'text_model.encoder.layers.17.self_attn.q_proj.bias', 'text_model.encoder.layers.17.self_attn.out_proj.weight', 'text_model.encoder.layers.17.self_attn.out_proj.bias', 'text_model.encoder.layers.17.layer_norm1.weight', 'text_model.encoder.layers.17.layer_norm1.bias', 'text_model.encoder.layers.17.mlp.fc1.weight', 'text_model.encoder.layers.17.mlp.fc1.bias', 'text_model.encoder.layers.17.mlp.fc2.weight', 'text_model.encoder.layers.17.mlp.fc2.bias', 'text_model.encoder.layers.17.layer_norm2.weight', 'text_model.encoder.layers.17.layer_norm2.bias', 'text_model.encoder.layers.18.self_attn.k_proj.weight', 'text_model.encoder.layers.18.self_attn.k_proj.bias', 'text_model.encoder.layers.18.self_attn.v_proj.weight', 'text_model.encoder.layers.18.self_attn.v_proj.bias', 'text_model.encoder.layers.18.self_attn.q_proj.weight', 'text_model.encoder.layers.18.self_attn.q_proj.bias', 'text_model.encoder.layers.18.self_attn.out_proj.weight', 'text_model.encoder.layers.18.self_attn.out_proj.bias', 'text_model.encoder.layers.18.layer_norm1.weight', 'text_model.encoder.layers.18.layer_norm1.bias', 'text_model.encoder.layers.18.mlp.fc1.weight', 'text_model.encoder.layers.18.mlp.fc1.bias', 'text_model.encoder.layers.18.mlp.fc2.weight', 'text_model.encoder.layers.18.mlp.fc2.bias', 'text_model.encoder.layers.18.layer_norm2.weight', 'text_model.encoder.layers.18.layer_norm2.bias', 'text_model.encoder.layers.19.self_attn.k_proj.weight', 'text_model.encoder.layers.19.self_attn.k_proj.bias', 'text_model.encoder.layers.19.self_attn.v_proj.weight', 'text_model.encoder.layers.19.self_attn.v_proj.bias', 'text_model.encoder.layers.19.self_attn.q_proj.weight', 'text_model.encoder.layers.19.self_attn.q_proj.bias', 'text_model.encoder.layers.19.self_attn.out_proj.weight', 'text_model.encoder.layers.19.self_attn.out_proj.bias', 'text_model.encoder.layers.19.layer_norm1.weight', 'text_model.encoder.layers.19.layer_norm1.bias', 'text_model.encoder.layers.19.mlp.fc1.weight', 'text_model.encoder.layers.19.mlp.fc1.bias', 'text_model.encoder.layers.19.mlp.fc2.weight', 'text_model.encoder.layers.19.mlp.fc2.bias', 'text_model.encoder.layers.19.layer_norm2.weight', 'text_model.encoder.layers.19.layer_norm2.bias', 'text_model.encoder.layers.20.self_attn.k_proj.weight', 
'text_model.encoder.layers.20.self_attn.k_proj.bias', 'text_model.encoder.layers.20.self_attn.v_proj.weight', 'text_model.encoder.layers.20.self_attn.v_proj.bias', 'text_model.encoder.layers.20.self_attn.q_proj.weight', 'text_model.encoder.layers.20.self_attn.q_proj.bias', 'text_model.encoder.layers.20.self_attn.out_proj.weight', 'text_model.encoder.layers.20.self_attn.out_proj.bias', 'text_model.encoder.layers.20.layer_norm1.weight', 'text_model.encoder.layers.20.layer_norm1.bias', 'text_model.encoder.layers.20.mlp.fc1.weight', 'text_model.encoder.layers.20.mlp.fc1.bias', 'text_model.encoder.layers.20.mlp.fc2.weight', 'text_model.encoder.layers.20.mlp.fc2.bias', 'text_model.encoder.layers.20.layer_norm2.weight', 'text_model.encoder.layers.20.layer_norm2.bias', 'text_model.encoder.layers.21.self_attn.k_proj.weight', 'text_model.encoder.layers.21.self_attn.k_proj.bias', 'text_model.encoder.layers.21.self_attn.v_proj.weight', 'text_model.encoder.layers.21.self_attn.v_proj.bias', 'text_model.encoder.layers.21.self_attn.q_proj.weight', 'text_model.encoder.layers.21.self_attn.q_proj.bias', 'text_model.encoder.layers.21.self_attn.out_proj.weight', 'text_model.encoder.layers.21.self_attn.out_proj.bias', 'text_model.encoder.layers.21.layer_norm1.weight', 'text_model.encoder.layers.21.layer_norm1.bias', 'text_model.encoder.layers.21.mlp.fc1.weight', 'text_model.encoder.layers.21.mlp.fc1.bias', 'text_model.encoder.layers.21.mlp.fc2.weight', 'text_model.encoder.layers.21.mlp.fc2.bias', 'text_model.encoder.layers.21.layer_norm2.weight', 'text_model.encoder.layers.21.layer_norm2.bias', 'text_model.encoder.layers.22.self_attn.k_proj.weight', 'text_model.encoder.layers.22.self_attn.k_proj.bias', 'text_model.encoder.layers.22.self_attn.v_proj.weight', 'text_model.encoder.layers.22.self_attn.v_proj.bias', 'text_model.encoder.layers.22.self_attn.q_proj.weight', 'text_model.encoder.layers.22.self_attn.q_proj.bias', 'text_model.encoder.layers.22.self_attn.out_proj.weight', 'text_model.encoder.layers.22.self_attn.out_proj.bias', 'text_model.encoder.layers.22.layer_norm1.weight', 'text_model.encoder.layers.22.layer_norm1.bias', 'text_model.encoder.layers.22.mlp.fc1.weight', 'text_model.encoder.layers.22.mlp.fc1.bias', 'text_model.encoder.layers.22.mlp.fc2.weight', 'text_model.encoder.layers.22.mlp.fc2.bias', 'text_model.encoder.layers.22.layer_norm2.weight', 'text_model.encoder.layers.22.layer_norm2.bias', 'text_model.encoder.layers.23.self_attn.k_proj.weight', 'text_model.encoder.layers.23.self_attn.k_proj.bias', 'text_model.encoder.layers.23.self_attn.v_proj.weight', 'text_model.encoder.layers.23.self_attn.v_proj.bias', 'text_model.encoder.layers.23.self_attn.q_proj.weight', 'text_model.encoder.layers.23.self_attn.q_proj.bias', 'text_model.encoder.layers.23.self_attn.out_proj.weight', 'text_model.encoder.layers.23.self_attn.out_proj.bias', 'text_model.encoder.layers.23.layer_norm1.weight', 'text_model.encoder.layers.23.layer_norm1.bias', 'text_model.encoder.layers.23.mlp.fc1.weight', 'text_model.encoder.layers.23.mlp.fc1.bias', 'text_model.encoder.layers.23.mlp.fc2.weight', 'text_model.encoder.layers.23.mlp.fc2.bias', 'text_model.encoder.layers.23.layer_norm2.weight', 'text_model.encoder.layers.23.layer_norm2.bias', 'text_model.encoder.layers.24.self_attn.k_proj.weight', 'text_model.encoder.layers.24.self_attn.k_proj.bias', 'text_model.encoder.layers.24.self_attn.v_proj.weight', 'text_model.encoder.layers.24.self_attn.v_proj.bias', 'text_model.encoder.layers.24.self_attn.q_proj.weight', 
'text_model.encoder.layers.24.self_attn.q_proj.bias', 'text_model.encoder.layers.24.self_attn.out_proj.weight', 'text_model.encoder.layers.24.self_attn.out_proj.bias', 'text_model.encoder.layers.24.layer_norm1.weight', 'text_model.encoder.layers.24.layer_norm1.bias', 'text_model.encoder.layers.24.mlp.fc1.weight', 'text_model.encoder.layers.24.mlp.fc1.bias', 'text_model.encoder.layers.24.mlp.fc2.weight', 'text_model.encoder.layers.24.mlp.fc2.bias', 'text_model.encoder.layers.24.layer_norm2.weight', 'text_model.encoder.layers.24.layer_norm2.bias', 'text_model.encoder.layers.25.self_attn.k_proj.weight', 'text_model.encoder.layers.25.self_attn.k_proj.bias', 'text_model.encoder.layers.25.self_attn.v_proj.weight', 'text_model.encoder.layers.25.self_attn.v_proj.bias', 'text_model.encoder.layers.25.self_attn.q_proj.weight', 'text_model.encoder.layers.25.self_attn.q_proj.bias', 'text_model.encoder.layers.25.self_attn.out_proj.weight', 'text_model.encoder.layers.25.self_attn.out_proj.bias', 'text_model.encoder.layers.25.layer_norm1.weight', 'text_model.encoder.layers.25.layer_norm1.bias', 'text_model.encoder.layers.25.mlp.fc1.weight', 'text_model.encoder.layers.25.mlp.fc1.bias', 'text_model.encoder.layers.25.mlp.fc2.weight', 'text_model.encoder.layers.25.mlp.fc2.bias', 'text_model.encoder.layers.25.layer_norm2.weight', 'text_model.encoder.layers.25.layer_norm2.bias', 'text_model.encoder.layers.26.self_attn.k_proj.weight', 'text_model.encoder.layers.26.self_attn.k_proj.bias', 'text_model.encoder.layers.26.self_attn.v_proj.weight', 'text_model.encoder.layers.26.self_attn.v_proj.bias', 'text_model.encoder.layers.26.self_attn.q_proj.weight', 'text_model.encoder.layers.26.self_attn.q_proj.bias', 'text_model.encoder.layers.26.self_attn.out_proj.weight', 'text_model.encoder.layers.26.self_attn.out_proj.bias', 'text_model.encoder.layers.26.layer_norm1.weight', 'text_model.encoder.layers.26.layer_norm1.bias', 'text_model.encoder.layers.26.mlp.fc1.weight', 'text_model.encoder.layers.26.mlp.fc1.bias', 'text_model.encoder.layers.26.mlp.fc2.weight', 'text_model.encoder.layers.26.mlp.fc2.bias', 'text_model.encoder.layers.26.layer_norm2.weight', 'text_model.encoder.layers.26.layer_norm2.bias', 'text_model.encoder.layers.27.self_attn.k_proj.weight', 'text_model.encoder.layers.27.self_attn.k_proj.bias', 'text_model.encoder.layers.27.self_attn.v_proj.weight', 'text_model.encoder.layers.27.self_attn.v_proj.bias', 'text_model.encoder.layers.27.self_attn.q_proj.weight', 'text_model.encoder.layers.27.self_attn.q_proj.bias', 'text_model.encoder.layers.27.self_attn.out_proj.weight', 'text_model.encoder.layers.27.self_attn.out_proj.bias', 'text_model.encoder.layers.27.layer_norm1.weight', 'text_model.encoder.layers.27.layer_norm1.bias', 'text_model.encoder.layers.27.mlp.fc1.weight', 'text_model.encoder.layers.27.mlp.fc1.bias', 'text_model.encoder.layers.27.mlp.fc2.weight', 'text_model.encoder.layers.27.mlp.fc2.bias', 'text_model.encoder.layers.27.layer_norm2.weight', 'text_model.encoder.layers.27.layer_norm2.bias', 'text_model.encoder.layers.28.self_attn.k_proj.weight', 'text_model.encoder.layers.28.self_attn.k_proj.bias', 'text_model.encoder.layers.28.self_attn.v_proj.weight', 'text_model.encoder.layers.28.self_attn.v_proj.bias', 'text_model.encoder.layers.28.self_attn.q_proj.weight', 'text_model.encoder.layers.28.self_attn.q_proj.bias', 'text_model.encoder.layers.28.self_attn.out_proj.weight', 'text_model.encoder.layers.28.self_attn.out_proj.bias', 'text_model.encoder.layers.28.layer_norm1.weight', 
'text_model.encoder.layers.28.layer_norm1.bias', 'text_model.encoder.layers.28.mlp.fc1.weight', 'text_model.encoder.layers.28.mlp.fc1.bias', 'text_model.encoder.layers.28.mlp.fc2.weight', 'text_model.encoder.layers.28.mlp.fc2.bias', 'text_model.encoder.layers.28.layer_norm2.weight', 'text_model.encoder.layers.28.layer_norm2.bias', 'text_model.encoder.layers.29.self_attn.k_proj.weight', 'text_model.encoder.layers.29.self_attn.k_proj.bias', 'text_model.encoder.layers.29.self_attn.v_proj.weight', 'text_model.encoder.layers.29.self_attn.v_proj.bias', 'text_model.encoder.layers.29.self_attn.q_proj.weight', 'text_model.encoder.layers.29.self_attn.q_proj.bias', 'text_model.encoder.layers.29.self_attn.out_proj.weight', 'text_model.encoder.layers.29.self_attn.out_proj.bias', 'text_model.encoder.layers.29.layer_norm1.weight', 'text_model.encoder.layers.29.layer_norm1.bias', 'text_model.encoder.layers.29.mlp.fc1.weight', 'text_model.encoder.layers.29.mlp.fc1.bias', 'text_model.encoder.layers.29.mlp.fc2.weight', 'text_model.encoder.layers.29.mlp.fc2.bias', 'text_model.encoder.layers.29.layer_norm2.weight', 'text_model.encoder.layers.29.layer_norm2.bias', 'text_model.encoder.layers.30.self_attn.k_proj.weight', 'text_model.encoder.layers.30.self_attn.k_proj.bias', 'text_model.encoder.layers.30.self_attn.v_proj.weight', 'text_model.encoder.layers.30.self_attn.v_proj.bias', 'text_model.encoder.layers.30.self_attn.q_proj.weight', 'text_model.encoder.layers.30.self_attn.q_proj.bias', 'text_model.encoder.layers.30.self_attn.out_proj.weight', 'text_model.encoder.layers.30.self_attn.out_proj.bias', 'text_model.encoder.layers.30.layer_norm1.weight', 'text_model.encoder.layers.30.layer_norm1.bias', 'text_model.encoder.layers.30.mlp.fc1.weight', 'text_model.encoder.layers.30.mlp.fc1.bias', 'text_model.encoder.layers.30.mlp.fc2.weight', 'text_model.encoder.layers.30.mlp.fc2.bias', 'text_model.encoder.layers.30.layer_norm2.weight', 'text_model.encoder.layers.30.layer_norm2.bias', 'text_model.encoder.layers.31.self_attn.k_proj.weight', 'text_model.encoder.layers.31.self_attn.k_proj.bias', 'text_model.encoder.layers.31.self_attn.v_proj.weight', 'text_model.encoder.layers.31.self_attn.v_proj.bias', 'text_model.encoder.layers.31.self_attn.q_proj.weight', 'text_model.encoder.layers.31.self_attn.q_proj.bias', 'text_model.encoder.layers.31.self_attn.out_proj.weight', 'text_model.encoder.layers.31.self_attn.out_proj.bias', 'text_model.encoder.layers.31.layer_norm1.weight', 'text_model.encoder.layers.31.layer_norm1.bias', 'text_model.encoder.layers.31.mlp.fc1.weight', 'text_model.encoder.layers.31.mlp.fc1.bias', 'text_model.encoder.layers.31.mlp.fc2.weight', 'text_model.encoder.layers.31.mlp.fc2.bias', 'text_model.encoder.layers.31.layer_norm2.weight', 'text_model.encoder.layers.31.layer_norm2.bias', 'text_model.final_layer_norm.weight', 'text_model.final_layer_norm.bias'])\n","dict_keys(['logit_scale', 'text_model.embeddings.position_embedding.weight', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 
'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.12.layer_norm1.bias', 'text_model.encoder.layers.12.layer_norm1.weight', 'text_model.encoder.layers.12.layer_norm2.bias', 'text_model.encoder.layers.12.layer_norm2.weight', 'text_model.encoder.layers.12.mlp.fc1.bias', 'text_model.encoder.layers.12.mlp.fc1.weight', 'text_model.encoder.layers.12.mlp.fc2.bias', 'text_model.encoder.layers.12.mlp.fc2.weight', 'text_model.encoder.layers.12.self_attn.k_proj.bias', 'text_model.encoder.layers.12.self_attn.k_proj.weight', 'text_model.encoder.layers.12.self_attn.out_proj.bias', 'text_model.encoder.layers.12.self_attn.out_proj.weight', 'text_model.encoder.layers.12.self_attn.q_proj.bias', 'text_model.encoder.layers.12.self_attn.q_proj.weight', 
'text_model.encoder.layers.12.self_attn.v_proj.bias', 'text_model.encoder.layers.12.self_attn.v_proj.weight', 'text_model.encoder.layers.13.layer_norm1.bias', 'text_model.encoder.layers.13.layer_norm1.weight', 'text_model.encoder.layers.13.layer_norm2.bias', 'text_model.encoder.layers.13.layer_norm2.weight', 'text_model.encoder.layers.13.mlp.fc1.bias', 'text_model.encoder.layers.13.mlp.fc1.weight', 'text_model.encoder.layers.13.mlp.fc2.bias', 'text_model.encoder.layers.13.mlp.fc2.weight', 'text_model.encoder.layers.13.self_attn.k_proj.bias', 'text_model.encoder.layers.13.self_attn.k_proj.weight', 'text_model.encoder.layers.13.self_attn.out_proj.bias', 'text_model.encoder.layers.13.self_attn.out_proj.weight', 'text_model.encoder.layers.13.self_attn.q_proj.bias', 'text_model.encoder.layers.13.self_attn.q_proj.weight', 'text_model.encoder.layers.13.self_attn.v_proj.bias', 'text_model.encoder.layers.13.self_attn.v_proj.weight', 'text_model.encoder.layers.14.layer_norm1.bias', 'text_model.encoder.layers.14.layer_norm1.weight', 'text_model.encoder.layers.14.layer_norm2.bias', 'text_model.encoder.layers.14.layer_norm2.weight', 'text_model.encoder.layers.14.mlp.fc1.bias', 'text_model.encoder.layers.14.mlp.fc1.weight', 'text_model.encoder.layers.14.mlp.fc2.bias', 'text_model.encoder.layers.14.mlp.fc2.weight', 'text_model.encoder.layers.14.self_attn.k_proj.bias', 'text_model.encoder.layers.14.self_attn.k_proj.weight', 'text_model.encoder.layers.14.self_attn.out_proj.bias', 'text_model.encoder.layers.14.self_attn.out_proj.weight', 'text_model.encoder.layers.14.self_attn.q_proj.bias', 'text_model.encoder.layers.14.self_attn.q_proj.weight', 'text_model.encoder.layers.14.self_attn.v_proj.bias', 'text_model.encoder.layers.14.self_attn.v_proj.weight', 'text_model.encoder.layers.15.layer_norm1.bias', 'text_model.encoder.layers.15.layer_norm1.weight', 'text_model.encoder.layers.15.layer_norm2.bias', 'text_model.encoder.layers.15.layer_norm2.weight', 'text_model.encoder.layers.15.mlp.fc1.bias', 'text_model.encoder.layers.15.mlp.fc1.weight', 'text_model.encoder.layers.15.mlp.fc2.bias', 'text_model.encoder.layers.15.mlp.fc2.weight', 'text_model.encoder.layers.15.self_attn.k_proj.bias', 'text_model.encoder.layers.15.self_attn.k_proj.weight', 'text_model.encoder.layers.15.self_attn.out_proj.bias', 'text_model.encoder.layers.15.self_attn.out_proj.weight', 'text_model.encoder.layers.15.self_attn.q_proj.bias', 'text_model.encoder.layers.15.self_attn.q_proj.weight', 'text_model.encoder.layers.15.self_attn.v_proj.bias', 'text_model.encoder.layers.15.self_attn.v_proj.weight', 'text_model.encoder.layers.16.layer_norm1.bias', 'text_model.encoder.layers.16.layer_norm1.weight', 'text_model.encoder.layers.16.layer_norm2.bias', 'text_model.encoder.layers.16.layer_norm2.weight', 'text_model.encoder.layers.16.mlp.fc1.bias', 'text_model.encoder.layers.16.mlp.fc1.weight', 'text_model.encoder.layers.16.mlp.fc2.bias', 'text_model.encoder.layers.16.mlp.fc2.weight', 'text_model.encoder.layers.16.self_attn.k_proj.bias', 'text_model.encoder.layers.16.self_attn.k_proj.weight', 'text_model.encoder.layers.16.self_attn.out_proj.bias', 'text_model.encoder.layers.16.self_attn.out_proj.weight', 'text_model.encoder.layers.16.self_attn.q_proj.bias', 'text_model.encoder.layers.16.self_attn.q_proj.weight', 'text_model.encoder.layers.16.self_attn.v_proj.bias', 'text_model.encoder.layers.16.self_attn.v_proj.weight', 'text_model.encoder.layers.17.layer_norm1.bias', 'text_model.encoder.layers.17.layer_norm1.weight', 
'text_model.encoder.layers.17.layer_norm2.bias', 'text_model.encoder.layers.17.layer_norm2.weight', 'text_model.encoder.layers.17.mlp.fc1.bias', 'text_model.encoder.layers.17.mlp.fc1.weight', 'text_model.encoder.layers.17.mlp.fc2.bias', 'text_model.encoder.layers.17.mlp.fc2.weight', 'text_model.encoder.layers.17.self_attn.k_proj.bias', 'text_model.encoder.layers.17.self_attn.k_proj.weight', 'text_model.encoder.layers.17.self_attn.out_proj.bias', 'text_model.encoder.layers.17.self_attn.out_proj.weight', 'text_model.encoder.layers.17.self_attn.q_proj.bias', 'text_model.encoder.layers.17.self_attn.q_proj.weight', 'text_model.encoder.layers.17.self_attn.v_proj.bias', 'text_model.encoder.layers.17.self_attn.v_proj.weight', 'text_model.encoder.layers.18.layer_norm1.bias', 'text_model.encoder.layers.18.layer_norm1.weight', 'text_model.encoder.layers.18.layer_norm2.bias', 'text_model.encoder.layers.18.layer_norm2.weight', 'text_model.encoder.layers.18.mlp.fc1.bias', 'text_model.encoder.layers.18.mlp.fc1.weight', 'text_model.encoder.layers.18.mlp.fc2.bias', 'text_model.encoder.layers.18.mlp.fc2.weight', 'text_model.encoder.layers.18.self_attn.k_proj.bias', 'text_model.encoder.layers.18.self_attn.k_proj.weight', 'text_model.encoder.layers.18.self_attn.out_proj.bias', 'text_model.encoder.layers.18.self_attn.out_proj.weight', 'text_model.encoder.layers.18.self_attn.q_proj.bias', 'text_model.encoder.layers.18.self_attn.q_proj.weight', 'text_model.encoder.layers.18.self_attn.v_proj.bias', 'text_model.encoder.layers.18.self_attn.v_proj.weight', 'text_model.encoder.layers.19.layer_norm1.bias', 'text_model.encoder.layers.19.layer_norm1.weight', 'text_model.encoder.layers.19.layer_norm2.bias', 'text_model.encoder.layers.19.layer_norm2.weight', 'text_model.encoder.layers.19.mlp.fc1.bias', 'text_model.encoder.layers.19.mlp.fc1.weight', 'text_model.encoder.layers.19.mlp.fc2.bias', 'text_model.encoder.layers.19.mlp.fc2.weight', 'text_model.encoder.layers.19.self_attn.k_proj.bias', 'text_model.encoder.layers.19.self_attn.k_proj.weight', 'text_model.encoder.layers.19.self_attn.out_proj.bias', 'text_model.encoder.layers.19.self_attn.out_proj.weight', 'text_model.encoder.layers.19.self_attn.q_proj.bias', 'text_model.encoder.layers.19.self_attn.q_proj.weight', 'text_model.encoder.layers.19.self_attn.v_proj.bias', 'text_model.encoder.layers.19.self_attn.v_proj.weight', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.20.layer_norm1.bias', 'text_model.encoder.layers.20.layer_norm1.weight', 'text_model.encoder.layers.20.layer_norm2.bias', 'text_model.encoder.layers.20.layer_norm2.weight', 'text_model.encoder.layers.20.mlp.fc1.bias', 'text_model.encoder.layers.20.mlp.fc1.weight', 'text_model.encoder.layers.20.mlp.fc2.bias', 
'text_model.encoder.layers.20.mlp.fc2.weight', 'text_model.encoder.layers.20.self_attn.k_proj.bias', 'text_model.encoder.layers.20.self_attn.k_proj.weight', 'text_model.encoder.layers.20.self_attn.out_proj.bias', 'text_model.encoder.layers.20.self_attn.out_proj.weight', 'text_model.encoder.layers.20.self_attn.q_proj.bias', 'text_model.encoder.layers.20.self_attn.q_proj.weight', 'text_model.encoder.layers.20.self_attn.v_proj.bias', 'text_model.encoder.layers.20.self_attn.v_proj.weight', 'text_model.encoder.layers.21.layer_norm1.bias', 'text_model.encoder.layers.21.layer_norm1.weight', 'text_model.encoder.layers.21.layer_norm2.bias', 'text_model.encoder.layers.21.layer_norm2.weight', 'text_model.encoder.layers.21.mlp.fc1.bias', 'text_model.encoder.layers.21.mlp.fc1.weight', 'text_model.encoder.layers.21.mlp.fc2.bias', 'text_model.encoder.layers.21.mlp.fc2.weight', 'text_model.encoder.layers.21.self_attn.k_proj.bias', 'text_model.encoder.layers.21.self_attn.k_proj.weight', 'text_model.encoder.layers.21.self_attn.out_proj.bias', 'text_model.encoder.layers.21.self_attn.out_proj.weight', 'text_model.encoder.layers.21.self_attn.q_proj.bias', 'text_model.encoder.layers.21.self_attn.q_proj.weight', 'text_model.encoder.layers.21.self_attn.v_proj.bias', 'text_model.encoder.layers.21.self_attn.v_proj.weight', 'text_model.encoder.layers.22.layer_norm1.bias', 'text_model.encoder.layers.22.layer_norm1.weight', 'text_model.encoder.layers.22.layer_norm2.bias', 'text_model.encoder.layers.22.layer_norm2.weight', 'text_model.encoder.layers.22.mlp.fc1.bias', 'text_model.encoder.layers.22.mlp.fc1.weight', 'text_model.encoder.layers.22.mlp.fc2.bias', 'text_model.encoder.layers.22.mlp.fc2.weight', 'text_model.encoder.layers.22.self_attn.k_proj.bias', 'text_model.encoder.layers.22.self_attn.k_proj.weight', 'text_model.encoder.layers.22.self_attn.out_proj.bias', 'text_model.encoder.layers.22.self_attn.out_proj.weight', 'text_model.encoder.layers.22.self_attn.q_proj.bias', 'text_model.encoder.layers.22.self_attn.q_proj.weight', 'text_model.encoder.layers.22.self_attn.v_proj.bias', 'text_model.encoder.layers.22.self_attn.v_proj.weight', 'text_model.encoder.layers.23.layer_norm1.bias', 'text_model.encoder.layers.23.layer_norm1.weight', 'text_model.encoder.layers.23.layer_norm2.bias', 'text_model.encoder.layers.23.layer_norm2.weight', 'text_model.encoder.layers.23.mlp.fc1.bias', 'text_model.encoder.layers.23.mlp.fc1.weight', 'text_model.encoder.layers.23.mlp.fc2.bias', 'text_model.encoder.layers.23.mlp.fc2.weight', 'text_model.encoder.layers.23.self_attn.k_proj.bias', 'text_model.encoder.layers.23.self_attn.k_proj.weight', 'text_model.encoder.layers.23.self_attn.out_proj.bias', 'text_model.encoder.layers.23.self_attn.out_proj.weight', 'text_model.encoder.layers.23.self_attn.q_proj.bias', 'text_model.encoder.layers.23.self_attn.q_proj.weight', 'text_model.encoder.layers.23.self_attn.v_proj.bias', 'text_model.encoder.layers.23.self_attn.v_proj.weight', 'text_model.encoder.layers.24.layer_norm1.bias', 'text_model.encoder.layers.24.layer_norm1.weight', 'text_model.encoder.layers.24.layer_norm2.bias', 'text_model.encoder.layers.24.layer_norm2.weight', 'text_model.encoder.layers.24.mlp.fc1.bias', 'text_model.encoder.layers.24.mlp.fc1.weight', 'text_model.encoder.layers.24.mlp.fc2.bias', 'text_model.encoder.layers.24.mlp.fc2.weight', 'text_model.encoder.layers.24.self_attn.k_proj.bias', 'text_model.encoder.layers.24.self_attn.k_proj.weight', 'text_model.encoder.layers.24.self_attn.out_proj.bias', 
'text_model.encoder.layers.24.self_attn.out_proj.weight', 'text_model.encoder.layers.24.self_attn.q_proj.bias', 'text_model.encoder.layers.24.self_attn.q_proj.weight', 'text_model.encoder.layers.24.self_attn.v_proj.bias', 'text_model.encoder.layers.24.self_attn.v_proj.weight', 'text_model.encoder.layers.25.layer_norm1.bias', 'text_model.encoder.layers.25.layer_norm1.weight', 'text_model.encoder.layers.25.layer_norm2.bias', 'text_model.encoder.layers.25.layer_norm2.weight', 'text_model.encoder.layers.25.mlp.fc1.bias', 'text_model.encoder.layers.25.mlp.fc1.weight', 'text_model.encoder.layers.25.mlp.fc2.bias', 'text_model.encoder.layers.25.mlp.fc2.weight', 'text_model.encoder.layers.25.self_attn.k_proj.bias', 'text_model.encoder.layers.25.self_attn.k_proj.weight', 'text_model.encoder.layers.25.self_attn.out_proj.bias', 'text_model.encoder.layers.25.self_attn.out_proj.weight', 'text_model.encoder.layers.25.self_attn.q_proj.bias', 'text_model.encoder.layers.25.self_attn.q_proj.weight', 'text_model.encoder.layers.25.self_attn.v_proj.bias', 'text_model.encoder.layers.25.self_attn.v_proj.weight', 'text_model.encoder.layers.26.layer_norm1.bias', 'text_model.encoder.layers.26.layer_norm1.weight', 'text_model.encoder.layers.26.layer_norm2.bias', 'text_model.encoder.layers.26.layer_norm2.weight', 'text_model.encoder.layers.26.mlp.fc1.bias', 'text_model.encoder.layers.26.mlp.fc1.weight', 'text_model.encoder.layers.26.mlp.fc2.bias', 'text_model.encoder.layers.26.mlp.fc2.weight', 'text_model.encoder.layers.26.self_attn.k_proj.bias', 'text_model.encoder.layers.26.self_attn.k_proj.weight', 'text_model.encoder.layers.26.self_attn.out_proj.bias', 'text_model.encoder.layers.26.self_attn.out_proj.weight', 'text_model.encoder.layers.26.self_attn.q_proj.bias', 'text_model.encoder.layers.26.self_attn.q_proj.weight', 'text_model.encoder.layers.26.self_attn.v_proj.bias', 'text_model.encoder.layers.26.self_attn.v_proj.weight', 'text_model.encoder.layers.27.layer_norm1.bias', 'text_model.encoder.layers.27.layer_norm1.weight', 'text_model.encoder.layers.27.layer_norm2.bias', 'text_model.encoder.layers.27.layer_norm2.weight', 'text_model.encoder.layers.27.mlp.fc1.bias', 'text_model.encoder.layers.27.mlp.fc1.weight', 'text_model.encoder.layers.27.mlp.fc2.bias', 'text_model.encoder.layers.27.mlp.fc2.weight', 'text_model.encoder.layers.27.self_attn.k_proj.bias', 'text_model.encoder.layers.27.self_attn.k_proj.weight', 'text_model.encoder.layers.27.self_attn.out_proj.bias', 'text_model.encoder.layers.27.self_attn.out_proj.weight', 'text_model.encoder.layers.27.self_attn.q_proj.bias', 'text_model.encoder.layers.27.self_attn.q_proj.weight', 'text_model.encoder.layers.27.self_attn.v_proj.bias', 'text_model.encoder.layers.27.self_attn.v_proj.weight', 'text_model.encoder.layers.28.layer_norm1.bias', 'text_model.encoder.layers.28.layer_norm1.weight', 'text_model.encoder.layers.28.layer_norm2.bias', 'text_model.encoder.layers.28.layer_norm2.weight', 'text_model.encoder.layers.28.mlp.fc1.bias', 'text_model.encoder.layers.28.mlp.fc1.weight', 'text_model.encoder.layers.28.mlp.fc2.bias', 'text_model.encoder.layers.28.mlp.fc2.weight', 'text_model.encoder.layers.28.self_attn.k_proj.bias', 'text_model.encoder.layers.28.self_attn.k_proj.weight', 'text_model.encoder.layers.28.self_attn.out_proj.bias', 'text_model.encoder.layers.28.self_attn.out_proj.weight', 'text_model.encoder.layers.28.self_attn.q_proj.bias', 'text_model.encoder.layers.28.self_attn.q_proj.weight', 'text_model.encoder.layers.28.self_attn.v_proj.bias', 
'text_model.encoder.layers.28.self_attn.v_proj.weight', 'text_model.encoder.layers.29.layer_norm1.bias', 'text_model.encoder.layers.29.layer_norm1.weight', 'text_model.encoder.layers.29.layer_norm2.bias', 'text_model.encoder.layers.29.layer_norm2.weight', 'text_model.encoder.layers.29.mlp.fc1.bias', 'text_model.encoder.layers.29.mlp.fc1.weight', 'text_model.encoder.layers.29.mlp.fc2.bias', 'text_model.encoder.layers.29.mlp.fc2.weight', 'text_model.encoder.layers.29.self_attn.k_proj.bias', 'text_model.encoder.layers.29.self_attn.k_proj.weight', 'text_model.encoder.layers.29.self_attn.out_proj.bias', 'text_model.encoder.layers.29.self_attn.out_proj.weight', 'text_model.encoder.layers.29.self_attn.q_proj.bias', 'text_model.encoder.layers.29.self_attn.q_proj.weight', 'text_model.encoder.layers.29.self_attn.v_proj.bias', 'text_model.encoder.layers.29.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.30.layer_norm1.bias', 'text_model.encoder.layers.30.layer_norm1.weight', 'text_model.encoder.layers.30.layer_norm2.bias', 'text_model.encoder.layers.30.layer_norm2.weight', 'text_model.encoder.layers.30.mlp.fc1.bias', 'text_model.encoder.layers.30.mlp.fc1.weight', 'text_model.encoder.layers.30.mlp.fc2.bias', 'text_model.encoder.layers.30.mlp.fc2.weight', 'text_model.encoder.layers.30.self_attn.k_proj.bias', 'text_model.encoder.layers.30.self_attn.k_proj.weight', 'text_model.encoder.layers.30.self_attn.out_proj.bias', 'text_model.encoder.layers.30.self_attn.out_proj.weight', 'text_model.encoder.layers.30.self_attn.q_proj.bias', 'text_model.encoder.layers.30.self_attn.q_proj.weight', 'text_model.encoder.layers.30.self_attn.v_proj.bias', 'text_model.encoder.layers.30.self_attn.v_proj.weight', 'text_model.encoder.layers.31.layer_norm1.bias', 'text_model.encoder.layers.31.layer_norm1.weight', 'text_model.encoder.layers.31.layer_norm2.bias', 'text_model.encoder.layers.31.layer_norm2.weight', 'text_model.encoder.layers.31.mlp.fc1.bias', 'text_model.encoder.layers.31.mlp.fc1.weight', 'text_model.encoder.layers.31.mlp.fc2.bias', 'text_model.encoder.layers.31.mlp.fc2.weight', 'text_model.encoder.layers.31.self_attn.k_proj.bias', 'text_model.encoder.layers.31.self_attn.k_proj.weight', 'text_model.encoder.layers.31.self_attn.out_proj.bias', 'text_model.encoder.layers.31.self_attn.out_proj.weight', 'text_model.encoder.layers.31.self_attn.q_proj.bias', 'text_model.encoder.layers.31.self_attn.q_proj.weight', 'text_model.encoder.layers.31.self_attn.v_proj.bias', 'text_model.encoder.layers.31.self_attn.v_proj.weight', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.4.layer_norm2.weight', 
'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 
'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.final_layer_norm.bias', 'text_model.final_layer_norm.weight', 'text_projection.weight'])\n","β
Model loaded from /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_e1_step_10000.safetensors\n","β
T5, CLIP, and Dual Shunt Adapter loaded and cast.\n"]}]},{"cell_type":"code","source":["\n","clip_mod.load_state_dict(temp_clip, strict=False)\n","# torch.compile takes no dtype argument; cast the module first, then compile\n","clip_mod = clip_mod.to(torch.float32)\n","clip_mod = torch.compile(clip_mod, mode=\"reduce-overhead\")\n","## βββ Load Dual Shunt Adapter ββββββββββββββββββββββββββββββββ"],"metadata":{"id":"URcbaWEJfiKJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Noise Train"],"metadata":{"id":"pj2a104jbp42"}},{"cell_type":"code","source":["import torch\n","import torch.nn.functional as F\n","from tqdm import tqdm\n","from pathlib import Path\n","\n","\n","def train_dual_shunt(\n","    t5_tok, clip_tok, t5_mod, clip_mod, adapter,\n","    repo_id: str,\n","    batch_size: int = 256,\n","    epochs: int = 1,\n","    lr: float = 5e-4,\n","    device: str = \"cuda\",\n","    g_target: float = 7.5,\n","    Ξ»_delta: float = 1.0,\n","    Ξ»_gate: float = 0.15,\n","    Ξ»_ent: float = 0.02,\n","    Ξ»_tau: float = 0.02,\n","    Ξ»_anchor: float = 0.12,\n","    Ξ»_guid: float = 0.1,\n","    save_every: int = 1000,\n","    max_token_length: int = 77,\n","    seed: int = 42,\n","    t5_prompt_mode: str = \"real_raw\",\n","    output_dir: str = \"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n","):\n","    import torch\n","    import torch.nn.functional as F\n","    from tqdm import tqdm\n","    from pathlib import Path\n","    import random\n","\n","    def extract_sparse_tokens(text: str, keep_min: int = 1, keep_max: int = 5, keep_prob: float = 0.3) -> str:\n","        words = text.strip().split()\n","        keep_count = min(len(words), random.randint(keep_min, keep_max))\n","        indices = sorted(random.sample(range(len(words)), k=keep_count))\n","        return \" \".join(words[i] for i in indices)\n","\n","    def seed_everything(seed: int = 42):\n","        torch.manual_seed(seed)\n","        torch.cuda.manual_seed_all(seed)\n","        
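# make cuDNN deterministic and disable autotuning so seeded runs are repeatable\n","        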
torch.backends.cudnn.deterministic = True\n"," torch.backends.cudnn.benchmark = False\n","\n"," LOG_2PI = 1.83787706641\n","\n"," def hetero_loss(delta, target, log_sigma):\n"," log_sigma = log_sigma.clamp(-5.0, 5.0)\n"," inv_var = torch.exp(-log_sigma)\n"," return 0.5 * (inv_var * (delta - target)**2 + log_sigma + LOG_2PI).mean()\n","\n"," def entropy(p: torch.Tensor) -> torch.Tensor:\n"," p = p / (p.sum(-1, keepdim=True) + 1e-9)\n"," return -(p * (p + 1e-9).log()).sum(-1).mean()\n","\n"," # βββ Setup βββββββββββββββββββββββββββββββββββββββββββββββ\n"," seed_everything(seed)\n"," device = torch.device(device)\n","\n"," ds = ParsedMultiCharDataset(repo_id, start_file=50, num_files=6, shuffle=True)\n"," dl = torch.utils.data.DataLoader(\n"," ds, batch_size=batch_size, num_workers=4,\n"," drop_last=True, persistent_workers=True,\n"," generator=torch.Generator().manual_seed(seed),\n"," prefetch_factor=8\n"," )\n","\n"," output_dir = Path(output_dir)\n"," output_dir.mkdir(parents=True, exist_ok=True)\n"," global_step = 0\n","\n"," opt = torch.optim.AdamW(adapter.parameters(), lr=lr)\n"," scaler = torch.amp.GradScaler(device=\"cuda\")\n","\n"," # βββ Training Loop βββββββββββββββββββββββββββββββββββββββ\n"," for epoch in range(1, epochs + 1):\n"," pbar = tqdm(dl, desc=f\"Epoch {epoch}/{epochs}\")\n"," for texts in pbar:\n"," global_step += 1\n","\n"," with torch.autocast(device.type, dtype=torch.bfloat16):\n"," # T5 prompt encoding\n"," t5_inputs = t5_tok(texts, padding=True, truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," t5_seq = t5_mod(**t5_inputs).last_hidden_state\n","\n"," # CLIP (plain)\n"," clip_inputs = clip_tok(texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," clip_seq_plain = clip_mod(**clip_inputs).last_hidden_state\n","\n"," # CLIP (captioned target)\n","\n"," # for null noise training\n"," cap_texts = [f\"{extract_sparse_tokens(t)}\" for t in texts]\n"," clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," clip_seq_cap = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n"," #with torch.no_grad():\n"," # summary_ids = t5_mod.generate(**t5_inputs, max_length=max_token_length)\n"," # cap_texts = t5_tok.batch_decode(summary_ids, skip_special_tokens=True)\n","\n"," # Encode summaries into CLIP\n"," #clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n"," # max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," #with torch.no_grad():\n"," # clip_seq_cap = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n"," # Adapter forward\n"," anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq_plain)\n"," delta_tgt = clip_seq_cap - clip_seq_plain\n","\n"," # π§ SAFE attention projection for auxiliary losses\n"," with torch.no_grad():\n"," t5_b = adapter.proj_t5(t5_seq)\n"," clip_b = adapter.proj_clip(clip_seq_plain)\n"," t2c_base, _ = adapter.cross_t2c(t5_b, clip_b, clip_b)\n"," c2t_base, _ = adapter.cross_c2t(clip_b, t5_b, t5_b)\n","\n"," pocket = adapter.pocket_blocks(t2c_base.detach())\n"," loss_pocket = F.mse_loss(pocket, t2c_base.detach()) * 0.05\n","\n"," fused_h = adapter.fuse(torch.cat([\n"," pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1),\n"," c2t_base\n"," ], dim=-1))\n"," loss_hidden = 
(fused_h.norm(dim=-1).mean() - 1.0).abs() * 0.01\n","\n"," # Core losses\n"," loss_delta = hetero_loss(delta, delta_tgt, log_sigma) * Ξ»_delta\n"," loss_gate = (gate.mean() - 0.25).abs() * Ξ»_gate\n"," loss_ent = 0.5 * (entropy(attn_t2c) + entropy(attn_c2t)) * Ξ»_ent\n"," loss_tau = tau.abs().mean() * Ξ»_tau\n"," loss_anchor = (1 - F.cosine_similarity(anchor.mean(1), clip_seq_plain.mean(1), dim=-1).mean()) * Ξ»_anchor\n"," loss_guid = F.mse_loss(g_pred, torch.full_like(g_pred, g_target)) * Ξ»_guid\n","\n"," total_loss = (\n"," loss_delta + loss_gate + loss_ent +\n"," loss_tau + loss_anchor + loss_guid +\n"," loss_pocket + loss_hidden\n"," )\n","\n"," # Backprop + step\n"," scaler.scale(total_loss).backward()\n"," scaler.step(opt)\n"," scaler.update()\n"," opt.zero_grad(set_to_none=True)\n"," if global_step % 100 == 0:\n"," print(f\"Gate mean: {gate.mean().item():.3f} log_sigma mean: {log_sigma.mean().item():.3f}\")\n"," if global_step % save_every == 0:\n"," path = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_noised_e1_step_{global_step}.safetensors\"\n"," save_safetensors(adapter, path)\n","\n"," pbar.set_postfix(loss=float(total_loss))\n","\n"," final = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_noised_e1_final.safetensors\"\n"," save_safetensors(adapter, final)\n"," print(f\"β
Epoch {epoch} complete. Final model saved.\")\n"],"metadata":{"id":"WoSAXzGDWsMv","executionInfo":{"status":"ok","timestamp":1748639025564,"user_tz":420,"elapsed":28,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["train_dual_shunt(\n"," t5_tok=t5_tok,\n"," clip_tok=clip_tok,\n"," t5_mod=t5_mod,\n"," clip_mod=clip_mod,\n"," adapter=adapter,\n"," repo_id=\"AbstractPhil/human-templated-captions-1b\",\n"," batch_size=256,\n"," epochs=1,\n"," lr=1e-4,\n"," device=\"cuda\",\n"," g_target=7.5,\n"," save_every=1000,\n"," seed=42,\n"," max_token_length=77,\n"," t5_prompt_mode=\"no_caption\",\n"," output_dir=\"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n",")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vJ8zeQ9-irqU","outputId":"0292845a-d502-476c-af48-0f4438145925"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 100/117187 [03:36<69:52:34, 2.15s/it, loss=1.4]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.672\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 200/117187 [07:12<70:05:53, 2.16s/it, loss=1.38]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.504\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 300/117187 [11:30<70:27:44, 2.17s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.641\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 400/117187 [15:06<69:24:20, 2.14s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.648\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 500/117187 [18:42<70:10:52, 2.17s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.664\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 600/117187 [22:19<70:05:26, 2.16s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.695\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 700/117187 [25:55<70:00:11, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.703\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 800/117187 [29:31<68:42:09, 2.13s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 900/117187 [33:08<69:25:05, 2.15s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 999/117187 [36:41<69:35:16, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1000/117187 [36:44<72:33:06, 2.25s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_1000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1100/117187 [40:21<69:00:44, 2.14s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1200/117187 [43:56<70:06:30, 2.18s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.762\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1300/117187 [47:33<69:19:52, 2.15s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.785\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1400/117187 [51:10<69:26:39, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%|β | 1500/117187 [54:47<69:38:53, 2.17s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.789\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%|β | 1600/117187 [58:24<69:27:47, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.691\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%|β | 1700/117187 [1:02:00<70:12:30, 2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.715\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 1800/117187 [1:05:37<69:13:31, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.730\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 1900/117187 [1:09:13<68:57:24, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 1999/117187 [1:12:48<69:23:59, 2.17s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2000/117187 [1:12:50<72:07:19, 2.25s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_2000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2100/117187 [1:16:26<70:14:40, 2.20s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2200/117187 [1:20:03<68:38:11, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.777\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2300/117187 [1:23:39<69:11:14, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2400/117187 [1:27:15<70:12:39, 2.20s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.762\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2500/117187 [1:30:52<68:46:14, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2600/117187 [1:34:29<68:55:45, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.785\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2700/117187 [1:38:05<69:16:25, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2800/117187 [1:41:42<69:14:21, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.734\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|β | 2900/117187 [1:45:19<68:12:24, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 2999/117187 [1:48:52<68:32:54, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.812\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3000/117187 [1:48:54<71:14:59, 2.25s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_3000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3100/117187 [1:52:31<68:08:09, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.863\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3200/117187 [1:56:07<68:19:20, 2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.844\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3300/117187 [1:59:44<67:46:09, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.852\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3400/117187 [2:03:19<68:13:23, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.863\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3500/117187 [2:06:56<68:10:13, 2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.887\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3600/117187 [2:10:33<68:50:05, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.727\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3700/117187 [2:14:10<68:26:47, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3800/117187 [2:17:46<69:02:10, 2.19s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3900/117187 [2:21:22<67:54:57, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 3999/117187 [2:24:56<68:50:27, 2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.770\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 4000/117187 [2:24:58<71:29:54, 2.27s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_4000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|β | 4100/117187 [2:28:35<68:26:47, 2.18s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4200/117187 [2:32:12<67:47:18, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4300/117187 [2:35:49<67:43:33, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4400/117187 [2:39:25<68:20:20, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4500/117187 [2:43:01<68:01:47, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.652\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4600/117187 [2:46:38<67:19:53, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.605\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4700/117187 [2:50:15<68:15:17, 2.18s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.609\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4800/117187 [2:53:51<66:49:04, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.688\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4900/117187 [2:57:28<67:53:59, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 4999/117187 [3:01:03<67:34:57, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.719\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 5000/117187 [3:01:05<70:09:23, 2.25s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_5000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 5100/117187 [3:04:42<67:32:22, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.734\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|β | 5200/117187 [3:08:18<67:25:42, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5300/117187 [3:11:54<67:17:14, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5400/117187 [3:15:30<67:11:10, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.754\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5500/117187 [3:19:06<67:10:37, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.754\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5600/117187 [3:22:43<66:17:06, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5700/117187 [3:26:20<67:08:19, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5800/117187 [3:29:56<66:15:35, 2.14s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5900/117187 [3:33:33<65:57:43, 2.13s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 5999/117187 [3:37:06<65:53:32, 2.13s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.770\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 6000/117187 [3:37:09<68:33:52, 2.22s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_6000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 6100/117187 [3:40:44<66:08:48, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.742\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 6200/117187 [3:44:19<68:30:20, 2.22s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.629\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 6300/117187 [3:47:56<66:18:20, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.605\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|β | 6400/117187 [3:51:32<66:18:28, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.570\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 6500/117187 [3:55:08<65:36:41, 2.13s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.621\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 6600/117187 [3:58:45<65:56:19, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.637\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 6700/117187 [4:02:21<66:20:04, 2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.664\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 6800/117187 [4:05:57<67:02:27, 2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.652\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 6900/117187 [4:09:33<65:58:51, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.570\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 6999/117187 [4:13:08<65:45:42, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.602\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 7000/117187 [4:13:11<68:46:49, 2.25s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["β
Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_7000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 7100/117187 [4:16:47<65:48:45, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.617\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|β | 7159/117187 [4:18:55<65:55:30, 2.16s/it, loss=1.34]"]}]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fy3PbHL4XEv3","executionInfo":{"status":"ok","timestamp":1748024232657,"user_tz":420,"elapsed":23471,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"fdc85aa5-3920-44bc-8bc3-81f8d8a90cf2"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"markdown","source":["# summarization train"],"metadata":{"id":"EKLDJqI17Tkm"}},{"cell_type":"code","source":["from pathlib import Path\n","from tqdm import tqdm\n","import torch\n","import torch.nn.functional as F\n","from torch.utils.data import DataLoader\n","\n","def train_dual_shunt(\n"," t5_tok, clip_tok, t5_mod, clip_mod, adapter,\n"," repo_id: str,\n"," batch_size: int = 256,\n"," epochs: int = 1,\n"," lr: float = 5e-4,\n"," device: str = \"cuda\",\n"," g_target: float = 7.5,\n"," Ξ»_delta: float = 1.0,\n"," Ξ»_gate: float = 0.15,\n"," Ξ»_ent: float = 0.02,\n"," Ξ»_tau: float = 0.02,\n"," Ξ»_anchor: float = 0.12,\n"," Ξ»_guid: float = 0.1,\n"," save_every: int = 1000,\n"," max_token_length: int = 77,\n"," seed: int = 42,\n"," t5_prompt_mode: str = \"summarization_curriculum\",\n"," output_dir: str = \"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n","):\n"," def seed_everything(seed: int = 42):\n"," torch.manual_seed(seed)\n"," torch.cuda.manual_seed_all(seed)\n"," torch.backends.cudnn.deterministic = True\n"," torch.backends.cudnn.benchmark = False\n","\n"," def hetero_loss(delta, target, log_sigma):\n"," LOG_2PI = 1.83787706641\n"," log_sigma = log_sigma.clamp(-5.0, 5.0)\n"," inv_var = torch.exp(-log_sigma)\n"," return 0.5 * (inv_var * (delta - target) ** 2 + log_sigma + LOG_2PI).mean()\n","\n"," def entropy(p: torch.Tensor) -> torch.Tensor:\n"," p = p / (p.sum(-1, keepdim=True) + 1e-9)\n"," return -(p * (p + 1e-9).log()).sum(-1).mean()\n","\n"," # Setup\n"," seed_everything(seed)\n"," device = torch.device(device)\n"," ds = ParsedMultiCharDataset(repo_id, start_file=50, num_files=6, shuffle=True)\n"," dl = DataLoader(ds, batch_size=batch_size, num_workers=4, drop_last=True,\n"," persistent_workers=True, generator=torch.Generator().manual_seed(seed),\n"," prefetch_factor=8)\n","\n"," output_dir = Path(output_dir)\n"," output_dir.mkdir(parents=True, exist_ok=True)\n"," global_step = 0\n","\n"," opt = torch.optim.AdamW(adapter.parameters(), lr=lr)\n"," scaler = torch.cuda.amp.GradScaler()\n","\n"," for epoch in range(1, epochs + 1):\n"," pbar = tqdm(dl, desc=f\"Epoch {epoch}/{epochs}\")\n"," for texts in pbar:\n"," global_step += 1\n","\n"," with torch.autocast(device.type, dtype=torch.bfloat16):\n"," # T5 encoding\n"," t5_inputs = t5_tok(texts, padding=True, truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," t5_seq = t5_mod(**t5_inputs).last_hidden_state\n","\n"," # CLIP (plain)\n"," clip_inputs = clip_tok(texts, padding=\"max_length\", truncation=True,\n"," 
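# CLIP's text encoder uses a fixed 77-token context, so pad every prompt to max_length\n","                                       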
max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," clip_seq_plain = clip_mod(**clip_inputs).last_hidden_state\n","\n"," # Generate summaries\n"," with torch.no_grad():\n"," summary_ids = t5_mod.generate(**t5_inputs, max_length=max_token_length)\n"," cap_texts = t5_tok.batch_decode(summary_ids, skip_special_tokens=True)\n"," clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," clip_seq_summary = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n"," null_clip_inputs = clip_tok([\"\"] * len(texts), padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," clip_seq_null = clip_mod(**null_clip_inputs).last_hidden_state\n","\n"," # Curriculum interpolation\n"," interpolation = min(global_step / 10000.0, 1.0)\n"," clip_seq_cap = (1.0 - interpolation) * clip_seq_null + interpolation * clip_seq_summary\n"," delta_tgt = clip_seq_cap - clip_seq_plain\n","\n"," # Adapter forward\n"," anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq_plain)\n","\n"," with torch.no_grad():\n"," t5_b = adapter.proj_t5(t5_seq)\n"," clip_b = adapter.proj_clip(clip_seq_plain)\n"," t2c_base, _ = adapter.cross_t2c(t5_b, clip_b, clip_b)\n"," c2t_base, _ = adapter.cross_c2t(clip_b, t5_b, t5_b)\n","\n"," pocket = adapter.pocket_blocks(t2c_base.detach())\n"," loss_pocket = F.mse_loss(pocket, t2c_base.detach()) * 0.05\n","\n"," fused_h = adapter.fuse(torch.cat([\n"," pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1),\n"," c2t_base\n"," ], dim=-1))\n"," loss_hidden = (fused_h.norm(dim=-1).mean() - 1.0).abs() * 0.01\n","\n"," # Losses\n"," warmup = min(global_step / 5000.0, 1.0)\n"," loss_delta = hetero_loss(delta, delta_tgt, log_sigma) * Ξ»_delta * warmup\n","\n"," gate_target_mean = 0.25 + 0.15 * interpolation\n"," loss_gate = (gate.mean() - gate_target_mean).abs() * Ξ»_gate\n","\n"," loss_ent = 0.5 * (entropy(attn_t2c) + entropy(attn_c2t)) * Ξ»_ent\n"," loss_tau = tau.abs().mean() * Ξ»_tau\n"," loss_anchor = (1 - F.cosine_similarity(anchor.mean(1), clip_seq_plain.mean(1), dim=-1).mean()) * Ξ»_anchor\n"," loss_guid = F.mse_loss(g_pred, torch.full_like(g_pred, g_target)) * Ξ»_guid\n","\n"," total_loss = (\n"," loss_delta + loss_gate + loss_ent +\n"," loss_tau + loss_anchor + loss_guid +\n"," loss_pocket + loss_hidden\n"," )\n","\n"," # Backprop + step\n"," scaler.scale(total_loss).backward()\n"," scaler.step(opt)\n"," scaler.update()\n"," opt.zero_grad(set_to_none=True)\n","\n"," if global_step % 100 == 0:\n"," print(f\"[Step {global_step}] Loss: {total_loss.item():.4f} Gate Mean: {gate.mean().item():.3f}\")\n","\n"," if global_step % save_every == 0:\n"," save_path = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_e{epoch}_step_{global_step}.safetensors\"\n"," save_safetensors(adapter, save_path)\n","\n"," final = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_e{epoch}_final.safetensors\"\n"," save_safetensors(adapter, final)\n"," print(f\"β
Epoch {epoch} complete. Final model saved to {final}\")\n"],"metadata":{"id":"iXmrn28M7L9L","executionInfo":{"status":"ok","timestamp":1748637860921,"user_tz":420,"elapsed":31,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["train_dual_shunt(\n"," t5_tok=t5_tok,\n"," clip_tok=clip_tok,\n"," t5_mod=t5_mod,\n"," clip_mod=clip_mod,\n"," adapter=adapter,\n"," repo_id=\"AbstractPhil/human-templated-captions-1b\",\n"," batch_size=256,\n"," epochs=1,\n"," lr=1e-4,\n"," device=\"cuda\",\n"," g_target=7.5,\n"," save_every=1000,\n"," seed=42,\n"," max_token_length=77,\n"," t5_prompt_mode=\"no_caption\",\n"," output_dir=\"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n",")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":408},"id":"aI2BH6Ja7YqZ","executionInfo":{"status":"error","timestamp":1748637873437,"user_tz":420,"elapsed":3426,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"67122a9c-64f5-4cc0-85fa-7f5b810aed77"},"execution_count":8,"outputs":[{"output_type":"stream","name":"stderr","text":["<ipython-input-7-50f64bd900da>:56: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n"," scaler = torch.cuda.amp.GradScaler()\n","Epoch 1/1: 0%| | 0/117187 [00:01<?, ?it/s]\n"]},{"output_type":"error","ename":"AttributeError","evalue":"'T5EncoderModel' object has no attribute 'generate'","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-8-db93b9908267>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m train_dual_shunt(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mt5_tok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt5_tok\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mclip_tok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclip_tok\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mt5_mod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt5_mod\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mclip_mod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclip_mod\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m<ipython-input-7-50f64bd900da>\u001b[0m in \u001b[0;36mtrain_dual_shunt\u001b[0;34m(t5_tok, clip_tok, t5_mod, clip_mod, adapter, repo_id, batch_size, epochs, lr, device, g_target, Ξ»_delta, Ξ»_gate, Ξ»_ent, Ξ»_tau, Ξ»_anchor, Ξ»_guid, save_every, max_token_length, seed, t5_prompt_mode, output_dir)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;31m# Generate summaries\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0msummary_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt5_mod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mt5_inputs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmax_length\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_token_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0mcap_texts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt5_tok\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip_special_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1926\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1927\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1928\u001b[0;31m raise AttributeError(\n\u001b[0m\u001b[1;32m 1929\u001b[0m \u001b[0;34mf\"'{type(self).__name__}' object has no attribute '{name}'\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1930\u001b[0m )\n","\u001b[0;31mAttributeError\u001b[0m: 'T5EncoderModel' object has no attribute 'generate'"]}]}]} |