# NOTE(review): the next three lines are web-page residue from the file's hosting
# site (uploader / commit message), not Python; commented out so the module parses.
# LINC-BIT's picture
# Upload 1912 files
# b84549f verified
from typing import List
import torch
from methods.base.model import BaseModel
import tqdm
from torch import nn
import torch.nn.functional as F
from abc import abstractmethod
from methods.elasticdnn.model.base import ElasticDNNUtil
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util
from utils.dl.common.model import LayerActivation
class ElasticDNN_OfflineFMModel(BaseModel):
    """Abstract interface for an offline foundation model (FM) in the ElasticDNN pipeline.

    Task-specific subclasses (classification, segmentation, detection, ...)
    implement the forward/loss logic and expose the utilities used to derive
    a reduced "master DNN" (MD) from the FM.
    """
    def get_required_model_components(self) -> List[str]:
        # Offline FM wrappers manage a single network under the key 'main'.
        return ['main']
    @abstractmethod
    def generate_md_by_reducing_width(self, reducing_width_ratio, samples: torch.Tensor):
        """Derive a narrower model from the FM; semantics defined by subclasses.

        ``samples`` is an example input batch (presumably used to trace/calibrate
        the reduction — confirm in concrete subclasses).
        """
        pass
    @abstractmethod
    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        """Run a forward pass on (x, y) and return the task training loss."""
        raise NotImplementedError
    @abstractmethod
    def get_feature_hook(self) -> LayerActivation:
        """Return a LayerActivation hook capturing an intermediate feature map."""
        pass
    @abstractmethod
    def get_elastic_dnn_util(self) -> ElasticDNNUtil:
        """Return the ElasticDNNUtil helper associated with this FM architecture."""
        pass
    @abstractmethod
    def get_lora_util(self) -> FMLoRA_Util:
        """Return the FMLoRA_Util helper for this FM (LoRA injection/absorption)."""
        pass
    @abstractmethod
    def get_task_head_params(self):
        """Return the parameters of the task head (e.g. for separate optimization)."""
        pass
