Inference fails on CPU: `ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)`

#10
by umarbutler - opened

Running the code below, which is taken verbatim from the README except for the addition of `device='cpu'`, raises `ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)`:

import torch
from transformers import pipeline
from pprint import pprint

pipe = pipeline(
    "fill-mask",
    model="answerdotai/ModernBERT-base",
    torch_dtype=torch.bfloat16,
    device='cpu',
)

input_text = "He walked to the [MASK]."
results = pipe(input_text)
pprint(results)

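The full traceback is below. It shows that, even though the pipeline runs on the CPU, ModernBERT dispatches to `flash_attention_forward`. Its rotary embedding then calls the Triton kernel from `flash_attn` (`flash_attn/ops/triton/rotary.py`), and Triton cannot access CPU tensors: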

ValueError                                Traceback (most recent call last)
Cell In[1], line 13
      5 pipe = pipeline(
      6     "fill-mask",
      7     model="answerdotai/ModernBERT-base",
      8     torch_dtype=torch.bfloat16,
      9     device='cpu',
     10 )
     12 input_text = "He walked to the [MASK]."
---> 13 results = pipe(input_text)
     14 pprint(results)

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/fill_mask.py:270, in FillMaskPipeline.__call__(self, inputs, **kwargs)
    248 def __call__(self, inputs, **kwargs):
    249     """
    250     Fill the masked token in the text(s) given as inputs.
    251 
   (...)
    268         - **token_str** (str) -- The predicted token (to replace the masked one).
    269     """
--> 270     outputs = super().__call__(inputs, **kwargs)
    271     if isinstance(inputs, list) and len(inputs) == 1:
    272         return outputs[0]

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/base.py:1301, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1293     return next(
   1294         iter(
   1295             self.get_iterator(
   (...)
   1298         )
   1299     )
   1300 else:
-> 1301     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/base.py:1308, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1306 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
   1307     model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1308     model_outputs = self.forward(model_inputs, **forward_params)
   1309     outputs = self.postprocess(model_outputs, **postprocess_params)
   1310     return outputs

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/base.py:1208, in Pipeline.forward(self, model_inputs, **forward_params)
   1206     with inference_context():
   1207         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1208         model_outputs = self._forward(model_inputs, **forward_params)
   1209         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1210 else:

File ~/dev/.venv/lib/python3.12/site-packages/transformers/pipelines/fill_mask.py:127, in FillMaskPipeline._forward(self, model_inputs)
    126 def _forward(self, model_inputs):
--> 127     model_outputs = self.model(**model_inputs)
    128     model_outputs["input_ids"] = model_inputs["input_ids"]
    129     return model_outputs

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:1059, in ModernBertForMaskedLM.forward(self, input_ids, attention_mask, sliding_window_mask, position_ids, labels, indices, cu_seqlens, max_seqlen, batch_size, seq_len, output_attentions, output_hidden_states, return_dict, **kwargs)
   1054         with torch.no_grad():
   1055             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
   1056                 inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
   1057             )
-> 1059 outputs = self.model(
   1060     input_ids,
   1061     attention_mask=attention_mask,
   1062     sliding_window_mask=sliding_window_mask,
   1063     position_ids=position_ids,
   1064     indices=indices,
   1065     cu_seqlens=cu_seqlens,
   1066     max_seqlen=max_seqlen,
   1067     batch_size=batch_size,
   1068     seq_len=seq_len,
   1069     output_attentions=output_attentions,
   1070     output_hidden_states=output_hidden_states,
   1071     return_dict=return_dict,
   1072 )
   1073 last_hidden_state = outputs[0]
   1075 if self.sparse_prediction and labels is not None:
   1076     # flatten labels and output first

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:913, in ModernBertModel.forward(self, input_ids, attention_mask, sliding_window_mask, position_ids, indices, cu_seqlens, max_seqlen, batch_size, seq_len, output_attentions, output_hidden_states, return_dict)
    902     layer_outputs = self._gradient_checkpointing_func(
    903         encoder_layer.__call__,
    904         hidden_states,
   (...)
    910         output_attentions,
    911     )
    912 else:
--> 913     layer_outputs = encoder_layer(
    914         hidden_states,
    915         attention_mask=attention_mask,
    916         sliding_window_mask=sliding_window_mask,
    917         position_ids=position_ids,
    918         cu_seqlens=cu_seqlens,
    919         max_seqlen=max_seqlen,
    920         output_attentions=output_attentions,
    921     )
    922 hidden_states = layer_outputs[0]
    923 if output_attentions and len(layer_outputs) > 1:

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:529, in ModernBertEncoderLayer.forward(self, hidden_states, attention_mask, sliding_window_mask, position_ids, cu_seqlens, max_seqlen, output_attentions)
    519 def forward(
    520     self,
    521     hidden_states: torch.Tensor,
   (...)
    527     output_attentions: Optional[bool] = False,
    528 ) -> torch.Tensor:
--> 529     attn_outputs = self.attn(
    530         self.attn_norm(hidden_states),
    531         attention_mask=attention_mask,
    532         sliding_window_mask=sliding_window_mask,
    533         position_ids=position_ids,
    534         cu_seqlens=cu_seqlens,
    535         max_seqlen=max_seqlen,
    536         output_attentions=output_attentions,
    537     )
    538     hidden_states = hidden_states + attn_outputs[0]
    539     mlp_output = (
    540         self.compiled_mlp(hidden_states)
    541         if self.config.reference_compile
    542         else self.mlp(self.mlp_norm(hidden_states))
    543     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:487, in ModernBertAttention.forward(self, hidden_states, output_attentions, **kwargs)
    484 else:
    485     qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
--> 487 attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
    488     self,
    489     qkv=qkv,
    490     rotary_emb=self.rotary_emb,
    491     local_attention=self.local_attention,
    492     bs=bs,
    493     dim=self.all_head_size,
    494     output_attentions=output_attentions,
    495     **kwargs,
    496 )
    497 hidden_states = attn_outputs[0]
    498 hidden_states = self.out_drop(self.Wo(hidden_states))

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:349, in flash_attention_forward(module, qkv, rotary_emb, cu_seqlens, max_seqlen, local_attention, bs, dim, target_dtype, **_kwargs)
    336 def flash_attention_forward(
    337     module: "ModernBertAttention",
    338     qkv: torch.Tensor,
   (...)
    347 ) -> Tuple[torch.Tensor]:
    348     # (total_seqlen, 3, nheads, headdim)
--> 349     qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    351     convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
    352     if convert_dtype:
    353         # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
    354         # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:178, in ModernBertUnpaddedRotaryEmbedding.forward(self, qkv, cu_seqlens, max_seqlen)
    175 if max_seqlen is not None:
    176     self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
--> 178 qkv = apply_rotary_unpadded(
    179     qkv,
    180     self._cos_cached,
    181     self._sin_cached,
    182     cu_seqlens=cu_seqlens,
    183     max_seqlen=max_seqlen,
    184 )
    186 return qkv

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:136, in apply_rotary_unpadded(qkv, cos, sin, cu_seqlens, max_seqlen)
    113 def apply_rotary_unpadded(
    114     qkv,
    115     cos,
   (...)
    118     max_seqlen: Optional[int] = None,
    119 ):
    120     """
    121     Arguments:
    122         qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
   (...)
    134     Apply rotary embedding to the first rotary_dim of x.
    135     """
--> 136     return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)

File ~/dev/.venv/lib/python3.12/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

File ~/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py:75, in ApplyRotaryEmbUnpad.forward(ctx, qkv, cos, sin, cu_seqlens, max_seqlen)
     71 # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
     72 # we get the same tensor
     73 # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
     74 qk = qkv[:, :2].view(total_nnz, -1, headdim)
---> 75 apply_rotary(
     76     qk,
     77     cos,
     78     sin,
     79     seqlen_offsets=0,
     80     cu_seqlens=cu_seqlens,
     81     max_seqlen=max_seqlen,
     82     interleaved=False,
     83     inplace=True,
     84 )
     86 ctx.save_for_backward(cos, sin, cu_seqlens)
     87 ctx.max_seqlen = max_seqlen

File ~/dev/.venv/lib/python3.12/site-packages/flash_attn/ops/triton/rotary.py:202, in apply_rotary(x, cos, sin, seqlen_offsets, cu_seqlens, max_seqlen, interleaved, inplace, conjugate)
    199 # Need this, otherwise Triton tries to launch from cuda:0 and we get
    200 # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    201 with torch.cuda.device(x.device.index):
--> 202     rotary_kernel[grid](
    203         output,  # data ptrs
    204         x,
    205         cos,
    206         sin,
    207         cu_seqlens,
    208         seqlen_offsets,
    209         seqlen,  # shapes
    210         rotary_dim,
    211         seqlen_ro,
    212         output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
    213         output.stride(-3),  # seqlen_stride or total_seqlen_stride
    214         output.stride(-2),  # nheads_stride
    215         output.stride(-1),  # headdim_stride
    216         x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
    217         x.stride(-3),  # seqlen stride or total_seqlen_stride
    218         x.stride(-2),  # nheads stride
    219         x.stride(-1),  # headdim stride
    220         BLOCK_K,
    221         isinstance(seqlen_offsets, torch.Tensor),
    222         is_varlen,
    223         interleaved,
    224         conjugate,
    225         BLOCK_M,
    226     )
    227 return output

File ~/dev/.venv/lib/python3.12/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    339 def __getitem__(self, grid) -> T:
    340     """
    341     A JIT function is launched with: fn[grid](*args, **kwargs).
    342     Hence JITFunction.__getitem__ returns a callable proxy that
    343     memorizes the grid.
    344     """
--> 345     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/triton/runtime/jit.py:691, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    689     # launch kernel
    690     launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
--> 691     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
    692                self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
    693 return kernel

File ~/dev/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py:365, in CudaLauncher.__call__(self, *args, **kwargs)
    364 def __call__(self, *args, **kwargs):
--> 365     self.launch(*args, **kwargs)

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

Sign up or log in to comment