Update modeling_sensevoice.py
Browse files- modeling_sensevoice.py +12 -1
modeling_sensevoice.py
CHANGED
@@ -1053,8 +1053,19 @@ class SenseVoiceSmall(nn.Module):
|
|
1053 |
|
1054 |
return encoder_out, encoder_out_lens
|
1055 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1056 |
def export(self, **kwargs):
|
1057 |
-
from export_meta import export_rebuild_model
|
1058 |
|
1059 |
if "max_seq_len" not in kwargs:
|
1060 |
kwargs["max_seq_len"] = 512
|
|
|
1053 |
|
1054 |
return encoder_out, encoder_out_lens
|
1055 |
|
1056 |
+
def export_rebuild_model(model, **kwargs):
|
1057 |
+
model.device = kwargs.get("device")
|
1058 |
+
model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
|
1059 |
+
model.forward = types.MethodType(export_forward, model)
|
1060 |
+
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
|
1061 |
+
model.export_input_names = types.MethodType(export_input_names, model)
|
1062 |
+
model.export_output_names = types.MethodType(export_output_names, model)
|
1063 |
+
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
|
1064 |
+
model.export_name = types.MethodType(export_name, model)
|
1065 |
+
return model
|
1066 |
+
|
1067 |
def export(self, **kwargs):
|
1068 |
+
# from export_meta import export_rebuild_model
|
1069 |
|
1070 |
if "max_seq_len" not in kwargs:
|
1071 |
kwargs["max_seq_len"] = 512
|