cheesyFishes committed
Commit 51c192a · verified · 1 Parent(s): 0388752

add back original code?

Files changed (1)
  1. custom_st.py +195 -34
custom_st.py CHANGED
@@ -4,6 +4,7 @@ import os
 import math
 from io import BytesIO
 from typing import Any, Dict, List, Literal, Optional, Union
+from urllib.parse import urlparse
 
 import requests
 import torch
@@ -21,54 +22,74 @@ class Transformer(nn.Module):
         max_pixels: int = 768 * 28 * 28,
         min_pixels: int = 1 * 28 * 28,
         dimension: int = 2048,
+        max_seq_length: Optional[int] = None,
+        model_args: Optional[Dict[str, Any]] = None,
+        processor_args: Optional[Dict[str, Any]] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        config_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
-        device: str = 'cuda:0',
+        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
         **kwargs,
     ) -> None:
         super(Transformer, self).__init__()
+
+        if backend != 'torch':
+            raise ValueError(
+                f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
+            )
 
-        self.device = device
         self.dimension = dimension
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
+        self.max_seq_length = max_seq_length
+
+        # Handle args
+        model_kwargs = model_args or {}
+        model_kwargs.update(kwargs)
+
+        processor_kwargs = processor_args or {}
+        processor_kwargs.update({
+            'min_pixels': min_pixels,
+            'max_pixels': max_pixels,
+            'cache_dir': cache_dir
+        })
 
-        # trust_remote_code is not needed here
-        kwargs.pop("trust_remote_code")
+        # remove trust_remote_code
+        model_kwargs.pop('trust_remote_code', None)
 
-        # Try to use flash attention if available, fallback to default attention if not
-        try:
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_name_or_path,
-                attn_implementation="flash_attention_2",
-                torch_dtype=torch.bfloat16,
-                device_map=device,
-                cache_dir=cache_dir,
-                **kwargs
-            ).eval()
-        except (ImportError, ValueError) as e:
-            print(f"Flash attention not available, falling back to default attention: {e}")
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_name_or_path,
-                torch_dtype=torch.bfloat16,
-                device_map=device,
-                cache_dir=cache_dir,
-                **kwargs
-            ).eval()
+        # Initialize model
+        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_name_or_path,
+            cache_dir=cache_dir,
+            **model_kwargs
+        ).eval()
 
         # Initialize processor
         self.processor = AutoProcessor.from_pretrained(
             processor_name_or_path or model_name_or_path,
-            min_pixels=min_pixels,
-            max_pixels=max_pixels,
-            cache_dir=cache_dir
+            **processor_kwargs
         )
 
+        # Set padding sides
         self.model.padding_side = "left"
         self.processor.tokenizer.padding_side = "left"
 
+        # Store prompts
         self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
         self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"
 
+        # Try to infer max_seq_length if not provided
+        if self.max_seq_length is None:
+            if (
+                hasattr(self.model, 'config')
+                and hasattr(self.model.config, 'max_position_embeddings')
+                and hasattr(self.processor.tokenizer, 'model_max_length')
+            ):
+                self.max_seq_length = min(
+                    self.model.config.max_position_embeddings,
+                    self.processor.tokenizer.model_max_length,
+                )
+
     def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
         h_bar = max(28, self._round_by_factor(height, 28))
         w_bar = max(28, self._round_by_factor(width, 28))
@@ -104,27 +125,142 @@ class Transformer(nn.Module):
         image_data = base64.b64decode(data)
         return Image.open(BytesIO(image_data))
 
-    def _process_input(self, texts: List[Union[str, Image.Image]]) -> tuple[List[str], List[Image.Image]]:
+    @staticmethod
+    def _is_valid_url(url: str) -> bool:
+        try:
+            result = urlparse(url)
+            # Check if scheme and netloc are present and scheme is http/https
+            return all([result.scheme in ('http', 'https'), result.netloc])
+        except Exception:
+            return False
+
+    @staticmethod
+    def _is_safe_path(path: str) -> bool:
+        try:
+            # Convert to absolute path and normalize
+            abs_path = os.path.abspath(os.path.normpath(path))
+            # Check if file exists and is a regular file (not a directory or special file)
+            return os.path.isfile(abs_path)
+        except Exception:
+            return False
+
+    @staticmethod
+    def _load_image_from_url(url: str) -> Image.Image:
+        try:
+            response = requests.get(
+                url,
+                stream=True,
+                timeout=10,  # Add timeout
+                headers={'User-Agent': 'Mozilla/5.0'}  # Add user agent
+            )
+            response.raise_for_status()
+
+            # Check content type
+            content_type = response.headers.get('content-type', '')
+            if not content_type.startswith('image/'):
+                raise ValueError(f"Invalid content type: {content_type}")
+
+            # Limit file size (e.g., 10MB)
+            content = BytesIO()
+            size = 0
+            max_size = 10 * 1024 * 1024  # 10MB
+
+            for chunk in response.iter_content(chunk_size=8192):
+                size += len(chunk)
+                if size > max_size:
+                    raise ValueError("File too large")
+                content.write(chunk)
+
+            content.seek(0)
+            return Image.open(content)
+        except Exception as e:
+            raise ValueError(f"Failed to load image from URL: {str(e)}")
+
+    @staticmethod
+    def _load_image_from_path(image_path: str) -> Image.Image:
+        try:
+            # Convert to absolute path and normalize
+            abs_path = os.path.abspath(os.path.normpath(image_path))
+
+            # Check file size before loading
+            file_size = os.path.getsize(abs_path)
+            max_size = 10 * 1024 * 1024  # 10MB
+            if file_size > max_size:
+                raise ValueError("File too large")
+
+            with Image.open(abs_path) as img:
+                # Make a copy to ensure file handle is closed
+                return img.copy()
+        except Exception as e:
+            raise ValueError(f"Failed to load image from path: {str(e)}")
+
+    @staticmethod
+    def _load_image_from_bytes(image_bytes: bytes) -> Image.Image:
+        try:
+            # Check size
+            if len(image_bytes) > 10 * 1024 * 1024:  # 10MB
+                raise ValueError("Image data too large")
+
+            return Image.open(BytesIO(image_bytes))
+        except Exception as e:
+            raise ValueError(f"Failed to load image from bytes: {str(e)}")
+
+    def _process_input(self, texts: List[Union[str, Image.Image, bytes]]) -> tuple[List[str], List[Image.Image]]:
         processed_texts = []
         processed_images = []
         dummy_image = Image.new('RGB', (56, 56))
 
         for sample in texts:
             if isinstance(sample, str):
-                processed_texts.append(self.query_prompt % sample)
-                processed_images.append(dummy_image)
+                # Check if the string is a valid URL
+                if self._is_valid_url(sample):
+                    try:
+                        img = self._load_image_from_url(sample)
+                        processed_texts.append(self.document_prompt)
+                        processed_images.append(self._resize_image(img))
+                    except Exception as e:
+                        # If URL loading fails, treat as regular text
+                        processed_texts.append(self.query_prompt % sample)
+                        processed_images.append(dummy_image)
+                # Check if the string is a valid file path
+                elif self._is_safe_path(sample):
+                    try:
+                        img = self._load_image_from_path(sample)
+                        processed_texts.append(self.document_prompt)
+                        processed_images.append(self._resize_image(img))
+                    except Exception as e:
+                        # If image loading fails, treat as regular text
+                        processed_texts.append(self.query_prompt % sample)
+                        processed_images.append(dummy_image)
+                else:
+                    # Regular text query
+                    processed_texts.append(self.query_prompt % sample)
+                    processed_images.append(dummy_image)
             elif isinstance(sample, Image.Image):
                 processed_texts.append(self.document_prompt)
                 processed_images.append(self._resize_image(sample))
+            elif isinstance(sample, bytes):
+                try:
+                    img = self._load_image_from_bytes(sample)
+                    processed_texts.append(self.document_prompt)
+                    processed_images.append(self._resize_image(img))
+                except Exception as e:
+                    # If bytes can't be converted to image, use dummy
+                    processed_texts.append(self.document_prompt)
+                    processed_images.append(dummy_image)
 
         return processed_texts, processed_images
 
     def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        cache_position = torch.arange(0, features['input_ids'].shape[0])
+        cache_position = torch.arange(0, features['input_ids'].shape[1])
         inputs = self.model.prepare_inputs_for_generation(
             **features, cache_position=cache_position, use_cache=False
         )
 
+        # ensure inputs are on the same device as the model
+        device = next(self.model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
         with torch.no_grad():
             output = self.model(
                 **inputs,
@@ -141,16 +277,41 @@ class Transformer(nn.Module):
     def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
         processed_texts, processed_images = self._process_input(texts)
 
-        inputs = self.processor(
+        return self.processor(
             text=processed_texts,
             images=processed_images,
            videos=None,
            padding=padding,
            return_tensors='pt'
        )
-
-        return {k: v.to(self.device) for k, v in inputs.items()}
 
     def save(self, output_path: str, safe_serialization: bool = True) -> None:
+        """Save the model, tokenizer and processor to the given path."""
         self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
-        self.processor.save_pretrained(output_path)
+        self.processor.save_pretrained(output_path)
+
+        # Save the configuration
+        config = {
+            'model_name_or_path': output_path,
+            'max_pixels': self.max_pixels,
+            'min_pixels': self.min_pixels,
+            'dimension': self.dimension,
+            'max_seq_length': self.max_seq_length,
+        }
+
+        config_path = os.path.join(output_path, 'sentence_bert_config.json')
+        with open(config_path, 'w') as f:
+            json.dump(config, f)
+
+    @staticmethod
+    def load(input_path: str) -> 'Transformer':
+        """Load a saved model from the given path."""
+        # Load configuration
+        config_path = os.path.join(input_path, 'sentence_bert_config.json')
+        if os.path.exists(config_path):
+            with open(config_path) as f:
+                config = json.load(f)
+        else:
+            config = {'model_name_or_path': input_path}
+
+        return Transformer(**config)
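
For orientation, a minimal usage sketch of the restored class, not part of the commit: it assumes custom_st.py is importable next to this file, that "path/to/model" and "exported-model" are placeholder paths rather than real repo ids, and that the dummy PIL image stands in for an actual document page. In practice the module is presumably picked up by sentence-transformers' custom-module loading rather than instantiated by hand.

# Minimal sketch (assumption-laden, not from the commit)
from PIL import Image

from custom_st import Transformer

model = Transformer(model_name_or_path="path/to/model")  # placeholder checkpoint path

# Plain strings that are not URLs or existing file paths become text queries;
# PIL images, image URLs, local image paths, and raw bytes become documents.
batch = ["what is shown in the chart?", Image.new("RGB", (448, 448))]

features = model.tokenize(batch)   # processor output as a dict of tensors
outputs = model.forward(features)  # tensors are moved to the model's device inside forward()

# Round-trip through the sentence_bert_config.json written by save()
model.save("exported-model")
reloaded = Transformer.load("exported-model")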