Update modeling_llava_qwen2.py
Browse files- modeling_llava_qwen2.py +2 -1
modeling_llava_qwen2.py
CHANGED
@@ -19,9 +19,10 @@ from transformers.image_utils import (ChannelDimension, PILImageResampling, to_n
|
|
19 |
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
20 |
from transformers.modeling_utils import PreTrainedModel
|
21 |
from transformers.utils import ModelOutput
|
|
|
22 |
|
23 |
torch.set_default_device('cuda')
|
24 |
-
|
25 |
|
26 |
class SigLipImageProcessor:
|
27 |
def __init__(self,
|
|
|
19 |
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
20 |
from transformers.modeling_utils import PreTrainedModel
|
21 |
from transformers.utils import ModelOutput
|
22 |
+
import subprocess
|
23 |
|
24 |
torch.set_default_device('cuda')
|
25 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
26 |
|
27 |
class SigLipImageProcessor:
|
28 |
def __init__(self,
|