from text import cleaned_text_to_sequence
import os

# if os.environ.get("version", "v1") == "v1":
#     from text import chinese
#     from text.symbols import symbols
# else:
#     from text import chinese2 as chinese
#     from text.symbols2 import symbols

from text import symbols as symbols_v1
from text import symbols2 as symbols_v2

# Each entry is (marker character, language code, target silence symbol); matching
# markers are handled by clean_special() below.
special = [
    # ("%", "zh", "SP"),
    ("¥", "zh", "SP2"),
    ("^", "zh", "SP3"),
    # ('@', 'zh', "SP4")  # dropped; keep it consistent with the second version
]
def load_nvrtc():
    """Locate and ctypes-load the NVRTC libraries bundled with the installed
    torch build (Windows) or the nvidia-cuda-nvrtc pip package (Linux)."""
    import torch, sys, os, ctypes
    from pathlib import Path

    if not torch.cuda.is_available():
        print("[INFO] CUDA is not available, skipping nvrtc setup.")
        return

    if sys.platform == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if torch_lib_dir.exists():
            os.add_dll_directory(str(torch_lib_dir))
            print(f"[INFO] Added DLL directory: {torch_lib_dir}")
            matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
            if not matching_files:
                print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
                return
            for dll_path in matching_files:
                dll_name = os.path.basename(dll_path)
                try:
                    ctypes.CDLL(dll_name)
                    print(f"[INFO] Loaded: {dll_name}")
                except OSError as e:
                    print(f"[WARNING] Failed to load {dll_name}: {e}")
        else:
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
    elif sys.platform == "linux":
        site_packages = Path(torch.__file__).resolve().parents[1]
        nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
        if not nvrtc_dir.exists():
            print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
            return
        matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
        if not matching_files:
            print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
            return
        for so_path in matching_files:
            try:
                ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)  # type: ignore
                print(f"[INFO] Loaded: {so_path}")
            except OSError as e:
                print(f"[WARNING] Failed to load {so_path}: {e}")


load_nvrtc()
def clean_text(text, language, version=None):
    """Normalize `text` for `language` and convert it to phones.
    Returns (phones, word2ph, norm_text); word2ph is None for non-zh/yue languages."""
    if version is None:
        version = os.environ.get("version", "v2")
    if version == "v1":
        symbols = symbols_v1.symbols
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        symbols = symbols_v2.symbols
        language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

    if language not in language_module_map:
        # Unsupported language code: fall back to English over a blank input.
        language = "en"
        text = " "
    for special_s, special_l, target_symbol in special:
        if special_s in text and language == special_l:
            return clean_special(text, language, special_s, target_symbol, version)
    language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
    if hasattr(language_module, "text_normalize"):
        norm_text = language_module.text_normalize(text)
    else:
        norm_text = text
    if language == "zh" or language == "yue":
        phones, word2ph = language_module.g2p(norm_text)
        assert len(phones) == sum(word2ph)
        assert len(norm_text) == len(word2ph)
    elif language == "en":
        phones = language_module.g2p(norm_text)
        if len(phones) < 4:
            phones = [","] + phones
        word2ph = None
    else:
        phones = language_module.g2p(norm_text)
        word2ph = None
    phones = ["UNK" if ph not in symbols else ph for ph in phones]
    return phones, word2ph, norm_text
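
# Illustrative sketch of the clean_text() contract (the example values below are
# assumed for illustration, not produced by this file):
#   phones, word2ph, norm_text = clean_text("你好", "zh")
#   phones    -> list of symbols from the active symbol set, unknowns mapped to "UNK"
#   word2ph   -> per-character phone counts for zh/yue (None otherwise)
#   norm_text -> the normalized text that was fed to g2p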
def clean_special(text, language, special_s, target_symbol, version=None):
    """Handle special silence-segment (SP) markers: run g2p with the marker
    replaced by "," and then map the resulting "," phones to `target_symbol`."""
    if version is None:
        version = os.environ.get("version", "v2")
    if version == "v1":
        symbols = symbols_v1.symbols
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        symbols = symbols_v2.symbols
        language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

    text = text.replace(special_s, ",")
    language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
    norm_text = language_module.text_normalize(text)
    phones = language_module.g2p(norm_text)
    new_ph = []
    for ph in phones[0]:
        assert ph in symbols
        if ph == ",":
            new_ph.append(target_symbol)
        else:
            new_ph.append(ph)
    return new_ph, phones[1], norm_text
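
# Illustrative sketch (assumed input, not from the original demo): with the "¥" marker,
# clean_special("你好¥世界", "zh", "¥", "SP2") first rewrites the text to "你好,世界",
# runs g2p, then replaces every "," phone with "SP2" so the output carries an explicit
# silence symbol at that position.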
def text_to_sequence(text, language, version=None):
    version = os.environ.get("version", version)
    if version is None:
        version = "v2"
    phones, word2ph, norm_text = clean_text(text, language, version)
    return cleaned_text_to_sequence(phones, version)
if __name__ == "__main__":
    print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))
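    # Additional usage sketches (assumed inputs, not part of the original demo):
    # English returns word2ph=None; an unsupported language code falls back to
    # English over a blank string.
    print(clean_text("Hello there, how are you?", "en"))
    print(clean_text("text in an unsupported language", "xx"))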