Enable `cache_params` to work with `generate()` from `GenerationMixin`
IMPORTANT: The cache doesn't seem to work very well in my tests, and given that it was disabled and still contained a `breakpoint()` call, I assume it was simply not ready yet. Still, this change is useful to understand how fast the model can run when the cache is used.
In my tests, the model in Q4 BF16 goes from generating about 3.3 tok/s to about 19.3 tok/s on the same RTX 5090. I understand that ideally this model should be run in native FP4 on the 5090, but that is not supported in PyTorch yet, so I guess it is only possible in the NeMo engine for now.
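For context, the kind of measurement behind those numbers can be reproduced roughly like this (the checkpoint name and prompt are placeholders, and the quantization step is omitted here for brevity):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "nvidia/<hybrid-mamba2-attention-checkpoint>"  # placeholder, substitute the real checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

# Compare decoding speed with and without the cache enabled.
for use_cache in (False, True):
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=128, do_sample=False, use_cache=use_cache)
    elapsed = time.perf_counter() - start
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    print(f"use_cache={use_cache}: {new_tokens / elapsed:.1f} tok/s")
```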
To briefly describe my changes:
- Rename `cache_params` into `past_key_values` for the CausalLM model, as was already used in its `prepare_inputs_for_generation` method, and rename the `cache_params` returned by the backbone back into `past_key_values`, so the transformers code can pass it back in subsequent calls (see the sketch below this list).
- Fix various minor issues with the cache itself, like not saving some later-expected values in `self` or accessing `.device` on lists rather than individual tensors (also illustrated below).
- Comment out a `breakpoint()` call in the mixer code (which might be working incorrectly, I'm not sure).
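To make the first bullet concrete, the shape of the change on the CausalLM side is roughly the following (a simplified sketch, not the literal diff; the backbone call signature and the output class are illustrative):

```python
from transformers.modeling_outputs import CausalLMOutputWithPast

# Simplified forward() of the CausalLM wrapper: the backbone keeps speaking
# "cache_params" internally, but the boundary towards generate() uses
# "past_key_values", which is the name GenerationMixin knows how to round-trip.
def forward(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
    outputs = self.backbone(
        input_ids,
        cache_params=past_key_values,  # cache handed back by generate()
        use_cache=use_cache,
        **kwargs,
    )
    logits = self.lm_head(outputs.last_hidden_state)
    return CausalLMOutputWithPast(
        logits=logits,
        # Re-expose the backbone cache under the name generate() expects,
        # so prepare_inputs_for_generation receives it on the next step.
        past_key_values=outputs.cache_params,
    )
```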
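The second bullet is mostly bookkeeping, e.g. storing values on `self` that later code expects to read back, and not calling `.device` on a list of per-layer state tensors. The latter looks like this in isolation (names invented for illustration):

```python
import torch

# Minimal illustration of the `.device`-on-a-list bug: the hybrid cache keeps
# per-layer states in a plain Python list, so `.device` only exists on its elements.
conv_states = [torch.zeros(1, 4, 4) for _ in range(3)]  # one tensor per layer

# conv_states.device            # AttributeError: 'list' object has no attribute 'device'
device = conv_states[0].device  # read the device from an individual tensor instead
print(device)
```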
IMPORTANT: I do not recommend merging this PR as-is, as the cache doesn't seem to work properly and visibly affects the output quality (though this might be because of my Q4 BF16 setup). I don't have time to investigate why, as that likely requires a deep understanding of the entire Mamba2+Attention hybrid pipeline, but I assume whoever put the `breakpoint()` in was already on track to figure this out ;-) GLHF!
I will also note that the name of the cache and the name in the warning differ, so maybe there is already a more advanced implementation somewhere at NVIDIA; I can't tell for sure.
Btw, another strategy that preserves the `cache_params` name:
https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py#L662
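If I read that code correctly, the idea there is to keep the `cache_params` kwarg and instead override the generation hook so that `generate()` carries the cache between steps; roughly like this (paraphrased, not a verbatim copy of the linked code):

```python
# Sketch of the alternative strategy: the model keeps accepting/returning
# `cache_params`, and this GenerationMixin hook copies it from the step output
# back into the kwargs used for the next forward call.
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, **kwargs):
    model_kwargs["cache_params"] = outputs.get("cache_params", None)
    # (the real implementation also maintains cache_position / attention_mask)
    return model_kwargs
```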