tahirjm committed on
Commit
5aece43
·
verified ·
1 Parent(s): b89c1a3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. model_onnx.py +3 -3
model_onnx.py CHANGED
@@ -29,8 +29,8 @@ class IndicASRModel(PreTrainedModel):
29
  self.models = {}
30
  names = ['encoder', 'ctc_decoder', 'rnnt_decoder', 'joint_enc', 'joint_pred', 'joint_pre_net'] + [f'joint_post_net_{z}' for z in ['as', 'bn', 'brx', 'doi', 'gu', 'hi', 'kn', 'kok', 'ks', 'mai', 'ml', 'mni', 'mr', 'ne', 'or', 'pa', 'sa', 'sat', 'sd', 'ta', 'te', 'ur']]
31
  self.models = {}
32
- self.models['preprocessor'] = torch.jit.load(f'{config.ts_folder}/assets/preprocessor.ts', map_location=self.config.device)
33
-
34
  for n in names:
35
  component_name = f'{config.ts_folder}/assets/{n}.onnx'
36
  if os.path.exists(config.ts_folder):
@@ -55,7 +55,7 @@ class IndicASRModel(PreTrainedModel):
55
 
56
  def encode(self, wav):
57
  # pass through preprocessor
58
- audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.config.device), length=torch.tensor([wav.shape[-1]]).to(self.config.device))
59
  outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
60
  return outputs, encoded_lengths
61
 
 
29
  self.models = {}
30
  names = ['encoder', 'ctc_decoder', 'rnnt_decoder', 'joint_enc', 'joint_pred', 'joint_pre_net'] + [f'joint_post_net_{z}' for z in ['as', 'bn', 'brx', 'doi', 'gu', 'hi', 'kn', 'kok', 'ks', 'mai', 'ml', 'mni', 'mr', 'ne', 'or', 'pa', 'sa', 'sat', 'sd', 'ta', 'te', 'ur']]
31
  self.models = {}
32
+ self.d = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+ self.models['preprocessor'] = torch.jit.load(f'{config.ts_folder}/assets/preprocessor.ts', map_location=self.d)
34
  for n in names:
35
  component_name = f'{config.ts_folder}/assets/{n}.onnx'
36
  if os.path.exists(config.ts_folder):
 
55
 
56
  def encode(self, wav):
57
  # pass through preprocessor
58
+ audio_signal, length = self.models['preprocessor'](input_signal=wav.to(self.d), length=torch.tensor([wav.shape[-1]]).to(self.d))
59
  outputs, encoded_lengths = self.models['encoder'].run(['outputs', 'encoded_lengths'], {'audio_signal': audio_signal.cpu().numpy(), 'length': length.cpu().numpy()})
60
  return outputs, encoded_lengths
61