Upload folder using huggingface_hub
Browse files
- model_onnx.py +3 -3
model_onnx.py
CHANGED
@@ -29,8 +29,8 @@ class IndicASRModel(PreTrainedModel):
|
|
29 |
self.models = {}
|
30 |
names = ['encoder', 'ctc_decoder', 'rnnt_decoder', 'joint_enc', 'joint_pred', 'joint_pre_net'] + [f'joint_post_net_{z}' for z in ['as', 'bn', 'brx', 'doi', 'gu', 'hi', 'kn', 'kok', 'ks', 'mai', 'ml', 'mni', 'mr', 'ne', 'or', 'pa', 'sa', 'sat', 'sd', 'ta', 'te', 'ur']]
|
31 |
self.models = {}
|
32 |
-
self.[… remainder of removed line not captured in this extract]
|
33 |
-
|
34 |
for n in names:
|
35 |
component_name = f'{config.ts_folder}/assets/{n}.onnx'
|
36 |
if os.path.exists(config.ts_folder):
|
@@ -55,7 +55,7 @@ class IndicASRModel(PreTrainedModel):
|
|
55 |
|
56 |
def encode(self, wav):
|
57 |
# pass through preprocessor
|
58 |
-
audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.[… remainder of removed line not captured in this extract]
|
59 |
outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
|
60 |
return outputs, encoded_lengths
|
61 |
|
|
|
29 |
self.models = {}
|
30 |
names = ['encoder', 'ctc_decoder', 'rnnt_decoder', 'joint_enc', 'joint_pred', 'joint_pre_net'] + [f'joint_post_net_{z}' for z in ['as', 'bn', 'brx', 'doi', 'gu', 'hi', 'kn', 'kok', 'ks', 'mai', 'ml', 'mni', 'mr', 'ne', 'or', 'pa', 'sa', 'sat', 'sd', 'ta', 'te', 'ur']]
|
31 |
self.models = {}
|
32 |
+
self.d = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
33 |
+
self.models['preprocessor'] = torch.jit.load(f'{config.ts_folder}/assets/preprocessor.ts', map_location=self.d)
|
34 |
for n in names:
|
35 |
component_name = f'{config.ts_folder}/assets/{n}.onnx'
|
36 |
if os.path.exists(config.ts_folder):
|
|
|
55 |
|
56 |
def encode(self, wav):
|
57 |
# pass through preprocessor
|
58 |
+
audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.d), length=torch.tensor([wav.shape[-1]]).to(self.d))
|
59 |
outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
|
60 |
return outputs, encoded_lengths
|
61 |
|