Spaces:
Running
Running
import http.client
import json
import os
import ssl
import urllib.error
import urllib.parse
import urllib.request
from pprint import pprint

import pandas as pd
import torch
from captum.attr import visualization
# Base URL of the BioPortal REST API queried by get_json().
REST_URL = "http://data.bioontology.org"
# BioPortal API key. Set BIOPORTAL_API_KEY in the environment to override it; the
# literal fallback keeps existing deployments working.
# NOTE(review): a key committed to source should be rotated and the fallback removed.
API_KEY = os.environ.get("BIOPORTAL_API_KEY", "604a90bc-ef14-4c26-a347-f4928fa086ea")
# WARNING: process-wide monkeypatch (see PEP 476) that disables TLS certificate
# verification for stdlib HTTPS clients using the default context, not only this module.
ssl._create_default_https_context = ssl._create_unverified_context
class PyTMinMaxScalerVectorized(object):
    """
    Min-max scale a tensor along dim 0, in place.

    Adapted from https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
    Values are mapped to ``(x - min) / (max - min)`` where min/max are reduced over
    dim 0 (for a 1-D tensor: the whole vector; for a 2-D tensor: each column).
    Slices whose values are all equal are mapped to 0 instead of NaN.
    """

    def __call__(self, tensor):
        """Scale ``tensor`` in place and return the same tensor object.

        Expects a floating-point tensor, because the result is written in place.
        """
        lo = tensor.min(dim=0, keepdim=True)[0]
        span = tensor.max(dim=0, keepdim=True)[0] - lo
        # A zero range would divide by zero (inf/NaN); dividing by 1 instead maps
        # every value of a constant slice to 0.
        span = torch.where(span == 0, torch.ones_like(span), span)
        tensor.sub_(lo).div_(span)
        return tensor
def get_drg_link(drg_code):
    """Return the AAPC code-lookup URL for ``drg_code``.

    NOTE(review): the URL points at the ICD-9 code pages even though the argument
    is named as a DRG code — confirm this is the intended target.
    """
    base_url = 'https://www.aapc.com/codes/icd9-codes/'
    return base_url + format(drg_code)
def prettify(dict_list, k):
    """Join the ``k`` entry of every dict in ``dict_list`` into newline-separated text.

    The selected values must be strings (``str.join`` requirement).
    """
    return "\n".join([entry[k] for entry in dict_list])
def get_json(text_to_annotate):
    """Query the BioPortal annotator for ICD9CM concepts in ``text_to_annotate``.

    Returns:
        The decoded JSON response, or ``[]`` when the request fails (network/HTTP
        errors) or the body cannot be decoded as JSON. Interrupts such as
        KeyboardInterrupt are no longer swallowed.
    """
    # Same encoding as quote(text): '/' stays literal, everything else is escaped.
    query = urllib.parse.urlencode(
        {
            "text": text_to_annotate,
            "ontologies": "ICD9CM",
            "longest_only": "false",
            "exclude_numbers": "false",
            "whole_word_only": "true",
            "exclude_synonyms": "false",
        },
        safe="/",
        quote_via=urllib.parse.quote,
    )
    url = REST_URL + "/annotator?" + query
    opener = urllib.request.build_opener()
    opener.addheaders = [('Authorization', 'apikey token=' + API_KEY)]
    try:
        # `with` closes the HTTP response even if decoding fails.
        with opener.open(url) as response:
            return json.loads(response.read())
    except (OSError, ValueError, http.client.HTTPException):
        # OSError covers URLError/HTTPError and socket timeouts; ValueError covers
        # JSONDecodeError and malformed URLs. Best-effort: fall back to no results.
        return []
def parse_results(results):
    """Flatten annotator results into one dict per annotated text span.

    Each output dict has:
        'start' -- annotation['from'] - 1
        'end'   -- annotation['to'] - 1
        'text'  -- annotation['text']
        'link'  -- the @id of the result's annotatedClass
    (presumably this converts 1-based API positions to 0-based; confirm against the API docs).
    """
    if len(results) == 0:
        return []
    return [
        {
            'start': annotation['from'] - 1,
            'end': annotation['to'] - 1,
            'text': annotation['text'],
            'link': result['annotatedClass']['@id'],
        }
        for result in results
        for annotation in result['annotations']
    ]
def get_icd_annotations(text):
    """Annotate ``text`` with ICD9CM concepts and return the flattened span dicts."""
    return parse_results(get_json(text))
def subfinder(mylist, pattern):
    """Return, in order, the elements of ``mylist`` that also appear in ``pattern``.

    Both arguments must provide ``.tolist()`` (e.g. tensors or numpy arrays).
    Despite the name, this is an element filter, not a subsequence search.
    """
    haystack = mylist.tolist()
    wanted = pattern.tolist()
    return [item for item in haystack if item in wanted]
