Update custom_st.py
Browse files
- custom_st.py +2 -5
custom_st.py
CHANGED
|
@@ -53,9 +53,6 @@ class Transformer(nn.Module):
|
|
| 53 |
'max_pixels': max_pixels,
|
| 54 |
'cache_dir': cache_dir
|
| 55 |
})
|
| 56 |
-
|
| 57 |
-
# remove trust_remote_code
|
| 58 |
-
model_kwargs.pop('trust_remote_code', None)
|
| 59 |
|
| 60 |
# Initialize model
|
| 61 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
@@ -259,7 +256,7 @@ class Transformer(nn.Module):
|
|
| 259 |
|
| 260 |
# ensure inputs are on the same device as the model
|
| 261 |
device = next(self.model.parameters()).device
|
| 262 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 263 |
|
| 264 |
with torch.no_grad():
|
| 265 |
output = self.model(
|
|
@@ -274,7 +271,7 @@ class Transformer(nn.Module):
|
|
| 274 |
)
|
| 275 |
return features
|
| 276 |
|
| 277 |
-
def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
| 278 |
processed_texts, processed_images = self._process_input(texts)
|
| 279 |
|
| 280 |
return self.processor(
|
|
|
|
| 53 |
'max_pixels': max_pixels,
|
| 54 |
'cache_dir': cache_dir
|
| 55 |
})
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Initialize model
|
| 58 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
|
|
| 256 |
|
| 257 |
# ensure inputs are on the same device as the model
|
| 258 |
device = next(self.model.parameters()).device
|
| 259 |
+
inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
|
| 260 |
|
| 261 |
with torch.no_grad():
|
| 262 |
output = self.model(
|
|
|
|
| 271 |
)
|
| 272 |
return features
|
| 273 |
|
| 274 |
+
def tokenize(self, texts: List[Union[str, Image.Image, bytes]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
| 275 |
processed_texts, processed_images = self._process_input(texts)
|
| 276 |
|
| 277 |
return self.processor(
|