Spaces:
Running
on
A100
Running
on
A100
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| from typing import List, Union | |
| import torch | |
| from torch.nn.functional import cross_entropy | |
| from llava.constants import IGNORE_INDEX | |
__all__ = ["soft_cross_entropy"]


def soft_cross_entropy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    soft_tokens: Union[torch.Tensor, List[int]],
    std: float = 1,
    ignore_index: int = IGNORE_INDEX,
) -> torch.Tensor:
    """Next-token cross-entropy with Gaussian-smoothed labels for selected tokens.

    Logits at position ``t`` are scored against the target at ``t + 1``.
    Positions whose target equals ``ignore_index`` are dropped. Remaining
    targets that are not in ``soft_tokens`` use ordinary hard-label
    cross-entropy. Targets that are in ``soft_tokens`` instead use a label
    distribution over ``soft_tokens``. The weight for soft token ``s`` is
    proportional to ``exp(-(target - s) ** 2 / (2 * std ** 2))``, a Gaussian
    over token-ID distance, normalized to sum to one.

    Args:
        outputs: Logits with the vocabulary as the last dimension and the
            sequence as the second-to-last. Presumably ``(batch, seq, vocab)``
            (TODO confirm with callers); the code only requires that
            ``outputs[..., :-1, :]`` lines up with ``targets[..., 1:]``.
        targets: Integer token IDs, shaped like ``outputs`` without the
            vocabulary dimension.
        soft_tokens: Token IDs that receive soft labels. Accepts a list,
            tuple, or tensor; it is converted to ``targets``' dtype and device.
        std: Width of the Gaussian in token-ID units. Must be nonzero, since
            ``std == 0`` divides by zero.
        ignore_index: Target value marking positions excluded from the loss.

    Returns:
        Scalar tensor: the summed loss divided by the number of non-ignored
        target positions. When every position is ignored, the result is 0
        rather than NaN (0 / 0).
    """
    # Shift so the logits at position t predict the token at t + 1.
    outputs = outputs[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()

    # Flatten to (N, vocab) logits and (N,) targets.
    targets = targets.view(-1)
    outputs = outputs.view(targets.size(0), -1)

    # Drop positions whose target is ignore_index.
    keep = targets != ignore_index
    outputs = outputs[keep]
    targets = targets[keep]

    # isin() and the index assignment below need an integer tensor on the same
    # device as targets. This conversion handles lists, tuples, and tensors.
    soft_tokens = torch.as_tensor(
        soft_tokens, dtype=targets.dtype, device=targets.device
    )
    is_soft = torch.isin(targets, soft_tokens)

    # Hard-label loss for ordinary tokens.
    hard_loss = cross_entropy(
        outputs[~is_soft], targets[~is_soft], reduction="sum"
    )

    # Soft-label loss. Build all Gaussian rows at once with shape (M, S),
    # where M is the number of soft targets and S = len(soft_tokens).
    # This avoids one small indexing op per soft target.
    soft_outputs = outputs[is_soft]
    soft_targets = targets[is_soft]
    diff = soft_targets[:, None] - soft_tokens[None, :]
    dist = torch.exp(-(diff**2) / (2 * std**2))
    soft_labels = torch.zeros_like(soft_outputs)
    # Each target is itself a soft token, so its row contains exp(0) = 1 and
    # the row sum is never zero (for nonzero std).
    soft_labels[:, soft_tokens] = dist / dist.sum(dim=1, keepdim=True)
    soft_loss = cross_entropy(soft_outputs, soft_labels, reduction="sum")

    # Average over non-ignored positions; avoid 0 / 0 = NaN when all are ignored.
    return (hard_loss + soft_loss) / max(targets.size(0), 1)