import torch
from torch import nn
from timm.models.vision_transformer import VisionTransformer
from functools import partial
from einops import rearrange
from dnns.yolov3.yolo_fpn import YOLOFPN
from dnns.yolov3.head import YOLOXHead
from utils.dl.common.model import set_module, get_module
from types import MethodType
import os

from utils.common.log import logger


class VisionTransformerYOLOv3(VisionTransformer):
    def forward_head(self, x):
        # print(222)
        return self.head(x)

    def forward_features(self, x):
        # print(111)
        return self._intermediate_layers(x, n=[len(self.blocks) // 3 - 1, len(self.blocks) // 3 * 2 - 1, len(self.blocks) - 1])

    def forward(self, x, targets=None):
        features = self.forward_features(x)
        return self.head(x, features, targets)

    @staticmethod
    def init_from_vit(vit: VisionTransformer):
        res = VisionTransformerYOLOv3()

        for attr in dir(vit):
            # if str(attr) not in ['forward_head', 'forward_features'] and not attr.startswith('__'):
            if isinstance(getattr(vit, attr), nn.Module):
                # print(attr)
                try:
                    setattr(res, attr, getattr(vit, attr))
                except Exception as e:
                    print(attr, str(e))

        return res
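
# With the 12-block ViT-B used in __main__ below, forward_features taps the outputs of
# blocks 12 // 3 - 1 = 3, 12 // 3 * 2 - 1 = 7 and 12 - 1 = 11 (0-based), i.e. one feature
# per third of the encoder. Rough shape sketch for a 224x224 input (assuming a 16x16
# patch embed: 14x14 = 196 patch tokens plus 1 class token):
#
#   feats = model.forward_features(torch.rand(1, 3, 224, 224))
#   # -> 3 tensors, each of shape (1, 197, 768); the heads below strip the class token
#   #    and reshape the rest into (1, 768, 14, 14) maps.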


class Norm2d(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
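
# Minimal usage sketch for Norm2d (illustrative, not part of the original file):
# it applies LayerNorm over the channel dimension of an NCHW feature map by
# permuting to NHWC and back, leaving the spatial layout untouched.
#
#   ln2d = Norm2d(768)
#   out = ln2d(torch.rand(1, 768, 14, 14))   # out.shape == (1, 768, 14, 14)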


class ViTYOLOv3Head(nn.Module):
    def __init__(self, im_size, patch_size, patch_dim, num_classes, use_bigger_fpns, cls_vit_ckpt_path, init_head):
        super(ViTYOLOv3Head, self).__init__()

        self.im_size = im_size
        self.patch_size = patch_size

        # target_patch_dim: [256, 512, 512]
        # self.change_patchs_dim = nn.ModuleList([nn.Linear(patch_dim, target_patch_dim) for target_patch_dim in [256, 512, 512]])

        # # input: (1, target_patch_dim, 14, 14)
        # # target feature size: {40, 20, 10}
        # self.change_features_size = nn.ModuleList([
        #     self.get_change_feature_size(cin, cout, t) for t, cin, cout in zip([40, 20, 10], [256, 512, 512], [256, 512, 512])
        # ])

        embed_dim = 768
        self.before_fpns = nn.ModuleList([
            # nn.Sequential(
            #     nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            #     nn.GroupNorm(embed_dim),
            #     nn.GELU(),
            #     nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            # ),
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            ),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])

        if use_bigger_fpns == 1:
            logger.info('use 421x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    Norm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
                # nn.MaxPool2d(kernel_size=2, stride=2)
            ])

        if use_bigger_fpns == -1:
            logger.info('use 1/0.5/0.25x fpns')
            self.before_fpns = nn.ModuleList([
                # nn.Sequential(
                #     nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                # ),
                nn.Identity(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), nn.MaxPool2d(kernel_size=2, stride=2))
            ])

        # self.fpn = YOLOFPN()
        self.fpn = nn.Identity()
        self.head = YOLOXHead(num_classes, in_channels=[768, 768, 768], act='lrelu')

        if init_head:
            logger.info('init head')
            self.load_pretrained_weight(cls_vit_ckpt_path)
        else:
            logger.info('do not init head')

    def load_pretrained_weight(self, cls_vit_ckpt_path):
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'yolox_darknet.pth'))

        # for k in [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]:
        #     del ckpt['model'][k]
        removed_k = [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]
        for k, v in ckpt['model'].items():
            if 'backbone.backbone' in k:
                removed_k += [k]
            if 'head.stems' in k and 'conv.weight' in k:
                removed_k += [k]
        for k in removed_k:
            del ckpt['model'][k]

        # print(ckpt['model'].keys())

        new_state_dict = {}
        for k, v in ckpt['model'].items():
            new_k = k.replace('backbone', 'fpn')
            new_state_dict[new_k] = v

        # cls_vit_ckpt = torch.load(cls_vit_ckpt_path)
        # for k, v in cls_vit_ckpt['main'].named_parameters():
        #     if not 'qkv.abs' not in k:
        #         continue
        #     new_state_dict[k] = v
        #     logger.info(f'load {k} from cls vit ckpt')

        self.load_state_dict(new_state_dict, strict=False)

    def get_change_feature_size(self, in_channels, out_channels, target_size):
        H, W = self.im_size
        GS = H // self.patch_size  # 14

        if target_size == GS:
            return nn.Identity()
        elif target_size < GS:
            return nn.AdaptiveMaxPool2d((target_size, target_size))
        else:
            return {
                20: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=0),
                    # nn.BatchNorm2d(out_channels),
                    # nn.ReLU(),
                    nn.AdaptiveMaxPool2d((target_size, target_size))
                ),
                40: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=3, padding=1),
                    # nn.BatchNorm2d(out_channels),
                    # nn.ReLU(),
                )
            }[target_size]

    def forward(self, input_images, x, targets=None):
        # print(111)

        # NOTE: YOLOX backbone (w/o FPN) output, or FPN input: {'dark3': torch.Size([4, 256, 40, 40]), 'dark4': torch.Size([4, 512, 20, 20]), 'dark5': torch.Size([4, 512, 10, 10])}
        # NOTE: YOLOXHead input: [torch.Size([4, 128, 40, 40]), torch.Size([4, 256, 20, 20]), torch.Size([4, 512, 10, 10])]

        # print(x)
        # print([i.size() for i in x])

        # drop the class token, then reshape the patch tokens back into a 2D feature map
        x = [i[:, 1:] for i in x]
        x = [i.permute(0, 2, 1).reshape(input_images.size(0), -1, 14, 14) for i in x]  # 14 is hardcoded, obtained from timm.layers.patch_embed.py
        # print([i.size() for i in x])
        # exit()

        # NOTE: old
        # x[0]: torch.Size([1, 196, 768])
        # H, W = self.im_size
        # GS = H // self.patch_size  # 14
        # xs = [cpd(x) for x, cpd in zip(xs, self.change_patchs_dim)]  # (1, 196, target_patch_dim)
        # xs = [rearrange(x, "b (h w) c -> b c h w", h=GS) for x in xs]  # (1, target_patch_dim, 14, 14)
        # xs = [cfs(x) for x, cfs in zip(xs, self.change_features_size)]
        # print([i.size() for i in xs])
        # ----------------

        # only the last tapped ViT feature map is used here; each before_fpn rescales it to one pyramid level
        xs = [before_fpn(x[-1]) for before_fpn in self.before_fpns]
        # print([i.size() for i in xs])
        # exit()
        # [torch.Size([1, 768, 28, 28]), torch.Size([1, 768, 14, 14]), torch.Size([1, 768, 7, 7])]

        xs = self.fpn(xs)
        # print('before head', [i.size() for i in xs])
        xs = tuple(xs)

        if targets is not None:
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(xs, targets, input_images)
            return {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }
        return self.head(xs)
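
# With the default before_fpns and a 224x224 input, ViTYOLOv3Head turns the
# 14x14 ViT feature map into a 3-level pyramid (illustrative shapes, batch 1):
#   ConvTranspose2d(stride=2) -> (1, 768, 28, 28)
#   Identity()                -> (1, 768, 14, 14)
#   MaxPool2d(stride=2)       -> (1, 768, 7, 7)
# matching the sizes noted inside forward() before they reach YOLOXHead.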


class ViTYOLOv3Head2(nn.Module):
    def __init__(self, im_size, patch_size, patch_dim, num_classes, use_bigger_fpns):
        super(ViTYOLOv3Head2, self).__init__()

        self.im_size = im_size
        self.patch_size = patch_size

        # target_patch_dim: [256, 512, 512]
        # self.change_patchs_dim = nn.ModuleList([nn.Linear(patch_dim, target_patch_dim) for target_patch_dim in [256, 512, 512]])

        # # input: (1, target_patch_dim, 14, 14)
        # # target feature size: {40, 20, 10}
        # self.change_features_size = nn.ModuleList([
        #     self.get_change_feature_size(cin, cout, t) for t, cin, cout in zip([40, 20, 10], [256, 512, 512], [256, 512, 512])
        # ])

        embed_dim = 768
        self.before_fpns = nn.ModuleList([
            # nn.Sequential(
            #     nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            #     nn.GroupNorm(embed_dim),
            #     nn.GELU(),
            #     nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            # ),
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            ),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])

        if use_bigger_fpns:
            logger.info('use 8/4/2x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    Norm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
                # nn.MaxPool2d(kernel_size=2, stride=2)
            ])

        # self.fpn = YOLOFPN()
        self.fpn = nn.Identity()
        self.head = YOLOXHead(num_classes, in_channels=[768, 768, 768], act='lrelu')

        self.load_pretrained_weight()

    def load_pretrained_weight(self):
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'yolox_darknet.pth'))

        # for k in [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]:
        #     del ckpt['model'][k]
        removed_k = [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]
        for k, v in ckpt['model'].items():
            if 'backbone.backbone' in k:
                removed_k += [k]
            if 'head.stems' in k and 'conv.weight' in k:
                removed_k += [k]
        for k in removed_k:
            del ckpt['model'][k]

        # print(ckpt['model'].keys())

        new_state_dict = {}
        for k, v in ckpt['model'].items():
            new_k = k.replace('backbone', 'fpn')
            new_state_dict[new_k] = v

        self.load_state_dict(new_state_dict, strict=False)

    def get_change_feature_size(self, in_channels, out_channels, target_size):
        H, W = self.im_size
        GS = H // self.patch_size  # 14

        if target_size == GS:
            return nn.Identity()
        elif target_size < GS:
            return nn.AdaptiveMaxPool2d((target_size, target_size))
        else:
            return {
                20: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=0),
                    # nn.BatchNorm2d(out_channels),
                    # nn.ReLU(),
                    nn.AdaptiveMaxPool2d((target_size, target_size))
                ),
                40: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=3, padding=1),
                    # nn.BatchNorm2d(out_channels),
                    # nn.ReLU(),
                )
            }[target_size]

    def forward(self, input_images, x, targets=None):
        # print(111)

        # NOTE: YOLOX backbone (w/o FPN) output, or FPN input: {'dark3': torch.Size([4, 256, 40, 40]), 'dark4': torch.Size([4, 512, 20, 20]), 'dark5': torch.Size([4, 512, 10, 10])}
        # NOTE: YOLOXHead input: [torch.Size([4, 128, 40, 40]), torch.Size([4, 256, 20, 20]), torch.Size([4, 512, 10, 10])]

        # print(x)
        # print([i.size() for i in x])

        # drop the class token, then reshape the patch tokens back into a 2D feature map
        x = [i[:, 1:] for i in x]
        x = [i.permute(0, 2, 1).reshape(input_images.size(0), -1, 14, 14) for i in x]  # 14 is hardcoded, obtained from timm.layers.patch_embed.py
        # print([i.size() for i in x])
        # exit()

        # NOTE: old
        # x[0]: torch.Size([1, 196, 768])
        # H, W = self.im_size
        # GS = H // self.patch_size  # 14
        # xs = [cpd(x) for x, cpd in zip(xs, self.change_patchs_dim)]  # (1, 196, target_patch_dim)
        # xs = [rearrange(x, "b (h w) c -> b c h w", h=GS) for x in xs]  # (1, target_patch_dim, 14, 14)
        # xs = [cfs(x) for x, cfs in zip(xs, self.change_features_size)]
        # print([i.size() for i in xs])
        # ----------------

        # each tapped ViT layer is rescaled by its own before_fpn to one pyramid level
        xs = [before_fpn(i) for i, before_fpn in zip(x, self.before_fpns)]
        # print([i.size() for i in xs])
        # exit()
        # [torch.Size([1, 768, 28, 28]), torch.Size([1, 768, 14, 14]), torch.Size([1, 768, 7, 7])]

        xs = self.fpn(xs)
        # print('before head', [i.size() for i in xs])
        xs = tuple(xs)

        if targets is not None:
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(xs, targets, input_images)
            return {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }
        return self.head(xs)


def _forward_head(self, x):
    return self.head(x)


# def ensure_forward_head_obj_repoint(self):
#     self.forward_head = MethodType(_forward_head, self)
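
# Note: ViTYOLOv3Head rescales only the last tapped ViT feature map to every
# pyramid level, while ViTYOLOv3Head2 gives each tapped layer its own
# before_fpn module (one layer per level). make_vit_yolov3 below chooses
# between them via use_multi_layer_feature; the Head2 path is currently
# disabled with a raise NotImplementedError.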


def make_vit_yolov3(vit: VisionTransformer, samples: torch.Tensor, patch_size, patch_dim, num_classes,
                    use_bigger_fpns=False, use_multi_layer_feature=False, cls_vit_ckpt_path=None, init_head=False):
    assert cls_vit_ckpt_path is None

    # vit -> fpn -> head

    # modify vit.forward() to make it output middle features
    # vit.forward_features = partial(vit._intermediate_layers,
    #                                n=[len(vit.blocks) // 3 - 1, len(vit.blocks) // 3 * 2 - 1, len(vit.blocks) - 1])
    # vit.forward_head = _forward_head
    # vit.__deepcopy__ = MethodType(ensure_forward_head_obj_repoint, vit)
    vit = VisionTransformerYOLOv3.init_from_vit(vit)

    if not use_multi_layer_feature:
        set_module(vit, 'head', ViTYOLOv3Head(
            im_size=(samples.size(2), samples.size(3)),
            patch_size=patch_size,
            patch_dim=patch_dim,
            num_classes=num_classes,
            use_bigger_fpns=use_bigger_fpns,
            cls_vit_ckpt_path=cls_vit_ckpt_path,
            init_head=init_head
        ))
    else:
        raise NotImplementedError
        logger.info('use multi layer feature')
        set_module(vit, 'head', ViTYOLOv3Head2(
            im_size=(samples.size(2), samples.size(3)),
            patch_size=patch_size,
            patch_dim=patch_dim,
            num_classes=num_classes,
            use_bigger_fpns=use_bigger_fpns
        ))
    # print(vit)

    # smoke test: in eval mode, the detector should return one decoded prediction
    # tensor per image with (num_classes + 5) values per anchor
    vit.eval()
    output = vit(samples)
    # print([oo.size() for oo in output])
    assert len(output) == samples.size(0) and output[0].size(1) == num_classes + 5, f'{[oo.size() for oo in output]}, {num_classes}'

    return vit
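
# Hedged usage sketch (not in the original file). After make_vit_yolov3, the
# returned model can be called with or without targets; the target layout is
# whatever dnns.yolov3.head.YOLOXHead expects (assumed here, not defined in
# this file).
#
#   det_model = make_vit_yolov3(vit_b_16(), torch.rand(1, 3, 224, 224), 16, 768, 20)
#   preds = det_model(torch.rand(1, 3, 224, 224))   # eval: (1, n_anchors, 20 + 5), per the assert above
#   # det_model.train()
#   # losses = det_model(images, targets)           # train: dict with 'total_loss', 'iou_loss', ...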


if __name__ == '__main__':
    from dnns.vit import vit_b_16
    vit_b_16 = vit_b_16()
    make_vit_yolov3(vit_b_16, torch.rand((1, 3, 224, 224)), 16, 768, 20)

    exit()

    from types import MethodType

    class Student(object):
        pass

    def set_name(self, name):
        self.name = name

    def get_name(self):
        print(self.name)

    s1 = Student()
    # bind the methods to the s1 and s2 instances
    s1.set_name = MethodType(set_name, s1)
    s1.get_name = MethodType(get_name, s1)

    s1.set_name('s1')

    from copy import deepcopy
    s2 = deepcopy(s1)
    s2.get_name()

    s2.set_name('s2')
    s1.get_name()
    s2.get_name()
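
    # The block above (after the exit()) is a scratch experiment: it checks how
    # copy.deepcopy treats MethodType attributes stored on an instance, i.e.
    # whether the copied object's bound methods end up pointing at the copy or
    # still at the original (the concern behind the commented-out __deepcopy__
    # hack in make_vit_yolov3).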