{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"collapsed_sections":["erBNP6hvU724","PiehSbGeWmor","EKLDJqI17Tkm"],"machine_shape":"hm","gpuType":"L4","mount_file_id":"1Tsf5s2FZEHr9S5ja8MqkIA3DtPGNb0Cy","authorship_tag":"ABX9TyOx4EAq5vmbjU6gpbNxmViP"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"9bf9176fc21b4f15896d7281ced5fa17":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_76cf4a2d556842e497e2e64c2ba0b3fe","IPY_MODEL_a6b3d0c5d84b477c82545fb809fdda1c","IPY_MODEL_456b3fd649ab4c0fbb5348a765fb9613"],"layout":"IPY_MODEL_32bcb31386ca4187ba6a554d26189502"}},"76cf4a2d556842e497e2e64c2ba0b3fe":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d4696d3bbdc149be9bab0cf4664c68c4","placeholder":"​","style":"IPY_MODEL_1575e84b9dc5454d94258035abf26da1","value":"Loading checkpoint shards: 100%"}},"a6b3d0c5d84b477c82545fb809fdda1c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_b4e41d9dcb2c4be8bf94919c93c56a83","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_d96c33b98b874cf8862a2bedb3596b4c","value":2}},"456b3fd649ab4c0fbb5348a765fb9613":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_d7f4744342ec4ebd8d2798cae2e90518","placeholder":"​","style":"IPY_MODEL_2f9c4dbc55d04276a5b9fa92de613bf5","value":" 2/2 [00:00<00:00,  2.19it/s]"}},"32bcb31386ca4187ba6a554d26189502":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d4696d3bbdc149be9bab0cf4664c68c4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1575e84b9dc5454d94258035abf26da1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b4e41d9dcb2c4be8bf94919c93c56a83":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d96c33b98b874cf8862a2bedb3596b4c":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"d7f4744342ec4ebd8d2798cae2e90518":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2f9c4dbc55d04276a5b9fa92de613bf5":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["## Config and Module"],"metadata":{"id":"erBNP6hvU724"}},{"cell_type":"code","execution_count":1,"metadata":{"id":"9WGrocZdPBY_","executionInfo":{"status":"ok","timestamp":1748636886799,"user_tz":420,"elapsed":19,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"outputs":[],"source":["ADAPTER_CONFIG = {\n"," \"adapter_id\": \"003\",\n"," \"name\": \"DualShuntAdapter-G\",\n","\n"," \"t5\": {\n"," \"model\": \"google/flan-t5-base\",\n"," \"hidden_size\": 768,\n"," },\n"," \"clip\": {\n"," \"model\": \"AbstractPhil/omega-vit-g-reformed\",\n"," \"hidden_size\": 1280,\n"," },\n","\n"," \"bottleneck\": 640,\n"," \"heads\": 20,\n","\n"," \"tau_init\": 0.1,\n"," \"max_guidance\": 10.0,\n","\n"," \"proj_layers\": 2,\n"," \"layer_norm\": True,\n"," \"dropout\": 0.1,\n"," \"use_dropout\": True,\n"," \"use_proj_stack\": True,\n"," \"assert_input_dims\": True,\n","\n"," \"routing\": {\n"," \"type\": \"cross_attention\",\n"," \"enable_causal_mask\": False,\n"," \"bidirectional\": True\n"," },\n","\n"," \"version\": \"v0.3.2\",\n"," \"description\": \"Final Dual Shunt Adapter with projection stack, dropout, and stacked residual refinement pocket.\"\n","}\n"]},{"cell_type":"code","source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","# ─── Residual Pocket Block ───────────────────────────────────\n","class BottleneckResBlock(nn.Module):\n"," def __init__(self, dim, kernel=3, dropout=0.1):\n"," super().__init__()\n"," self.norm = nn.LayerNorm(dim)\n"," self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)\n"," self.proj = nn.Sequential(\n"," nn.Linear(dim, dim * 2),\n"," nn.GELU(),\n"," nn.Linear(dim * 2, dim),\n"," nn.Dropout(dropout)\n"," )\n","\n"," def forward(self, x):\n"," residual = x\n"," x = self.norm(x)\n"," x = x.transpose(1, 2)\n"," x = self.conv(x).transpose(1, 2)\n"," return residual + self.proj(x)\n","\n","# ─── Two Stream Shunt Adapter ──────────────────────────────────────\n","class TwoStreamShuntAdapter(nn.Module):\n"," def __init__(self, config: dict):\n"," super().__init__()\n"," self.config = config\n"," self.t5_dim = config[\"t5\"][\"hidden_size\"]\n"," self.clip_dim = config[\"clip\"][\"hidden_size\"]\n"," self.bneck = config[\"bottleneck\"]\n"," self.heads = config[\"heads\"]\n"," self.tau_init = config[\"tau_init\"]\n"," self.max_guidance = config[\"max_guidance\"]\n","\n"," use_norm = config.get(\"layer_norm\", True)\n"," use_do = config.get(\"use_dropout\", True)\n"," do_p = config.get(\"dropout\", 0.1)\n"," proj_depth = config.get(\"proj_layers\", 2)\n","\n"," def build_projection(input_dim, output_dim):\n"," layers = []\n"," last_dim = input_dim\n"," if use_norm:\n"," layers.append(nn.LayerNorm(last_dim))\n"," for i in range(proj_depth):\n"," next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)\n"," layers.append(nn.Linear(last_dim, next_dim))\n"," layers.append(nn.GELU())\n"," if use_do:\n"," layers.append(nn.Dropout(do_p))\n"," last_dim = next_dim\n"," layers.append(nn.Linear(last_dim, output_dim))\n"," return nn.Sequential(*layers)\n","\n"," # Projections\n"," self.proj_t5 = build_projection(self.t5_dim, self.bneck)\n"," self.proj_clip = build_projection(self.clip_dim, self.bneck)\n","\n"," # Attention\n"," self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)\n"," self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)\n"," self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))\n","\n"," # Residual Pocket\n"," self.pocket_blocks = nn.Sequential(\n"," BottleneckResBlock(self.bneck, dropout=do_p),\n"," BottleneckResBlock(self.bneck, dropout=do_p)\n"," )\n","\n"," # Fuse\n"," self.fuse = nn.Sequential(\n"," nn.LayerNorm(2 * self.bneck),\n"," nn.Linear(2 * self.bneck, self.bneck * 2),\n"," nn.GELU(),\n"," nn.Linear(self.bneck * 2, self.bneck)\n"," )\n","\n"," # Output Projections\n"," self.anchor_proj = build_projection(self.bneck, self.clip_dim)\n"," self.delta_proj = build_projection(self.bneck, self.clip_dim)\n"," self.logsig_proj = build_projection(self.bneck, self.clip_dim)\n","\n"," self.gate_proj = nn.Sequential(\n"," nn.LayerNorm(self.bneck),\n"," nn.Linear(self.bneck, self.bneck),\n"," nn.GELU(),\n"," nn.Linear(self.bneck, 1),\n"," nn.Tanh(),\n"," nn.Sigmoid()\n"," )\n","\n"," self.guidance_proj = nn.Sequential(\n"," nn.LayerNorm(self.bneck),\n"," nn.Linear(self.bneck, 1),\n"," nn.Sigmoid()\n"," )\n","\n"," def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):\n"," if self.config.get(\"assert_input_dims\", True):\n"," assert t5_seq.size(-1) == self.t5_dim\n"," assert clip_seq.size(-1) == self.clip_dim\n","\n"," t5_b = self.proj_t5(t5_seq)\n"," clip_b = self.proj_clip(clip_seq)\n","\n"," t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)\n"," c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)\n","\n"," pocket = self.pocket_blocks(t2c)\n","\n"," pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)\n"," h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))\n","\n"," anchor = self.anchor_proj(h)\n"," delta = self.delta_proj(h) * self.gate_proj(h)\n"," log_sigma = self.logsig_proj(h)\n","\n"," g_tok = self.guidance_proj(h).squeeze(-1)\n"," g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance\n","\n"," return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)\n"],"metadata":{"id":"qjb_vFZRTQaC","executionInfo":{"status":"ok","timestamp":1748636888396,"user_tz":420,"elapsed":1586,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":2,"outputs":[]},{"cell_type":"code","source":["from safetensors.torch import save_file, load_file\n","\n","def save_safetensors(adapter: nn.Module, path: str, metadata: dict = None):\n"," \"\"\"\n"," Save the current adapter state to safetensors format.\n","\n"," All tensors are moved to CPU and saved as float32 for compatibility.\n"," Optional metadata may be embedded (e.g., version, prompt_mode).\n"," \"\"\"\n"," state = {k: v.float().cpu() for k, v in adapter.state_dict().items()}\n"," save_file(state, path, metadata=metadata or {})\n"," print(f\"✅ Model saved to {path}\")\n","\n","def load_safetensors(adapter: nn.Module, path: str, map_location=\"cpu\"):\n"," \"\"\"\n"," Load a safetensors checkpoint into the adapter.\n","\n"," Uses strict key matching. Tensors are loaded to the specified device.\n"," \"\"\"\n"," state = load_file(path, device=map_location)\n"," adapter.load_state_dict(state, strict=True)\n"," print(f\"✅ Model loaded from {path}\")\n","\n","\n"],"metadata":{"id":"zpOi5svciXJ6","executionInfo":{"status":"ok","timestamp":1748636888425,"user_tz":420,"elapsed":13,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":3,"outputs":[]},{"cell_type":"markdown","source":["## Data Loader"],"metadata":{"id":"PiehSbGeWmor"}},{"cell_type":"code","source":["import torch\n","import csv\n","\n","# ─────────────────────────────────────────────────────────────\n","# ░ Streaming Caption Dataset\n","# ─────────────────────────────────────────────────────────────\n","# ─────────────────────────────────────────────────────────────\n","# ░ Streaming Caption Dataset – 32 quality descriptors\n","# ─────────────────────────────────────────────────────────────\n","import csv, random, re\n","from typing import List\n","from torch.utils.data import IterableDataset, get_worker_info\n","from huggingface_hub import hf_hub_download\n","from torch.utils.data import DataLoader\n","from pathlib import Path\n","\n","\n","class ParsedMultiCharDataset(IterableDataset):\n"," \"\"\"\n"," Streams HF-hosted caption shards and, for each text chunk:\n","\n"," • If it starts with “a ” (case-insensitive, ignoring leading spaces),\n"," replace that leading token with a random photo/video quality\n"," descriptor followed by “, ”.\n","\n"," No preliminary file scanning; every CSV line is read exactly once.\n"," \"\"\"\n","\n"," # match leading “a ” only\n"," _PAT_START_A = re.compile(r\"^\\s*a\\s+\", re.IGNORECASE)\n","\n"," # 32 diverse quality descriptors\n"," _QUALITY_DESCRIPTORS: List[str] = [\n"," \"masterpiece,\",\n"," \"very aesthetic,\",\n"," \"most aesthetic,\",\n"," \"an absolutely perfect depiction of\",\n"," \"awa,\",\n"," \"very awa,\",\n"," \"dimly lit,\",\n"," \"beautifuly lit,\",\n"," \"very beautiful,\",\n"," \"masterful depiction of\",\n"," \"dedicated masterpiece,\",\n"," \"warmly lit\",\n"," \"best quality, most aesthetic,\",\n"," \"beautiful depiction of\",\n"," \"masterful artwork of\",\n"," \"high-resolution photograph,\",\n"," \"hyper-realistic image,\",\n"," \"ultra-detailed photo,\",\n"," \"studio-quality photograph,\",\n"," \"cinematic shot,\",\n"," \"4K HDR image,\",\n"," \"sharp-focus photo,\",\n"," \"professionally lit photograph,\",\n"," \"DSLR capture,\",\n"," \"film-grain photograph,\",\n"," \"bokeh-rich shot,\",\n"," \"medium-format scan,\",\n"," \"analog film still,\",\n"," \"moody cinematic frame,\",\n"," \"dramatic-lighting photo,\",\n"," \"vibrant editorial image,\",\n"," \"macro-lens close-up,\",\n"," \"aerial drone photo,\",\n"," \"soft-focus dreamlike photo,\",\n"," \"low-key studio shot,\",\n"," \"overhead product shot,\",\n"," \"golden-hour photograph,\",\n"," \"noir-style monochrome shot,\",\n"," \"vintage Polaroid scan,\",\n"," \"infrared photograph,\",\n"," \"ultra-wide panorama,\",\n"," \"tilt-shift miniature photo,\",\n"," \"long-exposure night shot,\",\n"," \"time-lapse still,\",\n"," \"splash-photography frame,\",\n"," \"fine-art print scan,\",\n"," \"astrophotography capture,\",\n"," \"score_9, score_8, score_7, score_6,\",\n"," \"score_1, score_2, score_3, score_4,\",\n"," \"masterpiece, most aesthetic,\",\n"," \"most aesthetic, very aesthetic,\",\n"," \"masterpiece, most aesthetic, very aesthetic, realistic, real,\",\n"," \"most aesthetic, realistic, real,\",\n"," \"very aesthetic, realistic, real,\",\n"," \"masterpiece, very aesthetic, realistic, real,\",\n"," \"most aesthetic, very aesthetic, realistic, real,\",\n"," \"masterpiece, very aesthetic, realistic, anime,\",\n"," \"most aesthetic, very aesthetic, realistic, anime,\",\n"," \"very aesthetic, realistic, anime,\",\n"," \"masterpiece, very aesthetic, realistic, anime,\",\n"," \"most aesthetic, very aesthetic, realistic, anime,\",\n"," \"very aesthetic, realistic, anime, anime,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"2d,\",\n"," \"3d,\",\n"," \"anime,\",\n"," \"real,\",\n"," \"cartoon,\"\n"," \"realistic,\",\n"," \"masterpiece, 2d,\",\n"," \"masterpiece, 3d,\",\n"," \"masterpiece, anime,\",\n"," \"masterpiece, real,\",\n"," \"masterpiece, cartoon,\",\n"," \"3d, anime,\",\n"," \"3d, anime, real,\",\n"," \"3d, anime, real, realistic,\",\n"," \"masterpiece, 3d,\",\n"," \"masterpiece, 3d, anime,\",\n"," \"masterpiece, 3d, anime, real,\",\n"," \"masterpiece, 3d, anime, real, realistic,\",\n"," \"masterpiece, anime,\",\n"," \"masterpiece, anime, real,\",\n"," \"masterpiece, anime, real, realistic,\",\n"," \"masterpiece, 3d, anime, real,\",\n"," \"very aesthetic, 3d,\",\n"," \"very aesthetic, 3d, anime,\",\n"," \"very aesthetic, 3d, anime, real,\",\n"," \"very aesthetic, 3d, anime, real, realistic,\",\n"," \"very aesthetic, 3d, anime, real,\",\n"," \"most aesthetic, 3d,\",\n"," \"most aesthetic, 3d, anime,\",\n"," \"most aesthetic, 3d, anime, real,\",\n"," \"most aesthetic, 3d, anime, real, realistic,\",\n"," \"anime, comic,\"\n"," \"manga, anime\",\n"," \"masterpiece, cartoon,\",\n"," \"masterpiece, cartoon, real,\",\n"," \"masterpiece, cartoon, real, realistic,\",\n"," \"masterpiece, cartoon, real,\",\n"," \"most aesthetic, cartoon,\",\n"," \"most aesthetic, cartoon, real,\",\n"," \"most aesthetic, cartoon, real, realistic,\",\n"," \"most aesthetic, cartoon, real,\",\n"," \"grid_a1 head,\"\n"," \"grid_a2 head,\",\n"," \"grid_a3 head,\",\n"," \"grid_a4 head,\",\n"," \"grid_a5 head,\",\n"," \"grid_b1 head,\"\n"," \"grid_b2 head,\",\n"," \"grid_b3 head,\",\n"," \"grid_b4 head,\",\n"," \"grid_b5 head,\",\n"," \"grid_c1 head,\"\n"," \"grid_c2 head,\",\n"," \"grid_c3 head,\",\n"," \"grid_c4 head,\",\n"," \"grid_c5 head,\",\n"," \"grid_d1 head,\"\n"," \"grid_d2 head,\",\n"," \"grid_d3 head,\",\n"," \"grid_d4 head,\",\n"," \"grid_d5 head,\",\n"," \"grid_e1 head,\"\n"," \"grid_e2 head,\",\n"," \"grid_e3 head,\",\n"," \"grid_e4 head,\",\n"," \"grid_e5 head,\",\n"," \"grid_a1 upper body,\"\n"," \"grid_a2 upper body,\",\n"," \"grid_a3 upper body,\",\n"," \"grid_a4 upper body,\",\n"," \"grid_a5 upper body,\",\n"," \"grid_b1 upper body,\"\n"," \"grid_b2 upper body,\",\n"," \"grid_b3 upper body,\",\n"," \"grid_b4 upper body,\",\n"," \"grid_b5 upper body,\",\n"," \"grid_c1 upper body,\"\n"," \"grid_c2 upper body,\",\n"," \"grid_c3 upper body,\",\n"," \"grid_c4 upper body,\",\n"," \"grid_c5 upper body,\",\n"," \"grid_d1 upper body,\"\n"," \"grid_d2 upper body,\",\n"," \"grid_d3 upper body,\",\n"," \"grid_d4 upper body,\",\n"," \"grid_d5 upper body,\",\n"," \"grid_e1 upper body,\"\n"," \"grid_e2 upper body,\",\n"," \"grid_e3 upper body,\",\n"," \"grid_e4 upper body,\",\n"," \"grid_e5 upper body,\",\n"," \"zone_ul upper body,\"\n"," \"zone_ur upper body,\"\n"," \"zone_ll upper body,\"\n"," \"zone_lr upper body,\",\n"," \"zone_ul head,\"\n"," \"zone_ur head,\"\n"," \"zone_ll head,\"\n"," \"zone_lr head,\",\n","\n"," #\"disgusting, 3d, lowres,\",\n"," #\"disgusting, 2d, lowres,\",\n"," #\"disgusting, cartoon, lowres,\",\n"," #\"disgusting, 3d, cartoon, lowres,\",\n"," #\"disgusting, 2d, cartoon, lowres,\",\n"," #\"disgusting\n","\n","\n"," ]\n","\n"," def __init__(\n"," self,\n"," repo_id: str,\n"," delimiter: str = \".,|,.\",\n"," start_file: int = 20,\n"," num_files: int = 80,\n"," shuffle: bool = False\n"," ):\n"," super().__init__()\n"," self.delimiter = delimiter\n"," self.paths = [\n"," hf_hub_download(\n"," repo_id,\n"," f\"captions/caption_{i + start_file:03d}.csv\",\n"," repo_type=\"dataset\",\n"," )\n"," for i in range(num_files)\n"," ]\n"," self.total_rows = 5_000_000 * num_files\n"," self.shuffle = shuffle\n","\n"," def __len__(self):\n"," return self.total_rows\n","\n"," def __iter__(self):\n"," worker = get_worker_info()\n"," paths = (\n"," self.paths\n"," if worker is None\n"," else self.paths[worker.id :: worker.num_workers]\n"," )\n","\n"," pat_a = self._PAT_START_A\n"," choose = random.choice\n"," q_pool = self._QUALITY_DESCRIPTORS\n"," delim = self.delimiter\n","\n"," for path in paths:\n"," with open(path, encoding=\"utf-8\", newline=\"\") as f:\n"," for row in csv.DictReader(f):\n"," for chunk in row.get(\"text\", \"\").split(delim):\n"," chunk = chunk.strip()\n"," if not chunk:\n"," continue\n"," if self.shuffle:\n"," random.shuffle(q_pool)\n"," # replace leading “a ” with descriptor + comma\n"," chunk = pat_a.sub(choose(q_pool) + \" \", chunk, count=1)\n","\n"," yield chunk\n","\n","\n"],"metadata":{"id":"mWRkuGO5U5Ma","executionInfo":{"status":"ok","timestamp":1748636889529,"user_tz":420,"elapsed":267,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"_FBrEkfzbB0M","executionInfo":{"status":"ok","timestamp":1748636889550,"user_tz":420,"elapsed":19,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":4,"outputs":[]},{"cell_type":"markdown","source":["## Load Models"],"metadata":{"id":"RiRNq1OjXQ4R"}},{"cell_type":"code","source":["import torch\n","from transformers import (\n"," T5EncoderModel, T5TokenizerFast,\n"," CLIPTextModel, CLIPTokenizerFast\n",")\n","\n","import torch\n","from transformers import (\n"," T5EncoderModel, T5TokenizerFast,\n"," CLIPTextModel, CLIPTokenizerFast\n",")\n","\n","from safetensors.torch import save_file, load_file\n","\n","\n","# ─── Runtime Settings ────────────────────────────────────────\n","DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","DTYPE = torch.float32 # 🔒 Force full FP32 precision\n","\n","# ─── Load Tokenizers ─────────────────────────────────────────\n","t5_tok = T5TokenizerFast.from_pretrained(\"google/flan-t5-base\")\n","clip_tok = CLIPTokenizerFast.from_pretrained(\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\")\n","\n","# ─── Load and Freeze T5 Model ───────────────────────────────\n","t5_mod = T5EncoderModel.from_pretrained(\"google/flan-t5-base\").to(DEVICE, dtype=DTYPE)\n","t5_mod.eval().requires_grad_(False)\n","\n","# ─── Load and Freeze CLIP Model ─────────────────────────────\n","clip_mod = CLIPTextModel.from_pretrained(\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\").to(DEVICE, dtype=DTYPE)\n","\n","\n","print(clip_mod.state_dict().keys())\n","\n","temp_clip = load_file(\"/content/drive/MyDrive/clips/OMEGA-24-CLIP_G.safetensors\")\n","\n","print(temp_clip.keys())\n","\n","clip_mod.load_state_dict(temp_clip, strict=False)\n","clip_mod.eval().requires_grad_(False)\n","## ─── Load Dual Shunt Adapter ────────────────────────────────\n","#adapter = DualShuntAdapter(config=ADAPTER_CONFIG).to(DEVICE, dtype=DTYPE)\n","#\n","#\n","#print(\"✅ All models loaded at float32 precision and ready.\")\n","\n","# ─── Initialize Adapter from Config ──────────────────────────\n","adapter = TwoStreamShuntAdapter(config=ADAPTER_CONFIG).to(DEVICE, dtype=DTYPE)\n","load_safetensors(adapter, \"/content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_e1_step_10000.safetensors\")\n","adapter.train()\n","print(\"✅ T5, CLIP, and Dual Shunt Adapter loaded and cast.\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":138,"referenced_widgets":["9bf9176fc21b4f15896d7281ced5fa17","76cf4a2d556842e497e2e64c2ba0b3fe","a6b3d0c5d84b477c82545fb809fdda1c","456b3fd649ab4c0fbb5348a765fb9613","32bcb31386ca4187ba6a554d26189502","d4696d3bbdc149be9bab0cf4664c68c4","1575e84b9dc5454d94258035abf26da1","b4e41d9dcb2c4be8bf94919c93c56a83","d96c33b98b874cf8862a2bedb3596b4c","d7f4744342ec4ebd8d2798cae2e90518","2f9c4dbc55d04276a5b9fa92de613bf5"]},"id":"NOpbg28ZXQH5","executionInfo":{"status":"ok","timestamp":1748636906923,"user_tz":420,"elapsed":12771,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"8ebbca72-abed-49df-ca20-04250b38be55"},"execution_count":5,"outputs":[{"output_type":"display_data","data":{"text/plain":["Loading checkpoint shards: 0%| | 0/2 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mclip_mod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp_clip\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mclip_mod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclip_mod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"reduce-overhead\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;31m## ─── Load Dual Shunt Adapter ────────────────────────────────\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mTypeError\u001b[0m: compile() got an unexpected keyword argument 'dtype'"]}]},{"cell_type":"markdown","source":["# Noise Train"],"metadata":{"id":"pj2a104jbp42"}},{"cell_type":"code","source":["import torch\n","import torch.nn.functional as F\n","from tqdm import tqdm\n","from pathlib import Path\n","\n","\n","def train_dual_shunt(\n"," t5_tok, clip_tok, t5_mod, clip_mod, adapter,\n"," repo_id: str,\n"," batch_size: int = 256,\n"," epochs: int = 1,\n"," lr: float = 5e-4,\n"," device: str = \"cuda\",\n"," g_target: float = 7.5,\n"," λ_delta: float = 1.0,\n"," λ_gate: float = 0.15,\n"," λ_ent: float = 0.02,\n"," λ_tau: float = 0.02,\n"," λ_anchor: float= 0.12,\n"," λ_guid: float = 0.1,\n"," save_every: int = 1000,\n"," max_token_length: int = 77,\n"," seed: int = 42,\n"," t5_prompt_mode: str = \"real_raw\",\n"," output_dir: str = \"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n","):\n"," import torch\n"," import torch.nn.functional as F\n"," from tqdm import tqdm\n"," from pathlib import Path\n"," import random\n","\n"," def extract_sparse_tokens(text: str, keep_min: int = 1, keep_max: int = 5, keep_prob: float = 0.3) -> str:\n"," words = text.strip().split()\n"," keep_count = min(len(words), random.randint(keep_min, keep_max))\n"," indices = sorted(random.sample(range(len(words)), k=keep_count))\n"," return \" \".join(words[i] for i in indices)\n","\n"," def seed_everything(seed: int = 42):\n"," torch.manual_seed(seed)\n"," torch.cuda.manual_seed_all(seed)\n"," torch.backends.cudnn.deterministic = True\n"," torch.backends.cudnn.benchmark = False\n","\n"," LOG_2PI = 1.83787706641\n","\n"," def hetero_loss(delta, target, log_sigma):\n"," log_sigma = log_sigma.clamp(-5.0, 5.0)\n"," inv_var = torch.exp(-log_sigma)\n"," return 0.5 * (inv_var * (delta - target)**2 + log_sigma + LOG_2PI).mean()\n","\n"," def entropy(p: torch.Tensor) -> torch.Tensor:\n"," p = p / (p.sum(-1, keepdim=True) + 1e-9)\n"," return -(p * (p + 1e-9).log()).sum(-1).mean()\n","\n"," # ─── Setup ───────────────────────────────────────────────\n"," seed_everything(seed)\n"," device = torch.device(device)\n","\n"," ds = ParsedMultiCharDataset(repo_id, start_file=50, num_files=6, shuffle=True)\n"," dl = torch.utils.data.DataLoader(\n"," ds, batch_size=batch_size, num_workers=4,\n"," drop_last=True, persistent_workers=True,\n"," generator=torch.Generator().manual_seed(seed),\n"," prefetch_factor=8\n"," )\n","\n"," output_dir = Path(output_dir)\n"," output_dir.mkdir(parents=True, exist_ok=True)\n"," global_step = 0\n","\n"," opt = torch.optim.AdamW(adapter.parameters(), lr=lr)\n"," scaler = torch.amp.GradScaler(device=\"cuda\")\n","\n"," # ─── Training Loop ───────────────────────────────────────\n"," for epoch in range(1, epochs + 1):\n"," pbar = tqdm(dl, desc=f\"Epoch {epoch}/{epochs}\")\n"," for texts in pbar:\n"," global_step += 1\n","\n"," with torch.autocast(device.type, dtype=torch.bfloat16):\n"," # T5 prompt encoding\n"," t5_inputs = t5_tok(texts, padding=True, truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," t5_seq = t5_mod(**t5_inputs).last_hidden_state\n","\n"," # CLIP (plain)\n"," clip_inputs = clip_tok(texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," clip_seq_plain = clip_mod(**clip_inputs).last_hidden_state\n","\n"," # CLIP (captioned target)\n","\n"," # for null noise training\n"," cap_texts = [f\"{extract_sparse_tokens(t)}\" for t in texts]\n"," clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," clip_seq_cap = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n"," #with torch.no_grad():\n"," # summary_ids = t5_mod.generate(**t5_inputs, max_length=max_token_length)\n"," # cap_texts = t5_tok.batch_decode(summary_ids, skip_special_tokens=True)\n","\n"," # Encode summaries into CLIP\n"," #clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n"," # max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," #with torch.no_grad():\n"," # clip_seq_cap = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n"," # Adapter forward\n"," anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq_plain)\n"," delta_tgt = clip_seq_cap - clip_seq_plain\n","\n"," # 🔧 SAFE attention projection for auxiliary losses\n"," with torch.no_grad():\n"," t5_b = adapter.proj_t5(t5_seq)\n"," clip_b = adapter.proj_clip(clip_seq_plain)\n"," t2c_base, _ = adapter.cross_t2c(t5_b, clip_b, clip_b)\n"," c2t_base, _ = adapter.cross_c2t(clip_b, t5_b, t5_b)\n","\n"," pocket = adapter.pocket_blocks(t2c_base.detach())\n"," loss_pocket = F.mse_loss(pocket, t2c_base.detach()) * 0.05\n","\n"," fused_h = adapter.fuse(torch.cat([\n"," pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1),\n"," c2t_base\n"," ], dim=-1))\n"," loss_hidden = (fused_h.norm(dim=-1).mean() - 1.0).abs() * 0.01\n","\n"," # Core losses\n"," loss_delta = hetero_loss(delta, delta_tgt, log_sigma) * λ_delta\n"," loss_gate = (gate.mean() - 0.25).abs() * λ_gate\n"," loss_ent = 0.5 * (entropy(attn_t2c) + entropy(attn_c2t)) * λ_ent\n"," loss_tau = tau.abs().mean() * λ_tau\n"," loss_anchor = (1 - F.cosine_similarity(anchor.mean(1), clip_seq_plain.mean(1), dim=-1).mean()) * λ_anchor\n"," loss_guid = F.mse_loss(g_pred, torch.full_like(g_pred, g_target)) * λ_guid\n","\n"," total_loss = (\n"," loss_delta + loss_gate + loss_ent +\n"," loss_tau + loss_anchor + loss_guid +\n"," loss_pocket + loss_hidden\n"," )\n","\n"," # Backprop + step\n"," scaler.scale(total_loss).backward()\n"," scaler.step(opt)\n"," scaler.update()\n"," opt.zero_grad(set_to_none=True)\n"," if global_step % 100 == 0:\n"," print(f\"Gate mean: {gate.mean().item():.3f} log_sigma mean: {log_sigma.mean().item():.3f}\")\n"," if global_step % save_every == 0:\n"," path = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_noised_e1_step_{global_step}.safetensors\"\n"," save_safetensors(adapter, path)\n","\n"," pbar.set_postfix(loss=float(total_loss))\n","\n"," final = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_noised_e1_final.safetensors\"\n"," save_safetensors(adapter, final)\n"," print(f\"✅ Epoch {epoch} complete. Final model saved.\")\n"],"metadata":{"id":"WoSAXzGDWsMv","executionInfo":{"status":"ok","timestamp":1748639025564,"user_tz":420,"elapsed":28,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["train_dual_shunt(\n"," t5_tok=t5_tok,\n"," clip_tok=clip_tok,\n"," t5_mod=t5_mod,\n"," clip_mod=clip_mod,\n"," adapter=adapter,\n"," repo_id=\"AbstractPhil/human-templated-captions-1b\",\n"," batch_size=256,\n"," epochs=1,\n"," lr=1e-4,\n"," device=\"cuda\",\n"," g_target=7.5,\n"," save_every=1000,\n"," seed=42,\n"," max_token_length=77,\n"," t5_prompt_mode=\"no_caption\",\n"," output_dir=\"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n",")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vJ8zeQ9-irqU","outputId":"0292845a-d502-476c-af48-0f4438145925"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 100/117187 [03:36<69:52:34, 2.15s/it, loss=1.4]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.672\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 200/117187 [07:12<70:05:53, 2.16s/it, loss=1.38]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.504\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 300/117187 [11:30<70:27:44, 2.17s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.641\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 400/117187 [15:06<69:24:20, 2.14s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.648\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 0%| | 500/117187 [18:42<70:10:52, 2.17s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.664\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 600/117187 [22:19<70:05:26, 2.16s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.695\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 700/117187 [25:55<70:00:11, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.703\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 800/117187 [29:31<68:42:09, 2.13s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 900/117187 [33:08<69:25:05, 2.15s/it, loss=1.37]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 999/117187 [36:41<69:35:16, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1000/117187 [36:44<72:33:06, 2.25s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_1000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1100/117187 [40:21<69:00:44, 2.14s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1200/117187 [43:56<70:06:30, 2.18s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.762\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1300/117187 [47:33<69:19:52, 2.15s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.785\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%| | 1400/117187 [51:10<69:26:39, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%|▏ | 1500/117187 [54:47<69:38:53, 2.17s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.789\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%|▏ | 1600/117187 [58:24<69:27:47, 2.16s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.691\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 1%|▏ | 1700/117187 [1:02:00<70:12:30, 2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.715\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 1800/117187 [1:05:37<69:13:31, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.730\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 1900/117187 [1:09:13<68:57:24, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 1999/117187 [1:12:48<69:23:59, 2.17s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2000/117187 [1:12:50<72:07:19, 2.25s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_2000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2100/117187 [1:16:26<70:14:40, 2.20s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2200/117187 [1:20:03<68:38:11, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.777\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2300/117187 [1:23:39<69:11:14, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2400/117187 [1:27:15<70:12:39, 2.20s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.762\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2500/117187 [1:30:52<68:46:14, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2600/117187 [1:34:29<68:55:45, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.785\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2700/117187 [1:38:05<69:16:25, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2800/117187 [1:41:42<69:14:21, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.734\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 2%|▏ | 2900/117187 [1:45:19<68:12:24, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 2999/117187 [1:48:52<68:32:54, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.812\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3000/117187 [1:48:54<71:14:59, 2.25s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_3000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3100/117187 [1:52:31<68:08:09, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.863\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3200/117187 [1:56:07<68:19:20, 2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.844\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3300/117187 [1:59:44<67:46:09, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.852\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3400/117187 [2:03:19<68:13:23, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.863\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3500/117187 [2:06:56<68:10:13, 2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.887\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3600/117187 [2:10:33<68:50:05, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.727\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3700/117187 [2:14:10<68:26:47, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3800/117187 [2:17:46<69:02:10, 2.19s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3900/117187 [2:21:22<67:54:57, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 3999/117187 [2:24:56<68:50:27, 2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.770\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 4000/117187 [2:24:58<71:29:54, 2.27s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_4000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 3%|▎ | 4100/117187 [2:28:35<68:26:47, 2.18s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▎ | 4200/117187 [2:32:12<67:47:18, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▎ | 4300/117187 [2:35:49<67:43:33, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.738\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4400/117187 [2:39:25<68:20:20, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4500/117187 [2:43:01<68:01:47, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.652\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4600/117187 [2:46:38<67:19:53, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.605\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4700/117187 [2:50:15<68:15:17, 2.18s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.609\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4800/117187 [2:53:51<66:49:04, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.688\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4900/117187 [2:57:28<67:53:59, 2.18s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.766\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 4999/117187 [3:01:03<67:34:57, 2.17s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.719\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 5000/117187 [3:01:05<70:09:23, 2.25s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_5000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 5100/117187 [3:04:42<67:32:22, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.734\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 4%|▍ | 5200/117187 [3:08:18<67:25:42, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.746\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▍ | 5300/117187 [3:11:54<67:17:14, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▍ | 5400/117187 [3:15:30<67:11:10, 2.16s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.754\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▍ | 5500/117187 [3:19:06<67:10:37, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.754\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▍ | 5600/117187 [3:22:43<66:17:06, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.750\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▍ | 5700/117187 [3:26:20<67:08:19, 2.17s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.773\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▍ | 5800/117187 [3:29:56<66:15:35, 2.14s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 5900/117187 [3:33:33<65:57:43, 2.13s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.758\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 5999/117187 [3:37:06<65:53:32, 2.13s/it, loss=1.36]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.770\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 6000/117187 [3:37:09<68:33:52, 2.22s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_6000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 6100/117187 [3:40:44<66:08:48, 2.14s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.742\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 6200/117187 [3:44:19<68:30:20, 2.22s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.629\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 6300/117187 [3:47:56<66:18:20, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.605\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 5%|▌ | 6400/117187 [3:51:32<66:18:28, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.570\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 6500/117187 [3:55:08<65:36:41, 2.13s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.621\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 6600/117187 [3:58:45<65:56:19, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.637\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 6700/117187 [4:02:21<66:20:04, 2.16s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.664\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 6800/117187 [4:05:57<67:02:27, 2.19s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.652\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 6900/117187 [4:09:33<65:58:51, 2.15s/it, loss=1.35]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.570\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 6999/117187 [4:13:08<65:45:42, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.602\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 7000/117187 [4:13:11<68:46:49, 2.25s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["✅ Model saved to /content/drive/MyDrive/dual_shunt_runs_omega_g/dual_shunt_omega_no_caption_noised_e1_step_7000.safetensors\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 7100/117187 [4:16:47<65:48:45, 2.15s/it, loss=1.34]"]},{"output_type":"stream","name":"stdout","text":["Gate mean: 0.270 log_sigma mean: -0.617\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1/1: 6%|▌ | 7159/117187 [4:18:55<65:55:30, 2.16s/it, loss=1.34]"]}]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fy3PbHL4XEv3","executionInfo":{"status":"ok","timestamp":1748024232657,"user_tz":420,"elapsed":23471,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"fdc85aa5-3920-44bc-8bc3-81f8d8a90cf2"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"markdown","source":["# summarization train"],"metadata":{"id":"EKLDJqI17Tkm"}},{"cell_type":"code","source":["from pathlib import Path\n","from tqdm import tqdm\n","import torch\n","import torch.nn.functional as F\n","from torch.utils.data import DataLoader\n","\n","def train_dual_shunt(\n"," t5_tok, clip_tok, t5_mod, clip_mod, adapter,\n"," repo_id: str,\n"," batch_size: int = 256,\n"," epochs: int = 1,\n"," lr: float = 5e-4,\n"," device: str = \"cuda\",\n"," g_target: float = 7.5,\n"," λ_delta: float = 1.0,\n"," λ_gate: float = 0.15,\n"," λ_ent: float = 0.02,\n"," λ_tau: float = 0.02,\n"," λ_anchor: float = 0.12,\n"," λ_guid: float = 0.1,\n"," save_every: int = 1000,\n"," max_token_length: int = 77,\n"," seed: int = 42,\n"," t5_prompt_mode: str = \"summarization_curriculum\",\n"," output_dir: str = \"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n","):\n"," def seed_everything(seed: int = 42):\n"," torch.manual_seed(seed)\n"," torch.cuda.manual_seed_all(seed)\n"," torch.backends.cudnn.deterministic = True\n"," torch.backends.cudnn.benchmark = False\n","\n"," def hetero_loss(delta, target, log_sigma):\n"," LOG_2PI = 1.83787706641\n"," log_sigma = log_sigma.clamp(-5.0, 5.0)\n"," inv_var = torch.exp(-log_sigma)\n"," return 0.5 * (inv_var * (delta - target) ** 2 + log_sigma + LOG_2PI).mean()\n","\n"," def entropy(p: torch.Tensor) -> torch.Tensor:\n"," p = p / (p.sum(-1, keepdim=True) + 1e-9)\n"," return -(p * (p + 1e-9).log()).sum(-1).mean()\n","\n"," # Setup\n"," seed_everything(seed)\n"," device = torch.device(device)\n"," ds = ParsedMultiCharDataset(repo_id, start_file=50, num_files=6, shuffle=True)\n"," dl = DataLoader(ds, batch_size=batch_size, num_workers=4, drop_last=True,\n"," persistent_workers=True, generator=torch.Generator().manual_seed(seed),\n"," prefetch_factor=8)\n","\n"," output_dir = Path(output_dir)\n"," output_dir.mkdir(parents=True, exist_ok=True)\n"," global_step = 0\n","\n"," opt = torch.optim.AdamW(adapter.parameters(), lr=lr)\n"," scaler = torch.cuda.amp.GradScaler()\n","\n"," for epoch in range(1, epochs + 1):\n"," pbar = tqdm(dl, desc=f\"Epoch {epoch}/{epochs}\")\n"," for texts in pbar:\n"," global_step += 1\n","\n"," with torch.autocast(device.type, dtype=torch.bfloat16):\n"," # T5 encoding\n"," t5_inputs = t5_tok(texts, padding=True, truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," t5_seq = t5_mod(**t5_inputs).last_hidden_state\n","\n"," # CLIP (plain)\n"," clip_inputs = clip_tok(texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," with torch.no_grad():\n"," clip_seq_plain = clip_mod(**clip_inputs).last_hidden_state\n","\n"," # Generate summaries\n"," with torch.no_grad():\n"," summary_ids = t5_mod.generate(**t5_inputs, max_length=max_token_length)\n"," cap_texts = t5_tok.batch_decode(summary_ids, skip_special_tokens=True)\n"," clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," clip_seq_summary = clip_mod(**clip_inputs_cap).last_hidden_state\n","\n"," null_clip_inputs = clip_tok([\"\"] * len(texts), padding=\"max_length\", truncation=True,\n"," max_length=max_token_length, return_tensors=\"pt\").to(device)\n"," clip_seq_null = clip_mod(**null_clip_inputs).last_hidden_state\n","\n"," # Curriculum interpolation\n"," interpolation = min(global_step / 10000.0, 1.0)\n"," clip_seq_cap = (1.0 - interpolation) * clip_seq_null + interpolation * clip_seq_summary\n"," delta_tgt = clip_seq_cap - clip_seq_plain\n","\n"," # Adapter forward\n"," anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq_plain)\n","\n"," with torch.no_grad():\n"," t5_b = adapter.proj_t5(t5_seq)\n"," clip_b = adapter.proj_clip(clip_seq_plain)\n"," t2c_base, _ = adapter.cross_t2c(t5_b, clip_b, clip_b)\n"," c2t_base, _ = adapter.cross_c2t(clip_b, t5_b, t5_b)\n","\n"," pocket = adapter.pocket_blocks(t2c_base.detach())\n"," loss_pocket = F.mse_loss(pocket, t2c_base.detach()) * 0.05\n","\n"," fused_h = adapter.fuse(torch.cat([\n"," pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1),\n"," c2t_base\n"," ], dim=-1))\n"," loss_hidden = (fused_h.norm(dim=-1).mean() - 1.0).abs() * 0.01\n","\n"," # Losses\n"," warmup = min(global_step / 5000.0, 1.0)\n"," loss_delta = hetero_loss(delta, delta_tgt, log_sigma) * λ_delta * warmup\n","\n"," gate_target_mean = 0.25 + 0.15 * interpolation\n"," loss_gate = (gate.mean() - gate_target_mean).abs() * λ_gate\n","\n"," loss_ent = 0.5 * (entropy(attn_t2c) + entropy(attn_c2t)) * λ_ent\n"," loss_tau = tau.abs().mean() * λ_tau\n"," loss_anchor = (1 - F.cosine_similarity(anchor.mean(1), clip_seq_plain.mean(1), dim=-1).mean()) * λ_anchor\n"," loss_guid = F.mse_loss(g_pred, torch.full_like(g_pred, g_target)) * λ_guid\n","\n"," total_loss = (\n"," loss_delta + loss_gate + loss_ent +\n"," loss_tau + loss_anchor + loss_guid +\n"," loss_pocket + loss_hidden\n"," )\n","\n"," # Backprop + step\n"," scaler.scale(total_loss).backward()\n"," scaler.step(opt)\n"," scaler.update()\n"," opt.zero_grad(set_to_none=True)\n","\n"," if global_step % 100 == 0:\n"," print(f\"[Step {global_step}] Loss: {total_loss.item():.4f} Gate Mean: {gate.mean().item():.3f}\")\n","\n"," if global_step % save_every == 0:\n"," save_path = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_e{epoch}_step_{global_step}.safetensors\"\n"," save_safetensors(adapter, save_path)\n","\n"," final = output_dir / f\"dual_shunt_omega_{t5_prompt_mode}_e{epoch}_final.safetensors\"\n"," save_safetensors(adapter, final)\n"," print(f\"✅ Epoch {epoch} complete. Final model saved to {final}\")\n"],"metadata":{"id":"iXmrn28M7L9L","executionInfo":{"status":"ok","timestamp":1748637860921,"user_tz":420,"elapsed":31,"user":{"displayName":"P C","userId":"00707517734723903966"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["train_dual_shunt(\n"," t5_tok=t5_tok,\n"," clip_tok=clip_tok,\n"," t5_mod=t5_mod,\n"," clip_mod=clip_mod,\n"," adapter=adapter,\n"," repo_id=\"AbstractPhil/human-templated-captions-1b\",\n"," batch_size=256,\n"," epochs=1,\n"," lr=1e-4,\n"," device=\"cuda\",\n"," g_target=7.5,\n"," save_every=1000,\n"," seed=42,\n"," max_token_length=77,\n"," t5_prompt_mode=\"no_caption\",\n"," output_dir=\"/content/drive/MyDrive/dual_shunt_runs_omega_g\"\n",")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":408},"id":"aI2BH6Ja7YqZ","executionInfo":{"status":"error","timestamp":1748637873437,"user_tz":420,"elapsed":3426,"user":{"displayName":"P C","userId":"00707517734723903966"}},"outputId":"67122a9c-64f5-4cc0-85fa-7f5b810aed77"},"execution_count":8,"outputs":[{"output_type":"stream","name":"stderr","text":[":56: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n"," scaler = torch.cuda.amp.GradScaler()\n","Epoch 1/1: 0%| | 0/117187 [00:01\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m train_dual_shunt(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mt5_tok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt5_tok\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mclip_tok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclip_tok\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mt5_mod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt5_mod\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mclip_mod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclip_mod\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_dual_shunt\u001b[0;34m(t5_tok, clip_tok, t5_mod, clip_mod, adapter, repo_id, batch_size, epochs, lr, device, g_target, λ_delta, λ_gate, λ_ent, λ_tau, λ_anchor, λ_guid, save_every, max_token_length, seed, t5_prompt_mode, output_dir)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;31m# Generate summaries\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0msummary_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt5_mod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mt5_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_length\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_token_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0mcap_texts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt5_tok\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip_special_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m clip_inputs_cap = clip_tok(cap_texts, padding=\"max_length\", truncation=True,\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1926\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1927\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1928\u001b[0;31m raise AttributeError(\n\u001b[0m\u001b[1;32m 1929\u001b[0m \u001b[0;34mf\"'{type(self).__name__}' object has no attribute '{name}'\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1930\u001b[0m )\n","\u001b[0;31mAttributeError\u001b[0m: 'T5EncoderModel' object has no attribute 'generate'"]}]}]}