Adding new tokens causes "CUDA error: device-side assert triggered"

#34
by mohamed-stifi - opened

I tried a sample script that adds new tokens to the tokenizer and then runs inference, but I get the following error:
RuntimeError Traceback (most recent call last)

/tmp/ipython-input-2193345016.py in <cell line: 0>()
2
3 with torch.inference_mode():
----> 4 generation = model.generate(**inputs, max_new_tokens=120, do_sample=False, cache_implementation="static")
5 generate_ids = generation[0][input_len:]
6

10 frames

/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3n/modeling_gemma3n.py in forward(self, input_ids, per_layer_inputs, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **kwargs)
1658 current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
1659 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
-> 1660 new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
1661 current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
1662 temp_hidden_states.append(current_hidden_state)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

The error occurs when I run the following script:

# Reproduction script: load Gemma 3n in 4-bit, add three tokens to the
# tokenizer, resize the model's token embeddings, then run greedy generation.
#
# NOTE: `torch`, `transformers` and `Gemma3nForCausalLM` are not imported in
# this snippet; it relies on them already being in scope (e.g. imported in an
# earlier notebook cell).

model_id = "google/gemma-3n-E2B-it"

# Pick bfloat16 when the GPU reports compute capability >= 8, else float16.
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype,  # What torch dtype to use, defaults to auto
    device_map="auto",  # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

model = Gemma3nForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    **model_kwargs
)

print("--- Model loaded with 4-bit quantization ---")
print(f"Model device: {model.device}")

# Tokenizer setup: reuse EOS as the padding token and pad on the right.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Register the new tokens with the tokenizer.
token_to_add = ['token1', 'token2', 'token3']

tokenizer.add_tokens(token_to_add)

# Look up the ids assigned to the new tokens (bare expression: notebook display).
new_tokens_id = tokenizer.convert_tokens_to_ids(token_to_add)
new_tokens_id

# NOTE(review): the reported traceback fails inside Gemma3n's forward pass with
# a device-side assert. Presumably the ids of the newly added tokens index past
# the end of an embedding table that `resize_token_embeddings` does not resize
# (Gemma3n appears to keep a separate per-layer input embedding) -- TODO
# confirm by rerunning with CUDA_LAUNCH_BLOCKING=1, or on CPU, to surface the
# exact failing operation.
model.resize_token_embeddings(len(tokenizer))

# Build a prompt that contains every new token and move it to the model device.
prompt = "hi token1, then token2, then token3 hi"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
inputs

# Prompt length in tokens, used to slice the continuation off the output.
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=120, do_sample=False, cache_implementation="static")
    # Continuation only (prompt tokens removed). Not used below: the decode
    # step operates on the full `generation` tensor.
    generate_ids = generation[0][input_len:]

# Decode the full output sequences (prompt + generated tokens) and print them.
decoded = tokenizer.batch_decode(generation, skip_special_tokens=True)
print(''.join(decoded))

Sign up or log in to comment