cheesyFishes committed on
Commit
0965666
·
verified ·
1 Parent(s): 9a36057

Update custom_st.py

Browse files
Files changed (1) hide show
  1. custom_st.py +2 -5
custom_st.py CHANGED
@@ -53,9 +53,6 @@ class Transformer(nn.Module):
53
  'max_pixels': max_pixels,
54
  'cache_dir': cache_dir
55
  })
56
-
57
- # remove trust_remote_code
58
- model_kwargs.pop('trust_remote_code', None)
59
 
60
  # Initialize model
61
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -259,7 +256,7 @@ class Transformer(nn.Module):
259
 
260
  # ensure inputs are on the same device as the model
261
  device = next(self.model.parameters()).device
262
- inputs = {k: v.to(device) for k, v in inputs.items()}
263
 
264
  with torch.no_grad():
265
  output = self.model(
@@ -274,7 +271,7 @@ class Transformer(nn.Module):
274
  )
275
  return features
276
 
277
- def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
278
  processed_texts, processed_images = self._process_input(texts)
279
 
280
  return self.processor(
 
53
  'max_pixels': max_pixels,
54
  'cache_dir': cache_dir
55
  })
 
 
 
56
 
57
  # Initialize model
58
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
 
256
 
257
  # ensure inputs are on the same device as the model
258
  device = next(self.model.parameters()).device
259
+ inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
260
 
261
  with torch.no_grad():
262
  output = self.model(
 
271
  )
272
  return features
273
 
274
+ def tokenize(self, texts: List[Union[str, Image.Image, bytes]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
275
  processed_texts, processed_images = self._process_input(texts)
276
 
277
  return self.processor(