def tokenize_icds(tokenizer, annotations, token_ids):
    """Build a 0/1 mask over ``token_ids`` marking tokens covered by ICD annotations.

    Each annotation's ``'text'`` is tokenized on its own (no special tokens), and
    every position where that exact token-id sequence occurs in ``token_ids`` is set
    to 1. All occurrences are marked.

    Args:
        tokenizer: callable with the Hugging Face tokenizer call signature used below;
            its result must expose ``input_ids``.
        annotations: iterable of dicts with a ``'text'`` key, such as the dicts
            produced by ``parse_results``.
        token_ids: tensor of token ids — assumed 1-D (one sequence); TODO confirm
            for batched input.

    Returns:
        Float tensor of zeros shaped like ``token_ids``, with 1 at matched positions.
    """
    icd_tokens = torch.zeros(token_ids.shape)
    seq_len = token_ids.shape[0]
    for annotation in annotations:
        icd_token_ids = tokenizer(
            annotation['text'], add_special_tokens=False, return_tensors='pt'
        ).input_ids[0]
        num_icd_tokens = icd_token_ids.shape[0]
        if num_icd_tokens == 0:
            # Nothing to look for (the text tokenized to an empty sequence).
            continue
        # Candidate span starts: every position holding the first ICD token.
        starts = (token_ids == icd_token_ids[0]).nonzero(as_tuple=True)[0].tolist()
        for start in starts:
            end = start + num_icd_tokens
            # Mark only spans that fit in the sequence and match the whole ICD
            # token sequence, not just its first token.
            if end <= seq_len and bool((token_ids[start:end] == icd_token_ids).all()):
                icd_tokens[start:end] = 1
    return icd_tokens
def get_attribution(text, tokenizer, model_outputs, inputs, k=7):
    """Return word-level attention and the reconstructed words for the first input.

    Args:
        text: unused here; kept for interface compatibility.
        tokenizer: tokenizer used to map ``inputs.input_ids`` back to token strings.
        model_outputs: model output sequence; ``model_outputs[-1][0]`` is used as
            the attention to aggregate — TODO confirm which tensor this is.
        inputs: tokenizer output exposing ``input_ids`` (first row is used).
        k: unused here; kept for interface compatibility.

    Returns:
        ``(aggregated_attn, words)`` as produced by ``reconstruct_text``.
    """
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    # Truncate at the first padding token; a sequence without padding is kept
    # whole instead of raising ValueError from list.index.
    try:
        padding_idx = tokens.index('[PAD]')
    except ValueError:
        padding_idx = len(tokens)
    # Drop the first and last remaining tokens (presumably [CLS] and [SEP]).
    tokens = tokens[:padding_idx][1:-1]
    attn = model_outputs[-1][0]
    # NOTE(review): the first token was dropped from `tokens` but `attn` is not
    # sliced; confirm attn positions line up with the stripped token list.
    return reconstruct_text(tokenizer=tokenizer, tokens=tokens, attn=attn)
def reconstruct_text(tokenizer, tokens, attn):
    """Merge WordPiece sub-tokens back into words and average their attention.

    A token starting with ``##`` continues the previous word; every other token
    (including a literal ``#``) starts a new word. The ``##`` prefix is stripped
    when joining the pieces of a word.

    Args:
        tokenizer: object exposing ``convert_tokens_to_string``; used only for the
            consistency check at the end.
        tokens: list of token strings with special tokens already removed.
        attn: tensor indexed by token position — assumed 1-D with one score per
            entry of ``tokens``; TODO confirm against the model output.

    Returns:
        ``(aggregated_attn, words)``: ``words`` is the list of reconstructed words,
        and ``aggregated_attn[j]`` is the mean of ``attn`` over the sub-token
        positions of ``words[j]``.
    """
    if not tokens:
        return torch.zeros(0), []

    # Group token positions by the word they belong to.
    word_groups = []
    for idx, token in enumerate(tokens):
        if idx > 0 and token.startswith('##'):
            word_groups[-1].append(idx)
        else:
            word_groups.append([idx])

    words = [
        ''.join(tokens[j][2:] if tokens[j].startswith('##') else tokens[j] for j in group)
        for group in word_groups
    ]
    aggregated_attn = torch.zeros(len(word_groups))
    for word_idx, group in enumerate(word_groups):
        aggregated_attn[word_idx] = torch.mean(attn[group])

    # Sanity check that the word split matches the tokenizer's own detokenization.
    # NOTE(review): only spaces before '.' and ',' are collapsed here; if the
    # tokenizer also collapses spaces before other punctuation, this assert fails.
    final_text = ' '.join(words).replace(' .', '.').replace(' ,', ',')
    assert final_text == tokenizer.convert_tokens_to_string(tokens)
    return aggregated_attn, words
