Spaces:
Runtime error
Runtime error
| from network_wrapper import network_wrapper | |
| class ClearVoice: | |
| """ The main class inferface to the end users for performing speech processing | |
| this class provides the desired model to perform the given task | |
| """ | |
| def __init__(self, task, model_names): | |
| """ Load the desired models for the specified task. Perform all the given models and return all results. | |
| Parameters: | |
| ---------- | |
| task: str | |
| the task matching any of the provided tasks: | |
| 'speech_enhancement' | |
| 'speech_separation' | |
| 'target_speaker_extraction' | |
| model_names: str or list of str | |
| the model names matching any of the provided models: | |
| 'FRCRN_SE_16K' | |
| 'MossFormer2_SE_48K' | |
| 'MossFormerGAN_SE_16K' | |
| 'MossFormer2_SS_16K' | |
| 'AV_MossFormer2_TSE_16K' | |
| Returns: | |
| -------- | |
| A ModelsList object, that can be run to get the desired results | |
| """ | |
| self.network_wrapper = network_wrapper() | |
| self.models = [] | |
| for model_name in model_names: | |
| model = self.network_wrapper(task, model_name) | |
| self.models += [model] | |
| def __call__(self, input_path, online_write=False, output_path=None): | |
| results = {} | |
| for model in self.models: | |
| result = model.process(input_path, online_write, output_path) | |
| if not online_write: | |
| results[model.name] = result | |
| if not online_write: | |
| if len(results) == 1: | |
| return results[model.name] | |
| else: | |
| return results | |
| def write(self, results, output_path): | |
| add_subdir = False | |
| use_key = False | |
| if len(self.models) > 1: add_subdir = True #multi_model is True | |
| for model in self.models: | |
| if isinstance(results, dict): | |
| if model.name in results: | |
| if len(results[model.name]) > 1: use_key = True | |
| else: | |
| if len(results) > 1: use_key = True #multi_input is True | |
| break | |
| for model in self.models: | |
| model.write(output_path, add_subdir, use_key) | |