|
from typing import List |
|
import torch |
|
from methods.base.model import BaseModel |
|
import tqdm |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from abc import abstractmethod |
|
from methods.elasticdnn.model.base import ElasticDNNUtil |
|
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util |
|
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util |
|
|
|
from utils.dl.common.model import LayerActivation |
|
|
|
|
|
class ElasticDNN_OfflineFMModel(BaseModel):
    """Abstract wrapper around a foundation model (FM) for the offline ElasticDNN pipeline.

    Subclasses implement the task-specific forward/loss plus the hooks the
    pipeline needs: width reduction, feature capture, elastic-DNN and LoRA
    utilities, and access to the task head's parameters.
    """

    def get_required_model_components(self) -> List[str]:
        # Only the main FM network is stored in `models_dict`.
        return ['main']

    # NOTE: all abstract methods now uniformly `raise NotImplementedError`
    # (previously a mix of `pass` and `raise`, which would let a forgotten
    # override silently return None).

    @abstractmethod
    def generate_md_by_reducing_width(self, reducing_width_ratio, samples: torch.Tensor):
        """Derive a reduced-width master DNN from the FM using `samples`."""
        raise NotImplementedError

    @abstractmethod
    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        """Run a forward pass and return the task loss for (x, y)."""
        raise NotImplementedError

    @abstractmethod
    def get_feature_hook(self) -> LayerActivation:
        """Return a LayerActivation hook capturing the feature layer's output."""
        raise NotImplementedError

    @abstractmethod
    def get_elastic_dnn_util(self) -> ElasticDNNUtil:
        """Return the utility object implementing elastic-width operations."""
        raise NotImplementedError

    @abstractmethod
    def get_lora_util(self) -> FMLoRA_Util:
        """Return the utility object implementing LoRA operations for this FM."""
        raise NotImplementedError

    @abstractmethod
    def get_task_head_params(self):
        """Return the parameters of the task-specific head."""
        raise NotImplementedError
|
|
|
|
|
class ElasticDNN_OfflineClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for image classification."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy of the main model over `test_loader`."""
        num_correct, num_samples = 0, 0
        self.to_eval_mode()

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                # Model output carries a `.logits` field (HF-style output object).
                logits = self.infer(inputs).logits
                preds = F.softmax(logits, dim=1).argmax(dim=1)

                batch_correct = torch.eq(preds, targets).sum().item()
                num_correct += batch_correct
                num_samples += len(targets)

                progress.set_description(
                    f'cur_batch_total: {len(targets)}, cur_batch_correct: {batch_correct}, '
                    f'cur_batch_acc: {(batch_correct / len(targets)):.4f}')

        return num_correct / num_samples

    def infer(self, x, *args, **kwargs):
        """Forward `x` through the main model."""
        main_model = self.models_dict['main']
        return main_model(x)
|
|
|
|
|
import numpy as np |
|
class StreamSegMetrics:
    """Streaming confusion-matrix metrics for semantic segmentation."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        # Running confusion matrix; rows = ground truth, columns = prediction.
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        """Accumulate one batch of (ground-truth, prediction) label maps."""
        for true_map, pred_map in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(true_map.flatten(), pred_map.flatten())

    @staticmethod
    def to_str(results):
        """Render `results` as a readable multi-line string (per-class IoU omitted)."""
        parts = ["\n"]
        for name, value in results.items():
            if name != "Class IoU":
                parts.append("%s: %f\n" % (name, value))
        return "".join(parts)

    def _fast_hist(self, label_true, label_pred):
        # Keep only labels inside [0, n_classes), then bin (true, pred) pairs
        # in a single bincount over the flattened 2-D index.
        valid = (label_true >= 0) & (label_true < self.n_classes)
        flat_index = self.n_classes * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(flat_index, minlength=self.n_classes ** 2)
        return counts.reshape(self.n_classes, self.n_classes)

    def get_results(self):
        """Return overall/mean accuracy, mean and per-class IoU, and freq-weighted accuracy."""
        hist = self.confusion_matrix
        diag = np.diag(hist)

        overall_acc = diag.sum() / hist.sum()
        mean_acc = np.nanmean(diag / hist.sum(axis=1))
        # IoU = TP / (TP + FP + FN), per class; NaN for absent classes.
        iou = diag / (hist.sum(axis=1) + hist.sum(axis=0) - diag)
        mean_iou = np.nanmean(iou)
        class_freq = hist.sum(axis=1) / hist.sum()
        fw_acc = (class_freq[class_freq > 0] * iou[class_freq > 0]).sum()

        return {
            "Overall Acc": overall_acc,
            "Mean Acc": mean_acc,
            "FreqW Acc": fw_acc,
            "Mean IoU": mean_iou,
            "Class IoU": dict(zip(range(self.n_classes), iou)),
        }

    def reset(self):
        """Zero the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
