shenyunhang commited on
Commit
0d71916
·
verified ·
1 Parent(s): 9c74c52

Update modeling_sensevoice.py

Browse files
Files changed (1) hide show
  1. modeling_sensevoice.py +12 -1
modeling_sensevoice.py CHANGED
@@ -1053,8 +1053,19 @@ class SenseVoiceSmall(nn.Module):
1053
 
1054
  return encoder_out, encoder_out_lens
1055
 
 
 
 
 
 
 
 
 
 
 
 
1056
  def export(self, **kwargs):
1057
- from export_meta import export_rebuild_model
1058
 
1059
  if "max_seq_len" not in kwargs:
1060
  kwargs["max_seq_len"] = 512
 
1053
 
1054
  return encoder_out, encoder_out_lens
1055
 
1056
+ def export_rebuild_model(model, **kwargs):
1057
+ model.device = kwargs.get("device")
1058
+ model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
1059
+ model.forward = types.MethodType(export_forward, model)
1060
+ model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
1061
+ model.export_input_names = types.MethodType(export_input_names, model)
1062
+ model.export_output_names = types.MethodType(export_output_names, model)
1063
+ model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
1064
+ model.export_name = types.MethodType(export_name, model)
1065
+ return model
1066
+
1067
  def export(self, **kwargs):
1068
+ # from export_meta import export_rebuild_model
1069
 
1070
  if "max_seq_len" not in kwargs:
1071
  kwargs["max_seq_len"] = 512