rewicks committed · Commit c9d06c3 · verified · 1 Parent(s): fa320a1

Update model.py

Files changed (1):
  1. model.py +60 -100
model.py CHANGED
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
- import torch.nn.functional as F
-
from transformers import PreTrainedModel

from typing import List

- from .config import LidirlLSTMConfig
+
+ from .config import LidirlCNNConfig

def torch_max_no_pads(model_out, lengths):
    indices = torch.arange(model_out.size(1)).to(model_out.device)
@@ -15,6 +14,12 @@ def torch_max_no_pads(model_out, lengths):
    max_pool = torch.max(model_out, 1)[0]
    return max_pool

+ class TransposeModule(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return x.transpose(1, 2)

class ProjectionLayer(nn.Module):
    """
@@ -33,115 +38,70 @@ class ProjectionLayer(nn.Module):
        return self.proj(x)


- class MinLSTMCell(nn.Module):
+ class ConvolutionalBlock(
+     nn.Module,
+ ):
    """
-     https://arxiv.org/pdf/2410.01201
-     https://github.com/YecanLee/min-LSTM-torch/blob/main/minLSTMcell.py
-     bidirectional and parallel
-     hold layer depth and sweep out the other dimensions
+     Convolutional block
+     https://jonathanbgn.com/2021/09/30/illustrated-wav2vec-2.html
    """
    def __init__(self,
-                  embed_dim,
-                  hidden_dim):
-         super(MinLSTMCell, self).__init__()
+                  embed_dim : int,
+                  channels : List[int],
+                  kernels : List[int],
+                  strides : List[int]):
+
+         super(ConvolutionalBlock, self).__init__()
+         layers = []
+
        self.embed_dim = embed_dim
-         self.hidden_dim = hidden_dim
-         self.output_dim = embed_dim
-
-         # Initialize the linear layers for the forget gate, input gate, and hidden state transformation
-         self.linear_f = nn.Linear(embed_dim, hidden_dim)
-         self.linear_i = nn.Linear(embed_dim, hidden_dim)
-         self.linear_h = nn.Linear(embed_dim, hidden_dim)
-
-     def parallel_scan_log(self, log_coeffs, log_values):
-         # log_coeffs: (batch_size, seq_len, input_size)
-         # log_values: (batch_size, seq_len + 1, input_size)
-         a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))
-         log_h0_plus_b_star = torch.logcumsumexp(
-             log_values - a_star, dim=1)
-         log_h = a_star + log_h0_plus_b_star
-         return torch.exp(log_h)[:, 1:]
-
-     def g(self, x):
-         return torch.where(x >= 0, x+0.5, torch.sigmoid(x))
-
-     def log_g(self, x):
-         return torch.where(x >= 0, (F.relu(x)+0.5).log(), -F.softplus(-x))
-
-     def forward(self, inputs):
-         h_init = torch.zeros(inputs.size(0), 1, self.hidden_dim, device=inputs.device)
-
-         diff = F.softplus(-self.linear_f(inputs)) - F.softplus(-self.linear_i(inputs))
-
-         log_f = -F.softplus(diff)
-         log_i = -F.softplus(-diff)
-         log_h_0 = torch.log(h_init)
-
-         log_tilde_h = self.log_g(self.linear_h(inputs))
-
-         h = self.parallel_scan_log(log_f, torch.cat([log_h_0, log_i + log_tilde_h], dim=1))
-         return h
-
-
- class LSTMBlock(nn.Module):
-     def __init__(self,
-                  embed_dim : int = 512,
-                  hidden_dim : int = 2048,
-                  num_layers : int = 6,
-                  dropout : float = 0.1,
-                  bidirectional : bool = False
-                  ):
-         super(LSTMBlock, self).__init__()
-
-         self.layers = []
-         last_dim = embed_dim
-         for _ in range(num_layers):
-             self.layers.append(MinLSTMCell(last_dim, hidden_dim))
-             self.layers.append(nn.LayerNorm(hidden_dim, elementwise_affine=True))
-             self.layers.append(nn.GELU())
-             self.layers.append(nn.Dropout(dropout))
-             last_dim = hidden_dim
-         self.model = nn.Sequential(*self.layers)
-         self.bidirectionality_term = 2 if bidirectional else 1
-         self.output_dim = hidden_dim * self.bidirectionality_term
-         self.bidirectional = bidirectional
-
-     def flip_sequence(self, inputs, lengths):
-         # Here we want to flip the sequence but keep the right-padding
-         # We can do this by flipping the sequence and then flipping the padding
-         new = []
-         for inp, leng in zip(inputs, lengths):
-             new.append(inp[:leng].flip(0))
-         return pad_sequence(new, batch_first=True).to(inputs.device)
+         input_dimension = embed_dim
+         for channel, kernel, stride in zip(channels, kernels, strides):
+             next_layer = nn.Conv1d(
+                 in_channels = input_dimension,
+                 out_channels = channel,
+                 kernel_size = kernel,
+                 stride = stride,
+                 padding = 'valid', # we handle the padding ourselves in the forward function
+             )
+             input_dimension = channel
+             layers.append(TransposeModule())
+             layers.append(next_layer)
+             layers.append(TransposeModule())
+             layers.append(nn.LayerNorm(channel, elementwise_affine=True))
+             layers.append(nn.GELU())
+             layers.append(nn.Dropout(0.1))
+         self.model = nn.Sequential(*layers)
+         self.output_dim = channels[-1]
+
+         self.min_length = 1
+         for kernel, stride in zip(kernels[::-1], strides[::-1]):
+             self.min_length = ((self.min_length - 1) * stride) + kernel

    def forward(self, inputs, lengths):
-         encoding = self.model(inputs)
-         last_token = encoding[torch.arange(encoding.size(0)), lengths - 1].view(inputs.size(0), 1, -1)
-         if self.bidirectional:
-             reverse_sequence = self.flip_sequence(inputs, lengths)
-             reverse_encoding = self.model(reverse_sequence)
-             reverse_last_token = reverse_encoding[torch.arange(reverse_encoding.size(0)), lengths - 1].view(inputs.size(0), 1, -1)
-             last_token = torch.cat((last_token, reverse_last_token), dim=-1)
+         # this is our padding trick instead of consistent padding
+         if inputs.size(1) < self.min_length:
+             pads = torch.zeros((inputs.size(0), self.min_length - inputs.size(1), self.embed_dim), device=inputs.device)
+             inputs = torch.cat((inputs, pads), dim=1)

-         return last_token, torch.ones((inputs.size(0), 1), device=inputs.device, dtype=torch.long)
+         outputs = self.model(inputs)

+         for layer_i in range(1, len(self.model), 6):
+             lengths = torch.floor(((lengths - self.model[layer_i].kernel_size[0]) / self.model[layer_i].stride[0]) + 1).to(lengths.device, dtype=torch.long)
+             lengths[lengths < 1] = 1

+         return outputs, lengths
+
- class LidirlLSTM(PreTrainedModel):
+ class LidirlCNN(PreTrainedModel):
    """
-     Defines the Lidirl LSTM Model
+     Defines the Lidirl CNN MODEL
    """
+     config_class = LidirlCNNConfig

-     config_class = LidirlLSTMConfig
    def __init__(self, config):
        super().__init__(config)
-
-         self.encoder = LSTMBlock(
-             embed_dim = config.embed_dim,
-             hidden_dim = config.hidden_dim,
-             num_layers = config.num_layers,
-             dropout = config.dropout,
-             bidirectional = config.bidirectional
-         )
+
+         self.encoder = ConvolutionalBlock(config.embed_dim, config.channels, config.kernels, config.strides)
        self.embed_layer = nn.Embedding(config.vocab_size, config.embed_dim)
        self.proj = ProjectionLayer(self.encoder.output_dim, config.label_size, config.montecarlo_layer)

@@ -154,6 +114,7 @@ class LidirlLSTM(PreTrainedModel):
        for key, value in config.labels.items():
            self.labels[value] = key

+
    def forward(self, inputs, lengths):
        inputs = inputs[:, :self.max_length]
        lengths = lengths.clamp(max=self.max_length)
@@ -191,5 +152,4 @@ class LidirlLSTM(PreTrainedModel):
            output[batch.item()].append(
                (self.labels[label.item()], probs[batch, label])
            )
-         return output
-
+         return output
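A note on the new `ConvolutionalBlock.min_length`: because every conv uses `padding='valid'`, the sequence shrinks at each layer, and walking the kernel/stride stack backwards gives the shortest input that still yields one output frame from the final conv. A quick sanity check with toy values (illustrative only; the real ones come from `LidirlCNNConfig`):

```python
# Toy kernel/stride stack; illustrative values, not the repo's config.
kernels, strides = [5, 3, 3], [2, 2, 1]

min_length = 1
for kernel, stride in zip(kernels[::-1], strides[::-1]):
    # One output frame of this layer covers `kernel` inputs, and consecutive
    # frames needed by the layer above are `stride` input positions apart.
    min_length = ((min_length - 1) * stride) + kernel

print(min_length)  # 17
```

With these toy values a length-16 batch would reach the last conv with too few frames; that is exactly the case the padding trick at the top of `forward` guards against.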
 
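The length bookkeeping in `forward` is the standard 'valid' convolution arithmetic, new_len = floor((len - kernel) / stride) + 1, applied once per conv; each conv group contributes six modules to the `Sequential` (Transpose, Conv1d, Transpose, LayerNorm, GELU, Dropout) with the conv at indices 1, 7, 13, …, hence `range(1, len(self.model), 6)`. A minimal sketch of the same update on a lengths tensor (toy kernel/stride pairs again, not the repo's config):

```python
import torch

# Hypothetical kernel/stride pairs; the commit reads them off each nn.Conv1d.
lengths = torch.tensor([40, 17, 100])
for kernel, stride in [(5, 2), (3, 2), (3, 1)]:
    lengths = torch.div(lengths - kernel, stride, rounding_mode='floor') + 1
    lengths = lengths.clamp(min=1)  # same effect as lengths[lengths < 1] = 1

print(lengths)  # tensor([ 6,  1, 21])
```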
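Finally, a hedged end-to-end shape check of the new block; the import path and all hyperparameters below are invented for the example (the real values live in `LidirlCNNConfig`):

```python
import torch
from model import ConvolutionalBlock  # import path is a guess

block = ConvolutionalBlock(
    embed_dim=64,
    channels=[128, 256],  # invented hyperparameters
    kernels=[5, 3],
    strides=[2, 2],
)
block.eval()  # turn off the Dropout layers for a deterministic run

x = torch.randn(4, 50, 64)               # (batch, seq_len, embed_dim)
lengths = torch.tensor([50, 17, 9, 50])
out, new_lengths = block(x, lengths)

print(out.shape)    # torch.Size([4, 11, 256]): floor((50-5)/2)+1 = 23, then floor((23-3)/2)+1 = 11
print(new_lengths)  # tensor([11,  3,  1, 11])
```

Note that `forward` only pads when the whole batch is shorter than `min_length`; individually short sequences just bottom out at length 1 through the per-layer length formula and its clamp.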