Enable `cache_params` to work with `generate()` from `GenerationMixin`

#3

IMPORTANT: The cache doesn't seem to work very well in my tests. Given that it was disabled and still contained a breakpoint() call, I assume it was simply not ready. Even so, enabling it is useful for understanding how fast the model can run when the cache is used.

In my tests, the model in Q4 BF16 goes from generating about 3.3 tok/s to about 19.3 tok/s on the same RTX 5090 (roughly a 5.8x speedup). Ideally this model would be run in native FP4 on the 5090, but this is not supported yet in PyTorch, so I guess this is only possible in the NeMo engine for now.
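For anyone who wants to reproduce this kind of measurement, here is a minimal sketch of a throughput helper. The names `tokens_per_second`, `speedup` and `generate_fn` are illustrative and not part of this PR; `generate_fn` stands for any callable that produces the requested number of new tokens, for example a thin wrapper around `model.generate(..., max_new_tokens=n, use_cache=True)`.

```python
import time


def tokens_per_second(generate_fn, n_new_tokens, n_runs=3):
    """Return the best observed throughput (tokens/s) over n_runs calls.

    generate_fn is a hypothetical callable that takes the number of new
    tokens to produce and blocks until generation has finished.
    """
    best = 0.0
    for _ in range(n_runs):
        start = time.perf_counter()
        generate_fn(n_new_tokens)
        # Guard against a zero interval for extremely fast callables.
        elapsed = max(time.perf_counter() - start, 1e-9)
        best = max(best, n_new_tokens / elapsed)
    return best


def speedup(cached_tps, uncached_tps):
    """Ratio between cached and uncached throughput."""
    return cached_tps / uncached_tps


# The figures quoted above: 19.3 tok/s with the cache, 3.3 tok/s without.
print(f"{speedup(19.3, 3.3):.1f}x")  # 5.8x
```

On a GPU, the callable should only return once the output is actually materialised (for example after moving the generated ids to the CPU), otherwise the timing can be misleading.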

To briefly describe my changes:

  1. Rename cache_params to past_key_values in the CausalLM model, matching the name its prepare_inputs_for_generation method already used, and rename the cache_params returned by the backbone back to past_key_values so that the transformers code can pass it back in on subsequent calls.
  2. Fix various minor issues in the cache itself, such as values that later code expects not being saved on self, or .device being accessed on lists rather than on individual tensors.
  3. Comment out a breakpoint() call in the mixer code (which might be working incorrectly, I'm not sure).
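The renaming in item 1 can be illustrated with a toy sketch. Nothing below is the actual modeling_nemotron_h.py code: `ToyBackbone`, `ToyCausalLM`, `toy_generate` and the output classes are hypothetical stand-ins, and the "cache" is just the list of tokens seen so far. The point is only that a generic decoding loop looks for `past_key_values` on the output and passes it back under the same name, so the wrapper has to translate to and from the backbone's `cache_params`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BackboneOutput:
    last_token: int
    cache_params: Optional[list]  # name used by the backbone


@dataclass
class CausalLMOutput:
    next_token: int
    past_key_values: Optional[list]  # name the decoding loop looks for


class ToyBackbone:
    """Stand-in backbone: the 'cache' is simply every token seen so far."""

    def __call__(self, input_ids, cache_params=None):
        cache = list(cache_params or []) + list(input_ids)
        return BackboneOutput(last_token=cache[-1], cache_params=cache)


class ToyCausalLM:
    def __init__(self):
        self.backbone = ToyBackbone()

    def __call__(self, input_ids, past_key_values=None):
        # Accept the generic name and hand it to the backbone under its own name.
        out = self.backbone(input_ids, cache_params=past_key_values)
        # Rename it back so the caller can pass it in again on the next step.
        return CausalLMOutput(next_token=out.last_token + 1,
                              past_key_values=out.cache_params)


def toy_generate(model, prompt, max_new_tokens):
    """Greedy loop that feeds only the newest token once a cache exists."""
    tokens, cache, step_input = list(prompt), None, list(prompt)
    for _ in range(max_new_tokens):
        out = model(step_input, past_key_values=cache)
        cache = out.past_key_values
        tokens.append(out.next_token)
        step_input = [out.next_token]
    return tokens, cache


tokens, cache = toy_generate(ToyCausalLM(), [1, 2, 3], max_new_tokens=3)
print(tokens)  # [1, 2, 3, 4, 5, 6]
```

If the wrapper returned the cache under the backbone's name instead, the loop would find no `past_key_values` and every step would start from an empty cache.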

IMPORTANT: I do not recommend merging this PR as-is, as the cache doesn't seem to work properly and affects the output quality quite visibly (though this might be because of my Q4 BF16 setup). I don't have time to investigate why this is the case, as this likely requires a deep understanding of the entire Mamba2+Attention hybrid pipeline, but I assume whoever put the breakpoint() in was already on track to figure this out ;-) GLHF!
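A possible way to narrow this down, sketched under assumptions: with greedy decoding (`do_sample=False`), the token ids produced with and without the cache should in principle match, apart from late divergences caused by small numerical differences. The helper below, whose name `first_divergence` is illustrative and whose token ids are made up, reports where two generated sequences first differ. An early divergence would point at the cache logic rather than at numerical noise.

```python
def first_divergence(reference_ids, candidate_ids):
    """Return the index of the first differing token, or None if identical."""
    for i, (ref, cand) in enumerate(zip(reference_ids, candidate_ids)):
        if ref != cand:
            return i
    # One sequence is a strict prefix of the other.
    if len(reference_ids) != len(candidate_ids):
        return min(len(reference_ids), len(candidate_ids))
    return None


# Made-up ids standing in for the uncached (reference) and cached outputs.
uncached = [5, 8, 13, 21]
cached = [5, 8, 14, 21]
print(first_divergence(uncached, cached))  # 2
```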

I will also note that the name of the cache and the name in the warning differ, so maybe there is already a more advanced implementation somewhere at NVIDIA. I can't tell for sure.

NVIDIA org

Thank you @FremyCompany for the detailed comments. Work in progress. Removed breakpoint() for now.

Cannot merge
This branch has merge conflicts in the following files:
  • modeling_nemotron_h.py