Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| from tqdm import tqdm | |
| from text.g2p_module import G2PModule, LexiconModule | |
| from text.symbol_table import SymbolTable | |
| ''' | |
| phoneExtractor: extract phone from text | |
| ''' | |
| class phoneExtractor: | |
| def __init__(self, cfg, dataset_name=None, phone_symbol_file=None): | |
| ''' | |
| Args: | |
| cfg: config | |
| dataset_name: name of dataset | |
| ''' | |
| self.cfg = cfg | |
| # phone symbols dict | |
| self.phone_symbols = set() | |
| # phone symbols dict file | |
| if phone_symbol_file is not None: | |
| self.phone_symbols_file = phone_symbol_file | |
| elif dataset_name is not None: | |
| self.dataset_name = dataset_name | |
| self.phone_symbols_file = os.path.join(cfg.preprocess.processed_dir, | |
| dataset_name, | |
| cfg.preprocess.symbols_dict) | |
| # initialize g2p module | |
| if cfg.preprocess.phone_extractor in ["espeak", "pypinyin", "pypinyin_initials_finals"]: | |
| self.g2p_module = G2PModule(backend=cfg.preprocess.phone_extractor) | |
| elif cfg.preprocess.phone_extractor == 'lexicon': | |
| assert cfg.preprocess.lexicon_path != "" | |
| self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path) | |
| else: | |
| print('No suppert to', cfg.preprocess.phone_extractor) | |
| raise | |
| def extract_phone(self, text): | |
| ''' | |
| Extract phone from text | |
| Args: | |
| text: text of utterance | |
| Returns: | |
| phone_symbols: set of phone symbols | |
| phone_seq: list of phone sequence of each utterance | |
| ''' | |
| if self.cfg.preprocess.phone_extractor in ["espeak", "pypinyin", "pypinyin_initials_finals"]: | |
| text = text.replace("β", '"').replace("β", '"') | |
| phone = self.g2p_module.g2p_conversion(text=text) | |
| self.phone_symbols.update(phone) | |
| phone_seq = [phn for phn in phone] | |
| elif self.cfg.preprocess.phone_extractor == 'lexicon': | |
| phone_seq = self.g2p_module.g2p_conversion(text) | |
| phone = phone_seq | |
| if not isinstance(phone_seq, list): | |
| phone_seq = phone_seq.split() | |
| return phone_seq | |
| def save_dataset_phone_symbols_to_table(self): | |
| # load and merge saved phone symbols | |
| if os.path.exists(self.phone_symbols_file): | |
| phone_symbol_dict_saved = SymbolTable.from_file(self.phone_symbols_file)._sym2id.keys() | |
| self.phone_symbols.update(set(phone_symbol_dict_saved)) | |
| # save phone symbols | |
| phone_symbol_dict = SymbolTable() | |
| for s in sorted(list(self.phone_symbols)): | |
| phone_symbol_dict.add(s) | |
| phone_symbol_dict.to_file(self.phone_symbols_file) | |
| def extract_utt_phone_sequence(cfg, metadata): | |
| ''' | |
| Extract phone sequence from text | |
| Args: | |
| cfg: config | |
| metadata: list of dict, each dict contains "Uid", "Text" | |
| ''' | |
| dataset_name = cfg.dataset[0] | |
| # output path | |
| out_path = os.path.join(cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir) | |
| os.makedirs(out_path, exist_ok=True) | |
| phone_extractor = phoneExtractor(cfg, dataset_name) | |
| for utt in tqdm(metadata): | |
| uid = utt["Uid"] | |
| text = utt["Text"] | |
| phone_seq = phone_extractor.extract_phone(text) | |
| phone_path = os.path.join(out_path, uid+'.phone') | |
| with open(phone_path, 'w') as fin: | |
| fin.write(' '.join(phone_seq)) | |
| if cfg.preprocess.phone_extractor != 'lexicon': | |
| phone_extractor.save_dataset_phone_symbols_to_table() | |
| def save_all_dataset_phone_symbols_to_table(self, cfg, dataset): | |
| # phone symbols dict | |
| phone_symbols = set() | |
| for dataset_name in dataset: | |
| phone_symbols_file = os.path.join(cfg.preprocess.processed_dir, | |
| dataset_name, | |
| cfg.preprocess.symbols_dict) | |
| # load and merge saved phone symbols | |
| assert os.path.exists(phone_symbols_file) | |
| phone_symbol_dict_saved = SymbolTable.from_file(phone_symbols_file)._sym2id.keys() | |
| phone_symbols.update(set(phone_symbol_dict_saved)) | |
| # save all phone symbols to each dataset | |
| phone_symbol_dict = SymbolTable() | |
| for s in sorted(list(phone_symbols)): | |
| phone_symbol_dict.add(s) | |
| for dataset_name in dataset: | |
| phone_symbols_file = os.path.join(cfg.preprocess.processed_dir, | |
| dataset_name, | |
| cfg.preprocess.symbols_dict) | |
| phone_symbol_dict.to_file(phone_symbols_file) | |