def load_rule(path):
    """Load a DRG rule table and build DRG lookup structures.

    Args:
        path: CSV path. The DRG system is chosen from the path string: 'MS'
            selects MS-DRG rules (filter column 'MS-DRG'), otherwise 'APR' selects
            APR-DRG rules (filter column 'APR-DRG'). The CSV must also contain
            'DRG_CODE', 'MDC' and 'WEIGHT' columns.

    Returns:
        ``(rule_df, drg2idx, i2d, d2mdc, d2w)``:
        rule_df -- the full table as read;
        drg2idx -- DRG_CODE -> contiguous index over the sorted codes that survive
                   the filter (MDC 15 and the listed DRGs removed);
        i2d     -- inverse of drg2idx;
        d2mdc   -- DRG_CODE -> MDC for every row (unfiltered; later rows win);
        d2w     -- DRG_CODE -> WEIGHT for every row (unfiltered; later rows win).

    Raises:
        ValueError: if ``path`` contains neither 'MS' nor 'APR'.
    """
    rule_df = pd.read_csv(path)
    # Remove MDC 15 (neonates) and a couple of post-care codes per DRG system.
    if 'MS' in path:
        drg_col, excluded = 'MS-DRG', [945, 946, 949, 950, 998, 999]
    elif 'APR' in path:
        drg_col, excluded = 'APR-DRG', [860, 863]
    else:
        raise ValueError(
            f"cannot determine DRG system from path {path!r}; expected 'MS' or 'APR' in it"
        )
    # Compare MDC as text so the filter works whether pandas parsed the column
    # as strings ('15') or integers (15).
    msk = (rule_df['MDC'].astype(str) != '15') & ~rule_df[drg_col].isin(excluded)
    space = sorted(rule_df[msk]['DRG_CODE'].unique())

    drg2idx = {drg: idx for idx, drg in enumerate(space)}
    i2d = {idx: drg for drg, idx in drg2idx.items()}
    # Built from all rows, not just the filtered ones; a repeated DRG_CODE keeps
    # the value from its last row.
    d2mdc = dict(zip(rule_df['DRG_CODE'], rule_df['MDC']))
    d2w = dict(zip(rule_df['DRG_CODE'], rule_df['WEIGHT']))
    return rule_df, drg2idx, i2d, d2mdc, d2w
def visualize_attn(model_results):
    """Render the word attention of one prediction as an HTML table string.

    Args:
        model_results: mapping with keys
            'class_dsc' -- predicted label (shown as both predicted and true class),
            'prob'      -- predicted probability,
            'attn'      -- attention tensor, presumably one value per entry of 'tokens',
            'tokens'    -- words to display.

    Returns:
        The HTML produced by ``visualize_text``. ``model_results`` is not modified.
    """
    class_id = model_results['class_dsc']
    # The scaler works in place; scale a copy so the caller's raw attention survives.
    normalized_attn = PyTMinMaxScalerVectorized()(model_results['attn'].clone())
    viz_record = visualization.VisualizationDataRecord(
        word_attributions=normalized_attn,
        pred_prob=model_results['prob'],
        pred_class=class_id,
        true_class=class_id,
        attr_class=0,
        attr_score=1,
        raw_input_ids=model_results['tokens'],
        convergence_score=1,
    )
    return visualize_text(viz_record)
def modify_attn_html(attn_html, href="espn.com"):
    """Wrap every ``<mark ...>...</mark>`` element of ``attn_html`` in a link.

    Only the mark element itself goes inside ``<a>``; anything after its closing
    ``</mark>`` (e.g. the cell's closing ``</td>``) stays outside the link. A
    ``<mark`` without a closing tag is wrapped together with the rest of its piece.

    Args:
        attn_html: HTML containing ``<mark`` elements.
        href: link target, HTML-escaped before insertion. The default keeps the
            previous placeholder. NOTE(review): replace it with a real target.

    Returns:
        The modified HTML string.
    """
    from html import escape

    anchor_open = f'<a href="{escape(href, quote=True)}">'
    pieces = attn_html.split('<mark')
    out = [pieces[0]]
    for piece in pieces[1:]:
        mark_body, close_tag, rest = piece.partition('</mark>')
        out.append(f'{anchor_open}<mark{mark_body}{close_tag}</a>{rest}')
    return "".join(out)
# copied out of captum because we need raw html instead of a jupyter widget
def visualize_text(datarecord):
    """Render one visualization record as a raw HTML table string.

    Adapted from captum's ``visualization.visualize_text``, returning the HTML
    string rather than displaying it.

    Args:
        datarecord: a captum ``VisualizationDataRecord``; only ``pred_class``,
            ``raw_input_ids`` and ``word_attributions`` are read.

    Returns:
        HTML with a header row and one data row: the predicted-class cell and the
        word-importance cell (from captum's ``format_word_importances``, passed
        through ``modify_attn_html``).
    """
    header_row = (
        "<tr>"
        "<th style='text-align: left'>Predicted DRG</th>"
        "<th style='text-align: left'>Word Importance</th>"
        "</tr>"
    )
    pred_class_html = visualization.format_classname(datarecord.pred_class)
    word_attn_html = modify_attn_html(
        visualization.format_word_importances(
            datarecord.raw_input_ids, datarecord.word_attributions
        )
    )
    data_row = "".join(["<tr>", pred_class_html, word_attn_html, "</tr>"])
    # `width: 100%` is CSS, so it has to live in a style attribute to take effect.
    return "".join(['<table style="width: 100%">', header_row, data_row, "</table>"])