SER_Naturalistic / utils /loss_manager_org.py
Samara369's picture
Upload 96 files
901595e verified
raw
history blame
3.19 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import torch.autograd as autograd
from collections import defaultdict
class LogManager:
def __init__(self):
self.log_book=defaultdict(lambda: [])
def alloc_stat_type(self, stat_type):
self.log_book[stat_type] = []
def alloc_stat_type_list(self, stat_type_list):
for stat_type in stat_type_list:
self.alloc_stat_type(stat_type)
def init_stat(self):
for stat_type in self.log_book.keys():
self.log_book[stat_type] = []
def add_stat(self, stat_type, stat):
assert stat_type in self.log_book, "Wrong stat type"
self.log_book[stat_type].append(stat)
def add_torch_stat(self, stat_type, stat):
assert stat_type in self.log_book, "Wrong stat type"
self.log_book[stat_type].append(stat.detach().cpu().item())
def get_stat(self, stat_type):
result_stat = 0
stat_list = self.log_book[stat_type]
if len(stat_list) != 0:
result_stat = np.mean(stat_list)
result_stat = np.round(result_stat, 4)
return result_stat
def print_stat(self):
for stat_type in self.log_book.keys():
if len(self.log_book[stat_type]) == 0:
continue
stat = self.get_stat(stat_type)
print(stat_type,":",stat, end=' / ')
print(" ")
def get_stat_str(self):
result_str = ""
for stat_type in self.log_book.keys():
if len(self.log_book[stat_type]) == 0:
continue
stat = self.get_stat(stat_type)
result_str += str(stat) + " / "
return result_str
def CCC_loss(pred, lab, m_lab=None, v_lab=None, is_numpy=False):
"""
pred: (N, 3)
lab: (N, 3)
"""
if is_numpy:
pred = torch.Tensor(pred).float().cuda()
lab = torch.Tensor(lab).float().cuda()
m_pred = torch.mean(pred, 0, keepdim=True)
m_lab = torch.mean(lab, 0, keepdim=True)
d_pred = pred - m_pred
d_lab = lab - m_lab
v_pred = torch.var(pred, 0, unbiased=False)
v_lab = torch.var(lab, 0, unbiased=False)
corr = torch.sum(d_pred * d_lab, 0) / (torch.sqrt(torch.sum(d_pred ** 2, 0)) * torch.sqrt(torch.sum(d_lab ** 2, 0)))
s_pred = torch.std(pred, 0, unbiased=False)
s_lab = torch.std(lab, 0, unbiased=False)
ccc = (2*corr*s_pred*s_lab) / (v_pred + v_lab + (m_pred[0]-m_lab[0])**2)
return ccc
def MSE_emotion(pred, lab):
aro_loss = F.mse_loss(pred[:][0], lab[:][0])
dom_loss = F.mse_loss(pred[:][1], lab[:][1])
val_loss = F.mse_loss(pred[:][2], lab[:][2])
return [aro_loss, dom_loss, val_loss]
def CE_weight_category(pred, lab, weights):
criterion = torch.nn.CrossEntropyLoss(weight=weights)
return criterion(pred, lab)
def calc_err(pred, lab):
p = pred.detach()
t = lab.detach()
total_num = p.size()[0]
ans = torch.argmax(p, dim=1)
corr = torch.sum((ans==t).long())
err = (total_num-corr) / total_num
return err
def calc_acc(pred, lab):
err = calc_err(pred, lab)
return 1.0 - err