add backend handling
Browse files- custom_st.py +6 -0
custom_st.py
CHANGED
|
@@ -26,9 +26,15 @@ class Transformer(nn.Module):
|
|
| 26 |
processor_args: Optional[Dict[str, Any]] = None,
|
| 27 |
cache_dir: Optional[str] = None,
|
| 28 |
device: str = 'cuda:0',
|
|
|
|
| 29 |
**kwargs,
|
| 30 |
) -> None:
|
| 31 |
super(Transformer, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
self.device = device
|
| 34 |
self.dimension = dimension
|
|
|
|
| 26 |
processor_args: Optional[Dict[str, Any]] = None,
|
| 27 |
cache_dir: Optional[str] = None,
|
| 28 |
device: str = 'cuda:0',
|
| 29 |
+
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
| 30 |
**kwargs,
|
| 31 |
) -> None:
|
| 32 |
super(Transformer, self).__init__()
|
| 33 |
+
|
| 34 |
+
if backend != 'torch':
|
| 35 |
+
raise ValueError(
|
| 36 |
+
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
| 37 |
+
)
|
| 38 |
|
| 39 |
self.device = device
|
| 40 |
self.dimension = dimension
|