Spaces:
Running
Running
File size: 4,761 Bytes
1841ebe f73dc21 1841ebe f73dc21 1841ebe f73dc21 2690a96 f73dc21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, set_seed
from torch.utils.data import DataLoader
from torch.nn import Linear, Module
from typing import Dict, List
from collections import Counter, defaultdict
from itertools import chain
import torch
# Fix the random seeds at import time so repeated runs start from the same state.
torch.manual_seed(0)
# NOTE(review): transformers.set_seed presumably re-seeds torch as well, which
# would make 34 (not 0) the effective CPU seed -- confirm this ordering is intended.
set_seed(34)
if torch.cuda.is_available():
    # Seed every visible CUDA device, only when CUDA is actually present.
    torch.cuda.manual_seed_all(0)
class MimicTransformer(Module):
    """Sequence classifier with an extra per-token head on top of attention.

    Wraps a Hugging Face ``AutoModelForSequenceClassification`` (plus its
    tokenizer and config) and adds a linear layer that turns each token's row
    of last-layer attention weights into a single logit (``xai_logits`` in
    :meth:`forward`).
    """

    def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
        """Load config, tokenizer and model, and build the attention head.

        :param num_labels: number of classes passed to the model config.
        :param tokenizer_name: short alias resolved by :meth:`find_tokenizer`
            into the checkpoint id used for config, tokenizer and model.
        :param cutoff: sequence length used for padding/truncation; ignored
            for longformer checkpoints, which take it from the model config.
        """
        super().__init__()
        self.tokenizer_name = self.find_tokenizer(tokenizer_name)
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.tokenizer_name, config=self.config
        )
        self.model.eval()
        if 'longformer' in self.tokenizer_name:
            self.cutoff = self.model.config.max_position_embeddings
        else:
            self.cutoff = cutoff
        # One logit per token, computed from that token's attention row, so
        # the input width must equal the (padded) sequence length.
        self.linear = Linear(in_features=self.cutoff, out_features=1)

    def parse_icds(self, instances: List[Dict]):
        """Tokenize ICD mentions and tag every token id as B-ATTN or I-ATTN.

        For each instance the nested ``instance['icd']`` lists are flattened
        and de-duplicated by their ``'start'`` value (the last entry with a
        given start wins). Each surviving entry is mutated in place: it gains
        ``'tokens'`` (token ids of its ``'text'``) and ``'labels'`` (``B-ATTN``
        for the first token, ``I-ATTN`` for the rest).

        :param instances: batch items, each with an ``'icd'`` key holding an
            iterable of iterables of dicts with ``'start'`` and ``'text'``.
        :return: ``defaultdict(set)`` mapping token id -> a one-element set
            holding the tag that token received most often; ties go to
            ``I-ATTN``.
        """
        token_list = defaultdict(set)
        tag_counts = Counter()
        for instance in instances:
            icds = chain.from_iterable(instance['icd'])
            unique_icds = list({icd['start']: icd for icd in icds}.values())
            for icd_dict in unique_icds:
                token_ids = self.tokenizer(icd_dict['text'], add_special_tokens=False)['input_ids']
                icd_dict['tokens'] = token_ids
                icd_dict['labels'] = []
                for position, token in enumerate(token_ids):
                    label = "B-ATTN" if position == 0 else "I-ATTN"
                    icd_dict['labels'].append(label)
                    token_list[token].add(label)
                    tag_counts[(token, label)] += 1
        # Resolve tokens seen with both tags by keeping the more frequent one.
        for token, tags in token_list.items():
            if len(tags) == 2:
                if tag_counts[(token, "B-ATTN")] > tag_counts[(token, "I-ATTN")]:
                    tags.remove('I-ATTN')
                else:
                    tags.remove('B-ATTN')
        return token_list

    def collate_mimic(self, instances: List[Dict], device='cuda'):
        """Turn a list of raw instances into a batch of tensors.

        :param instances: batch items; each must provide ``'description'``
            (iterable of strings, joined with spaces before encoding),
            ``'entry'``, ``'drg'`` (integer label) and ``'icd'`` (see
            :meth:`parse_icds`).
        :param device: device the tensors are moved to.
        :return: dict with ``'text'`` (token ids, long), ``'drg'`` (labels,
            long, one column), ``'entry'`` (list, passed through),
            ``'icds'`` (result of :meth:`parse_icds`) and ``'xai'`` (float
            mask with 1 wherever the token id occurs in ``'icds'``).
        """
        tokenized = [
            self.tokenizer.encode(
                ' '.join(instance['description']),
                max_length=self.cutoff,
                truncation=True,
                padding='max_length',
            )
            for instance in instances
        ]
        entries = [instance['entry'] for instance in instances]
        labels = torch.tensor(
            [instance['drg'] for instance in instances], dtype=torch.long
        ).to(device).unsqueeze(1)
        inputs = torch.tensor(tokenized, dtype=torch.long).to(device)
        icds = self.parse_icds(instances)
        # Build the mask from the Python token lists rather than reading the
        # (possibly on-GPU) tensor back one element at a time.
        xai_labels = torch.tensor(
            [[1.0 if token_id in icds else 0.0 for token_id in row] for row in tokenized],
            dtype=torch.float32,
        ).to(device)
        return {
            'text': inputs,
            'drg': labels,
            'entry': entries,
            'icds': icds,
            'xai': xai_labels
        }

    def forward(self, input_ids, attention_mask=None, drg_labels=None):
        """Run the classifier and derive per-token logits from attention.

        :param input_ids: token ids forwarded to the wrapped model.
        :param attention_mask: optional mask forwarded to the wrapped model.
        :param drg_labels: optional labels; when given they are forwarded as
            ``labels`` to the wrapped model.
        :return: ``(cls_results, xai_logits)`` where ``cls_results`` is the
            wrapped model's output and ``xai_logits`` has one value per token.
        """
        # Compare against None: truth-testing a tensor raises for more than
        # one element and would drop a single label whose value is 0.
        if drg_labels is not None:
            cls_results = self.model(
                input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True
            )
        else:
            cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        # Attentions are the last output; take the last layer.
        last_attn = cls_results[-1][-1]  # (batch, attn_heads, tokens, tokens)
        # Only the final attention head feeds the linear layer.
        last_layer_attn = last_attn[:, -1, :, :]
        # The linear layer expects rows of width self.cutoff, so the sequence
        # length must match the value used when the layer was built.
        xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
        return (cls_results, xai_logits)

    def find_tokenizer(self, tokenizer_name):
        """Resolve a short alias to a Hugging Face checkpoint id.

        :param tokenizer_name: ``'clinical_longformer'``, ``'clinical'``, or
            anything else for the standard BERT fallback.
        :return: the checkpoint id string.
        """
        known_checkpoints = {
            'clinical_longformer': 'yikuan8/Clinical-Longformer',
            'clinical': 'emilyalsentzer/Bio_ClinicalBERT',
        }
        # Any other alias falls back to the standard transformer.
        return known_checkpoints.get(tokenizer_name, 'bert-base-uncased')