Image-Text-to-Text · sentence-transformers · Safetensors · Transformers · qwen2_vl · Qwen2-VL · conversational
cheesyFishes committed (verified)
Commit 47f5e7c · 1 parent: 6187d4b

make flash attention optional

Files changed (1):
  custom_st.py  +19 -9
custom_st.py CHANGED
@@ -32,15 +32,25 @@ class Transformer(nn.Module):
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
 
-        # Initialize model
-        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_name_or_path,
-            attn_implementation="flash_attention_2",
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            cache_dir=cache_dir,
-            **kwargs
-        ).eval()
+        # Try to use flash attention if available, fallback to default attention if not
+        try:
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_name_or_path,
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16,
+                device_map=device,
+                cache_dir=cache_dir,
+                **kwargs
+            ).eval()
+        except (ImportError, ValueError) as e:
+            print(f"Flash attention not available, falling back to default attention: {e}")
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_name_or_path,
+                torch_dtype=torch.bfloat16,
+                device_map=device,
+                cache_dir=cache_dir,
+                **kwargs
+            ).eval()
 
         # Initialize processor
         self.processor = AutoProcessor.from_pretrained(
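
For comparison, the same "optional flash attention" behavior can also be achieved by checking for the flash-attn package up front instead of catching the load error. The snippet below is only a sketch against the standard transformers API; the checkpoint name is a placeholder for illustration, not this repository's model.

import importlib.util

import torch
from transformers import Qwen2VLForConditionalGeneration

# Request flash attention 2 only when the flash-attn package is importable;
# otherwise pass None so transformers picks its default attention backend.
attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else None

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",    # placeholder checkpoint for illustration
    attn_implementation=attn_impl,  # None -> transformers' default backend
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

This avoids attempting a model load that is known to fail, but it will not catch environments where flash-attn is installed yet unusable (for example, CPU-only machines); the try/except in the commit still covers that case.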