Remove seq_len arg in rotary_emb (#1443)
Browse files
* remove seq_len in llama rotary_emb
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
    	
        src/axolotl/monkeypatch/llama_attn_hijack_flash.py
    CHANGED
    
    | 
         @@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn( 
     | 
|
| 284 | 
         
             
                # [bsz, nh, q_len, hd]
         
     | 
| 285 | 
         
             
                # pylint: disable=duplicate-code
         
     | 
| 286 | 
         | 
| 287 | 
         
            -
                 
     | 
| 288 | 
         
            -
                if past_key_value is not None:
         
     | 
| 289 | 
         
            -
                    kv_seq_len += past_key_value[0].shape[-2]
         
     | 
| 290 | 
         
            -
                cos, sin = self.rotary_emb(
         
     | 
| 291 | 
         
            -
                    value_states, seq_len=kv_seq_len, position_ids=position_ids
         
     | 
| 292 | 
         
            -
                )
         
     | 
| 293 | 
         
             
                query_states, key_states = apply_rotary_pos_emb(
         
     | 
| 294 | 
         
             
                    query_states, key_states, cos, sin, position_ids
         
     | 
| 295 | 
         
             
                )
         
     | 
| 
         @@ -435,13 +430,7 @@ def flashattn_forward( 
     | 
|
| 435 | 
         
             
                # [bsz, q_len, nh, hd]
         
     | 
| 436 | 
         
             
                # [bsz, nh, q_len, hd]
         
     | 
| 437 | 
         | 
| 438 | 
         
            -
                 
     | 
| 439 | 
         
            -
                if past_key_value is not None:
         
     | 
| 440 | 
         
            -
                    kv_seq_len += past_key_value[0].shape[-2]
         
     | 
| 441 | 
         
            -
             
     | 
| 442 | 
         
            -
                cos, sin = self.rotary_emb(
         
     | 
| 443 | 
         
            -
                    value_states, seq_len=kv_seq_len, position_ids=position_ids
         
     | 
| 444 | 
         
            -
                )
         
     | 
| 445 | 
         
             
                query_states, key_states = apply_rotary_pos_emb(
         
     | 
| 446 | 
         
             
                    query_states, key_states, cos, sin, position_ids
         
     | 
| 447 | 
         
             
                )
         
     | 
| 
         | 
|
| 284 | 
         
             
                # [bsz, nh, q_len, hd]
         
     | 
| 285 | 
         
             
                # pylint: disable=duplicate-code
         
     | 
| 286 | 
         | 
| 287 | 
         
            +
                cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 288 | 
         
             
                query_states, key_states = apply_rotary_pos_emb(
         
     | 
| 289 | 
         
             
                    query_states, key_states, cos, sin, position_ids
         
     | 
| 290 | 
         
             
                )
         
     | 
| 
         | 
|
| 430 | 
         
             
                # [bsz, q_len, nh, hd]
         
     | 
| 431 | 
         
             
                # [bsz, nh, q_len, hd]
         
     | 
| 432 | 
         | 
| 433 | 
         
            +
                cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 434 | 
         
             
                query_states, key_states = apply_rotary_pos_emb(
         
     | 
| 435 | 
         
             
                    query_states, key_states, cos, sin, position_ids
         
     | 
| 436 | 
         
             
                )
         
     | 
    	
        src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
    CHANGED
    
    | 
         @@ -80,11 +80,7 @@ def xformers_forward( 
     | 
|
| 80 | 
         
             
                # [bsz, q_len, nh, hd]
         
     | 
| 81 | 
         
             
                # [bsz, nh, q_len, hd]
         
     | 
| 82 | 
         | 
| 83 | 
         
            -
                 
     | 
| 84 | 
         
            -
                if past_key_value is not None:
         
     | 
| 85 | 
         
            -
                    kv_seq_len += past_key_value[0].shape[-2]
         
     | 
| 86 | 
         
            -
             
     | 
| 87 | 
         
            -
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         
     | 
| 88 | 
         
             
                query_states, key_states = apply_rotary_pos_emb(
         
     | 
| 89 | 
         
             
                    query_states, key_states, cos, sin, position_ids
         
     | 
| 90 | 
         
             
                )
         
     | 
| 
         | 
|
| 80 | 
         
             
                # [bsz, q_len, nh, hd]
         
     | 
| 81 | 
         
             
                # [bsz, nh, q_len, hd]
         
     | 
| 82 | 
         | 
| 83 | 
         
            +
                cos, sin = self.rotary_emb(value_states)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 84 | 
         
             
                query_states, key_states = apply_rotary_pos_emb(
         
     | 
| 85 | 
         
             
                    query_states, key_states, cos, sin, position_ids
         
     | 
| 86 | 
         
             
                )
         
     |