Image-Text-to-Text · sentence-transformers · Safetensors · Transformers · qwen2_vl · Qwen2-VL · conversational
cheesyFishes committed on
Commit 6a23f44 · verified · 1 Parent(s): fdba9e3

improve again

Files changed (1): custom_st.py (+60, -70)
custom_st.py CHANGED
@@ -9,7 +9,7 @@ import requests
 import torch
 from PIL import Image
 from torch import nn
-from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, AutoConfig
+from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 
 class Transformer(nn.Module):
     save_in_root: bool = True
@@ -21,11 +21,9 @@ class Transformer(nn.Module):
         max_pixels: int = 768 * 28 * 28,
         min_pixels: int = 1 * 28 * 28,
         dimension: int = 2048,
+        max_seq_length: Optional[int] = None,
         cache_dir: Optional[str] = None,
         device: str = 'cuda:0',
-        config_args: Optional[Dict[str, Any]] = None,
-        model_args: Optional[Dict[str, Any]] = None,
-        processor_args: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         super(Transformer, self).__init__()
@@ -34,61 +32,55 @@ class Transformer(nn.Module):
         self.dimension = dimension
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
-        self.model_name_or_path = model_name_or_path
-        self.processor_name_or_path = processor_name_or_path or model_name_or_path
-        self.cache_dir = cache_dir
+        self.max_seq_length = max_seq_length
 
-        self.config_args = config_args or {}
-        self.model_args = model_args or {}
-        self.processor_args = processor_args or {}
-
-        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
-        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"
-
-    @classmethod
-    def load(cls, input_path: str) -> 'Transformer':
-        config_path = os.path.join(input_path, 'config.json')
-        if os.path.exists(config_path):
-            with open(config_path) as f:
-                config = json.load(f)
-        else:
-            config = {}
-
-        instance = cls(model_name_or_path=input_path, **config)
-
-        # Load model with flash attention if available
+        # Initialize model
         try:
-            instance.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                input_path,
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_name_or_path,
                 attn_implementation="flash_attention_2",
                 torch_dtype=torch.bfloat16,
-                device_map=instance.device,
-                cache_dir=instance.cache_dir,
-                **instance.model_args
+                device_map=device,
+                cache_dir=cache_dir,
+                **kwargs
             ).eval()
         except (ImportError, ValueError) as e:
             print(f"Flash attention not available, falling back to default attention: {e}")
-            instance.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                input_path,
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_name_or_path,
                 torch_dtype=torch.bfloat16,
-                device_map=instance.device,
-                cache_dir=instance.cache_dir,
-                **instance.model_args
+                device_map=device,
+                cache_dir=cache_dir,
+                **kwargs
            ).eval()
 
         # Initialize processor
-        instance.processor = AutoProcessor.from_pretrained(
-            input_path,
-            min_pixels=instance.min_pixels,
-            max_pixels=instance.max_pixels,
-            cache_dir=instance.cache_dir,
-            **instance.processor_args
+        self.processor = AutoProcessor.from_pretrained(
+            processor_name_or_path or model_name_or_path,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+            cache_dir=cache_dir
         )
 
-        instance.model.padding_side = "left"
-        instance.processor.tokenizer.padding_side = "left"
-
-        return instance
+        # Set padding sides
+        self.model.padding_side = "left"
+        self.processor.tokenizer.padding_side = "left"
+
+        # Store prompts
+        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
+        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"
+
+        # Try to infer max_seq_length if not provided
+        if self.max_seq_length is None:
+            if (
+                hasattr(self.model, 'config')
+                and hasattr(self.model.config, 'max_position_embeddings')
+                and hasattr(self.processor.tokenizer, 'model_max_length')
+            ):
+                self.max_seq_length = min(
+                    self.model.config.max_position_embeddings,
+                    self.processor.tokenizer.model_max_length,
+                )
 
     def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
         h_bar = max(28, self._round_by_factor(height, 28))
@@ -132,21 +124,8 @@ class Transformer(nn.Module):
 
         for sample in texts:
             if isinstance(sample, str):
-                if sample.startswith('http') or sample.startswith('data:image/'):
-                    try:
-                        if sample.startswith('http'):
-                            response = requests.get(sample)
-                            image = Image.open(BytesIO(response.content)).convert('RGB')
-                        else:
-                            image = self._decode_data_image(sample).convert('RGB')
-                        processed_texts.append(self.document_prompt)
-                        processed_images.append(self._resize_image(image))
-                    except Exception as e:
-                        processed_texts.append(self.query_prompt % sample)
-                        processed_images.append(dummy_image)
-                else:
-                    processed_texts.append(self.query_prompt % sample)
-                    processed_images.append(dummy_image)
+                processed_texts.append(self.query_prompt % sample)
+                processed_images.append(dummy_image)
             elif isinstance(sample, Image.Image):
                 processed_texts.append(self.document_prompt)
                 processed_images.append(self._resize_image(sample))
@@ -186,21 +165,32 @@ class Transformer(nn.Module):
         return {k: v.to(self.device) for k, v in inputs.items()}
 
     def save(self, output_path: str, safe_serialization: bool = True) -> None:
+        """Save the model, tokenizer and processor to the given path."""
+        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
+        self.processor.save_pretrained(output_path)
+
         # Save the configuration
         config = {
-            'model_name_or_path': self.model_name_or_path,
-            'processor_name_or_path': self.processor_name_or_path,
+            'model_name_or_path': output_path,
             'max_pixels': self.max_pixels,
             'min_pixels': self.min_pixels,
             'dimension': self.dimension,
-            'config_args': self.config_args,
-            'model_args': self.model_args,
-            'processor_args': self.processor_args,
+            'max_seq_length': self.max_seq_length,
         }
 
-        os.makedirs(output_path, exist_ok=True)
-        with open(os.path.join(output_path, 'config.json'), 'w') as f:
+        config_path = os.path.join(output_path, 'sentence_bert_config.json')
+        with open(config_path, 'w') as f:
            json.dump(config, f)
 
-        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
-        self.processor.save_pretrained(output_path)
+    @staticmethod
+    def load(input_path: str) -> 'Transformer':
+        """Load a saved model from the given path."""
+        # Load configuration
+        config_path = os.path.join(input_path, 'sentence_bert_config.json')
+        if os.path.exists(config_path):
+            with open(config_path) as f:
+                config = json.load(f)
+        else:
+            config = {'model_name_or_path': input_path}
+
+        return Transformer(**config)
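
For orientation, a minimal usage sketch of the module this diff modifies. It assumes the repository's modules.json registers custom_st.Transformer as the first sentence-transformers module (so it is picked up via trust_remote_code), and "org/model-id" and "page.png" are placeholders, not names from this repo.

# Minimal usage sketch (not part of this commit); assumes modules.json wires
# custom_st.Transformer into the sentence-transformers pipeline and that
# "org/model-id" stands in for the actual repository id.
from PIL import Image
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("org/model-id", trust_remote_code=True)

# Plain strings are wrapped in query_prompt by the module's tokenize()
query_embeddings = model.encode(["what does this chart show?"])

# PIL images are wrapped in document_prompt and resized before encoding
doc_embeddings = model.encode([Image.open("page.png")])

print(query_embeddings.shape, doc_embeddings.shape)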