class ElasticDNN_OfflineClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for image classification tasks."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy of the 'main' model over ``test_loader``."""
        total_correct = 0
        total_seen = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                x = x.to(self.device)
                y = y.to(self.device)
                # Model output carries logits in a wrapper object (HF-style).
                logits = self.infer(x).logits
                pred = F.softmax(logits, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                total_correct += correct
                total_seen += len(y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')
        return total_correct / total_seen

    def infer(self, x, *args, **kwargs):
        """Forward ``x`` through the main model."""
        main_model = self.models_dict['main']
        return main_model(x)
import numpy as np
class StreamSegMetrics:
    """
    Stream Metrics for Semantic Segmentation Task
    """

    def __init__(self, n_classes):
        # Running confusion matrix: rows = ground truth, cols = prediction.
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        """Accumulate one batch of (true, predicted) label maps."""
        for gt, pr in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(gt.flatten(), pr.flatten())

    @staticmethod
    def to_str(results):
        """Render every scalar metric (everything except 'Class IoU') as text."""
        text = "\n"
        for name, value in results.items():
            if name != "Class IoU":
                text += "%s: %f\n" % (name, value)
        return text

    def _fast_hist(self, label_true, label_pred):
        # Keep only pixels whose ground-truth label is a valid class id,
        # then bin-count the flattened (true, pred) index pairs.
        valid = (label_true >= 0) & (label_true < self.n_classes)
        flat_index = self.n_classes * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(flat_index, minlength=self.n_classes ** 2)
        return counts.reshape(self.n_classes, self.n_classes)

    def get_results(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc
        """
        hist = self.confusion_matrix
        diag = np.diag(hist)
        overall_acc = diag.sum() / hist.sum()
        # Per-class accuracy; nanmean skips classes absent from the ground truth.
        mean_acc = np.nanmean(diag / hist.sum(axis=1))
        union = hist.sum(axis=1) + hist.sum(axis=0) - diag
        iu = diag / union
        mean_iu = np.nanmean(iu)
        # Frequency-weighted accuracy: IoU weighted by ground-truth class frequency.
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))
        return {
            "Overall Acc": overall_acc,
            "Mean Acc": mean_acc,
            "FreqW Acc": fwavacc,
            "Mean IoU": mean_iu,
            "Class IoU": cls_iu,
        }

    def reset(self):
        """Zero the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
class ElasticDNN_OfflineSegFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for semantic segmentation (mean-IoU metric)."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        # Number of segmentation classes, needed for the confusion matrix.
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mean IoU of the 'main' model over ``test_loader``."""
        self.to_eval_mode()
        seg_metrics = StreamSegMetrics(self.num_classes)
        seg_metrics.reset()
        pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for batch_index, (x, y) in pbar:
                x = x.to(self.device, dtype=x.dtype, non_blocking=True, copy=False)
                y = y.to(self.device, dtype=y.dtype, non_blocking=True, copy=False)
                output = self.infer(x)
                # Per-pixel class prediction: argmax over the channel dimension.
                pred = output.detach().max(dim=1)[1].cpu().numpy()
                seg_metrics.update((y + 0).cpu().numpy(), pred)
                # Running results are recomputed each batch purely for display.
                running = seg_metrics.get_results()
                pbar.set_description(f'cur batch mIoU: {running["Mean IoU"]:.4f}')
        return seg_metrics.get_results()['Mean IoU']

    def infer(self, x, *args, **kwargs):
        """Forward ``x`` through the main model."""
        return self.models_dict['main'](x)
class ElasticDNN_OfflineDetFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for object detection (YOLOv3-style COCO evaluation)."""
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        # Number of detection classes forwarded to the COCO evaluator.
        self.num_classes = num_classes
    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mAP@50; a MergedDataset is scored as the mean over its parts."""
        _d = test_loader.dataset
        from data import build_dataloader
        if _d.__class__.__name__ == 'MergedDataset':
            # Evaluate each sub-dataset separately and average the mAP@50 scores.
            datasets = _d.datasets
            test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None) for d in datasets]
            accs = [self.get_accuracy(loader) for loader in test_loaders]
            return sum(accs) / len(accs)
        model = self.models_dict['main']
        device = self.device
        model.eval()
        model = model.to(device)
        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        with torch.no_grad():
            # HiddenPrints suppresses the evaluator's console output.
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                # evaluate() returns a tuple; index 1 holds mAP at IoU=0.5.
                res = evaluator.evaluate(model, False, False)
                map50 = res[1]
        return map50
    def infer(self, x, *args, **kwargs):
        """Forward pass; extra positional args (e.g. labels) are passed through.

        Fix: removed a leftover debug ``print(args, len(args))`` so behavior
        matches the MD counterpart ``ElasticDNN_OfflineDetMDModel.infer``.
        """
        if len(args) > 0:
            return self.models_dict['main'](x, *args)  # forward(x, label)
        return self.models_dict['main'](x)
class ElasticDNN_OfflineSenClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for sentence classification (dict-style batches)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy; ``x`` is a dict of tokenizer tensors."""
        total_correct = 0
        total_seen = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                # Move every tensor field of the batch dict onto the target device.
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                total_correct += correct
                total_seen += len(y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')
        return total_correct / total_seen

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'].forward(**x)
from accelerate.utils.operations import pad_across_processes
class ElasticDNN_OfflineTrFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for translation; accuracy is per-batch corpus BLEU, averaged."""
    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the mean per-batch corpus BLEU score over ``test_loader``.

        Fix: references are now decoded from the processed ``label`` tensor
        (padded across processes, with the -100 ignore-index replaced by the
        pad token) instead of the raw ``y``. Previously ``label`` was fully
        prepared and then never used, and decoding raw ``y`` would break on
        -100 entries.
        """
        from sacrebleu import corpus_bleu
        acc = 0
        num_batches = 0
        self.to_eval_mode()
        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()
        def _decode(o):
            # https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L133
            o = tokenizer.batch_decode(o, skip_special_tokens=True)
            return [oi.strip() for oi in o]
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                label = y.to(self.device)
                # Greedy tokens from logits (a generate() path exists in infer()).
                # generated_tokens = self.infer(x, generate=True)
                generated_tokens = self.infer(x).logits.argmax(-1)
                # pad tokens
                generated_tokens = pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                # pad label
                label = pad_across_processes(
                    label, dim=1, pad_index=tokenizer.pad_token_id
                )
                label = label.cpu().numpy()
                # Replace the loss ignore-index so the reference decodes cleanly.
                label = np.where(label != -100, label, tokenizer.pad_token_id)
                decoded_output = _decode(generated_tokens)
                decoded_y = _decode(label)  # fix: was _decode(y), leaving `label` unused
                decoded_y = [decoded_y]  # sacrebleu expects a list of reference streams
                if batch_index == 0:
                    print(decoded_y, decoded_output)
                bleu = corpus_bleu(decoded_output, decoded_y).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')
                acc += bleu
                num_batches += 1
        acc /= num_batches
        return acc
    def infer(self, x, *args, **kwargs):
        """Forward pass; with ``generate`` in kwargs, autoregressively generate instead."""
        if 'token_type_ids' in x.keys():
            # The seq2seq model does not accept BERT-style token_type_ids.
            del x['token_type_ids']
        if 'generate' in kwargs:
            return self.models_dict['main'].generate(
                x['input_ids'],
                attention_mask=x["attention_mask"],
                max_length=512
            )
        return self.models_dict['main'](**x)
from nltk.metrics import accuracy as nltk_acc
class ElasticDNN_OfflineTokenClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for token classification (per-token accuracy)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return token-level accuracy, skipping position 0 and padding."""
        total_correct = 0
        total_tokens = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                # output: (batch, seq, n_labels); y: (batch, seq)
                for logits_i, labels_i, ids_i in zip(output, y, x['input_ids']):
                    # Non-zero input ids mark the real (non-padding) length.
                    seq_len = ids_i.nonzero().size(0)
                    pred_i = F.softmax(logits_i, dim=-1).argmax(dim=-1)
                    # Position 0 is excluded (presumably a [CLS]-style token — confirm).
                    correct = torch.eq(pred_i[1: seq_len], labels_i[1: seq_len]).sum().item()
                    total_correct += correct
                    total_tokens += seq_len
                    pbar.set_description(f'seq_len: {seq_len}, cur_seq_acc: {(correct / seq_len):.4f}')
        return total_correct / total_tokens

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'](**x)
class ElasticDNN_OfflineMMClsFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for multi-modal (image/text, CLIP-style) classification."""
    # def __init__(self, name: str, models_dict_path: str, device: str, class_to_label_idx_map):
    #     super().__init__(name, models_dict_path, device)
    #     self.class_to_label_idx_map = class_to_label_idx_map
    def get_accuracy(self, test_loader, *args, **kwargs):
        """Image-to-text matching accuracy over roughly the first 2000 samples.

        Each batch's texts are deduplicated; the ground-truth index of an image
        is the position of its original text within the deduplicated list.
        """
        acc = 0
        sample_num = 0
        self.to_eval_mode()
        batch_size = 1  # updated from the first batch; drives the ~2000-sample cap below
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                if batch_index * batch_size > 2000:
                    break
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                # print(x)
                raw_texts = x['texts'][:]  # copy of the texts before deduplication
                # NOTE(review): set() ordering is arbitrary, but y below is rebuilt
                # from this same list, so predictions and labels stay aligned.
                x['texts'] = list(set(x['texts']))
                # print(x['texts'])
                batch_size = len(y)
                # NOTE(review): infer() overwrites 'for_training' with the model's
                # training flag, so this assignment is effectively redundant here.
                x['for_training'] = False
                output = self.infer(x)
                output = output.logits_per_image  # image-vs-unique-text similarity scores
                # print(output.size())
                # exit()
                # y = torch.arange(len(y), device=self.device)
                y = torch.LongTensor([x['texts'].index(rt) for rt in raw_texts]).to(self.device)
                # print(y)
                # exit()
                # print(output.size(), y.size())
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')
        acc /= sample_num
        return acc
    def infer(self, x, *args, **kwargs):
        # Mirror the wrapped model's training flag into the batch dict.
        x['for_training'] = self.models_dict['main'].training
        return self.models_dict['main'](**x)
class VQAScore:
    """Streaming VQA metric: mean target score of the argmax-chosen answer."""

    def __init__(self):
        # Running accumulators (kept on CPU; update() moves inputs to them).
        self.score = torch.tensor(0.0)
        self.total = torch.tensor(0.0)

    def update(self, logits, target):
        """Accumulate one batch of answer logits against soft target scores."""
        logits = logits.detach().float().to(self.score.device)
        target = target.detach().float().to(self.score.device)
        # Index of the highest-scoring answer for each sample.
        best = torch.max(logits, 1)[1]
        # One-hot mask of the chosen answer, matching target's dtype/device.
        chosen = torch.zeros(*target.size()).to(target)
        chosen.scatter_(1, best.view(-1, 1), 1)
        self.score += (chosen * target).sum()
        self.total += len(best)

    def compute(self):
        """Return accumulated score divided by accumulated sample count."""
        return self.score / self.total
class ElasticDNN_OfflineVQAFMModel(ElasticDNN_OfflineFMModel):
    """Offline FM wrapper for visual question answering (VQAScore metric)."""
    # def __init__(self, name: str, models_dict_path: str, device: str, class_to_label_idx_map):
    #     super().__init__(name, models_dict_path, device)
    #     self.class_to_label_idx_map = class_to_label_idx_map

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the accumulated VQA score over ``test_loader``."""
        scorer = VQAScore()
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                logits = self.infer(x).logits
                scorer.update(logits, y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {scorer.compute():.4f}')
        return scorer.compute()

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'](**x)
class ElasticDNN_OfflineMDModel(BaseModel):
    """Abstract interface for an offline master DNN (MD) — the reduced model
    distilled from a foundation model in the ElasticDNN pipeline."""
    def get_required_model_components(self) -> List[str]:
        # Offline MD wrappers manage a single network under the key 'main'.
        return ['main']
    @abstractmethod
    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        """Run a forward pass on (x, y) and return the task training loss."""
        raise NotImplementedError
    @abstractmethod
    def get_feature_hook(self) -> LayerActivation:
        """Return a LayerActivation hook capturing an intermediate feature map."""
        pass
    @abstractmethod
    def get_distill_loss(self, student_output, teacher_output):
        """Return the distillation loss between MD (student) and FM (teacher) outputs."""
        pass
    @abstractmethod
    def get_matched_param_of_fm(self, self_param_name, fm: nn.Module):
        """Map one of this MD's parameter names to the corresponding FM parameter."""
        pass
class ElasticDNN_OfflineClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for image classification."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy of the 'main' model over ``test_loader``."""
        total_correct = 0
        total_seen = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                x = x.to(self.device)
                y = y.to(self.device)
                # MD models return raw logits directly (no wrapper object).
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                total_correct += correct
                total_seen += len(y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')
        return total_correct / total_seen

    def infer(self, x, *args, **kwargs):
        """Forward ``x`` through the main model."""
        return self.models_dict['main'](x)
class ElasticDNN_OfflineSegMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for semantic segmentation (mean-IoU metric)."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        # Number of segmentation classes for the confusion matrix.
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mean IoU of the 'main' model over ``test_loader``."""
        self.to_eval_mode()
        metrics = StreamSegMetrics(self.num_classes)
        metrics.reset()
        progress = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for batch_index, (x, y) in progress:
                x = x.to(self.device, dtype=x.dtype, non_blocking=True, copy=False)
                y = y.to(self.device, dtype=y.dtype, non_blocking=True, copy=False)
                logits = self.infer(x)
                # Per-pixel prediction: argmax over the channel dimension.
                pred = logits.detach().max(dim=1)[1].cpu().numpy()
                metrics.update((y + 0).cpu().numpy(), pred)
                # Recomputed each batch only to feed the progress bar.
                res = metrics.get_results()
                progress.set_description(f'cur batch mIoU: {res["Mean IoU"]:.4f}')
        return metrics.get_results()['Mean IoU']

    def infer(self, x, *args, **kwargs):
        """Forward ``x`` through the main model."""
        return self.models_dict['main'](x)
class ElasticDNN_OfflineDetMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for object detection (YOLOv3-style COCO evaluation)."""

    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        # Number of detection classes forwarded to the COCO evaluator.
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return mAP@50; a MergedDataset is scored as the mean over its parts."""
        from data import build_dataloader
        dataset = test_loader.dataset
        if dataset.__class__.__name__ == 'MergedDataset':
            # Score each sub-dataset independently, then average the results.
            sub_loaders = [
                build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None)
                for d in dataset.datasets
            ]
            scores = [self.get_accuracy(loader) for loader in sub_loaders]
            return sum(scores) / len(scores)
        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        net = self.models_dict['main']
        net.eval()
        net = net.to(self.device)
        with torch.no_grad():
            # HiddenPrints suppresses the evaluator's console output.
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                # evaluate() returns a tuple; index 1 holds mAP at IoU=0.5.
                eval_result = evaluator.evaluate(net, False, False)
        return eval_result[1]

    def infer(self, x, *args, **kwargs):
        """Forward pass; extra positional args (e.g. labels) are passed through."""
        main = self.models_dict['main']
        if len(args) > 0:
            return main(x, *args)  # forward(x, label)
        return main(x)
class ElasticDNN_OfflineSenClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for sentence classification (dict-style batches)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return top-1 accuracy; ``x`` is a dict of tokenizer tensors."""
        total_correct = 0
        total_seen = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                # Move every tensor field of the batch dict onto the target device.
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                total_correct += correct
                total_seen += len(y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')
        return total_correct / total_seen

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'](**x)
class ElasticDNN_OfflineTrMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for translation; accuracy is per-batch corpus BLEU, averaged."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the mean per-batch corpus BLEU score over ``test_loader``."""
        from sacrebleu import corpus_bleu
        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(token_ids):
            # https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L133
            texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            # Spaces are stripped entirely (character-level comparison, e.g. for CJK).
            return [t.strip().replace(' ', '') for t in texts]

        bleu_sum = 0
        batches = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                hypotheses = _decode(output.argmax(-1))
                references = [_decode(y)]  # sacrebleu expects a list of reference streams
                bleu = corpus_bleu(hypotheses, references).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')
                bleu_sum += bleu
                batches += 1
        return bleu_sum / batches

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'](**x)
class ElasticDNN_OfflineTokenClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for token classification (per-token accuracy)."""

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return token-level accuracy, skipping position 0 and padding."""
        total_correct = 0
        total_tokens = 0
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                # output: (batch, seq, n_labels); y: (batch, seq)
                for logits_i, labels_i, ids_i in zip(output, y, x['input_ids']):
                    # Non-zero input ids mark the real (non-padding) length.
                    seq_len = ids_i.nonzero().size(0)
                    pred_i = F.softmax(logits_i, dim=-1).argmax(dim=-1)
                    # Position 0 is excluded (presumably a [CLS]-style token — confirm).
                    correct = torch.eq(pred_i[1: seq_len], labels_i[1: seq_len]).sum().item()
                    total_correct += correct
                    total_tokens += seq_len
                    pbar.set_description(f'seq_len: {seq_len}, cur_seq_acc: {(correct / seq_len):.4f}')
        return total_correct / total_tokens

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'](**x)
class ElasticDNN_OfflineMMClsMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for multi-modal (image/text, CLIP-style) classification."""
    def get_accuracy(self, test_loader, *args, **kwargs):
        """Image-to-text matching accuracy over roughly the first 2000 samples.

        Each batch's texts are deduplicated; the ground-truth index of an image
        is the position of its original text within the deduplicated list.
        """
        acc = 0
        sample_num = 0
        self.to_eval_mode()
        batch_size = 1  # updated from the first batch; drives the ~2000-sample cap below
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                if batch_index * batch_size > 2000:
                    break
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                # print(x)
                raw_texts = x['texts'][:]  # copy of the texts before deduplication
                # NOTE(review): set() ordering is arbitrary, but y below is rebuilt
                # from this same list, so predictions and labels stay aligned.
                x['texts'] = list(set(x['texts']))
                # print(x['texts'])
                batch_size = len(y)
                # Evaluation-mode flag consumed by the wrapped model's forward().
                x['for_training'] = False
                output = self.infer(x)
                output = output.logits_per_image  # image-vs-unique-text similarity scores
                # print(output.size())
                # exit()
                # y = torch.arange(len(y), device=self.device)
                y = torch.LongTensor([x['texts'].index(rt) for rt in raw_texts]).to(self.device)
                # print(y)
                # exit()
                # print(output.size(), y.size())
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')
        acc /= sample_num
        return acc
    def infer(self, x, *args, **kwargs):
        # Unlike the FM counterpart, 'for_training' is not overridden here.
        return self.models_dict['main'](**x)
class ElasticDNN_OfflineVQAMDModel(ElasticDNN_OfflineMDModel):
    """Offline MD wrapper for visual question answering (VQAScore metric)."""
    # def __init__(self, name: str, models_dict_path: str, device: str, class_to_label_idx_map):
    #     super().__init__(name, models_dict_path, device)
    #     self.class_to_label_idx_map = class_to_label_idx_map

    def get_accuracy(self, test_loader, *args, **kwargs):
        """Return the accumulated VQA score over ``test_loader``."""
        scorer = VQAScore()
        self.to_eval_mode()
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for key, value in x.items():
                    if isinstance(value, torch.Tensor):
                        x[key] = value.to(self.device)
                y = y.to(self.device)
                logits = self.infer(x).logits
                scorer.update(logits, y)
                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {scorer.compute():.4f}')
        return scorer.compute()

    def infer(self, x, *args, **kwargs):
        """Unpack the batch dict into keyword args of the main model."""
        return self.models_dict['main'](**x)