update batch inference
modeling_qwen.py (+32 -20)
```diff
@@ -35,6 +35,8 @@ from torch import nn
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
+SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+
 
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
```
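The new `SUPPORT_TORCH2` flag gates the `torch.nn.functional.scaled_dot_product_attention` fallback introduced later in this diff, since that fused kernel only exists in PyTorch 2.x. A minimal standalone sketch of what the flag enables; the tensor shapes are illustrative, not taken from the model:

```python
import torch
import torch.nn.functional as F

# Mirrors the flag added in the diff: fused SDPA only exists on PyTorch >= 2.0.
SUPPORT_TORCH2 = hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2

if SUPPORT_TORCH2:
    # (batch, heads, seq, head_dim); sizes chosen only for the demo
    q = k = v = torch.randn(2, 4, 8, 16)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 8, 16])
```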
```diff
@@ -186,7 +188,7 @@ class FlashSelfAttention(torch.nn.Module):
             device=q.device,
         )
 
-        if attention_mask is not None:
+        if batch_size > 1 and attention_mask is not None:
             k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
             if q.size(0) == v.size(0):
                 q = q[indices_k]
@@ -222,7 +224,7 @@ class FlashSelfAttention(torch.nn.Module):
             softmax_scale=self.softmax_scale,
             causal=is_causal,
         )
-        if attention_mask is not None and seqlen_q == seqlen_k:
+        if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
             output = self.pad_input(output, indices_k, batch_size, seqlen_out)
         else:
             new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
```
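These two hunks make `FlashSelfAttention` strip padding only when there is actually a padded batch: with `batch_size == 1` (or no mask) the unpad/re-pad round trip is skipped. Below is a self-contained sketch of the idea behind `self.unpad_input` / `self.pad_input` in plain PyTorch rather than the flash-attn helpers; the right-padded toy mask and shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Toy batch: 3 sequences padded to length 5; 1 = real token, 0 = padding.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0],
])
hidden = torch.randn(3, 5, 8)  # (batch, seq, hidden)

# "Unpad": keep only real tokens, remembering where they came from and the
# cumulative sequence lengths a varlen attention kernel consumes.
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([5, 3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
packed = hidden.reshape(-1, 8)[indices]                                      # (10, 8), no padding rows

# "Pad" back after attention: scatter the packed rows to their original slots.
restored = hidden.new_zeros(3 * 5, 8)
restored[indices] = packed
restored = restored.reshape(3, 5, 8)

print(cu_seqlens)      # tensor([ 0,  5,  8, 10], dtype=torch.int32)
print(restored.shape)  # torch.Size([3, 5, 8])
```

The packed rows plus `cu_seqlens` are what the variable-length flash-attention kernel works on, which is why the mask only matters once at least one sequence in the batch is shorter than the batch width.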
```diff
@@ -451,7 +453,7 @@ class QWenAttention(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
         registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
```
```diff
@@ -543,11 +545,7 @@ class QWenAttention(nn.Module):
             and query.is_cuda
         ):
             q, k, v = query, key, value
-            context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
-
-            # b s h d -> b s (h d)
-            context_layer = context_layer.flatten(2,3).contiguous()
-
+            attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
         else:
             query = query.permute(0, 2, 1, 3)
             if not self.use_cache_quantization:
@@ -561,12 +559,28 @@
                 and not query.is_cuda
             ):
                 raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
-            attn_output, attn_weight = self._attn(
-                query, key, value, registered_causal_mask, attention_mask, head_mask
-            )
-            context_layer = self._merge_heads(
-                attn_output, self.num_heads, self.head_dim
-            )
+
+            if not self.use_cache_quantization and SUPPORT_TORCH2:
+                causal_mask = registered_causal_mask[
+                    :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
+                ]
+                if attention_mask is not None:
+                    attention_mask = attention_mask.expand(
+                        -1, -1, causal_mask.size(2), -1
+                    ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+                else:
+                    attention_mask = causal_mask
+                attn_output = F.scaled_dot_product_attention(
+                    query, key, value, attn_mask=attention_mask
+                ).transpose(1, 2)
+                attn_weight = None
+            else:
+                attn_output, attn_weight = self._attn(
+                    query, key, value, registered_causal_mask, attention_mask, head_mask
+                )
+        context_layer = self._merge_heads(
+            attn_output, self.num_heads, self.head_dim
+        )
 
         attn_output = self.c_proj(context_layer)
 
```
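The rewritten non-flash path slices the registered causal mask down to the current query window and folds the padding mask into it as an additive bias before calling `torch.nn.functional.scaled_dot_product_attention`. Here is a standalone sketch of the same mask combination, assuming a `(batch, 1, 1, key_len)` additive padding mask like the one a typical prepare-inputs step produces; shapes and values are illustrative:

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, k_len, head_dim = 2, 4, 3, 7, 16
query = torch.randn(batch, heads, q_len, head_dim)
key   = torch.randn(batch, heads, k_len, head_dim)
value = torch.randn(batch, heads, k_len, head_dim)

# Boolean causal mask over the full context, sliced to the current window,
# mirroring registered_causal_mask[:, :, k_len - q_len : k_len, :k_len].
registered_causal_mask = torch.tril(torch.ones(1, 1, k_len, k_len, dtype=torch.bool))
causal_mask = registered_causal_mask[:, :, k_len - q_len : k_len, :k_len]

# Additive padding mask: 0 for real tokens, dtype-min for padded positions.
pad = torch.tensor([[1, 1, 1, 1, 1, 1, 1],
                    [0, 0, 1, 1, 1, 1, 1]])  # second sample is left-padded
attention_mask = (1.0 - pad[:, None, None, :].float()) * torch.finfo(query.dtype).min

# Combine: broadcast the padding bias over query positions, then also block
# every position the causal mask forbids.
attn_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill(
    ~causal_mask, torch.finfo(query.dtype).min
)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 3, 16])
```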
```diff
@@ -624,7 +638,7 @@ class QWenBlock(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
         registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
```
```diff
@@ -890,11 +904,9 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha = self.get_ntk_alpha(kv_seq_len)
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
-
-        rotary_pos_emb_list = []
-        for ntk_alpha in ntk_alpha_list:
-            rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
-            rotary_pos_emb_list.append(rotary_pos_emb)
+        rotary_pos_emb_list = [
+            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+        ]
 
         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
```
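The rotary embedding is now materialised once per NTK scaling factor, so a batch whose sequences have very different true lengths gets one appropriately scaled table each. A hedged sketch of that pattern with a generic NTK-scaled RoPE table; this is not Qwen's `RotaryEmbedding` class, and the base-scaling formula here is an assumption:

```python
import torch

def rope_tables(seq_len: int, dim: int, base: float = 10000.0, ntk_alpha: float = 1.0):
    """Generic RoPE cos/sin tables with NTK-aware base scaling (illustrative only)."""
    scaled_base = base * ntk_alpha ** (dim / (dim - 2))
    inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

kv_seq_len = 64
ntk_alpha_list = [1.0, 2.0]  # e.g. one factor per sample with very different true lengths

# Same shape of result as the diff's list comprehension: one table per alpha.
rotary_pos_emb_list = [rope_tables(kv_seq_len, 128, ntk_alpha=a) for a in ntk_alpha_list]
print(len(rotary_pos_emb_list), rotary_pos_emb_list[0][0].shape)  # 2 torch.Size([64, 128])
```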
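Taken together, these changes let a left-padded batch run through `generate` in a single call. A generic sketch of how that might be exercised through the standard transformers API; the model id, the pad-token fallback and the generation settings are assumptions, and Qwen's chat prompt formatting (`make_context`) is omitted:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint id; any model shipping this modeling_qwen.py should behave similarly.
model_id = "Qwen/Qwen-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="auto").eval()

# Assumption: reuse an end-of-document / eos id as padding if no pad token is defined;
# adjust this to whatever pad token your tokenizer actually provides.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = getattr(tokenizer, "eod_id", tokenizer.eos_token_id)

prompts = ["What is the capital of France?", "Write a haiku about autumn."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# attention_mask marks the left padding, which the patched FlashSelfAttention / SDPA
# paths above use to ignore padded positions during batched generation.
out = model.generate(**batch, max_new_tokens=64, do_sample=False)
for text in tokenizer.batch_decode(out, skip_special_tokens=True):
    print(text)
```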