# SER_Naturalistic / net / ser_atte.py
# Uploaded by Samara369 ("Upload 96 files"), commit 901595e (verified), 1.22 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
class EmotionRegression(nn.Module):
    """MLP regression head mapping input feature vectors to ``output_dim`` values.

    Architecture: input dropout, then ``num_layers`` hidden blocks of
    ``Linear -> LayerNorm -> ReLU -> Dropout``, then a final ``Linear``
    projection to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden block.
        num_layers: Number of hidden blocks. The first block is always
            built, so values below 1 still yield one hidden block.
        output_dim: Size of the last dimension of the output tensor.
        dropout: Drop probability shared by the input dropout and by the
            dropout at the end of every hidden block.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim, dropout=0.5):
        super().__init__()
        # The first block maps input_dim -> hidden_dim; the remaining
        # (num_layers - 1) blocks map hidden_dim -> hidden_dim. The blocks
        # are created in this order so parameter creation (and therefore
        # seeded initialisation) and the "fc.<i>.*" state_dict keys match
        # the original layout.
        in_dims = [input_dim] + [hidden_dim] * max(num_layers - 1, 0)
        self.fc = nn.ModuleList(
            [self._make_block(in_dim, hidden_dim, dropout) for in_dim in in_dims]
        )
        # Kept wrapped in a Sequential so checkpoint keys remain "out.0.*".
        self.out = nn.Sequential(nn.Linear(hidden_dim, output_dim))
        self.inp_drop = nn.Dropout(dropout)

    @staticmethod
    def _make_block(in_dim, hidden_dim, dropout):
        """Build one ``Linear -> LayerNorm -> ReLU -> Dropout`` hidden block."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def get_repr(self, x):
        """Return the hidden representation produced by the last hidden block.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``hidden_dim``.
        """
        h = self.inp_drop(x)
        for block in self.fc:
            h = block(h)
        return h

    def forward(self, x):
        """Run the hidden blocks, then project the result to ``output_dim``."""
        return self.out(self.get_repr(x))