qnguyen3 commited on
Commit
0872fb5
1 Parent(s): 859ccbd

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +2 -1
modeling_llava_qwen2.py CHANGED
@@ -19,9 +19,10 @@ from transformers.image_utils import (ChannelDimension, PILImageResampling, to_n
19
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
20
  from transformers.modeling_utils import PreTrainedModel
21
  from transformers.utils import ModelOutput
 
22
 
23
  torch.set_default_device('cuda')
24
-
25
 
26
  class SigLipImageProcessor:
27
  def __init__(self,
 
19
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
20
  from transformers.modeling_utils import PreTrainedModel
21
  from transformers.utils import ModelOutput
22
+ import subprocess
23
 
24
  torch.set_default_device('cuda')
25
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
26
 
27
  class SigLipImageProcessor:
28
  def __init__(self,