|
|
|
|
|
class ElasticDNN_OfflineSegFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for semantic segmentation."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mean IoU of the main model over `test_loader`."""
        self.to_eval_mode()
        seg_metrics = StreamSegMetrics(self.num_classes)
        seg_metrics.reset()
        import tqdm
        progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                             leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for _, (images, labels) in progress:
                images = images.to(self.device, dtype=images.dtype, non_blocking=True, copy=False)
                labels = labels.to(self.device, dtype=labels.dtype, non_blocking=True, copy=False)

                logits = self.infer(images)
                # Per-pixel argmax over the class dimension.
                preds = logits.detach().max(dim=1)[1].cpu().numpy()
                seg_metrics.update((labels + 0).cpu().numpy(), preds)

                running = seg_metrics.get_results()
                progress.set_description(f'cur batch mIoU: {running["Mean IoU"]:.4f}')

        return seg_metrics.get_results()['Mean IoU']

    def infer(self, x, *args, **kwargs):
        """Forward `x` through the main model."""
        return self.models_dict['main'](x)
|
|
|
|
|
class ElasticDNN_OfflineDetFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for object detection, evaluated via COCO mAP@50."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mAP@50 of the main model over `test_loader`.

        A `MergedDataset` is evaluated one sub-dataset at a time and the
        per-dataset mAPs are averaged.
        """
        _d = test_loader.dataset
        from data import build_dataloader
        if _d.__class__.__name__ == 'MergedDataset':
            datasets = _d.datasets
            test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None) for d in datasets]
            accs = [self.get_accuracy(loader) for loader in test_loaders]
            return sum(accs) / len(accs)

        model = self.models_dict['main']
        device = self.device
        model.eval()
        model = model.to(device)

        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        with torch.no_grad():
            # COCOEvaluator prints progress; suppress it during evaluation.
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                res = evaluator.evaluate(model, False, False)
                map50 = res[1]
        return map50

    def infer(self, x, *args, **kwargs):
        """Forward `x` (plus optional extra positional args) through the main model."""
        # FIX: removed leftover debug `print(args, len(args))`; the MD
        # counterpart of this class never had it.
        if len(args) > 0:
            return self.models_dict['main'](x, *args)
        return self.models_dict['main'](x)
|
|
|
|
|
class ElasticDNN_OfflineSenClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for sentence classification (dict-of-tensors inputs)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy of the main model over `test_loader`."""
        num_correct, num_samples = 0, 0
        self.to_eval_mode()

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                # Move every tensor field of the input dict onto the device.
                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)
                targets = targets.to(self.device)

                logits = self.infer(inputs)
                preds = F.softmax(logits, dim=1).argmax(dim=1)
                batch_correct = torch.eq(preds, targets).sum().item()
                num_correct += batch_correct
                num_samples += len(targets)

                progress.set_description(
                    f'cur_batch_total: {len(targets)}, cur_batch_correct: {batch_correct}, '
                    f'cur_batch_acc: {(batch_correct / len(targets)):.4f}')

        return num_correct / num_samples

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model's forward."""
        return self.models_dict['main'].forward(**x)
