CUDA error: device-side assert triggered after adding new tokens
I tried a sample script that adds new tokens to the tokenizer and then runs inference, but I get the following error:
RuntimeError Traceback (most recent call last)
/tmp/ipython-input-2193345016.py in <cell line: 0>()
2
3 with torch.inference_mode():
----> 4 generation = model.generate(**inputs, max_new_tokens=120, do_sample=False, cache_implementation="static")
5 generate_ids = generation[0][input_len:]
6
10 frames
/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3n/modeling_gemma3n.py in forward(self, input_ids, per_layer_inputs, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **kwargs)
1658 current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
1659 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
-> 1660 new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
1661 current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
1662 temp_hidden_states.append(current_hidden_state)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
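If it helps, this is the minimal way I understand the debugging hint in the message (my own sketch, nothing Gemma-specific): CUDA_LAUNCH_BLOCKING has to be set before torch initializes CUDA so the reported stack trace points at the op that actually fails.

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # run kernels synchronously so the traceback is accurate
import torch  # only import torch after setting the variable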
by running the following script:
import torch
import transformers
from transformers import Gemma3nForCausalLM

model_id = "google/gemma-3n-E2B-it"
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16
# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype,      # Which torch dtype to use, defaults to auto
    device_map="auto",            # Let torch decide how to load the model
)
# BitsAndBytesConfig: enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
model = Gemma3nForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    **model_kwargs,
)
print("--- Model loaded with 4-bit quantization ---")
print(f"Model device: {model.device}")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
token_to_add = ['token1', 'token2', 'token3']
tokenizer.add_tokens(token_to_add)
new_tokens_id = tokenizer.convert_tokens_to_ids(token_to_add)
print(new_tokens_id)
model.resize_token_embeddings(len(tokenizer))
prompt = "hi token1, then token2, then token3 hi"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
print(inputs)
input_len = inputs["input_ids"].shape[-1]
# Generate
with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=120, do_sample=False, cache_implementation="static")
    generate_ids = generation[0][input_len:]
decoded = tokenizer.batch_decode(generation, skip_special_tokens=True)
print(''.join(decoded))
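A sanity check I would add right after model.resize_token_embeddings (this is only my assumption that the device-side assert might be an out-of-range embedding index; I have not confirmed it):

# Hypothetical check: make sure the new token ids fit inside the resized embedding table
embed_rows = model.get_input_embeddings().weight.shape[0]
print("tokenizer size:", len(tokenizer))
print("embedding rows:", embed_rows)
print("new token ids: ", new_tokens_id)
assert max(new_tokens_id) < embed_rows, "a new token id falls outside the input embedding table"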