anakin87 commited on
Commit
8bfc45f
·
verified ·
1 Parent(s): 6a159c9

choose fa2 if GPU available

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -27,16 +27,15 @@ MAX_MAX_NEW_TOKENS = 2048
27
  DEFAULT_MAX_NEW_TOKENS = 1024
28
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
29
 
30
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
-
32
  model_id = "google/gemma-3-270m-it"
33
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,)
 
 
34
  model = AutoModelForCausalLM.from_pretrained(
35
  model_id,
36
  device_map="auto",
37
  torch_dtype=torch.bfloat16,
38
- attn_implementation="flash_attention_2",
39
- trust_remote_code=True,
40
  )
41
  model.config.sliding_window = 4096
42
  model.eval()
 
27
  DEFAULT_MAX_NEW_TOKENS = 1024
28
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
29
 
 
 
30
  model_id = "google/gemma-3-270m-it"
31
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
32
+
33
+ attn_impl = "flash_attention_2" if torch.cuda.is_available() else "eager"
34
  model = AutoModelForCausalLM.from_pretrained(
35
  model_id,
36
  device_map="auto",
37
  torch_dtype=torch.bfloat16,
38
+ attn_implementation=attn_impl,
 
39
  )
40
  model.config.sliding_window = 4096
41
  model.eval()