cheesyFishes
committed on
Update custom_st.py
Browse files- custom_st.py +2 -5
custom_st.py
CHANGED
@@ -53,9 +53,6 @@ class Transformer(nn.Module):
|
|
53 |
'max_pixels': max_pixels,
|
54 |
'cache_dir': cache_dir
|
55 |
})
|
56 |
-
|
57 |
-
# remove trust_remote_code
|
58 |
-
model_kwargs.pop('trust_remote_code', None)
|
59 |
|
60 |
# Initialize model
|
61 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
@@ -259,7 +256,7 @@ class Transformer(nn.Module):
|
|
259 |
|
260 |
# ensure inputs are on the same device as the model
|
261 |
device = next(self.model.parameters()).device
|
262 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
263 |
|
264 |
with torch.no_grad():
|
265 |
output = self.model(
|
@@ -274,7 +271,7 @@ class Transformer(nn.Module):
|
|
274 |
)
|
275 |
return features
|
276 |
|
277 |
-
def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
278 |
processed_texts, processed_images = self._process_input(texts)
|
279 |
|
280 |
return self.processor(
|
|
|
53 |
'max_pixels': max_pixels,
|
54 |
'cache_dir': cache_dir
|
55 |
})
|
|
|
|
|
|
|
56 |
|
57 |
# Initialize model
|
58 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
|
256 |
|
257 |
# ensure inputs are on the same device as the model
|
258 |
device = next(self.model.parameters()).device
|
259 |
+
inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
|
260 |
|
261 |
with torch.no_grad():
|
262 |
output = self.model(
|
|
|
271 |
)
|
272 |
return features
|
273 |
|
274 |
+
def tokenize(self, texts: List[Union[str, Image.Image, bytes]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
275 |
processed_texts, processed_images = self._process_input(texts)
|
276 |
|
277 |
return self.processor(
|