perticarari committed on
Commit
6fcf1a4
·
verified ·
1 Parent(s): 008a8d1

Initial commit

Browse files
Files changed (2) hide show
  1. cust_transformer.py +47 -0
  2. modules.json +1 -1
cust_transformer.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from sentence_transformers import models
3
+ import torch
4
+ import torch.nn as nn
5
+
6
class CustTrans(models.Transformer):
    """Transformer module that rescales token embeddings with a per-task vector.

    Every known task type owns one row of ``TaskEmbedding``. ``forward`` runs
    the wrapped ``auto_model`` and multiplies its ``last_hidden_state``
    element-wise by the row of the selected task. When the task type is not
    one of ``task_types`` (including when none is set) the encoder output is
    passed through unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Build the wrapped transformer and the default task table.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Fallback used by ``forward`` when the caller does not pass a task.
        self.curr_task_type = None
        self._rebuild_taskembedding(['sts', 'quora'])

    def forward(self, inputs, task_type=None):
        """Encode ``inputs`` and store the result under ``'token_embeddings'``.

        Args:
            inputs: Mapping that is unpacked as keyword arguments into
                ``self.auto_model``. It is modified in place.
            task_type: Task to apply. ``None`` falls back to
                ``self.curr_task_type``.

        Returns:
            The same ``inputs`` mapping with ``'token_embeddings'`` set.
        """
        enc = self.auto_model(**inputs).last_hidden_state

        # Identity comparison: ``== None`` would invoke ``__eq__`` on
        # whatever object the caller passed.
        if task_type is None:
            task_type = self.curr_task_type

        if task_type in self.task_types:
            # Build the index on the embedding's device so the lookup also
            # works after the module has been moved (e.g. to a GPU).
            idx = torch.tensor(
                self.task_types.index(task_type),
                device=self.TaskEmbedding.weight.device,
            )
            hyp = self.TaskEmbedding(idx)
            inputs['token_embeddings'] = self._project(enc, hyp)
        else:
            # Unknown or unset task: leave the encoder output untouched.
            inputs['token_embeddings'] = enc

        return inputs

    def _set_curr_task_type(self, task_type):
        """Set the task used by ``forward`` when none is passed explicitly."""
        self.curr_task_type = task_type

    def _set_taskembedding_grad(self, value):
        """Enable or disable gradient tracking for the task embedding weight."""
        self.TaskEmbedding.weight.requires_grad = value

    def _set_transformer_grad(self, value):
        """Enable or disable gradient tracking for every ``auto_model`` parameter."""
        for param in self.auto_model.parameters():
            param.requires_grad = value

    def _rebuild_taskembedding(self, task_types, embedding_dim=768):
        """Replace the task table and re-initialise ``TaskEmbedding``.

        Row ``i`` of the initial weight is all ones except for a zero at
        column ``i`` (``1 - eye``).

        Args:
            task_types: Task identifiers; the position of an identifier is
                its row in the embedding.
            embedding_dim: Width of each task vector. It has to be
                broadcastable against the encoder output in ``_project``.
                Defaults to 768, the value that was previously hard-coded.
        """
        # Copy so that later mutation of the caller's sequence cannot
        # desynchronise the lookup list from the embedding rows.
        self.task_types = list(task_types)
        self.task_emb = 1 - torch.eye(len(self.task_types), embedding_dim)
        # ``from_pretrained`` is a classmethod; calling it on the class avoids
        # allocating a throwaway randomly-initialised embedding first. It
        # freezes the weight by default; use ``_set_taskembedding_grad(True)``
        # to make it trainable.
        self.TaskEmbedding = nn.Embedding.from_pretrained(self.task_emb)

    def _project(self, v, normal_hyper):
        """Return ``v`` scaled element-wise by ``normal_hyper``.

        NOTE(review): despite the name, this is a plain element-wise product,
        not an orthogonal projection onto the hyperplane with normal
        ``normal_hyper``. Confirm which of the two is intended.
        """
        return v * normal_hyper
modules.json CHANGED
@@ -3,7 +3,7 @@
3
  "idx": 0,
4
  "name": "0",
5
  "path": "",
6
- "type": "custom_trans.CustTrans"
7
  },
8
  {
9
  "idx": 1,
 
3
  "idx": 0,
4
  "name": "0",
5
  "path": "",
6
+ "type": "cust_transformer.CustTrans"
7
  },
8
  {
9
  "idx": 1,