Spaces:
Paused
Paused
| from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, set_seed | |
| from torch.utils.data import DataLoader | |
| from torch.nn import Linear, Module | |
| from typing import Dict, List | |
| from collections import Counter, defaultdict | |
| from itertools import chain | |
| import torch | |
# Seed the random number generators once, at import time.
# NOTE(review): transformers' set_seed presumably re-seeds torch as well, so
# the effective CPU seed may end up being 34 rather than 0 -- confirm before
# reordering or merging these calls.
_TORCH_SEED = 0
torch.manual_seed(_TORCH_SEED)
set_seed(34)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(_TORCH_SEED)
class MimicTransformer(Module):
    """Sequence classifier with an auxiliary per-token attention ("xai") head.

    Wraps a Hugging Face ``AutoModelForSequenceClassification`` and adds a
    linear layer that turns each row of one last-layer attention head into a
    single logit per token.
    """

    def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
        """Load the config, tokenizer and classifier for the chosen checkpoint.

        :param num_labels: number of classification labels given to the config.
        :param tokenizer_name: short alias resolved by :meth:`find_tokenizer`.
        :param cutoff: maximum sequence length. Ignored for Longformer
            checkpoints, which use the model's ``max_position_embeddings``.
        """
        super().__init__()
        self.tokenizer_name = self.find_tokenizer(tokenizer_name)
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
        self.model.eval()
        if 'longformer' in self.tokenizer_name:
            self.cutoff = self.model.config.max_position_embeddings
        else:
            self.cutoff = cutoff
        # Applied to the last axis of an attention map, so inputs must be
        # padded to exactly ``self.cutoff`` tokens (see ``collate_mimic``).
        self.linear = Linear(in_features=self.cutoff, out_features=1)

    def parse_icds(self, instances: List[Dict]):
        """Map every token id found in the ICD entity texts to one ATTN tag.

        Each ICD dict is tokenized without special tokens; its first token is
        tagged ``B-ATTN`` and the rest ``I-ATTN``. The ICD dicts are mutated in
        place: ``'tokens'`` and ``'labels'`` keys are added. When a token id
        was seen with both tags, only the more frequent one is kept (ties keep
        ``I-ATTN``).

        :param instances: dicts whose ``'icd'`` value is an iterable of
            iterables of ICD dicts carrying ``'start'`` and ``'text'`` keys.
        :return: ``defaultdict(set)`` from token id to its remaining tag(s).
        """
        token_tags = defaultdict(set)
        tag_counts = Counter()
        for instance in instances:
            # De-duplicate by 'start': the last entry for an offset wins while
            # first-seen ordering of offsets is preserved.
            unique_icds = {
                icd['start']: icd for icd in chain.from_iterable(instance['icd'])
            }
            for icd_dict in unique_icds.values():
                token_ids = self.tokenizer(icd_dict['text'], add_special_tokens=False)['input_ids']
                tags = ['B-ATTN' if pos == 0 else 'I-ATTN' for pos in range(len(token_ids))]
                icd_dict['tokens'] = token_ids
                icd_dict['labels'] = tags
                for token, tag in zip(token_ids, tags):
                    token_tags[token].add(tag)
                    tag_counts[(token, tag)] += 1
        for token, tags in token_tags.items():
            if len(tags) == 2:
                begin_count = tag_counts[(token, 'B-ATTN')]
                inside_count = tag_counts[(token, 'I-ATTN')]
                tags.discard('I-ATTN' if begin_count > inside_count else 'B-ATTN')
        return token_tags

    def collate_mimic(
            self, instances: List[Dict], device='cuda'
    ):
        """Turn raw instances into a batch of tensors on ``device``.

        :param instances: dicts with ``'description'`` (iterable of strings),
            ``'drg'``, ``'entry'`` and ``'icd'`` keys.
        :param device: device the tensors are moved to.
        :return: dict with ``'text'`` (token ids), ``'drg'`` (labels with a
            trailing unit axis), ``'entry'``, ``'icds'`` (result of
            :meth:`parse_icds`) and ``'xai'`` (float mask that is 1 where the
            token id occurs in ``'icds'``).
        """
        tokenized = [
            self.tokenizer.encode(
                ' '.join(instance['description']),
                max_length=self.cutoff, truncation=True, padding='max_length'
            ) for instance in instances
        ]
        entries = [instance['entry'] for instance in instances]
        labels = torch.tensor(
            [instance['drg'] for instance in instances], dtype=torch.long
        ).to(device).unsqueeze(1)
        inputs = torch.tensor(tokenized, dtype=torch.long).to(device)
        icds = self.parse_icds(instances)
        # Build the mask from the Python token lists rather than looping over
        # tensor elements, which would cost one ``.item()`` call per token.
        xai_labels = torch.tensor(
            [[float(token in icds) for token in row] for row in tokenized],
            dtype=torch.float32,
        ).to(device)
        return {
            'text': inputs,
            'drg': labels,
            'entry': entries,
            'icds': icds,
            'xai': xai_labels
        }

    def forward(self, input_ids, attention_mask=None, drg_labels=None):
        """Run the classifier and derive per-token logits from attention.

        :param input_ids: token ids, padded to ``self.cutoff`` tokens.
        :param attention_mask: optional mask forwarded to the wrapped model.
        :param drg_labels: optional labels forwarded to the wrapped model.
        :return: tuple ``(cls_results, xai_logits)``.
        """
        model_kwargs = {'attention_mask': attention_mask, 'output_attentions': True}
        # Compare against None: truth-testing a multi-element tensor raises,
        # and a lone label of 0 would wrongly count as "no labels".
        if drg_labels is not None:
            model_kwargs['labels'] = drg_labels
        cls_results = self.model(input_ids, **model_kwargs)
        # Attentions are the final output field; take the last layer.
        last_attn = cls_results[-1][-1]  # (batch, attn_heads, tokens, tokens)
        # Only the final head is used; averaging over layers or heads was
        # tried previously and is intentionally not done here.
        last_layer_attn = last_attn[:, -1, :, :]
        xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
        return (cls_results, xai_logits)

    def find_tokenizer(self, tokenizer_name):
        """Resolve a short alias to a Hugging Face checkpoint id.

        :param tokenizer_name: ``'clinical_longformer'``, ``'clinical'`` or
            anything else for the standard BERT fallback.
        :return: the checkpoint id string.
        """
        checkpoints = {
            'clinical_longformer': 'yikuan8/Clinical-Longformer',
            'clinical': 'emilyalsentzer/Bio_ClinicalBERT',
        }
        # standard transformer
        return checkpoints.get(tokenizer_name, 'bert-base-uncased')