from typing import List, Iterable, Tuple
from functools import partial

import numpy as np
import torch
import json

from utils.token_processing import fix_byte_spaces
from utils.gen_utils import map_nlist

def round_return_value(attentions, ndigits=5):
    """Round attention values right before they are passed back to the frontend,
    because converting tensors back to lists introduces a little numerical error.

    attentions: {
        'aa': {
            left,
            right,
            att
        }
    }
    """
    rounder = partial(round, ndigits=ndigits)
    nested_rounder = partial(map_nlist, rounder)
    new_out = attentions  # Reuse the input dict rather than copying, to save memory
    new_out["aa"]["att"] = nested_rounder(attentions["aa"]["att"])
    return new_out

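# A minimal usage sketch for `round_return_value`. The payload values are
# made-up floats standing in for a real attention list, and it assumes
# `map_nlist` (from utils.gen_utils) applies the rounding function
# recursively over a nested list.
#
# >>> payload = {"aa": {"left": [], "right": [], "att": [[0.1234567, 0.8765432]]}}
# >>> round_return_value(payload, ndigits=3)["aa"]["att"]
# [[0.123, 0.877]]
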
def flatten_batch(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Remove the batch dimension of every tensor inside the Iterable container `x`."""
    return tuple([x_.squeeze(0) for x_ in x])

def squeeze_contexts(x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """Combine the last two dimensions of the context."""
    shape = x[0].shape
    new_shape = shape[:-2] + (-1,)
    return tuple([x_.view(new_shape) for x_ in x])

def add_blank(xs: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
    """The embeddings have n_layers + 1 entries, the extra one indicating the
    final output embedding; prepend a blank tensor to match that shape."""
    return (torch.zeros_like(xs[0]),) + xs

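# A sketch of the three tensor helpers above, using dummy shapes
# (batch=1, heads=2, tokens=4, head_dim=8). The shapes are illustrative
# assumptions, not values from the original pipeline.
#
# >>> att = tuple(torch.rand(1, 2, 4, 4) for _ in range(3))   # one entry per layer
# >>> flatten_batch(att)[0].shape                              # batch dim squeezed away
# torch.Size([2, 4, 4])
# >>> ctx = tuple(torch.rand(4, 2, 8) for _ in range(3))       # [tokens, heads, head_dim]
# >>> squeeze_contexts(ctx)[0].shape                           # heads, head_dim merged
# torch.Size([4, 16])
# >>> len(add_blank(ctx))                                      # one blank layer prepended
# 4
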
class TransformerOutputFormatter:
    def __init__(
        self,
        sentence: str,
        tokens: List[str],
        special_tokens_mask: List[int],
        att: Tuple[torch.Tensor],
        topk_words: List[List[str]],
        topk_probs: List[List[float]],
        model_config,
    ):
        assert len(tokens) > 0, "Cannot have an empty token output!"

        modified_att = flatten_batch(att)

        self.sentence = sentence
        self.tokens = tokens
        self.special_tokens_mask = special_tokens_mask
        self.attentions = modified_att
        self.topk_words = topk_words
        self.topk_probs = topk_probs
        self.model_config = model_config

        try:
            # GPT-style configs name these fields n_layer / n_head / n_embd
            self.n_layer = self.model_config.n_layer
            self.n_head = self.model_config.n_head
            self.hidden_dim = self.model_config.n_embd
        except AttributeError:
            try:
                # BERT-style configs use num_hidden_layers / num_attention_heads / hidden_size
                self.n_layer = self.model_config.num_hidden_layers
                self.n_head = self.model_config.num_attention_heads
                self.hidden_dim = self.model_config.hidden_size
            except AttributeError:
                # Unrecognized config type: surface the original error
                raise

        self.__len = len(tokens)  # Number of tokens in the input
        assert self.__len == self.attentions[0].shape[-1], \
            "Attentions don't represent the passed tokens!"

    def to_json(self, layer: int, ndigits=5):
        """The original API expects the following response:

        aa: {
            att: number[][][]
            left: list of {text, topk_words, topk_probs}
            right: list of {text, topk_words, topk_probs}
        }
        """
        # Convert the attention tensors to lists and round them for the frontend
        rounder = partial(round, ndigits=ndigits)
        nested_rounder = partial(map_nlist, rounder)

        def tolist(tens):
            return [t.tolist() for t in tens]

        def to_resp(tok: str, topk_words, topk_probs):
            return {
                "text": tok,
                "topk_words": topk_words,
                "topk_probs": nested_rounder(topk_probs),
            }

        side_info = [
            to_resp(t, w, p)
            for t, w, p in zip(self.tokens, self.topk_words, self.topk_probs)
        ]

        out = {"aa": {
            "att": nested_rounder(tolist(self.attentions[layer])),
            "left": side_info,
            "right": side_info,
        }}

        return out

    def display_tokens(self, tokens):
        return fix_byte_spaces(tokens)

    def __repr__(self):
        lim = 50
        if len(self.sentence) > lim:
            s = self.sentence[:lim - 3] + "..."
        else:
            s = self.sentence
        return f"TransformerOutput({s})"

    def __len__(self):
        return self.__len

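# A usage sketch for the formatter. The tensor shapes, token strings, and the
# `MockConfig` below are invented for illustration; a real call site would pass
# a Hugging Face model config and real model outputs instead.
#
# >>> class MockConfig:
# ...     n_layer, n_head, n_embd = 2, 2, 8        # GPT-style field names
# >>> att = tuple(torch.rand(1, 2, 3, 3) for _ in range(2))  # [batch, heads, tok, tok] per layer
# >>> fmt = TransformerOutputFormatter(
# ...     sentence="a b c",
# ...     tokens=["a", "b", "c"],
# ...     special_tokens_mask=[0, 0, 0],
# ...     att=att,
# ...     topk_words=[["x"], ["y"], ["z"]],
# ...     topk_probs=[[0.5], [0.5], [0.5]],
# ...     model_config=MockConfig(),
# ... )
# >>> sorted(fmt.to_json(layer=0)["aa"].keys())
# ['att', 'left', 'right']
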
def to_numpy(x):
    """Embeddings, contexts, and attentions are stored as torch.Tensors in a
    tuple. Convert this to a numpy array for storage in hdf5."""
    return np.array([x_.detach().numpy() for x_ in x])


def to_searchable(t: torch.Tensor) -> np.ndarray:
    """Convert a single tensor to a float32 numpy array for searching."""
    return t.detach().numpy().astype(np.float32)
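
# A small sketch of the two converters, using dummy tensors; the shapes are
# illustrative assumptions.
#
# >>> layers = tuple(torch.rand(3, 4) for _ in range(2))
# >>> to_numpy(layers).shape          # stacks the per-layer tensors
# (2, 3, 4)
# >>> to_searchable(layers[0]).dtype
# dtype('float32')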