|
|
|
|
|
from accelerate.utils.operations import pad_across_processes |
|
|
|
|
|
class ElasticDNN_OfflineTrFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for translation; accuracy is the mean per-batch corpus BLEU."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the average per-batch corpus BLEU of the main model over `test_loader`."""
        from sacrebleu import corpus_bleu

        acc = 0
        num_batches = 0

        self.to_eval_mode()

        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(o):
            # Decode token ids to text, dropping special tokens.
            o = tokenizer.batch_decode(o, skip_special_tokens=True)
            return [oi.strip() for oi in o]

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                label = y.to(self.device)

                # Greedy decoding straight from the logits.
                generated_tokens = self.infer(x).logits.argmax(-1)

                generated_tokens = pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )

                label = pad_across_processes(
                    label, dim=1, pad_index=tokenizer.pad_token_id
                )
                label = label.cpu().numpy()
                # Replace the -100 ignore index so the tokenizer can decode it.
                label = np.where(label != -100, label, tokenizer.pad_token_id)

                decoded_output = _decode(generated_tokens)
                # FIX: decode the padded, -100-masked `label` rather than the raw
                # `y` — previously the masking above was computed but never used.
                decoded_y = _decode(label)

                decoded_y = [decoded_y]

                if batch_index == 0:
                    print(decoded_y, decoded_output)

                bleu = corpus_bleu(decoded_output, decoded_y).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')

                acc += bleu
                num_batches += 1

        acc /= num_batches
        return acc

    def infer(self, x, *args, **kwargs):
        """Forward the input dict; with `generate` in kwargs, autoregressively generate instead."""
        if 'token_type_ids' in x.keys():
            # The seq2seq model used here does not accept token_type_ids.
            del x['token_type_ids']

        if 'generate' in kwargs:
            return self.models_dict['main'].generate(
                x['input_ids'],
                attention_mask=x["attention_mask"],
                max_length=512
            )

        return self.models_dict['main'](**x)
|
|
|
|
|
from nltk.metrics import accuracy as nltk_acc |
|
|
|
class ElasticDNN_OfflineTokenClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for token classification (per-token accuracy)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return token-level accuracy over non-padding positions of `test_loader`."""
        total_correct, total_tokens = 0, 0
        self.to_eval_mode()

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)

                targets = targets.to(self.device)
                outputs = self.infer(inputs)

                # Score each sequence separately so padding can be excluded.
                for seq_logits, seq_targets, seq_ids in zip(outputs, targets, inputs['input_ids']):
                    # Non-zero input ids mark real (non-padding) tokens.
                    seq_len = seq_ids.nonzero().size(0)

                    seq_preds = F.softmax(seq_logits, dim=-1).argmax(dim=-1)
                    # Position 0 is skipped (presumably a special token — matches
                    # the original [1:seq_len] slicing).
                    seq_correct = torch.eq(seq_preds[1: seq_len], seq_targets[1: seq_len]).sum().item()

                    total_correct += seq_correct
                    total_tokens += seq_len

                    progress.set_description(
                        f'seq_len: {seq_len}, cur_seq_acc: {(seq_correct / seq_len):.4f}')

        return total_correct / total_tokens

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)
|
|
|
|
|
class ElasticDNN_OfflineMMClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for multi-modal (image/text similarity) classification."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return image-to-text matching accuracy over at most ~2000 samples."""
        num_correct, num_samples = 0, 0
        self.to_eval_mode()
        batch_size = 1

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for batch_index, (inputs, targets) in progress:
                # Cap evaluation at roughly 2000 samples to keep it fast.
                if batch_index * batch_size > 2000:
                    break

                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)
                targets = targets.to(self.device)

                # Deduplicate candidate texts; keep the originals so each image
                # can be matched back to its own text's index in the new list.
                raw_texts = inputs['texts'][:]
                inputs['texts'] = list(set(inputs['texts']))

                batch_size = len(targets)
                inputs['for_training'] = False

                similarity = self.infer(inputs).logits_per_image

                # Ground-truth index of each image's text in the dedup'd list.
                targets = torch.LongTensor(
                    [inputs['texts'].index(rt) for rt in raw_texts]).to(self.device)

                preds = F.softmax(similarity, dim=1).argmax(dim=1)
                batch_correct = torch.eq(preds, targets).sum().item()
                num_correct += batch_correct
                num_samples += len(targets)
                progress.set_description(
                    f'cur_batch_total: {len(targets)}, cur_batch_correct: {batch_correct}, '
                    f'cur_batch_acc: {(batch_correct / len(targets)):.4f}')

        return num_correct / num_samples

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict; forward the model's current train/eval flag."""
        x['for_training'] = self.models_dict['main'].training
        return self.models_dict['main'](**x)
