fix device handling
Browse files- custom_st.py +1 -5
custom_st.py
CHANGED
|
@@ -27,7 +27,6 @@ class Transformer(nn.Module):
|
|
| 27 |
tokenizer_args: Optional[Dict[str, Any]] = None,
|
| 28 |
config_args: Optional[Dict[str, Any]] = None,
|
| 29 |
cache_dir: Optional[str] = None,
|
| 30 |
-
device: str = 'cpu',
|
| 31 |
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
| 32 |
**kwargs,
|
| 33 |
) -> None:
|
|
@@ -38,7 +37,6 @@ class Transformer(nn.Module):
|
|
| 38 |
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
| 39 |
)
|
| 40 |
|
| 41 |
-
self.device = device
|
| 42 |
self.dimension = dimension
|
| 43 |
self.max_pixels = max_pixels
|
| 44 |
self.min_pixels = min_pixels
|
|
@@ -160,15 +158,13 @@ class Transformer(nn.Module):
|
|
| 160 |
def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
| 161 |
processed_texts, processed_images = self._process_input(texts)
|
| 162 |
|
| 163 |
-
|
| 164 |
text=processed_texts,
|
| 165 |
images=processed_images,
|
| 166 |
videos=None,
|
| 167 |
padding=padding,
|
| 168 |
return_tensors='pt'
|
| 169 |
)
|
| 170 |
-
|
| 171 |
-
return {k: v.to(self.device) for k, v in inputs.items()}
|
| 172 |
|
| 173 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
| 174 |
"""Save the model, tokenizer and processor to the given path."""
|
|
|
|
| 27 |
tokenizer_args: Optional[Dict[str, Any]] = None,
|
| 28 |
config_args: Optional[Dict[str, Any]] = None,
|
| 29 |
cache_dir: Optional[str] = None,
|
|
|
|
| 30 |
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
| 31 |
**kwargs,
|
| 32 |
) -> None:
|
|
|
|
| 37 |
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
| 38 |
)
|
| 39 |
|
|
|
|
| 40 |
self.dimension = dimension
|
| 41 |
self.max_pixels = max_pixels
|
| 42 |
self.min_pixels = min_pixels
|
|
|
|
| 158 |
def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
| 159 |
processed_texts, processed_images = self._process_input(texts)
|
| 160 |
|
| 161 |
+
return self.processor(
|
| 162 |
text=processed_texts,
|
| 163 |
images=processed_images,
|
| 164 |
videos=None,
|
| 165 |
padding=padding,
|
| 166 |
return_tensors='pt'
|
| 167 |
)
|
|
|
|
|
|
|
| 168 |
|
| 169 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
| 170 |
"""Save the model, tokenizer and processor to the given path."""
|