jon-tow commited on
Commit
c24bc36
1 Parent(s): 556500a

fix: make `flash_attn` optional (`trust_remote_code` breaks dynamic module check)

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. modeling_stablelm_epoch.py +6 -4
README.md CHANGED
@@ -59,7 +59,7 @@ model = AutoModelForCausalLM.from_pretrained(
59
  "stabilityai/stablelm-3b-4e1t",
60
  trust_remote_code=True,
61
  torch_dtype="auto",
62
- + use_flash_attention_2=True,
63
  )
64
  model.cuda()
65
  inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
 
59
  "stabilityai/stablelm-3b-4e1t",
60
  trust_remote_code=True,
61
  torch_dtype="auto",
62
+ attn_implementation="flash_attention_2",
63
  )
64
  model.cuda()
65
  inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
modeling_stablelm_epoch.py CHANGED
@@ -33,14 +33,16 @@ from transformers.modeling_outputs import (
33
  CausalLMOutputWithPast,
34
  )
35
  from transformers.modeling_utils import PreTrainedModel
36
- from transformers.utils import logging, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
37
 
38
  from .configuration_stablelm_epoch import StableLMEpochConfig
39
 
40
-
41
- if is_flash_attn_2_available():
42
  from flash_attn import flash_attn_func, flash_attn_varlen_func
43
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
 
44
 
45
 
46
  logger = logging.get_logger(__name__)
 
33
  CausalLMOutputWithPast,
34
  )
35
  from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
37
 
38
  from .configuration_stablelm_epoch import StableLMEpochConfig
39
 
40
+ try:
 
41
  from flash_attn import flash_attn_func, flash_attn_varlen_func
42
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
43
+ except:
44
+ flash_attn_func, flash_attn_varlen_func = None, None
45
+ index_first_axis, pad_input, unpad_input = None, None, None
46
 
47
 
48
  logger = logging.get_logger(__name__)