Inference Issue with Multiple GPUs

#70
by m-nameer - opened

Error:
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1478: indexSelectSmallIndex: block: [4,0,0], thread: [109,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1478: indexSelectSmallIndex: block: [4,0,0], thread: [110,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1478: indexSelectSmallIndex: block: [4,0,0], thread: [111,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1478: indexSelectSmallIndex: block: [4,0,0], thread: [112,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
.........
Traceback (most recent call last):
File "/home/aisquare/MTIS/Nameer/Live/LLM_Finetuning/test_contiguous_fix.py", line 36, in
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/generation/utils.py", line 2539, in generate
result = self._sample(
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/generation/utils.py", line 2870, in _sample
outputs = model_forward(**model_inputs, return_dict=True)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1077, in forward
outputs = self.model(
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/utils/generic.py", line 940, in wrapper
output = func(self, *args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 937, in forward
outputs = self.language_model(
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 555, in forward
layer_outputs = decoder_layer(
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in call
return super().call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 389, in forward
hidden_states, self_attn_weights = self.self_attn(
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 315, in forward
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/aisquare/miniconda3/envs/ddp_finetuning/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda, b, CUDA_R_16BF, ldb, &fbeta, c, std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)

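Note on reading this trace: the indexSelectSmallIndex assertions fire first, and once a device-side assert trips, the CUDA context is poisoned, so the later CUBLAS_STATUS_EXECUTION_FAILED from F.linear is most likely a downstream symptom rather than the root cause. Because CUDA kernel launches are asynchronous, the Python frame that reports the error can also be far from the kernel that failed. A minimal sketch for getting an accurate trace, assuming you can rerun the script:

# Make kernel launches synchronous so the traceback points at the op that
# actually tripped the assertion. The variable must be set before torch
# initializes CUDA, i.e. before the first torch import.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported only after the env var is set
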
Code:
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import torch

model_id = "google/gemma-3-4b-it"

# Put the model in bf16 (weights), not the inputs
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "Hello, How are you?"}]},
]

# Create inputs on CPU with correct dtypes (input_ids=int64)
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
)

# Move to the model device without changing dtype
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# (Optional) sanity check
print(inputs["input_ids"].dtype)       # torch.int64
print(inputs["attention_mask"].dtype)  # torch.int64 or torch.bool

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
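
If it helps to narrow this down: the srcIndex < srcSelectDimSize assertion usually means an index_select (for example the token-embedding lookup) received an index outside the indexed dimension. Below is a hedged diagnostic sketch to run right before generate; get_input_embeddings and hf_device_map are standard transformers/accelerate attributes, but the out-of-range hypothesis itself is an assumption, not a confirmed diagnosis:

# Compare the largest token id against the actual embedding table size.
embed = model.get_input_embeddings()
vocab_rows = embed.weight.shape[0]
max_id = inputs["input_ids"].max().item()
print(f"max token id: {max_id}, embedding rows: {vocab_rows}")
assert max_id < vocab_rows, "token id out of range for the embedding table"

# With device_map="auto", accelerate records where each submodule landed;
# an unexpected split across GPUs can also surface as indexing failures.
print(model.hf_device_map)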

Transformers version:
Name: transformers
Version: 4.56.0
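
One more isolation step worth trying (a suggestion, not a confirmed fix): rerun the same script with a single visible GPU. If it then succeeds, the failure is specific to the multi-GPU sharding rather than the model, dtypes, or inputs.

# Pin the process to one GPU; like CUDA_LAUNCH_BLOCKING, this must be set
# before torch initializes CUDA, so it goes at the very top of the script.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch  # imported only after the env var is set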