|
|
|
|
|
|
|
class VQAScore:
    """VQA score: soft-target credit of the argmax answer, averaged over samples."""

    def __init__(self):
        # Running numerator (sum of per-sample scores) and denominator (count).
        self.score = torch.tensor(0.0)
        self.total = torch.tensor(0.0)

    def update(self, logits, target):
        """Accumulate one batch of answer logits against soft targets."""
        logits = logits.detach().float().to(self.score.device)
        target = target.detach().float().to(self.score.device)
        chosen_answer = torch.max(logits, 1)[1]
        # One-hot mask of the chosen answer selects its soft-target credit.
        chosen_mask = torch.zeros(*target.size()).to(target)
        chosen_mask.scatter_(1, chosen_answer.view(-1, 1), 1)
        self.score += (chosen_mask * target).sum()
        self.total += len(chosen_answer)

    def compute(self):
        """Return the mean score accumulated so far."""
        return self.score / self.total
|
|
|
|
|
class ElasticDNN_OfflineVQAFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for visual question answering (VQA score metric)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the VQA score of the main model over `test_loader`."""
        # (removed unused `acc` / `sample_num` accumulators — VQAScore tracks
        # the running numerator and denominator itself)
        vqa_score = VQAScore()

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x).logits

                vqa_score.update(output, y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}')

        return vqa_score.compute()

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)
|
|
|
|
|
class ElasticDNN_OfflineMDModel(BaseModel):
    """Abstract wrapper around a master DNN (MD) for the offline ElasticDNN pipeline.

    Subclasses implement the task loss, feature capture, distillation loss, and
    the mapping from MD parameters back to FM parameters.
    """

    def get_required_model_components(self) -> List[str]:
        # Only the main MD network is stored in `models_dict`.
        return ['main']

    # NOTE: all abstract methods now uniformly `raise NotImplementedError`
    # (previously a mix of `pass` and `raise`).

    @abstractmethod
    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        """Run a forward pass and return the task loss for (x, y)."""
        raise NotImplementedError

    @abstractmethod
    def get_feature_hook(self) -> LayerActivation:
        """Return a LayerActivation hook capturing the feature layer's output."""
        raise NotImplementedError

    @abstractmethod
    def get_distill_loss(self, student_output, teacher_output):
        """Return the distillation loss between student and teacher outputs."""
        raise NotImplementedError

    @abstractmethod
    def get_matched_param_of_fm(self, self_param_name, fm: nn.Module):
        """Return the FM parameter corresponding to this model's named parameter."""
        raise NotImplementedError
|
|
|
|
|
class ElasticDNN_OfflineClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for image classification."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy of the main model over `test_loader`."""
        num_correct, num_samples = 0, 0
        self.to_eval_mode()

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                logits = self.infer(inputs)
                preds = F.softmax(logits, dim=1).argmax(dim=1)
                batch_correct = torch.eq(preds, targets).sum().item()
                num_correct += batch_correct
                num_samples += len(targets)

                progress.set_description(
                    f'cur_batch_total: {len(targets)}, cur_batch_correct: {batch_correct}, '
                    f'cur_batch_acc: {(batch_correct / len(targets)):.4f}')

        return num_correct / num_samples

    def infer(self, x, *args, **kwargs):
        """Forward `x` through the main model."""
        return self.models_dict['main'](x)
