import logging
from collections import defaultdict
from typing import List, Optional
import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
MBartConfig,
MBartForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
class SwinEncoder(nn.Module):
r"""
Encoder based on SwinTransformer
Set the initial weights and configuration with a pretrained SwinTransformer and then
modify the detailed configurations
Args:
input_size: Input image size (width, height)
align_long_axis: Whether to rotate image if height is greater than width
window_size: Window size(=patch size) of SwinTransformer
encoder_layer: Number of layers of SwinTransformer encoder
name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local.
otherwise, `swin_base_patch4_window12_384` will be set (using `timm`).
"""
def __init__(
self,
input_size,
align_long_axis: bool = False,
window_size: int = 7,
encoder_layer: List[int] = [2, 2, 14, 2],
        patch_size: List[int] = [4, 4],
embed_dim: int = 128,
num_heads: List[int] = [4, 8, 16, 32],
):
super().__init__()
if isinstance(input_size, int):
input_size = [input_size, input_size]
self.input_size = input_size
self.align_long_axis = align_long_axis
self.window_size = window_size
self.encoder_layer = encoder_layer
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.model = SwinTransformer(
img_size=self.input_size,
depths=self.encoder_layer,
window_size=self.window_size,
patch_size=self.patch_size,
embed_dim=self.embed_dim,
num_heads=self.num_heads,
num_classes=0,
)
    def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, num_channels, height, width)
            text_embedding: unused here; accepted for interface compatibility with
                vision towers that condition on the text prompt
        """
x = self.model.patch_embed(x)
x = self.model.pos_drop(x)
x = self.model.layers(x)
return x
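# A minimal usage sketch for SwinEncoder (illustrative; assumes an older timm
# that keeps the (batch, tokens, channels) layout, and an input size divisible
# by 32 * window_size):
#
#     encoder = SwinEncoder(input_size=[896, 672], window_size=7)
#     pixels = torch.randn(1, 3, 896, 672)
#     feats = encoder(pixels)  # (1, (896 // 32) * (672 // 32), 8 * 128) == (1, 588, 1024)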
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16.

    `_set_dtype` must be called once to choose the compute dtype before `forward` is used.
    """

    def _set_dtype(self, dtype):
        self._dtype = dtype

    def forward(self, x: torch.Tensor):
        # Normalize in the configured dtype, then cast back so surrounding
        # fp16 modules see their original dtype.
        orig_type = x.dtype
        ret = super().forward(x.type(dtype=self._dtype))
        return ret.type(orig_type)
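# Example (illustrative): keep LayerNorm in fp32 while surrounding modules run fp16.
#
#     ln = LayerNorm(1024)
#     ln._set_dtype(torch.float32)
#     y = ln(torch.randn(1, 588, 1024, dtype=torch.float16))  # y.dtype == torch.float16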
class BARTDecoder(nn.Module):
"""
Decoder based on Multilingual BART
Set the initial weights and configuration with a pretrained multilingual BART model,
and modify the detailed configurations as a Donut decoder
Args:
decoder_layer:
Number of layers of BARTDecoder
max_position_embeddings:
The maximum sequence length to be trained
name_or_path:
Name of a pretrained model name either registered in huggingface.co. or saved in local,
otherwise, `facebook/mbart-large-50` will be set (using `transformers`)
"""
def __init__(
self,
tokenizer,
decoder_layer: int,
max_position_embeddings: int,
hidden_dimension: int = 1024,
**kwargs,
):
super().__init__()
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_position_embeddings
self.hidden_dimension = hidden_dimension
self.tokenizer = tokenizer
self.model = MBartForCausalLM(
config=MBartConfig(
tie_word_embeddings=True,
is_decoder=True,
is_encoder_decoder=False,
add_cross_attention=True,
decoder_layers=self.decoder_layer,
max_position_embeddings=self.max_position_embeddings,
vocab_size=len(self.tokenizer),
scale_embedding=True,
add_final_layer_norm=True,
d_model=self.hidden_dimension,
)
)
# self.model.config.is_encoder_decoder = True # to get cross-attention
self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
def add_special_tokens(self, list_of_tokens: List[str]):
"""
Add special tokens to tokenizer and resize the token embeddings
"""
newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
def add_tokens(self, list_of_tokens: List[str]):
"""
        Add tokens to the tokenizer and resize the token embeddings
"""
newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
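    # Example (illustrative; `decoder` is a BARTDecoder instance and the token
    # strings are hypothetical): extend the vocabulary before fine-tuning,
    # resizing the embeddings as a side effect.
    #
    #     decoder.add_special_tokens(["<s_table>", "</s_table>"])
    #     decoder.add_tokens(["<sep/>"])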
def prepare_inputs_for_inference(
self,
input_ids: torch.Tensor,
encoder_outputs: torch.Tensor,
past=None,
past_key_values=None,
        use_cache: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
input_ids: (batch_size, sequence_length)
Returns:
input_ids: (batch_size, sequence_length)
attention_mask: (batch_size, sequence_length)
encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
"""
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
past = past or past_key_values
if past is not None:
input_ids = input_ids[:, -1:]
output = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
"encoder_hidden_states": encoder_outputs.last_hidden_state,
}
return output
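    # Since `prepare_inputs_for_generation` is patched to the method above,
    # `generate(..., encoder_outputs=...)` feeds the vision features to every
    # decoding step as `encoder_hidden_states`, and only the newest token id is
    # passed once a key/value cache exists. One step, sketched (illustrative;
    # `generated_ids` and `cache` are placeholders):
    #
    #     step_inputs = decoder.prepare_inputs_for_inference(
    #         input_ids=generated_ids, encoder_outputs=encoder_outputs, past_key_values=cache
    #     )
    #     step_output = decoder(**step_inputs)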
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@staticmethod
def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
"""
Resize position embeddings
Truncate if sequence length of MBart backbone is greater than given max_length,
else interpolate to max_length
"""
if weight.shape[0] > max_length:
weight = weight[:max_length, ...]
else:
weight = (
F.interpolate(
weight.permute(1, 0).unsqueeze(0),
size=max_length,
mode="linear",
align_corners=False,
)
.squeeze(0)
.permute(1, 0)
)
return weight
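# Example (illustrative): mbart-large's learned position table has shape
# (1026, 1024) (max positions plus 2 offset rows); growing it interpolates each
# channel linearly along the position axis, while shrinking truncates.
#
#     old = torch.randn(1026, 1024)
#     new = BARTDecoder.resize_bart_abs_pos_emb(old, 4098)  # (4098, 1024)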
class DonutConfig(PretrainedConfig):
def __init__(
self,
decoder_layer: int = 10,
        max_position_embeddings: Optional[int] = None,
max_length: int = 4096,
hidden_dimension: int = 1024,
**kwargs,
):
        super().__init__(**kwargs)
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
self.max_length = max_length
self.hidden_dimension = hidden_dimension
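# Example (illustrative): a config matching the defaults above; unrecognized
# kwargs are forwarded to PretrainedConfig.
#
#     config = DonutConfig(decoder_layer=10, max_length=4096, hidden_dimension=1024)
#     assert config.max_position_embeddings == 4096  # falls back to max_length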
class RunningVarTorch:
    """Running variance over a sliding window of the last `L` pushed vectors."""

    def __init__(self, L=15, norm=False):
self.values = None
self.L = L
self.norm = norm
def push(self, x: torch.Tensor):
assert x.dim() == 1
if self.values is None:
self.values = x[:, None]
elif self.values.shape[1] < self.L:
self.values = torch.cat((self.values, x[:, None]), 1)
else:
self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
def variance(self):
if self.values is None:
return
if self.norm:
return torch.var(self.values, 1) / self.values.shape[1]
else:
return torch.var(self.values, 1)
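# Example (illustrative): the window keeps the last L pushed vectors and
# `variance` reduces over that window per element.
#
#     rv = RunningVarTorch(L=3)
#     for v in (torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([4.0])):
#         rv.push(v)
#     rv.variance()  # tensor([2.3333]) -- unbiased variance of [1, 2, 4]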
class StoppingCriteriaScores(StoppingCriteria):
def __init__(self, threshold: float = 0.015, window_size: int = 200):
super().__init__()
self.threshold = threshold
self.vars = RunningVarTorch(norm=True)
self.varvars = RunningVarTorch(L=window_size)
self.stop_inds = defaultdict(int)
self.stopped = defaultdict(bool)
self.size = 0
self.window_size = window_size
@torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_scores = scores[-1]
self.vars.push(last_scores.max(1)[0].float().cpu())
self.varvars.push(self.vars.variance())
self.size += 1
if self.size < self.window_size:
return False
varvar = self.varvars.variance()
for b in range(len(last_scores)):
if varvar[b] < self.threshold:
if self.stop_inds[b] > 0 and not self.stopped[b]:
self.stopped[b] = self.stop_inds[b] >= self.size
else:
self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
else:
self.stop_inds[b] = 0
self.stopped[b] = False
return all(self.stopped.values()) and len(self.stopped) > 0
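# The criterion above is a repetition detector: once the variance of the
# running score variance stays below `threshold` over a `window_size`-step
# window, the model is assumed to be stuck in a degenerate loop and generation
# is halted for that batch element. It plugs into `generate` as shown in
# `DonutModel.inference` below:
#
#     stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()])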
def batch(l, b=15):
    """Sliding windows of length `b` over `l` (note: the final window is not emitted)."""
    subs = []
    for i in range(len(l) - b):
        subs.append(l[i : i + b])
    return subs


def subdiv(l, b=10):
    """Growing prefixes of `l`, from length `b` up to `len(l) - 1`."""
    subs = []
    for i in range(len(l) - b):
        subs.append(l[: i + b])
    return subs
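# Example (illustrative):
#
#     batch([1, 2, 3, 4, 5], b=2)   # [[1, 2], [2, 3], [3, 4]] -- last window skipped
#     subdiv([1, 2, 3, 4, 5], b=2)  # [[1, 2], [1, 2, 3], [1, 2, 3, 4]]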
class DonutModel(PreTrainedModel):
config_class = DonutConfig
base_model_prefix = "donut"
def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
super().__init__(config)
self.config = config
self.tokenizer = tokenizer
        # Vision tower; the inference helpers below expect it to provide
        # prepare_input(), forward_features() and get_last_layer_cross_attn_score().
        self.vpm = vision_tower
        # build language model
self.llm = BARTDecoder(
tokenizer=tokenizer,
decoder_layer=self.config.decoder_layer,
max_position_embeddings=self.config.max_position_embeddings,
hidden_dimension=self.config.hidden_dimension,
)
        self.ids_to_tokens = {token_id: content for content, token_id in self.llm.tokenizer.vocab.items()}
    def get_input_embeddings(self, tensor):
        # Note: unlike the standard PreTrainedModel.get_input_embeddings(), this
        # returns the embedded tensor rather than the embedding module.
        return self.llm.model.get_input_embeddings()(tensor)
def forward(
self,
inputs: dict,
):
image_tensors = inputs["pixel_values"]
input_ids = inputs["input_ids"].contiguous()
attention_mask = inputs["attention_mask"]
labels = inputs["labels"].contiguous()
encoder_outputs = self.vpm(
image_tensors,
text_embedding=self.llm.model.get_input_embeddings()(input_ids),
)
decoder_outputs = self.llm(
input_ids=input_ids,
encoder_hidden_states=encoder_outputs,
attention_mask=attention_mask,
labels=labels,
)
return decoder_outputs
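    # A minimal training-step sketch (illustrative; shapes are assumptions, and
    # ignored label positions use -100 as in the underlying MBartForCausalLM):
    #
    #     outputs = model({
    #         "pixel_values": torch.randn(2, 3, 896, 672),
    #         "input_ids": input_ids,            # (2, seq_len)
    #         "attention_mask": attention_mask,  # (2, seq_len)
    #         "labels": labels,                  # (2, seq_len)
    #     })
    #     outputs.loss.backward()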
def get_hidden_states_during_inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
):
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
all_hidden_states = self.vpm.forward_features(
image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
)
return all_hidden_states
def get_attn_weights_during_inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
):
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
last_attn_score = self.vpm.get_last_layer_cross_attn_score(
image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
)
return last_attn_score
def inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
return_attentions: bool = False,
early_stopping: bool = True,
):
"""
Generate a token sequence in an auto-regressive manner.
Args:
image: input document image (PIL.Image)
image_tensors: (1, num_channels, height, width)
convert prompt to tensor if image_tensor is not fed
"""
output = {
"predictions": list(),
"sequences": list(),
"repeats": list(),
"repetitions": list(),
}
if image is None and image_tensors is None:
            logging.warning("Image not found")
return output
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))
encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
if len(encoder_outputs.last_hidden_state.size()) == 1:
encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)
# get decoder output
decoder_output = self.llm.model.generate(
input_ids=prompt_ids,
encoder_outputs=encoder_outputs,
min_length=1,
max_length=self.config.max_length,
pad_token_id=self.llm.tokenizer.pad_token_id,
eos_token_id=self.llm.tokenizer.eos_token_id,
use_cache=True,
return_dict_in_generate=True,
output_scores=True,
output_attentions=return_attentions,
do_sample=False,
num_beams=1,
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
)
output["repetitions"] = decoder_output.sequences.clone()
output["sequences"] = decoder_output.sequences.clone()
output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]
output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
return output
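# A minimal inference sketch (illustrative; the prompt string and file name are
# hypothetical):
#
#     prompt_ids = model.llm.tokenizer("<s>", add_special_tokens=False, return_tensors="pt").input_ids
#     result = model.inference(prompt_ids=prompt_ids, image=Image.open("doc.png"))
#     print(result["repetitions"][0])  # decoded sequence, special tokens included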