Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 4,753 Bytes
			
			| 251e479 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | import gc
import torch
import torch.nn.functional as F
from flow.flow_utils import flow_warp
# AdaIn
def calc_mean_std(feat, eps=1e-5):
    """Return per-sample, per-channel mean and standard deviation of ``feat``.

    Statistics are taken over the last two dimensions of a 4-D tensor, which
    are flattened together. These are the statistics used for AdaIN-style
    feature re-normalisation.

    Args:
        feat: 4-D tensor. The first two dims are treated as (N, C). It does
            not need to be contiguous.
        eps: small value added to the variance before the square root, to
            avoid a zero standard deviation (and divide-by-zero downstream).

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)`` so they broadcast
        against ``feat``. The variance uses ``Tensor.var``'s default
        (unbiased) estimator.
    """
    size = feat.size()
    assert len(size) == 4, 'calc_mean_std expects a 4-D tensor'
    N, C = size[:2]
    # reshape (not view) so non-contiguous inputs such as transposed or sliced
    # tensors work. It still returns a view without copying when it can.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
class AttentionControl():
    """Record attention contexts and x0 tensors, then re-inject them later.

    Behaviour is driven by flags set through ``set_task``:

    * ``forward`` / ``__call__`` intercepts calls with ``is_cross`` False and
      ``place_in_unet == 'up'``. Presumably these are the U-Net decoder's
      self-attention layers (TODO confirm against the caller). It can:
      store the context, replace it with stored features, and/or update the
      stored "previous" features.
    * ``update_x0`` can store x0 and its channel statistics. It can also
      apply an AdaIN re-normalisation toward the stored statistics and blend
      in a flow-warped copy of the previous x0.

    The flow/mask warping suggests a frame-to-frame use, for example video.
    That is an inference from the names; verify against callers.

    Args:
        inner_strength: stored as ``self.inner_strength``; not read in this
            class.
        mask_period: stored as ``self.mask_period``; not read in this class.
        cross_period: ``(start, end)`` fractions of ``total_step``. While
            ``cur_step`` is inside this range, ``forward`` may replace the
            context.
        ada_period: ``(start, end)`` fractions of ``total_step`` during which
            ``update_x0`` applies AdaIN.
        warp_period: ``(start, end)`` fractions of ``total_step`` during
            which ``update_x0`` blends in the warped previous x0.
    """

    def __init__(self, inner_strength, mask_period, cross_period, ada_period,
                 warp_period):
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        # Index of the next intercepted forward() call. It is reset by
        # set_total_step() and set_task().
        self.cur_index = 0
        self.init_store = False
        self.restore = False
        self.update = False
        self.flow = None
        self.mask = None
        self.restorex0 = False
        self.updatex0 = False
        # Previously created only by set_task(). It is initialised here with
        # the same default so the attribute always exists.
        self.restore_step = 1.0
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period

    @staticmethod
    def get_empty_store():
        """Return a fresh, empty feature store.

        Keys:
            'first': contexts recorded while ``init_store`` is set.
            'previous': contexts that are overwritten while ``update`` is set.
            'x0_previous': x0 tensors, indexed by step.
            'first_ada': interleaved per-step ``[mean, std]`` of x0.
        """
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': []
        }

    def _in_period(self, period):
        """Return True if cur_step is in [total*period[0], total*period[1]]."""
        return (self.total_step * period[0] <= self.cur_step
                <= self.total_step * period[1])

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Possibly record and/or replace ``context``, then return it.

        Only calls with ``is_cross`` False and ``place_in_unet == 'up'`` are
        handled. All other calls return ``context`` unchanged and do not
        advance ``cur_index``.
        """
        if is_cross or place_in_unet != 'up':
            return context
        if self.init_store:
            self.step_store['first'].append(context.detach())
            self.step_store['previous'].append(context.detach())
        if self.update:
            # Snapshot the incoming context before it may be replaced below.
            tmp = context.clone().detach()
        if self.restore and self._in_period(self.cross_period):
            # torch.cat allocates a new tensor, so no extra clone is needed.
            context = torch.cat(
                (self.step_store['first'][self.cur_index],
                 self.step_store['previous'][self.cur_index]),
                dim=1)
        if self.update:
            self.step_store['previous'][self.cur_index] = tmp
        self.cur_index += 1
        return context

    def update_x0(self, x0):
        """Possibly record x0 and/or adjust it using stored data.

        Store entries are indexed by ``cur_step``. Within ``ada_period``,
        x0 is instance-normalised, then re-scaled and re-shifted with the
        stored std and mean. Within ``warp_period``, the result is
        ``mask * flow_warp(previous_x0) + (1 - mask) * x0``.

        Raises:
            RuntimeError: if the warp blend is required but ``set_warp`` has
                not provided a flow and mask.
        """
        if self.init_store:
            self.step_store['x0_previous'].append(x0.detach())
            style_mean, style_std = calc_mean_std(x0.detach())
            # Stored interleaved: index 2*step is the mean, 2*step+1 the std.
            self.step_store['first_ada'].append(style_mean.detach())
            self.step_store['first_ada'].append(style_std.detach())
        if self.updatex0:
            # Snapshot before any modification below.
            tmp = x0.clone().detach()
        if self.restorex0:
            if self._in_period(self.ada_period):
                style_mean = self.step_store['first_ada'][2 * self.cur_step]
                style_std = self.step_store['first_ada'][2 * self.cur_step + 1]
                x0 = F.instance_norm(x0) * style_std + style_mean
            if self._in_period(self.warp_period):
                if self.flow is None or self.mask is None:
                    raise RuntimeError(
                        'set_warp() must be called before restoring x0 '
                        'inside warp_period')
                pre = self.step_store['x0_previous'][self.cur_step]
                x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + (
                    1 - self.mask) * x0
        if self.updatex0:
            self.step_store['x0_previous'][self.cur_step] = tmp
        return x0

    def set_warp(self, flow, mask):
        """Store private copies of the flow and blend mask used by update_x0."""
        self.flow = flow.clone()
        self.mask = mask.clone()

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Alias for ``forward``."""
        return self.forward(context, is_cross, place_in_unet)

    def set_step(self, step):
        """Set the current step used for period checks and store indexing."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the step count that scales the periods, and reset cur_index."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop all stored tensors and try to release their memory."""
        del self.step_store
        # Collect first so tensors held only by reference cycles are freed.
        # Only then can empty_cache hand their blocks back to the device.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Reset all mode flags, then enable those named in ``task``.

        ``task`` is matched by substring, so several modes can be combined
        in one string:
        'initfirst' (record; also clears the store), 'updatestyle',
        'keepstyle', 'updatex0', 'keepx0'.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True
 |