Commit · 5ee2c37
1 Parent(s): 4fa2261
Remove triton flash implementation

modeling_bert.py  CHANGED  (+0 -18)
@@ -63,12 +63,6 @@ try:
 except ImportError:
     scaled_dot_product_attention = None
 
-# Triton implementation
-try:
-    from .flash_attn_triton import flash_attn_func
-except Exception:
-    flash_attn_func = None
-
 # This is used by encode but user may not have it installed
 try:
     from tqdm.autonotebook import trange
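Only the except-branch of the surviving import guard is visible in this hunk's context (the `try:` in the hunk header sits just above it). For orientation, a minimal sketch of what that guard presumably wraps, assuming it is PyTorch's built-in SDPA; the try-line itself is not part of this diff:

    try:
        # Assumed import behind the `except ImportError` context above (PyTorch >= 2.0)
        from torch.nn.functional import scaled_dot_product_attention
    except ImportError:
        scaled_dot_product_attention = None  # older torch: handled at call time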
@@ -324,18 +318,6 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        if self.attn_implementation == 'triton':
-            b, s, h = hidden_states.shape
-            q = self.query(hidden_states)
-            k = self.key(hidden_states)
-            v = self.value(hidden_states)
-            # B x S x hidden_dim -> B x S x num_heads x head_dim
-            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
-            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
-            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
-            attn = flash_attn_func(q, k, v, bias)
-            return (attn.view(b, s, h),)
-
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
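The removed branch handed (batch, seq, num_heads, head_dim) projections directly to the triton flash_attn_func together with an additive bias. For reference only, a hedged sketch of the equivalent computation using torch.nn.functional.scaled_dot_product_attention, the kernel the surrounding imports fall back to; the helper name sdpa_attention and the bias shape are illustrative assumptions, not code from this file. Note the transposes: torch's SDPA expects (batch, heads, seq, head_dim).

    import torch.nn.functional as F

    def sdpa_attention(q, k, v, bias):
        # q, k, v: (batch, seq, num_heads, head_dim) -- the layout the removed
        # triton call consumed; `bias` is an additive attention bias (e.g. ALiBi),
        # assumed broadcastable to (batch, num_heads, seq, seq).
        b, s, n_heads, d = q.shape
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))      # -> (b, heads, s, d)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
        # Merge heads back, mirroring the removed `attn.view(b, s, h)` return.
        return attn.transpose(1, 2).reshape(b, s, n_heads * d)

This mirrors the math of the dropped branch; whether it matches the file's actual remaining attention path is not shown in this diff.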