Create utils.py
Browse files- cijiang/utils.py +103 -0
cijiang/utils.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import json
|
| 3 |
+
from colorama import Fore, Style, init
|
| 4 |
+
|
| 5 |
+
init(autoreset=True)
|
| 6 |
+
|
| 7 |
+
with open('rules/ALL_SYLLABLES.txt', 'r', encoding='utf-8') as f:
|
| 8 |
+
ALL_SYLLABLES = f.read().strip().split()
|
| 9 |
+
ALL_SYLLABLES = [syllable for syllable in ALL_SYLLABLES if syllable]
|
| 10 |
+
|
| 11 |
+
YUNMU_LIST = ['a', 'o', 'e', 'i', 'u', 'v',
|
| 12 |
+
'ai', 'ei', 'ao', 'ou', 'ia', 'ie', 'iao', 'iu', 'ua', 'uo', 'uai', 'ui', 've',
|
| 13 |
+
'an', 'en', 'in', 'un', 'vn', 'ian', 'uan', 'vuan',
|
| 14 |
+
'ang', 'eng', 'ing', 'ong',
|
| 15 |
+
'zhi', 'chi', 'shi', 'ri', 'zi', 'ci', 'si',
|
| 16 |
+
'yi', 'wu', 'yu', 'yin', 'yun', 'ye', 'yue', 'yuan','ying']
|
| 17 |
+
|
| 18 |
+
def get_yunmu(syllable):
|
| 19 |
+
syllable = syllable.lower().replace('ü', 'v')
|
| 20 |
+
yunmu_list = sorted(YUNMU_LIST, key=lambda x: -len(x))
|
| 21 |
+
|
| 22 |
+
if syllable in yunmu_list:
|
| 23 |
+
return syllable
|
| 24 |
+
|
| 25 |
+
shengmus = [
|
| 26 |
+
'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h',
|
| 27 |
+
'j', 'q', 'x', 'z', 'c', 's', 'r', 'y', 'w'
|
| 28 |
+
]
|
| 29 |
+
for shengmu in sorted(shengmus, key=lambda x: -len(x)):
|
| 30 |
+
if syllable.startswith(shengmu):
|
| 31 |
+
possible_yunmu = syllable[len(shengmu):]
|
| 32 |
+
|
| 33 |
+
for yunmu in yunmu_list:
|
| 34 |
+
if possible_yunmu == yunmu:
|
| 35 |
+
return yunmu
|
| 36 |
+
|
| 37 |
+
if shengmu in ['j', 'q', 'x', 'y'] and possible_yunmu.startswith('u'):
|
| 38 |
+
|
| 39 |
+
possible_yunmu_v = 'v' + possible_yunmu[1:]
|
| 40 |
+
for yunmu in yunmu_list:
|
| 41 |
+
if possible_yunmu_v == yunmu:
|
| 42 |
+
return yunmu
|
| 43 |
+
|
| 44 |
+
if shengmu == 'y':
|
| 45 |
+
y_map = {
|
| 46 |
+
'u': 'yu',
|
| 47 |
+
'ue': 'yue',
|
| 48 |
+
'uan': 'yuan',
|
| 49 |
+
'un': 'yun',
|
| 50 |
+
'i': 'yi',
|
| 51 |
+
'in': 'yin',
|
| 52 |
+
'ing': 'ying',
|
| 53 |
+
'e': 'ye'
|
| 54 |
+
}
|
| 55 |
+
if possible_yunmu in y_map:
|
| 56 |
+
return y_map[possible_yunmu]
|
| 57 |
+
|
| 58 |
+
if shengmu == 'w' and possible_yunmu == 'u':
|
| 59 |
+
return 'wu'
|
| 60 |
+
|
| 61 |
+
if shengmu == 'y' and possible_yunmu == 'i':
|
| 62 |
+
return 'yi'
|
| 63 |
+
|
| 64 |
+
if shengmu == 'y' and possible_yunmu == 'v':
|
| 65 |
+
return 'yu'
|
| 66 |
+
|
| 67 |
+
if possible_yunmu.startswith('v'):
|
| 68 |
+
for yunmu in yunmu_list:
|
| 69 |
+
if possible_yunmu == yunmu:
|
| 70 |
+
return yunmu
|
| 71 |
+
for yunmu in yunmu_list:
|
| 72 |
+
if syllable == yunmu:
|
| 73 |
+
return yunmu
|
| 74 |
+
for yunmu in yunmu_list:
|
| 75 |
+
if syllable.endswith(yunmu):
|
| 76 |
+
return yunmu
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def print_results(rhymer, text, target_rhyme, top_results=8, beam_width=20, num_candidates=5000):
|
| 81 |
+
out = rhymer.get_rhymes(text, target_rhyme, beam_width=beam_width, num_candidates=num_candidates)
|
| 82 |
+
mask_count = text.count("[M]")
|
| 83 |
+
context = text.split('[M]')[0]
|
| 84 |
+
|
| 85 |
+
print(f"======= 韵脚: |{target_rhyme}|")
|
| 86 |
+
for i, (seq, log_prob) in enumerate(out[:top_results]):
|
| 87 |
+
rhymes = seq[-mask_count:].split()
|
| 88 |
+
colored_rhymes = [Fore.RED + part + Style.RESET_ALL if idx < mask_count else part for idx, part in enumerate(rhymes)]
|
| 89 |
+
colored_rhymes = ''.join(colored_rhymes) # Join the parts back together
|
| 90 |
+
|
| 91 |
+
print(f"{i+1}. {context}{colored_rhymes} (score: {log_prob:.3f})")
|
| 92 |
+
print("=" + "=" * 40)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
syllable_to_yunmu = defaultdict(str)
|
| 97 |
+
for syllable in ALL_SYLLABLES:
|
| 98 |
+
yunmu = get_yunmu(syllable)
|
| 99 |
+
if yunmu:
|
| 100 |
+
syllable_to_yunmu[syllable] = yunmu
|
| 101 |
+
|
| 102 |
+
with open('rules/syllable_to_yunmu.json', 'w', encoding='utf-8') as f:
|
| 103 |
+
json.dump(syllable_to_yunmu, f, ensure_ascii=False, indent=4)
|