File size: 13,953 Bytes
cb80c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import (
    IncrementalNet,
    CosineIncrementalNet,
    SimpleCosineIncrementalNet,
)
from utils.toolkit import target2onehot, tensor2numpy
import ot
from torch import nn
import copy

# Numerical floor added to norms / OT cost matrices to avoid division by zero.
EPSILON: float = 1e-8

# --- Optimisation schedule used by COIL._train (SGD + MultiStepLR) ---
epochs: int = 100  # training epochs per task
lrate: float = 0.1  # initial SGD learning rate
milestones: list = [40, 80]  # epochs at which the learning rate is decayed
lrate_decay: float = 0.1  # multiplicative LR decay applied at each milestone
batch_size: int = 32  # also used by the prototype-extraction loaders

# Total exemplar budget split across all seen classes (when memory is not fixed).
memory_size: int = 2000

# Softmax temperature for knowledge distillation.
T: int = 2


class COIL(BaseLearner):
    """Incremental learner that uses optimal transport (OT) between class prototypes.

    OT plans are computed with ``ot.sinkhorn`` between L2-normalised class
    prototypes and used in two directions:

    * ``solving_ot`` (called from ``after_task``) maps the current classifier
      weights onto the classes of the next task. The result is passed to
      ``update_fc`` at the start of that task and is also kept as
      ``self._ot_new_branch`` for the epoch-0 auxiliary loss.
    * ``solving_ot_to_old`` maps the current task's classifier weights back
      onto the old classes. Its output is used as an extra distillation
      target from epoch 1 onward.
    """

    def __init__(self, args):
        super().__init__(args)
        self._network = SimpleCosineIncrementalNet(args, False)
        self.data_manager = None
        # OT-transported weights produced in ``after_task``; handed to
        # ``update_fc`` in the next ``incremental_train`` call.
        self.nextperiod_initialization = None
        self.sinkhorn_reg = args["sinkhorn"]
        # Initial value from args; overwritten by the norm-matching factor
        # computed in ``solving_ot``.
        self.calibration_term = args["calibration_term"]
        self.args = args

    def after_task(self):
        """Compute OT weights for the next task, then freeze a copy of the model."""
        self.nextperiod_initialization = self.solving_ot()
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes

    def _sinkhorn_plan(self, source_means, target_means, mass_denominator, cost_offset=0.0):
        """Solve entropic OT between two prototype sets with ``ot.sinkhorn``.

        Every prototype on both sides receives mass ``1 / mass_denominator``.
        The cost is the ``args["norm_term"]``-norm distance between prototypes,
        plus ``cost_offset`` when it is non-zero.

        Returns:
            float32 plan of shape ``(len(source_means), len(target_means))``
            placed on the classifier's device.
        """
        cost = torch.cdist(source_means, target_means, p=self.args["norm_term"])
        if cost_offset:
            cost = cost + cost_offset
        # NOTE(review): both marginals use 1/mass_denominator, so their total
        # masses differ whenever the two prototype sets have different sizes
        # (unbalanced input to ot.sinkhorn). This is preserved from the
        # original implementation; confirm it is intended.
        mu_source = torch.ones(len(source_means)) / mass_denominator
        mu_target = torch.ones(len(target_means)) / mass_denominator
        plan = ot.sinkhorn(mu_source, mu_target, cost, self.sinkhorn_reg)
        # Follow the classifier's device instead of a hard-coded ``.cuda()``.
        return torch.as_tensor(plan).float().to(self._network.fc.weight.device)

    def solving_ot(self):
        """Transport the current classifier weights onto the next task's classes.

        Prototypes are computed for every class seen so far and for the classes
        of the next task. An entropic OT problem is then solved between old and
        new prototypes, and the L2-normalised current classifier weights are
        mapped through the plan. The result is rescaled so that its mean row
        norm matches that of the current weights.

        Side effects: overwrites ``self.calibration_term`` with that scale
        factor and stores the result in ``self._ot_new_branch``.

        Returns:
            Tensor with one row per class of the next task, or ``None`` when all
            classes have already been learned.
        """
        with torch.no_grad():
            if self._total_classes == self.data_manager.get_total_classnum():
                logging.info("training over, no more ot solving")
                return None
            # Size of the *next* task. Increments need not be uniform (e.g. a
            # smaller last task), so do not assume it equals task 1's size.
            next_task_size = self.data_manager.get_task_size(self._cur_task + 1)
            self._extract_class_means(
                self.data_manager, 0, self._total_classes + next_task_size
            )
            n_old = self._total_classes
            former_class_means = torch.tensor(self._ot_prototype_means[:n_old])
            next_period_class_means = torch.tensor(
                self._ot_prototype_means[n_old : n_old + next_task_size]
            )
            transport_plan = self._sinkhorn_plan(
                former_class_means, next_period_class_means, n_old
            )  # (n_old, n_next)
            transformed_hat_W = n_old * torch.mm(
                transport_plan.T, F.normalize(self._network.fc.weight, p=2, dim=1)
            )
            # Match the mean row norm of the transported weights to the current ones.
            old_norm = torch.norm(self._network.fc.weight, p=2, dim=1)
            new_norm = torch.norm(transformed_hat_W, p=2, dim=1)
            self.calibration_term = torch.mean(old_norm) / torch.mean(new_norm)
            self._ot_new_branch = transformed_hat_W * self.calibration_term
        return transformed_hat_W * self.calibration_term

    def solving_ot_to_old(self):
        """Transport the current task's classifier weights back onto old classes.

        Old-class prototypes come from exemplar memory, and current-task
        prototypes come from their training data. The plan maps the
        L2-normalised weights of the current task's classes to one row per old
        class, scaled by ``self.calibration_term``.

        Returns:
            Tensor with ``self._known_classes`` rows (computed without autograd).
        """
        with torch.no_grad():
            current_class_num = self.data_manager.get_task_size(self._cur_task)
            self._extract_class_means_with_memory(
                self.data_manager, self._known_classes, self._total_classes
            )
            n_old = self._known_classes
            former_class_means = torch.tensor(self._ot_prototype_means[:n_old])
            next_period_class_means = torch.tensor(
                self._ot_prototype_means[n_old : self._total_classes]
            )
            # EPSILON keeps the cost strictly positive (numerical safety).
            transport_plan = self._sinkhorn_plan(
                next_period_class_means, former_class_means, n_old, cost_offset=EPSILON
            )  # (n_new, n_old)
            new_class_weights = F.normalize(
                self._network.fc.weight[-current_class_num:, :], p=2, dim=1
            )
            transformed_hat_W = torch.mm(transport_plan.T, new_class_weights)
            return transformed_hat_W * n_old * self.calibration_term

    def incremental_train(self, data_manager):
        """Train on the next task of ``data_manager`` and refresh exemplar memory."""
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )

        # Passes the OT-transported weights from ``after_task`` (None before the
        # first ``after_task``) to ``update_fc``.
        self._network.update_fc(self._total_classes, self.nextperiod_initialization)
        self.data_manager = data_manager

        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )
        # Weight of the distillation loss; (1 - lamda) weights cross-entropy.
        self.lamda = self._known_classes / self._total_classes

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )

        self._train(self.train_loader, self.test_loader)

        if self.args["fixed_memory"]:
            exemplar_size = self.args["memory_per_class"]
        else:
            exemplar_size = memory_size // self._total_classes
        self._reduce_exemplar(data_manager, exemplar_size)
        self._construct_exemplar(data_manager, exemplar_size)

    def _train(self, train_loader, test_loader):
        """Move the model(s) to the device and train with SGD + MultiStepLR."""
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)
        optimizer = optim.SGD(
            self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=5e-4
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )
        self._update_representation(train_loader, test_loader, optimizer, scheduler)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Run ``epochs`` training epochs and log the last epoch's metrics.

        On the first task the loss is plain cross-entropy. On later tasks it is
        a ``lamda``-weighted mix of distillation from the frozen old network
        (temperature ``T``) and cross-entropy, plus an OT-based term:

        * epoch 0: the new-class logits are distilled towards logits produced
          by the OT-initialised weights ``self._ot_new_branch`` (weight
          ``0.001 * weight_ot_init``);
        * later epochs: the old network's soft targets are distilled onto
          logits produced by the OT-transported old-class weights, recomputed
          every 30 batches (weight ``args["reg_term"] * weight_ot_co_tuning``).
        """
        info = None
        prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
            # The new-branch weight decays from 1 (it is only used in epoch 0);
            # the old-branch weight grows quadratically over training.
            weight_ot_init = max(1.0 - (epoch / 2) ** 2, 0)
            weight_ot_co_tuning = (epoch / epochs) ** 2.0

            self._network.train()
            losses = 0.0
            correct, total = 0, 0

            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                output = self._network(inputs)
                logits = output["logits"]

                clf_loss = F.cross_entropy(logits, targets)
                if self._old_network is None:
                    loss = clf_loss
                else:
                    old_logits = self._old_network(inputs)["logits"].detach()
                    hat_pai_k = F.softmax(old_logits / T, dim=1)
                    log_pai_k = F.log_softmax(
                        logits[:, : self._known_classes] / T, dim=1
                    )
                    distill_loss = -torch.mean(torch.sum(hat_pai_k * log_pai_k, dim=1))
                    base_loss = distill_loss * self.lamda + clf_loss * (1 - self.lamda)
                    features = F.normalize(output["features"], p=2, dim=1)

                    if epoch < 1:
                        # Soft targets from the OT-initialised new-class weights.
                        log_new = F.log_softmax(
                            logits[:, self._known_classes :] / T, dim=1
                        )
                        ot_new_logits = F.linear(
                            features, F.normalize(self._ot_new_branch, p=2, dim=1)
                        )
                        ot_new_targets = F.softmax(ot_new_logits / T, dim=1)
                        new_branch_distill_loss = -torch.mean(
                            torch.sum(log_new * ot_new_targets, dim=1)
                        )
                        loss = base_loss + 0.001 * (
                            weight_ot_init * new_branch_distill_loss
                        )
                    else:
                        if i % 30 == 0:
                            # Re-extracts prototypes for memory and current-task
                            # data, so it is refreshed every 30 batches rather
                            # than every batch.
                            with torch.no_grad():
                                self._ot_old_branch = self.solving_ot_to_old()
                        ot_old_logits = F.linear(
                            features, F.normalize(self._ot_old_branch, p=2, dim=1)
                        )
                        ot_old_log_probs = F.log_softmax(ot_old_logits / T, dim=1)
                        old_branch_distill_loss = -torch.mean(
                            torch.sum(hat_pai_k * ot_old_log_probs, dim=1)
                        )
                        loss = base_loss + self.args["reg_term"] * (
                            weight_ot_co_tuning * old_branch_distill_loss
                        )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                epochs,
                losses / len(train_loader),
                train_acc,
                test_acc,
            )
            prog_bar.set_description(info)

        # ``info`` is only set once an epoch has run (avoids NameError if epochs == 0).
        if info is not None:
            logging.info(info)

    def _class_prototype(self, dataset):
        """Return the unit-norm mean of the L2-normalised feature vectors of ``dataset``."""
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )
        vectors, _ = self._extract_vectors(loader)
        vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
        class_mean = np.mean(vectors, axis=0)
        mean_norm = np.linalg.norm(class_mean)
        # Guard against a zero mean, which would otherwise produce NaNs.
        return class_mean / mean_norm if mean_norm > 0 else class_mean

    def _extract_class_means(self, data_manager, low, high):
        """Reset ``self._ot_prototype_means`` and fill rows ``[low, high)``.

        Each row is the prototype of one class computed from its training data
        (test-mode transforms). All other rows are left at zero.
        """
        self._ot_prototype_means = np.zeros(
            (data_manager.get_total_classnum(), self._network.feature_dim)
        )
        with torch.no_grad():
            for class_idx in range(low, high):
                _, _, idx_dataset = data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source="train",
                    mode="test",
                    ret_data=True,
                )
                self._ot_prototype_means[class_idx, :] = self._class_prototype(
                    idx_dataset
                )
        self._network.train()

    def _extract_class_means_with_memory(self, data_manager, low, high):
        """Reset ``self._ot_prototype_means`` and fill rows ``[0, high)``.

        Classes ``[0, low)`` use the exemplar memory (``self._data_memory`` /
        ``self._targets_memory``). Classes ``[low, high)`` use their training
        data. Rows from ``high`` onward are left at zero.
        """
        self._ot_prototype_means = np.zeros(
            (data_manager.get_total_classnum(), self._network.feature_dim)
        )
        memoryx, memoryy = self._data_memory, self._targets_memory
        with torch.no_grad():
            for class_idx in range(0, low):
                class_mask = memoryy == class_idx
                _, _, idx_dataset = data_manager.get_dataset(
                    [],
                    source="train",
                    appendent=(memoryx[class_mask], memoryy[class_mask]),
                    mode="test",
                    ret_data=True,
                )
                self._ot_prototype_means[class_idx, :] = self._class_prototype(
                    idx_dataset
                )

            for class_idx in range(low, high):
                _, _, idx_dataset = data_manager.get_dataset(
                    np.arange(class_idx, class_idx + 1),
                    source="train",
                    mode="test",
                    ret_data=True,
                )
                self._ot_prototype_means[class_idx, :] = self._class_prototype(
                    idx_dataset
                )
        self._network.train()