diff --git "a/src/Train.ipynb" "b/src/Train.ipynb" new file mode 100644--- /dev/null +++ "b/src/Train.ipynb" @@ -0,0 +1,1847 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", + "# from subprocess import call\n", + "# command = \"jupyter nbconvert Train.ipynb --to python\"\n", + "# call(command,shell=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e012513-880e-4f88-9680-013397af1c8f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "b0f0f4f3", + "metadata": { + "tags": [] + }, + "source": [ + "# Import packages & functions" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5bad764b-45c1-45ce-a716-8d055e09821a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-09-05 13:05:25,854] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "import json\n", + "import argparse\n", + "import numpy as np\n", + "import time\n", + "import random\n", + "import h5py\n", + "from tqdm import tqdm\n", + "\n", + "import webdataset as wds\n", + "import gc\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "from torchvision import transforms\n", + "from flash_attn import flash_attn_qkvpacked_func, flash_attn_func\n", + "\n", + "# tf32 data type is faster than standard float32\n", + "torch.backends.cuda.matmul.allow_tf32 = True\n", + "\n", + "# custom functions #\n", + "import utils" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LOCAL RANK 0\n", + "[2023-09-05 13:05:34,712] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented\n", + "[2023-09-05 13:05:34,712] [INFO] [comm.py:594:init_distributed] cdb=None\n", + "[2023-09-05 13:05:34,713] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n" + ] + } + ], + "source": [ + "local_rank = os.getenv('RANK')\n", + "if local_rank is None: \n", + " local_rank = 0\n", + "else:\n", + " local_rank = int(local_rank)\n", + "print(\"LOCAL RANK \", local_rank)\n", + "\n", + "### Single-GPU config ###\n", + "## Feel free to uncomment the below 4 lines and comment out all the multi-gpu config code to simplify things for single-gpu\n", + "# from accelerate import Accelerator\n", + "# num_devices = torch.cuda.device_count()\n", + "# if num_devices==0: num_devices = 1\n", + "# accelerator = Accelerator(split_batches=False)\n", + "# global_batch_size = 128\n", + " \n", + "### Multi-GPU config ###\n", + "from accelerate import Accelerator, DeepSpeedPlugin\n", + "num_devices = torch.cuda.device_count()\n", + "if num_devices==0: num_devices = 1\n", + "if num_devices <= 1 and utils.is_interactive():\n", + " # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", + " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", + " os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", + " os.environ[\"RANK\"] = \"0\"\n", + " os.environ[\"LOCAL_RANK\"] = \"0\"\n", + " os.environ[\"WORLD_SIZE\"] = \"1\"\n", + " os.environ[\"GLOBAL_BATCH_SIZE\"] = \"128\" # set this to your batch size!\n", + " global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n", + "\n", + "# alter the deepspeed config according to your global and local batch size\n", + "if local_rank == 0:\n", + " with open('deepspeed_config_stage2.json', 'r') as file:\n", + " config = json.load(file)\n", + " config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", + " config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", + " with open('deepspeed_config_stage2.json', 'w') as file:\n", + " json.dump(config, file)\n", + "else:\n", + " # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", + " time.sleep(10)\n", + "deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", + "accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PID of this process = 804611\n", + "device: cuda:0\n", + "Distributed environment: DEEPSPEED Backend: nccl\n", + "Num processes: 1\n", + "Process index: 0\n", + "Local process index: 0\n", + "Device: cuda:0\n", + "\n", + "Mixed precision type: fp16\n", + "ds_config: {'bf16': {'enabled': False}, 'fp16': {'enabled': True}, 'zero_optimization': {'stage': 2, 'contiguous_gradients': True, 'stage3_gather_16bit_weights_on_model_save': True, 'stage3_max_live_parameters': 1000000000.0, 'stage3_max_reuse_distance': 1000000000.0, 'stage3_prefetch_bucket_size': 10000000.0, 'stage3_param_persistence_threshold': 100000.0, 'reduce_bucket_size': 10000000.0, 'sub_group_size': 1000000000.0, 'offload_optimizer': {'device': 'none', 'nvme_path': '/scratch', 'pin_memory': True}, 'offload_param': {'device': 'none', 'nvme_path': '/scratch', 'buffer_size': 4000000000.0, 'pin_memory': True}}, 'aio': {'block_size': 26214400, 'queue_depth': 32, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}, 'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'steps_per_print': inf, 'train_batch_size': 128, 'train_micro_batch_size_per_gpu': 128, 'wall_clock_breakdown': False, 'zero_allow_untested_optimizer': True}\n", + "\n", + "distributed = True num_devices = 1 local rank = 0 world size = 1\n" + ] + } + ], + "source": [ + "print(\"PID of this process =\",os.getpid())\n", + "device = accelerator.device\n", + "print(\"device:\",device)\n", + "num_workers = num_devices\n", + "print(accelerator.state)\n", + "world_size = accelerator.state.num_processes\n", + "distributed = not accelerator.state.distributed_type == 'NO'\n", + "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n", + "print = accelerator.print # only print if local_rank=0" + ] + }, + { + "cell_type": "markdown", + "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", + "metadata": { + "tags": [] + }, + "source": [ + "# Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=test', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=5e-3', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug']\n" + ] + } + ], + "source": [ + "# if running this interactively, can specify jupyter_args here for argparser to use\n", + "if utils.is_interactive():\n", + " # Example use\n", + " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", + " --model_name=test \\\n", + " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n", + " --max_lr=5e-3 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug\"\n", + "\n", + " jupyter_args = jupyter_args.split()\n", + " print(jupyter_args)\n", + " \n", + " from IPython.display import clear_output # function to clear print outputs in cell\n", + " %load_ext autoreload \n", + " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", + " %autoreload 2 " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "global batch_size 128\n", + "batch_size 128\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", + "parser.add_argument(\n", + " \"--model_name\", type=str, default=\"testing\",\n", + " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", + ")\n", + "parser.add_argument(\n", + " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", + " help=\"Path to where NSD data is stored / where to download it to\",\n", + ")\n", + "parser.add_argument(\n", + " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", + ")\n", + "parser.add_argument(\n", + " \"--batch_size\", type=int, default=32,\n", + " help=\"Batch size can be increased by 10x if only training v2c and not diffusion prior\",\n", + ")\n", + "parser.add_argument(\n", + " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n", + " help=\"whether to log to wandb\",\n", + ")\n", + "parser.add_argument(\n", + " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", + " help=\"if not using wandb and want to resume from a ckpt\",\n", + ")\n", + "parser.add_argument(\n", + " \"--wandb_project\",type=str,default=\"stability\",\n", + " help=\"wandb project name\",\n", + ")\n", + "parser.add_argument(\n", + " \"--mixup_pct\",type=float,default=.33,\n", + " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", + ")\n", + "parser.add_argument(\n", + " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", + " help=\"whether to use image augmentation\",\n", + ")\n", + "parser.add_argument(\n", + " \"--num_epochs\",type=int,default=240,\n", + " help=\"number of epochs of training\",\n", + ")\n", + "parser.add_argument(\n", + " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", + ")\n", + "parser.add_argument(\n", + " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", + ")\n", + "parser.add_argument(\n", + " \"--ckpt_interval\",type=int,default=5,\n", + " help=\"save backup ckpt and reconstruct every x epochs\",\n", + ")\n", + "parser.add_argument(\n", + " \"--seed\",type=int,default=42,\n", + ")\n", + "parser.add_argument(\n", + " \"--max_lr\",type=float,default=3e-4,\n", + ")\n", + "parser.add_argument(\n", + " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n", + " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n", + ")\n", + "\n", + "if utils.is_interactive():\n", + " args = parser.parse_args(jupyter_args)\n", + "else:\n", + " args = parser.parse_args()\n", + "\n", + "# create global variables without the args prefix\n", + "for attribute_name in vars(args).keys():\n", + " globals()[attribute_name] = getattr(args, attribute_name)\n", + "\n", + "print(\"global batch_size\", batch_size)\n", + "batch_size = int(batch_size / num_devices)\n", + "print(\"batch_size\", batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", + "if not os.path.exists(outdir):\n", + " os.makedirs(outdir,exist_ok=True)\n", + "if use_image_aug:\n", + " import kornia\n", + " from kornia.augmentation.container import AugmentationSequential\n", + " img_augment = AugmentationSequential(\n", + " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", + " kornia.augmentation.Resize((224, 224)),\n", + " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", + " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", + " kornia.augmentation.RandomGrayscale(p=0.3),\n", + " same_on_batch=False,\n", + " data_keys=[\"input\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "42d13c25-1369-4c49-81d4-83d713586096", + "metadata": { + "tags": [] + }, + "source": [ + "# Prep data, models, and dataloaders" + ] + }, + { + "cell_type": "markdown", + "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", + "metadata": {}, + "source": [ + "## Dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "81084834-035f-4465-ad59-59e6b806a2f5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", + "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" + ] + } + ], + "source": [ + "if subj==1:\n", + " num_train = 24958\n", + " num_test = 2770\n", + "test_batch_size = num_test\n", + "\n", + "def my_split_by_node(urls): return urls\n", + " \n", + "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", + "print(train_url)\n", + "\n", + "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", + " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", + " .decode(\"torch\")\\\n", + " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", + " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", + "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n", + "\n", + "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", + "print(test_url)\n", + "\n", + "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", + " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", + " .decode(\"torch\")\\\n", + " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", + " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", + "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)" + ] + }, + { + "cell_type": "markdown", + "id": "203b060a-2dd2-4c35-929b-c576be82eb52", + "metadata": {}, + "source": [ + "### check dataloaders are working" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# test_indices = []\n", + "# test_images = []\n", + "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", + "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n", + "# test_images = np.append(test_images, behav[:,0,0].numpy())\n", + "# test_indices = test_indices.astype(np.int16)\n", + "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n", + "# print(\"---\\n\")\n", + "\n", + "# train_indices = []\n", + "# train_images = []\n", + "# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", + "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n", + "# train_images = np.append(train_images, behav[:,0,0].numpy())\n", + "# train_indices = train_indices.astype(np.int16)\n", + "# print(train_i, (train_i+1) * batch_size, len(train_indices))" + ] + }, + { + "cell_type": "markdown", + "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", + "metadata": {}, + "source": [ + "## Load voxel betas, K-means clustering model, and images" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "039dd330-7339-4f88-8f00-45f95e47baa0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "subj01 betas loaded into memory\n", + "voxels torch.Size([27750, 15729])\n", + "images torch.Size([73000, 3, 224, 224])\n" + ] + } + ], + "source": [ + "# load betas\n", + "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", + "voxels = f['betas'][:]\n", + "print(f\"subj0{subj} betas loaded into memory\")\n", + "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n", + "if subj==1:\n", + " voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n", + "print(\"voxels\", voxels.shape)\n", + "num_voxels = voxels.shape[-1]\n", + "\n", + "# load orig images\n", + "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", + "images = f['images'][:]\n", + "images = torch.Tensor(images).to(\"cpu\").half()\n", + "print(\"images\", images.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b0420dc0-199e-4c1a-857d-b1747058b467", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ViT-L/14 cuda:0\n" + ] + } + ], + "source": [ + "from models import Clipper\n", + "eva02_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", + "\n", + "clip_seq_dim = 257\n", + "clip_emb_dim = 768\n", + "hidden_dim = 4096" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "MindEyeModule()" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class MindEyeModule(nn.Module):\n", + " def __init__(self):\n", + " super(MindEyeModule, self).__init__()\n", + " def forward(self, x):\n", + " return x\n", + " \n", + "model = MindEyeModule()\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "64,430,080 total\n", + "64,430,080 trainable\n", + "param counts:\n", + "64,430,080 total\n", + "64,430,080 trainable\n", + "torch.Size([2, 15729]) torch.Size([2, 4096])\n" + ] + } + ], + "source": [ + "class RidgeRegression(torch.nn.Module):\n", + " # make sure to add weight_decay when initializing optimizer\n", + " def __init__(self, input_size, out_features): \n", + " super(RidgeRegression, self).__init__()\n", + " self.linear = torch.nn.Linear(input_size, out_features)\n", + " def forward(self, x):\n", + " return self.linear(x)\n", + " \n", + "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", + "utils.count_params(model.ridge)\n", + "utils.count_params(model)\n", + "\n", + "b = torch.randn((2,voxels.shape[1]))\n", + "print(b.shape, model.ridge(b).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8c7f2d47-08a4-40d9-ba63-a6b11c559d42", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"from functools import partial\\nclass BrainNetwork(nn.Module):\\n def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\\n super().__init__()\\n norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\\n act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\\n act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\\n self.mlp = nn.ModuleList([\\n nn.Sequential(\\n nn.Linear(h, h),\\n *[item() for item in act_and_norm],\\n nn.Dropout(drop2)\\n ) for _ in range(n_blocks)\\n ])\\n self.lin1 = nn.Linear(h, out_dim, bias=True)\\n self.n_blocks = n_blocks\\n self.clip_size = clip_size\\n self.use_projector = use_projector\\n if use_projector:\\n self.projector = nn.Sequential(\\n nn.LayerNorm(clip_size),\\n nn.GELU(),\\n nn.Linear(clip_size, 2048),\\n nn.LayerNorm(2048),\\n nn.GELU(),\\n nn.Linear(2048, 2048),\\n nn.LayerNorm(2048),\\n nn.GELU(),\\n nn.Linear(2048, clip_size)\\n )\\n \\n def forward(self, x):\\n residual = x\\n for res_block in range(self.n_blocks):\\n x = self.mlp[res_block](x)\\n x += residual\\n residual = x\\n print(x.shape)\\n x = x.reshape(len(x), -1)\\n print(x.shape)\\n x = self.lin1(x)\\n print(x.shape)\\n if self.use_projector:\\n return self.projector(x.reshape(len(x), -1, self.clip_size))\\n return x\\n\\nmodel.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\\nutils.count_params(model.backbone)\\nutils.count_params(model)\\n\\nb = torch.randn((2,hidden_dim))\\nprint(b.shape, model.backbone(b).shape)\"" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\"from functools import partial\n", + "class BrainNetwork(nn.Module):\n", + " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", + " super().__init__()\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop2)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.clip_size = clip_size\n", + " self.use_projector = use_projector\n", + " if use_projector:\n", + " self.projector = nn.Sequential(\n", + " nn.LayerNorm(clip_size),\n", + " nn.GELU(),\n", + " nn.Linear(clip_size, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, clip_size)\n", + " )\n", + " \n", + " def forward(self, x):\n", + " residual = x\n", + " for res_block in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", + " x += residual\n", + " residual = x\n", + " print(x.shape)\n", + " x = x.reshape(len(x), -1)\n", + " print(x.shape)\n", + " x = self.lin1(x)\n", + " print(x.shape)\n", + " if self.use_projector:\n", + " return self.projector(x.reshape(len(x), -1, self.clip_size))\n", + " return x\n", + "\n", + "model.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\n", + "utils.count_params(model.backbone)\n", + "utils.count_params(model)\n", + "\n", + "b = torch.randn((2,hidden_dim))\n", + "print(b.shape, model.backbone(b).shape)\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "863fcb22-f588-480f-ad1c-14bcda9130ef", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "883,253,376 total\n", + "883,253,376 trainable\n", + "param counts:\n", + "947,683,456 total\n", + "947,683,456 trainable\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "class BrainNetwork(nn.Module):\n", + " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", + " super().__init__()\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop2)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.clip_size = clip_size\n", + " self.use_projector = use_projector\n", + " if use_projector:\n", + " self.projector = nn.Sequential(\n", + " nn.LayerNorm(clip_size),\n", + " nn.GELU(),\n", + " nn.Linear(clip_size, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, clip_size)\n", + " )\n", + " \n", + " def forward(self, x):\n", + " residual = x\n", + " for res_block in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", + " x += residual\n", + " residual = x\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin1(x)\n", + " if self.use_projector:\n", + " return self.projector(x.reshape(len(x), -1, self.clip_size))\n", + " return x\n", + "\n", + "from flash_attn import flash_attn_qkvpacked_func, flash_attn_func\n", + "from einops import rearrange\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(input_dim, hidden_dim),\n", + " nn.GELU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(hidden_dim, output_dim),\n", + " nn.Dropout(dropout)\n", + " )\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + "class TransformerBlock(nn.Module):\n", + " def __init__(self, input_embedding_dim = 16, sequence_length = 4096, num_heads = 8, dropout_attention = 0.3, dropout_residual = 0.3, dropout_pre_norm = 0.3, use_pre_norm = True, voxels_embeddings = False, embedding_in = False):\n", + " super().__init__()\n", + " self.attention = flash_attn_qkvpacked_func\n", + " self.feed_forward = FeedForward(input_embedding_dim, input_embedding_dim*4, input_embedding_dim)\n", + " self.norm1 = nn.LayerNorm(input_embedding_dim)\n", + " self.norm2 = nn.LayerNorm(input_embedding_dim)\n", + "\n", + " self.dropout_attention = nn.Dropout(dropout_attention)\n", + " self.dropout_residual = nn.Dropout(dropout_residual)\n", + " self.dropout_pre_norm = nn.Dropout(dropout_pre_norm)\n", + "\n", + " self.use_pre_norm = use_pre_norm\n", + " self.voxels_embeddings = voxels_embeddings\n", + "\n", + " if self.voxels_embeddings:\n", + " self.voxels_embeddings_projection = nn.Linear(1, input_embedding_dim)\n", + "\n", + " self.num_heads = num_heads\n", + " self.input_embedding_dim = input_embedding_dim\n", + " self.sequence_length = sequence_length\n", + " self.embedding_in = embedding_in\n", + "\n", + " #query, key, value projection for each head\n", + " self.qkv_projection = nn.ModuleList([\n", + " nn.Linear(input_embedding_dim, input_embedding_dim*3//self.num_heads, bias=False) for _ in range(num_heads)\n", + " ])\n", + "\n", + " def forward(self, voxels):\n", + " # x: (batch_size, voxels_shape)\n", + " if not self.embedding_in:\n", + " if self.voxels_embeddings:\n", + " voxels = self.voxels_embeddings_projection(voxels.unsqueeze(-1)) # (batch_size, voxels_shape, input_embedding_dim)\n", + " else:\n", + " # reshape voxels to (batch_size, voxels_shape//input_embedding_dim, input_embedding_dim)\n", + " voxels = rearrange(voxels, 'b (s i) -> b s i', s=self.sequence_length//self.input_embedding_dim, i=self.input_embedding_dim)\n", + " # voxels: (batch_size, sequence_length, input_embedding_dim)\n", + " voxels = self.dropout_pre_norm(voxels)\n", + " voxels = self.norm1(voxels)\n", + " voxels = self.dropout_attention(voxels)\n", + " qkv = torch.stack([proj(voxels) for proj in self.qkv_projection], dim=3) # (batch_size, sequence_length, 3, num_heads, input_embedding_dim)\n", + " qkv = rearrange(qkv, 'b t (hd kqv) h -> b t kqv h hd', kqv = 3).type(torch.float16).to('cuda')\n", + " qkv = self.attention(qkv, self.dropout_attention.p)\n", + " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], -1) # (batch_size, sequence_length, input_embedding_dim*3*num_heads)\n", + " qkv = self.dropout_residual(qkv)\n", + " voxels = voxels + qkv\n", + " voxels = self.norm2(voxels)\n", + " voxels = self.dropout_residual(voxels)\n", + " voxels = self.feed_forward(voxels)\n", + " return voxels\n", + " \n", + "\n", + " \n", + "\n", + "# create model\n", + "# TransformerBlockTest = TransformerBlock(input_embedding_dim=16, sequence_length=4096, num_heads=8, dropout_attention=0.3, dropout_residual=0.3, dropout_pre_norm=0.3, use_pre_norm=True, voxels_embeddings=True).to('cuda')\n", + "# inputTest = torch.randn((2,4096)).to('cuda')\n", + "# print(inputTest.shape, TransformerBlockTest(inputTest).shape)\n", + "\n", + "\n", + "\n", + "class BrainTransformerNetwork(nn.Module):\n", + " def __init__(self, out_dim=768, in_dim=15724, embed_dim=16, n_blocks=4, num_heads=8, dropout_attention=0.3, dropout_residual=0.3, dropout_pre_norm=0.3, use_pre_norm=True, voxels_embeddings=True):\n", + " super().__init__()\n", + " self.out_dim = out_dim\n", + " self.in_dim = in_dim\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks = n_blocks\n", + " self.num_heads = num_heads\n", + " self.dropout_attention = dropout_attention\n", + " self.dropout_residual = dropout_residual\n", + " self.dropout_pre_norm = dropout_pre_norm\n", + " self.use_pre_norm = use_pre_norm\n", + " self.voxels_embeddings = voxels_embeddings\n", + " self.transformer = nn.Sequential()\n", + " for i in range(n_blocks):\n", + " if voxels_embeddings and i == 0:\n", + " self.transformer.add_module(f\"transformer_block_{i}\", TransformerBlock(input_embedding_dim=embed_dim, sequence_length=in_dim, num_heads=num_heads, dropout_attention=dropout_attention, dropout_residual=dropout_residual, dropout_pre_norm=dropout_pre_norm, use_pre_norm=use_pre_norm, voxels_embeddings=False, embedding_in = False))\n", + " else:\n", + " self.transformer.add_module(f\"transformer_block_{i}\", TransformerBlock(input_embedding_dim=embed_dim, sequence_length=in_dim//embed_dim, num_heads=num_heads, dropout_attention=dropout_attention, dropout_residual=dropout_residual, dropout_pre_norm=dropout_pre_norm, use_pre_norm=use_pre_norm, voxels_embeddings=False, embedding_in = True))\n", + " \n", + " #self.pre_head_lin = nn.Linear(in_dim//embed_dim * embed_dim, in_dim, bias = False)\n", + " #self.head_lin = nn.Linear(in_dim, out_dim, bias = False)\n", + " self.gelu = nn.GELU()\n", + " self.BrainNetwork = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\n", + "\n", + " def forward(self, x):\n", + " x = self.transformer(x)\n", + " x = rearrange(x, 'b s i -> b (s i)')\n", + " x = self.gelu(x)\n", + " x = self.BrainNetwork(x)\n", + " return x\n", + " \n", + "import math\n", + "# create model\n", + "# TestBrainTransformerNetwork = BrainTransformerNetwork(out_dim=768, in_dim=4096, embed_dim=16, n_blocks=4, num_heads=8, dropout_attention=0.3, dropout_residual=0.3, dropout_pre_norm=0.3, use_pre_norm=True, voxels_embeddings=True).to('cuda')\n", + "# inputTest = torch.randn((2,4096)).to('cuda')\n", + "# print(inputTest.shape, TestBrainTransformerNetwork(inputTest).shape)\n", + "\n", + "#model.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\n", + "#utils.count_params(model.backbone)\n", + "#utils.count_params(model)\n", + "\n", + "#b = torch.randn((2,hidden_dim))\n", + "#print(b.shape, model.backbone(b).shape)\n", + "\n", + "model.backbone = BrainTransformerNetwork(out_dim=clip_seq_dim*clip_emb_dim, in_dim=4096, embed_dim=64, n_blocks=2, num_heads=16, dropout_attention=0.5, dropout_residual=0.4, dropout_pre_norm=0.4, use_pre_norm=True, voxels_embeddings=True).to('cuda')\n", + "utils.count_params(model.backbone)\n", + "utils.count_params(model)\n", + "\n", + "#b = torch.randn((2,hidden_dim)).to('cuda')\n", + "#print(b.shape, model.backbone(b).shape)\n", + "b = None" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1ce49f9e-2e43-42fb-8072-c991abfcce79", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#b = torch.randn((2,hidden_dim)).to('cuda')\n", + "#print(b.shape, model.backbone(b).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Done with model preparations!\n" + ] + } + ], + "source": [ + "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", + "opt_grouped_parameters = [\n", + " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", + " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 2e-2},\n", + " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", + "]\n", + "\n", + "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n", + "\n", + "if lr_scheduler_type == 'linear':\n", + " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", + " optimizer,\n", + " total_iters=int(num_epochs*(num_train*num_devices//batch_size)),\n", + " last_epoch=-1\n", + " )\n", + "elif lr_scheduler_type == 'cycle':\n", + " total_steps=int(num_epochs*(num_train*num_devices//batch_size))\n", + " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer, \n", + " max_lr=max_lr,\n", + " total_steps=total_steps,\n", + " final_div_factor=10,\n", + " last_epoch=-1, pct_start=2/num_epochs\n", + " )\n", + " \n", + "def save_ckpt(tag): \n", + " ckpt_path = outdir+f'/{tag}.pth'\n", + " print(f'saving {ckpt_path}',flush=True)\n", + " unwrapped_model = accelerator.unwrap_model(model)\n", + " try:\n", + " torch.save({\n", + " 'epoch': epoch,\n", + " 'model_state_dict': unwrapped_model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'lr_scheduler': lr_scheduler.state_dict(),\n", + " 'train_losses': losses,\n", + " 'test_losses': test_losses,\n", + " 'lrs': lrs,\n", + " }, ckpt_path)\n", + " except:\n", + " print(\"Couldn't save... moving on to prevent crashing.\")\n", + " del unwrapped_model\n", + " \n", + "print(\"\\nDone with model preparations!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "62a7f9f3-aedb-4c9e-925b-64a3642b8c43", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "wandb_log = True" + ] + }, + { + "cell_type": "markdown", + "id": "983f458b-35b8-49f2-b6db-80296cece730", + "metadata": {}, + "source": [ + "# Weights and Biases" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb mindeye2 run transformer test run\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb_config:\n", + " {'model_name': 'test', 'batch_size': 128, 'num_epochs': 12, 'use_image_aug': False, 'max_lr': 0.005, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': True, 'num_devices': 1, 'world_size': 1}\n" + ] + }, + { + "data": { + "text/html": [ + "wandb version 0.15.9 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20230905_130625-6wdt860v" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run transformer test run to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://stability.wandb.io/ckadirt/mindeye2" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://stability.wandb.io/ckadirt/mindeye2/runs/6wdt860v" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "# params for wandb\n", + "if local_rank==0 and wandb_log: # only use main process for wandb logging\n", + " import wandb\n", + " \n", + " wandb_project = 'mindeye2'\n", + " wandb_run = 'transformer test run'\n", + " wandb_notes = ''\n", + " \n", + " print(f\"wandb {wandb_project} run {wandb_run}\")\n", + " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", + " wandb_config = {\n", + " \"model_name\": model_name,\n", + " \"batch_size\": batch_size,\n", + " \"num_epochs\": num_epochs,\n", + " \"use_image_aug\": use_image_aug,\n", + " \"max_lr\": max_lr,\n", + " \"lr_scheduler_type\": lr_scheduler_type,\n", + " \"mixup_pct\": mixup_pct,\n", + " \"num_train\": num_train,\n", + " \"num_test\": num_test,\n", + " \"seed\": seed,\n", + " \"distributed\": distributed,\n", + " \"num_devices\": num_devices,\n", + " \"world_size\": world_size,\n", + " }\n", + " print(\"wandb_config:\\n\",wandb_config)\n", + " if False: # wandb_auto_resume\n", + " print(\"wandb_id:\",model_name)\n", + " wandb.init(\n", + " id = model_name,\n", + " project=wandb_project,\n", + " name=wandb_run,\n", + " config=wandb_config,\n", + " notes=wandb_notes,\n", + " resume=\"allow\",\n", + " )\n", + " else:\n", + " wandb.init(\n", + " project=wandb_project,\n", + " name=wandb_run,\n", + " config=wandb_config,\n", + " notes=wandb_notes,\n", + " )\n", + "else:\n", + " wandb_log = False" + ] + }, + { + "cell_type": "markdown", + "id": "d5690151-2131-4918-b750-e869cbd1a8a8", + "metadata": {}, + "source": [ + "# Main" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "epoch = 0\n", + "losses, test_losses, lrs = [], [], []\n", + "best_test_loss = 1e9\n", + "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", + "\n", + "resume_from_ckpt = False\n", + "\n", + "# Optionally resume from checkpoint #\n", + "if resume_from_ckpt:\n", + " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", + " try:\n", + " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", + " except:\n", + " print('last.pth failed... trying last_backup.pth')\n", + " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", + " epoch = checkpoint['epoch']\n", + " print(\"Epoch\",epoch)\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", + " diffusion_prior.load_state_dict(checkpoint['model_state_dict'])\n", + " del checkpoint\n", + "elif False: #wandb_log:\n", + " if wandb.run.resumed:\n", + " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", + " try:\n", + " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", + " except:\n", + " print('last.pth failed... trying last_backup.pth')\n", + " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", + " epoch = checkpoint['epoch']\n", + " print(\"Epoch\",epoch)\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", + " diffusion_prior.load_state_dict(checkpoint['model_state_dict'])\n", + " del checkpoint\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-09-05 13:06:35,335] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.9.5, git-hash=unknown, git-branch=unknown\n", + "[2023-09-05 13:06:35,733] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False\n", + "[2023-09-05 13:06:35,735] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer\n", + "[2023-09-05 13:06:35,735] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimizer\n", + "[2023-09-05 13:06:35,737] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW\n", + "[2023-09-05 13:06:35,737] [INFO] [utils.py:54:is_zero_supported_optimizer] Checking ZeRO support for optimizer=AdamW type=\n", + "[2023-09-05 13:06:35,738] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 2 optimizer\n", + "[2023-09-05 13:06:35,738] [INFO] [stage_1_and_2.py:133:__init__] Reduce bucket size 10000000\n", + "[2023-09-05 13:06:35,738] [INFO] [stage_1_and_2.py:134:__init__] Allgather bucket size 500,000,000\n", + "[2023-09-05 13:06:35,739] [INFO] [stage_1_and_2.py:135:__init__] CPU Offload: False\n", + "[2023-09-05 13:06:35,739] [INFO] [stage_1_and_2.py:136:__init__] Round robin gradient partitioning: False\n", + "Rank: 0 partition count [1, 1, 1] and sizes[(64430080, False), (883012608, False), (240768, False)] \n", + "[2023-09-05 13:06:37,941] [INFO] [utils.py:785:see_memory_usage] Before initializing optimizer states\n", + "[2023-09-05 13:06:37,942] [INFO] [utils.py:786:see_memory_usage] MA 7.31 GB Max_MA 7.31 GB CA 7.33 GB Max_CA 7 GB \n", + "[2023-09-05 13:06:37,943] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 96.0 GB, percent = 8.6%\n", + "[2023-09-05 13:06:38,117] [INFO] [utils.py:785:see_memory_usage] After initializing optimizer states\n", + "[2023-09-05 13:06:38,118] [INFO] [utils.py:786:see_memory_usage] MA 14.37 GB Max_MA 24.48 GB CA 24.99 GB Max_CA 25 GB \n", + "[2023-09-05 13:06:38,119] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 96.0 GB, percent = 8.6%\n", + "[2023-09-05 13:06:38,120] [INFO] [stage_1_and_2.py:488:__init__] optimizer state initialized\n", + "[2023-09-05 13:06:38,267] [INFO] [utils.py:785:see_memory_usage] After initializing ZeRO optimizer\n", + "[2023-09-05 13:06:38,268] [INFO] [utils.py:786:see_memory_usage] MA 14.37 GB Max_MA 14.37 GB CA 24.99 GB Max_CA 25 GB \n", + "[2023-09-05 13:06:38,269] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 96.0 GB, percent = 8.6%\n", + "[2023-09-05 13:06:38,272] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = AdamW\n", + "[2023-09-05 13:06:38,273] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using client LR scheduler\n", + "[2023-09-05 13:06:38,273] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = None\n", + "[2023-09-05 13:06:38,274] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[0.00019999999999999966, 0.00019999999999999966, 0.00019999999999999966], mom=[(0.95, 0.999), (0.95, 0.999), (0.95, 0.999)]\n", + "[2023-09-05 13:06:38,275] [INFO] [config.py:960:print] DeepSpeedEngine configuration:\n", + "[2023-09-05 13:06:38,276] [INFO] [config.py:964:print] activation_checkpointing_config {\n", + " \"partition_activations\": false, \n", + " \"contiguous_memory_optimization\": false, \n", + " \"cpu_checkpointing\": false, \n", + " \"number_checkpoints\": null, \n", + " \"synchronize_checkpoint_boundary\": false, \n", + " \"profile\": false\n", + "}\n", + "[2023-09-05 13:06:38,276] [INFO] [config.py:964:print] aio_config ................... {'block_size': 26214400, 'queue_depth': 32, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}\n", + "[2023-09-05 13:06:38,277] [INFO] [config.py:964:print] amp_enabled .................. False\n", + "[2023-09-05 13:06:38,277] [INFO] [config.py:964:print] amp_params ................... False\n", + "[2023-09-05 13:06:38,278] [INFO] [config.py:964:print] autotuning_config ............ {\n", + " \"enabled\": false, \n", + " \"start_step\": null, \n", + " \"end_step\": null, \n", + " \"metric_path\": null, \n", + " \"arg_mappings\": null, \n", + " \"metric\": \"throughput\", \n", + " \"model_info\": null, \n", + " \"results_dir\": \"autotuning_results\", \n", + " \"exps_dir\": \"autotuning_exps\", \n", + " \"overwrite\": true, \n", + " \"fast\": true, \n", + " \"start_profile_step\": 3, \n", + " \"end_profile_step\": 5, \n", + " \"tuner_type\": \"gridsearch\", \n", + " \"tuner_early_stopping\": 5, \n", + " \"tuner_num_trials\": 50, \n", + " \"model_info_path\": null, \n", + " \"mp_size\": 1, \n", + " \"max_train_batch_size\": null, \n", + " \"min_train_batch_size\": 1, \n", + " \"max_train_micro_batch_size_per_gpu\": 1.024000e+03, \n", + " \"min_train_micro_batch_size_per_gpu\": 1, \n", + " \"num_tuning_micro_batch_sizes\": 3\n", + "}\n", + "[2023-09-05 13:06:38,279] [INFO] [config.py:964:print] bfloat16_enabled ............. False\n", + "[2023-09-05 13:06:38,279] [INFO] [config.py:964:print] checkpoint_parallel_write_pipeline False\n", + "[2023-09-05 13:06:38,280] [INFO] [config.py:964:print] checkpoint_tag_validation_enabled True\n", + "[2023-09-05 13:06:38,280] [INFO] [config.py:964:print] checkpoint_tag_validation_fail False\n", + "[2023-09-05 13:06:38,281] [INFO] [config.py:964:print] comms_config ................. \n", + "[2023-09-05 13:06:38,282] [INFO] [config.py:964:print] communication_data_type ...... None\n", + "[2023-09-05 13:06:38,282] [INFO] [config.py:964:print] compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}\n", + "[2023-09-05 13:06:38,283] [INFO] [config.py:964:print] curriculum_enabled_legacy .... False\n", + "[2023-09-05 13:06:38,283] [INFO] [config.py:964:print] curriculum_params_legacy ..... False\n", + "[2023-09-05 13:06:38,284] [INFO] [config.py:964:print] data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}\n", + "[2023-09-05 13:06:38,284] [INFO] [config.py:964:print] data_efficiency_enabled ...... False\n", + "[2023-09-05 13:06:38,285] [INFO] [config.py:964:print] dataloader_drop_last ......... False\n", + "[2023-09-05 13:06:38,285] [INFO] [config.py:964:print] disable_allgather ............ False\n", + "[2023-09-05 13:06:38,286] [INFO] [config.py:964:print] dump_state ................... False\n", + "[2023-09-05 13:06:38,287] [INFO] [config.py:964:print] dynamic_loss_scale_args ...... None\n", + "[2023-09-05 13:06:38,287] [INFO] [config.py:964:print] eigenvalue_enabled ........... False\n", + "[2023-09-05 13:06:38,288] [INFO] [config.py:964:print] eigenvalue_gas_boundary_resolution 1\n", + "[2023-09-05 13:06:38,288] [INFO] [config.py:964:print] eigenvalue_layer_name ........ bert.encoder.layer\n", + "[2023-09-05 13:06:38,289] [INFO] [config.py:964:print] eigenvalue_layer_num ......... 0\n", + "[2023-09-05 13:06:38,289] [INFO] [config.py:964:print] eigenvalue_max_iter .......... 100\n", + "[2023-09-05 13:06:38,290] [INFO] [config.py:964:print] eigenvalue_stability ......... 1e-06\n", + "[2023-09-05 13:06:38,290] [INFO] [config.py:964:print] eigenvalue_tol ............... 0.01\n", + "[2023-09-05 13:06:38,291] [INFO] [config.py:964:print] eigenvalue_verbose ........... False\n", + "[2023-09-05 13:06:38,291] [INFO] [config.py:964:print] elasticity_enabled ........... False\n", + "[2023-09-05 13:06:38,292] [INFO] [config.py:964:print] flops_profiler_config ........ {\n", + " \"enabled\": false, \n", + " \"recompute_fwd_factor\": 0.0, \n", + " \"profile_step\": 1, \n", + " \"module_depth\": -1, \n", + " \"top_modules\": 1, \n", + " \"detailed\": true, \n", + " \"output_file\": null\n", + "}\n", + "[2023-09-05 13:06:38,293] [INFO] [config.py:964:print] fp16_auto_cast ............... False\n", + "[2023-09-05 13:06:38,293] [INFO] [config.py:964:print] fp16_enabled ................. True\n", + "[2023-09-05 13:06:38,294] [INFO] [config.py:964:print] fp16_master_weights_and_gradients False\n", + "[2023-09-05 13:06:38,294] [INFO] [config.py:964:print] global_rank .................. 0\n", + "[2023-09-05 13:06:38,295] [INFO] [config.py:964:print] grad_accum_dtype ............. None\n", + "[2023-09-05 13:06:38,295] [INFO] [config.py:964:print] gradient_accumulation_steps .. 1\n", + "[2023-09-05 13:06:38,296] [INFO] [config.py:964:print] gradient_clipping ............ 1.0\n", + "[2023-09-05 13:06:38,296] [INFO] [config.py:964:print] gradient_predivide_factor .... 1.0\n", + "[2023-09-05 13:06:38,297] [INFO] [config.py:964:print] hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8\n", + "[2023-09-05 13:06:38,297] [INFO] [config.py:964:print] initial_dynamic_scale ........ 65536\n", + "[2023-09-05 13:06:38,298] [INFO] [config.py:964:print] load_universal_checkpoint .... False\n", + "[2023-09-05 13:06:38,299] [INFO] [config.py:964:print] loss_scale ................... 0\n", + "[2023-09-05 13:06:38,299] [INFO] [config.py:964:print] memory_breakdown ............. False\n", + "[2023-09-05 13:06:38,300] [INFO] [config.py:964:print] mics_hierarchial_params_gather False\n", + "[2023-09-05 13:06:38,300] [INFO] [config.py:964:print] mics_shard_size .............. -1\n", + "[2023-09-05 13:06:38,301] [INFO] [config.py:964:print] monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False\n", + "[2023-09-05 13:06:38,302] [INFO] [config.py:964:print] nebula_config ................ {\n", + " \"enabled\": false, \n", + " \"persistent_storage_path\": null, \n", + " \"persistent_time_interval\": 100, \n", + " \"num_of_version_in_retention\": 2, \n", + " \"enable_nebula_load\": true, \n", + " \"load_path\": null\n", + "}\n", + "[2023-09-05 13:06:38,302] [INFO] [config.py:964:print] optimizer_legacy_fusion ...... False\n", + "[2023-09-05 13:06:38,303] [INFO] [config.py:964:print] optimizer_name ............... None\n", + "[2023-09-05 13:06:38,303] [INFO] [config.py:964:print] optimizer_params ............. None\n", + "[2023-09-05 13:06:38,304] [INFO] [config.py:964:print] pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}\n", + "[2023-09-05 13:06:38,304] [INFO] [config.py:964:print] pld_enabled .................. False\n", + "[2023-09-05 13:06:38,305] [INFO] [config.py:964:print] pld_params ................... False\n", + "[2023-09-05 13:06:38,305] [INFO] [config.py:964:print] prescale_gradients ........... False\n", + "[2023-09-05 13:06:38,306] [INFO] [config.py:964:print] scheduler_name ............... None\n", + "[2023-09-05 13:06:38,307] [INFO] [config.py:964:print] scheduler_params ............. None\n", + "[2023-09-05 13:06:38,307] [INFO] [config.py:964:print] sparse_attention ............. None\n", + "[2023-09-05 13:06:38,308] [INFO] [config.py:964:print] sparse_gradients_enabled ..... False\n", + "[2023-09-05 13:06:38,308] [INFO] [config.py:964:print] steps_per_print .............. inf\n", + "[2023-09-05 13:06:38,309] [INFO] [config.py:964:print] train_batch_size ............. 128\n", + "[2023-09-05 13:06:38,309] [INFO] [config.py:964:print] train_micro_batch_size_per_gpu 128\n", + "[2023-09-05 13:06:38,310] [INFO] [config.py:964:print] use_node_local_storage ....... False\n", + "[2023-09-05 13:06:38,310] [INFO] [config.py:964:print] wall_clock_breakdown ......... False\n", + "[2023-09-05 13:06:38,311] [INFO] [config.py:964:print] world_size ................... 1\n", + "[2023-09-05 13:06:38,312] [INFO] [config.py:964:print] zero_allow_untested_optimizer True\n", + "[2023-09-05 13:06:38,312] [INFO] [config.py:964:print] zero_config .................. stage=2 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=10000000 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=False load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=5, buffer_size=4000000000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=10000000 param_persistence_threshold=100000 model_persistence_threshold=sys.maxsize max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True\n", + "[2023-09-05 13:06:38,313] [INFO] [config.py:964:print] zero_enabled ................. True\n", + "[2023-09-05 13:06:38,313] [INFO] [config.py:964:print] zero_force_ds_cpu_optimizer .. True\n", + "[2023-09-05 13:06:38,314] [INFO] [config.py:964:print] zero_optimization_stage ...... 2\n", + "[2023-09-05 13:06:38,314] [INFO] [config.py:950:print_user_config] json = {\n", + " \"bf16\": {\n", + " \"enabled\": false\n", + " }, \n", + " \"fp16\": {\n", + " \"enabled\": true\n", + " }, \n", + " \"zero_optimization\": {\n", + " \"stage\": 2, \n", + " \"contiguous_gradients\": true, \n", + " \"stage3_gather_16bit_weights_on_model_save\": true, \n", + " \"stage3_max_live_parameters\": 1.000000e+09, \n", + " \"stage3_max_reuse_distance\": 1.000000e+09, \n", + " \"stage3_prefetch_bucket_size\": 1.000000e+07, \n", + " \"stage3_param_persistence_threshold\": 1.000000e+05, \n", + " \"reduce_bucket_size\": 1.000000e+07, \n", + " \"sub_group_size\": 1.000000e+09, \n", + " \"offload_optimizer\": {\n", + " \"device\": \"none\", \n", + " \"nvme_path\": \"/scratch\", \n", + " \"pin_memory\": true\n", + " }, \n", + " \"offload_param\": {\n", + " \"device\": \"none\", \n", + " \"nvme_path\": \"/scratch\", \n", + " \"buffer_size\": 4.000000e+09, \n", + " \"pin_memory\": true\n", + " }\n", + " }, \n", + " \"aio\": {\n", + " \"block_size\": 2.621440e+07, \n", + " \"queue_depth\": 32, \n", + " \"thread_count\": 1, \n", + " \"single_submit\": false, \n", + " \"overlap_events\": true\n", + " }, \n", + " \"gradient_accumulation_steps\": 1, \n", + " \"gradient_clipping\": 1.0, \n", + " \"steps_per_print\": inf, \n", + " \"train_batch_size\": 128, \n", + " \"train_micro_batch_size_per_gpu\": 128, \n", + " \"wall_clock_breakdown\": false, \n", + " \"zero_allow_untested_optimizer\": true\n", + "}\n" + ] + } + ], + "source": [ + "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n", + "model, optimizer, train_dl, test_dl, lr_scheduler\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "b4a3368c-e6ce-49cc-b970-ee3dba12dfcd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test starting with epoch 0 / 12\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/12 [00:00