import logging
from dataclasses import fields
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from .config import ActivationCheckpointingStrategy, ModelConfig
from .configuration_olmo import OLMoConfig
from .model import OLMo

log = logging.getLogger(__name__)


def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Utility function that builds OLMo's native ``ModelConfig`` from a Hugging Face ``OLMoConfig``.
    """
    # Copy every ModelConfig field from the Hugging Face config, which stores them as flat attributes.
    kwargs = {}
    for field in fields(ModelConfig):
        kwargs[field.name] = getattr(config, field.name)

    model_config = ModelConfig(**kwargs)

    # Map the Hugging Face attention implementation onto OLMo's flash-attention flag.
    if config._attn_implementation == "flash_attention_2":
        model_config.flash_attention = True
    elif config._attn_implementation in ("eager", "sdpa"):
        model_config.flash_attention = False
    else:
        raise ValueError(f"Unexpected _attn_implementation {config._attn_implementation}")

    return model_config
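

# Illustrative usage sketch (not executed on import; the checkpoint path is a placeholder):
#
#   hf_config = OLMoConfig.from_pretrained("path/to/olmo-checkpoint")
#   model_config = create_model_config_from_pretrained_config(hf_config)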


class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OLMoConfig
    base_model_prefix = "model"
    _no_split_modules = ["OLMoBlock"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    supports_gradient_checkpointing = True

    def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
        super().__init__(config)
        self._gradient_checkpointing_func: Optional[Callable] = None
        self._gradient_checkpointing = False

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            # Initialize the wrapped model on CPU so that no GPU memory is allocated before the
            # checkpoint weights are loaded.
            model_config.init_device = "cpu"
            self.model = OLMo(model_config, init_params=init_params)
        else:
            self.model = model

    @property
    def gradient_checkpointing(self) -> bool:
        return self._gradient_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, enabled: bool):
        if self._gradient_checkpointing == enabled:
            return

        # HF's gradient checkpointing machinery does not expose a checkpointing strategy, so
        # checkpoint whole transformer blocks whenever checkpointing is enabled.
        checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
        self.model.set_activation_checkpointing(
            checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
        )
        self._gradient_checkpointing = enabled
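
    # Illustrative usage sketch (hedged): recent `transformers` versions route
    # `gradient_checkpointing_enable()` through this property after assigning
    # `_gradient_checkpointing_func`. The checkpoint path is a placeholder.
    #
    #   model = OLMoForCausalLM.from_pretrained("path/to/olmo-checkpoint")
    #   model.gradient_checkpointing_enable()   # checkpoints whole OLMo blocks
    #   model.gradient_checkpointing_disable()  # store all activations again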

    def forward(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # Accepted for compatibility with `transformers` generation utilities that pass
        # `cache_position`; it is not used by this wrapper.
        cache_position: Optional[Cache] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        if output_attentions:
            raise ValueError("output_attentions is not yet supported in OLMo")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model.forward(
            input_ids=input_ids,
            input_embeddings=inputs_embeds,
            attention_mask=attention_mask,
            attention_bias=attention_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        logits = outputs.logits
        hidden_states = outputs.hidden_states

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens before computing the cross-entropy loss.
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.embedding_size)
            shift_labels = shift_labels.view(-1)
            # Make sure the labels are on the same device as the logits (e.g. under model parallelism).
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
            hidden_states=hidden_states,
        )
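
    # Illustrative training-style call (hedged sketch; `tokenizer` is assumed to come from the
    # same checkpoint and is not part of this module). Passing `labels=input_ids` yields the
    # shifted next-token cross-entropy loss computed above.
    #
    #   batch = tokenizer("OLMo is an open language model.", return_tensors="pt")
    #   out = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
    #   out.loss.backward()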

    def can_generate(self) -> bool:
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values:
            # With cached key/values, only the most recently generated token needs to be processed.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

        model_inputs.update(kwargs)
        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
        return model_inputs
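
    # Illustrative generation sketch (hedged; checkpoint path, tokenizer, and sampling settings
    # are placeholders). `generate()` calls `prepare_inputs_for_generation` on every decoding
    # step, so with a KV cache only the newest token is fed back through the model.
    #
    #   model = OLMoForCausalLM.from_pretrained("path/to/olmo-checkpoint")
    #   inputs = tokenizer("Language modeling is", return_tensors="pt")
    #   generated = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    #   print(tokenizer.decode(generated[0], skip_special_tokens=True))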

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def tie_weights(self):
        """
        This function is intentionally left as a no-op.

        Weight tying is handled as follows:
        - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying`
          configuration. See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
        - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
          See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the
          `forward` method.

        Therefore, there is no need to explicitly tie the weights in this function.
        """
        pass

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> torch.nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.

        Note:
            This method differs from the base class implementation by resizing the `embedding_size` attribute of the
            model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
            is less than the `vocab_size`. In OLMo, `embedding_size` is the number of rows in the token embedding
            matrix (the vocabulary size, possibly padded up for throughput), while `vocab_size` is the number of
            unique tokens in the vocabulary.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update both copies of the config so they stay in sync with the resized embedding matrix.
        self.config.embedding_size = model_embeds.weight.shape[0]
        self.model.config.embedding_size = model_embeds.weight.shape[0]

        # Warn if the new embedding size is smaller than the configured vocabulary size.
        if self.config.embedding_size < self.config.vocab_size:
            warning_message = (
                f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
                f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
                "size is less than or equal to the new token embedding size."
            )
            log.warning(warning_message)

        # Tie weights again if needed (a no-op for OLMo; see `tie_weights`).
        self.tie_weights()

        return model_embeds
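
    # Illustrative usage sketch (hedged; `tokenizer` is a placeholder for a tokenizer that has had
    # tokens added to it):
    #
    #   model.resize_token_embeddings(new_num_tokens=len(tokenizer))
    #   # or pad the embedding matrix to a multiple of 128 without specifying a new token count:
    #   model.resize_token_embeddings(new_num_tokens=None, pad_to_multiple_of=128)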


# Register OLMo with the Hugging Face auto classes so `AutoModelForCausalLM` can load it.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
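
# Illustrative loading sketch (hedged; the checkpoint path is a placeholder and `hf_olmo` is the
# assumed name of the package containing this module, whose import triggers the registration above):
#
#   import hf_olmo  # noqa: F401
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("path/to/olmo-checkpoint")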