import torch
import torch.nn as nn
from sentence_transformers import models


class CustTrans(models.Transformer):
    """Transformer module that applies a task-specific projection to the token embeddings."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.curr_task_type = None
        # Register the supported task types and build their embedding table.
        self._rebuild_taskembedding(['sts', 'quora'])

    def forward(self, inputs, task_type=None):
        # Contextual token embeddings from the underlying Hugging Face model.
        enc = self.auto_model(**inputs).last_hidden_state

        # Fall back to the task type selected via _set_curr_task_type().
        if task_type is None:
            task_type = self.curr_task_type

        if task_type in self.task_types:
            # Look up the task-specific vector and project the encodings with it.
            idx = torch.tensor(self.task_types.index(task_type),
                               device=self.TaskEmbedding.weight.device)
            hyp = self.TaskEmbedding(idx)
            inputs['token_embeddings'] = self._project(enc, hyp)
        else:
            # Unknown (or unset) task type: pass the encodings through unchanged.
            inputs['token_embeddings'] = enc

        return inputs

    def _set_curr_task_type(self, task_type):
        # Default task type used when forward() is called without one.
        self.curr_task_type = task_type

    def _set_taskembedding_grad(self, value):
        # Freeze or unfreeze the task embedding table.
        self.TaskEmbedding.weight.requires_grad = value

    def _set_transformer_grad(self, value):
        # Freeze or unfreeze the underlying transformer weights.
        for param in self.auto_model.parameters():
            param.requires_grad = value

    def _rebuild_taskembedding(self, task_types):
        self.task_types = task_types
        # Row i is all ones except a zero at position i, so each task masks out
        # one dimension of the 768-dimensional token embeddings (768 is the
        # hard-coded hidden size of the base model).
        self.task_emb = 1 - torch.eye(len(self.task_types), 768)
        # from_pretrained() freezes the weights by default; call
        # _set_taskembedding_grad(True) to make them trainable.
        self.TaskEmbedding = nn.Embedding.from_pretrained(self.task_emb)

    def _project(self, v, normal_hyper):
        # Element-wise multiplication with the task vector; with the 0/1 masks
        # built above this projects v onto the hyperplane orthogonal to the
        # masked dimension.
        return v * normal_hyper
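
# A minimal usage sketch, assuming a standard sentence-transformers pipeline;
# the checkpoint name 'bert-base-uncased' and the example sentence are
# placeholders, not part of the module above.
if __name__ == '__main__':
    from sentence_transformers import SentenceTransformer

    transformer = CustTrans('bert-base-uncased')
    pooling = models.Pooling(transformer.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[transformer, pooling])

    # Choose which task-specific projection is applied while encoding.
    transformer._set_curr_task_type('sts')
    sts_embeddings = model.encode(['A sentence to embed.'])

    # Alternate what is trained: freeze the transformer and unfreeze the task
    # embeddings (or vice versa) before setting up a training loop.
    transformer._set_transformer_grad(False)
    transformer._set_taskembedding_grad(True)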