|
|
|
|
|
class ElasticDNN_OfflineSegMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for semantic segmentation."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mean IoU of the main model over `test_loader`."""
        self.to_eval_mode()
        seg_metrics = StreamSegMetrics(self.num_classes)
        seg_metrics.reset()
        import tqdm
        progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                             leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for _, (images, labels) in progress:
                images = images.to(self.device, dtype=images.dtype, non_blocking=True, copy=False)
                labels = labels.to(self.device, dtype=labels.dtype, non_blocking=True, copy=False)

                logits = self.infer(images)
                # Per-pixel argmax over the class dimension.
                preds = logits.detach().max(dim=1)[1].cpu().numpy()
                seg_metrics.update((labels + 0).cpu().numpy(), preds)

                running = seg_metrics.get_results()
                progress.set_description(f'cur batch mIoU: {running["Mean IoU"]:.4f}')

        return seg_metrics.get_results()['Mean IoU']

    def infer(self, x, *args, **kwargs):
        """Forward `x` through the main model."""
        return self.models_dict['main'](x)
|
|
|
|
|
class ElasticDNN_OfflineDetMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for object detection, evaluated via COCO mAP@50."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mAP@50; a MergedDataset is averaged across its sub-datasets."""
        dataset = test_loader.dataset
        from data import build_dataloader
        if dataset.__class__.__name__ == 'MergedDataset':
            # Evaluate every sub-dataset separately, then average the mAPs.
            sub_loaders = [
                build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None)
                for d in dataset.datasets
            ]
            sub_maps = [self.get_accuracy(loader) for loader in sub_loaders]
            return sum(sub_maps) / len(sub_maps)

        model = self.models_dict['main']
        model.eval()
        model = model.to(self.device)

        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        # COCOEvaluator prints progress; suppress it during evaluation.
        with torch.no_grad(), HiddenPrints():
            evaluator = COCOEvaluator(
                dataloader=test_loader,
                img_size=(224, 224),
                confthre=0.01,
                nmsthre=0.65,
                num_classes=self.num_classes,
                testdev=False
            )
            eval_res = evaluator.evaluate(model, False, False)
        return eval_res[1]

    def infer(self, x, *args, **kwargs):
        """Forward `x` (plus optional extra positional args) through the main model."""
        main_model = self.models_dict['main']
        return main_model(x, *args) if args else main_model(x)
