	Upload 3 files
- model/IAT_main.py +133 -0
- model/blocks.py +281 -0
- model/global_net.py +132 -0
    	
model/IAT_main.py  ADDED
@@ -0,0 +1,133 @@
import torch
from torch import nn
import os
import math

from timm.models.layers import trunc_normal_
from model.blocks import CBlock_ln, SwinTransformerBlock
from model.global_net import Global_pred


class Local_pred(nn.Module):
    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # note: the same block instance is repeated, so its weights are shared across positions
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add


# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add


class IAT(nn.Module):
    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)

        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        shape = image.shape
        image = image.view(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.view(shape)
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)

        if not self.with_global:
            return img_high
        else:
            gamma, color = self.global_net(img_low)
            b = img_high.shape[0]
            img_high = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
            img_high = torch.stack(
                [self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :] for i in range(b)], dim=0)
            img_high = img_high.permute(0, 3, 1, 2)  # (B,H,W,C) -> (B,C,H,W)
            return img_high


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    img = torch.rand(1, 3, 400, 600)  # random data; torch.Tensor(...) would be uninitialized memory
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    high = net(img)  # forward returns a single enhanced image tensor
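A side note on `IAT.apply_color`: `torch.tensordot(image, ccm, dims=[[-1], [-1]])` contracts the channel axis of the flattened image against the rows of the color matrix, i.e. it is a per-pixel 3x3 matrix product. A minimal check of that equivalence (my sketch, not part of the upload):

import torch

image = torch.rand(5, 3)  # five RGB pixels
ccm = torch.rand(3, 3)    # color correction matrix
out = torch.tensordot(image, ccm, dims=[[-1], [-1]])
assert torch.allclose(out, image @ ccm.T)  # per-pixel 3x3 color transform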
    	
model/blocks.py  ADDED
@@ -0,0 +1,281 @@
"""
Code copied from the UniFormer source code:
https://github.com/Sense-X/UniFormer
"""
import os
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple


# ResMLP's normalization
class Aff(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable affine parameters
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        x = x * self.alpha + self.beta
        return x


# Color Normalization
class Aff_channel(nn.Module):
    def __init__(self, dim, channel_first=True):
        super().__init__()
        # learnable affine parameters plus a learnable color-mixing matrix
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        if self.channel_first:
            x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]])
            x2 = x1 * self.alpha + self.beta
        else:
            x1 = x * self.alpha + self.beta
            x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]])
        return x2


class Mlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CMlp(nn.Module):
    # convolutional variant of the Mlp above, built from 1x1 convolutions
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CBlock_ln(nn.Module):
    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm1(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)

        x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(norm_x))))
        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm2(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_2 * self.mlp(norm_x))
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module.
    The relative position bias of the original Swin implementation is not used in this variant.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


## Layer_norm, Aff_norm, Aff_channel_norm
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: Aff_channel
    """

    def __init__(self, dim, num_heads=2, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Aff_channel):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # NOTE: the reverse cyclic shift is omitted, so this block should be used with shift_size=0 (the default)
        x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.transpose(1, 2).reshape(B, C, H, W)

        return x


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    cb_block = CBlock_ln(dim=16)
    x = torch.rand(1, 16, 400, 600)
    swin = SwinTransformerBlock(dim=16, num_heads=4)
    x = cb_block(x)
    x = swin(x)  # 400 and 600 are divisible by the default window_size of 8
    print(x.shape)
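`window_partition` and `window_reverse` are exact inverses whenever H and W are divisible by the window size, which is how `SwinTransformerBlock` rebuilds the feature map after windowed attention. A small check (my sketch, under that divisibility assumption):

import torch
from model.blocks import window_partition, window_reverse

x = torch.rand(2, 16, 24, 4)                  # (B, H, W, C), H and W divisible by window_size
windows = window_partition(x, window_size=8)  # (num_windows*B, 8, 8, C) = (12, 8, 8, 4)
y = window_reverse(windows, 8, H=16, W=24)
assert torch.equal(y, x)                      # partition and reverse are inverse operations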
    	
model/global_net.py  ADDED
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
import os
from model.blocks import Mlp


class query_Attention(nn.Module):
    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # 10 learned query tokens: 1 for gamma, 9 for the 3x3 color matrix
        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # on newer PyTorch, torch.div(C, self.num_heads, rounding_mode='floor') avoids integer-division warnings
        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 10, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class query_SABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = query_Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x.flatten(2).transpose(1, 2)
        # no residual here: the attention maps the N input tokens onto the 10 learned queries
        x = self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels // 2),
            nn.GELU(),
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        x = self.proj(x)
        return x


class Global_pred(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
        super(Global_pred, self).__init__()
        if type == 'exp':
            self.gamma_base = nn.Parameter(torch.ones(1), requires_grad=False)  # frozen in exposure correction
        else:
            self.gamma_base = nn.Parameter(torch.ones(1), requires_grad=True)
        self.color_base = nn.Parameter(torch.eye(3), requires_grad=True)  # basic color matrix
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
        self.gamma_linear = nn.Linear(out_channels, 1)
        self.color_linear = nn.Linear(out_channels, 1)

        self.apply(self._init_weights)

        for name, p in self.named_parameters():
            if name == 'generator.attn.v.weight':
                nn.init.constant_(p, 0)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.conv_large(x)
        x = self.generator(x)
        # split the 10 query outputs: token 0 predicts gamma, tokens 1-9 the 3x3 color matrix
        gamma, color = x[:, 0].unsqueeze(1), x[:, 1:]
        gamma = self.gamma_linear(gamma).squeeze(-1) + self.gamma_base
        color = self.color_linear(color).squeeze(-1).view(-1, 3, 3) + self.color_base
        return gamma, color


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    img = torch.rand(8, 3, 400, 600)
    global_net = Global_pred()
    gamma, color = global_net(img)
    print(gamma.shape, color.shape)
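For completeness, a sketch (mine, with a small dummy input) of how the predicted gamma and color matrix would be applied to one image, following the same steps as `IAT.apply_color` and the per-image power in `IAT.forward`:

import torch
from model.global_net import Global_pred

net = Global_pred(in_channels=3, out_channels=64, num_heads=4)
img = torch.rand(2, 3, 64, 64)
gamma, color = net(img)                                         # shapes (2, 1) and (2, 3, 3)

pixels = img[0].permute(1, 2, 0).reshape(-1, 3)                 # (H*W, 3) RGB pixels
pixels = torch.tensordot(pixels, color[0], dims=[[-1], [-1]])   # per-pixel 3x3 color transform
pixels = torch.clamp(pixels, 1e-8, 1.0) ** gamma[0]             # gamma correction
print(pixels.shape)                                             # torch.Size([4096, 3])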