import torch


def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Perform attention-aware mean pooling.

    This method takes in embeddings of shape (batch, sequence, embedding_size) and performs average
    pooling across the sequence dimension to yield embeddings of size (batch, embedding_size).

    From:
    https://github.com/UKPLab/sentence-transformers/blob/0b5ef4be93d2b21de3a918a084b48aab6ba48595/sentence_transformers/model_card_templates.py#L134  # noqa: E501

    Args:
        token_embeddings (`torch.Tensor`): The embeddings we wish to pool over, of shape
            (batch, sequence, embedding_size). This will pool over the sequence to yield
            (batch, embedding_size).
        attention_mask (`torch.Tensor`): The binary attention mask across the embeddings, of shape
            (batch, sequence).

    Returns:
        (`torch.Tensor`) The mean pooled embeddings of size (batch, embedding_size).
    """
    # Expand the mask to (batch, sequence, embedding_size) so padded positions contribute zero.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum over the sequence dimension and divide by the number of real tokens,
    # clamped to avoid division by zero for fully padded rows.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
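# A minimal usage sketch (illustrative only): the tensor sizes and mask values below are
# assumptions chosen to show the pooling behaviour, not taken from any particular model.
if __name__ == "__main__":
    batch, sequence, embedding_size = 2, 4, 8
    token_embeddings = torch.randn(batch, sequence, embedding_size)
    # The second example in the batch has two padding positions (mask value 0),
    # so its mean is taken over the first two tokens only.
    attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
    pooled = mean_pooling(token_embeddings, attention_mask)
    print(pooled.shape)  # torch.Size([2, 8])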