|
|
|
|
|
class ElasticDNN_OfflineSenClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for sentence classification."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy of the main model over `test_loader`."""
        num_correct, num_samples = 0, 0
        self.to_eval_mode()

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                # Move every tensor field of the input dict onto the device.
                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)
                targets = targets.to(self.device)

                logits = self.infer(inputs)
                preds = F.softmax(logits, dim=1).argmax(dim=1)

                batch_correct = torch.eq(preds, targets).sum().item()
                num_correct += batch_correct
                num_samples += len(targets)

                progress.set_description(
                    f'cur_batch_total: {len(targets)}, cur_batch_correct: {batch_correct}, '
                    f'cur_batch_acc: {(batch_correct / len(targets)):.4f}')

        return num_correct / num_samples

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)
|
|
|
|
|
class ElasticDNN_OfflineTrMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for translation; accuracy is the mean per-batch corpus BLEU."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the average per-batch corpus BLEU of the main model over `test_loader`."""
        from sacrebleu import corpus_bleu

        bleu_sum, num_batches = 0, 0
        self.to_eval_mode()

        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(token_ids):
            # Decode ids to text; spaces are stripped out entirely (original behavior).
            texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True,
                                           clean_up_tokenization_spaces=True)
            return [t.strip().replace(' ', '') for t in texts]

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)
                targets = targets.to(self.device)

                logits = self.infer(inputs)
                hypotheses = _decode(logits.argmax(-1))
                references = [_decode(targets)]

                batch_bleu = corpus_bleu(hypotheses, references).score
                progress.set_description(f'cur_batch_bleu: {batch_bleu:.4f}')

                bleu_sum += batch_bleu
                num_batches += 1

        return bleu_sum / num_batches

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)
|
|
|
|
|
class ElasticDNN_OfflineTokenClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for token classification (per-token accuracy)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return token-level accuracy over non-padding positions of `test_loader`."""
        total_correct, total_tokens = 0, 0
        self.to_eval_mode()

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for _, (inputs, targets) in progress:
                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)

                targets = targets.to(self.device)
                outputs = self.infer(inputs)

                # Score each sequence separately so padding can be excluded.
                for seq_logits, seq_targets, seq_ids in zip(outputs, targets, inputs['input_ids']):
                    # Non-zero input ids mark real (non-padding) tokens.
                    seq_len = seq_ids.nonzero().size(0)

                    seq_preds = F.softmax(seq_logits, dim=-1).argmax(dim=-1)
                    # Position 0 is skipped (presumably a special token — matches
                    # the original [1:seq_len] slicing).
                    seq_correct = torch.eq(seq_preds[1: seq_len], seq_targets[1: seq_len]).sum().item()

                    total_correct += seq_correct
                    total_tokens += seq_len

                    progress.set_description(
                        f'seq_len: {seq_len}, cur_seq_acc: {(seq_correct / seq_len):.4f}')

        return total_correct / total_tokens

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)
|
|
|
|
|
class ElasticDNN_OfflineMMClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for multi-modal (image/text similarity) classification."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return image-to-text matching accuracy over at most ~2000 samples."""
        num_correct, num_samples = 0, 0
        self.to_eval_mode()
        batch_size = 1

        with torch.no_grad():
            progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader),
                                 dynamic_ncols=True, leave=False)
            for batch_index, (inputs, targets) in progress:
                # Cap evaluation at roughly 2000 samples to keep it fast.
                if batch_index * batch_size > 2000:
                    break

                for key, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(self.device)
                targets = targets.to(self.device)

                # Deduplicate candidate texts; keep the originals so each image
                # can be matched back to its own text's index in the new list.
                raw_texts = inputs['texts'][:]
                inputs['texts'] = list(set(inputs['texts']))

                batch_size = len(targets)
                inputs['for_training'] = False

                similarity = self.infer(inputs).logits_per_image

                # Ground-truth index of each image's text in the dedup'd list.
                targets = torch.LongTensor(
                    [inputs['texts'].index(rt) for rt in raw_texts]).to(self.device)

                preds = F.softmax(similarity, dim=1).argmax(dim=1)
                batch_correct = torch.eq(preds, targets).sum().item()
                num_correct += batch_correct
                num_samples += len(targets)
                progress.set_description(
                    f'cur_batch_total: {len(targets)}, cur_batch_correct: {batch_correct}, '
                    f'cur_batch_acc: {(batch_correct / len(targets)):.4f}')

        return num_correct / num_samples

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)
|
|
|
|
|
class ElasticDNN_OfflineVQAMDModel(ElasticDNN_OfflineMDModel):
    """Offline master-DNN wrapper for visual question answering (VQA score metric)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the VQA score of the main model over `test_loader`."""
        # (removed unused `acc` / `sample_num` accumulators — VQAScore tracks
        # the running numerator and denominator itself)
        vqa_score = VQAScore()

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x).logits

                vqa_score.update(output, y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}')

        return vqa_score.compute()

    def infer(self, x, *args, **kwargs):
        """Unpack the input dict into the main model."""
        return self.models_dict['main'](**x)