lingyu98 commited on
Commit
77ad230
·
verified ·
1 Parent(s): f4134b8

Create utils.py

Browse files
Files changed (1) hide show
  1. cijiang/utils.py +103 -0
cijiang/utils.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import json
3
+ from colorama import Fore, Style, init
4
+
5
+ init(autoreset=True)
6
+
7
+ with open('rules/ALL_SYLLABLES.txt', 'r', encoding='utf-8') as f:
8
+ ALL_SYLLABLES = f.read().strip().split()
9
+ ALL_SYLLABLES = [syllable for syllable in ALL_SYLLABLES if syllable]
10
+
11
+ YUNMU_LIST = ['a', 'o', 'e', 'i', 'u', 'v',
12
+ 'ai', 'ei', 'ao', 'ou', 'ia', 'ie', 'iao', 'iu', 'ua', 'uo', 'uai', 'ui', 've',
13
+ 'an', 'en', 'in', 'un', 'vn', 'ian', 'uan', 'vuan',
14
+ 'ang', 'eng', 'ing', 'ong',
15
+ 'zhi', 'chi', 'shi', 'ri', 'zi', 'ci', 'si',
16
+ 'yi', 'wu', 'yu', 'yin', 'yun', 'ye', 'yue', 'yuan','ying']
17
+
18
+ def get_yunmu(syllable):
19
+ syllable = syllable.lower().replace('ü', 'v')
20
+ yunmu_list = sorted(YUNMU_LIST, key=lambda x: -len(x))
21
+
22
+ if syllable in yunmu_list:
23
+ return syllable
24
+
25
+ shengmus = [
26
+ 'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h',
27
+ 'j', 'q', 'x', 'z', 'c', 's', 'r', 'y', 'w'
28
+ ]
29
+ for shengmu in sorted(shengmus, key=lambda x: -len(x)):
30
+ if syllable.startswith(shengmu):
31
+ possible_yunmu = syllable[len(shengmu):]
32
+
33
+ for yunmu in yunmu_list:
34
+ if possible_yunmu == yunmu:
35
+ return yunmu
36
+
37
+ if shengmu in ['j', 'q', 'x', 'y'] and possible_yunmu.startswith('u'):
38
+
39
+ possible_yunmu_v = 'v' + possible_yunmu[1:]
40
+ for yunmu in yunmu_list:
41
+ if possible_yunmu_v == yunmu:
42
+ return yunmu
43
+
44
+ if shengmu == 'y':
45
+ y_map = {
46
+ 'u': 'yu',
47
+ 'ue': 'yue',
48
+ 'uan': 'yuan',
49
+ 'un': 'yun',
50
+ 'i': 'yi',
51
+ 'in': 'yin',
52
+ 'ing': 'ying',
53
+ 'e': 'ye'
54
+ }
55
+ if possible_yunmu in y_map:
56
+ return y_map[possible_yunmu]
57
+
58
+ if shengmu == 'w' and possible_yunmu == 'u':
59
+ return 'wu'
60
+
61
+ if shengmu == 'y' and possible_yunmu == 'i':
62
+ return 'yi'
63
+
64
+ if shengmu == 'y' and possible_yunmu == 'v':
65
+ return 'yu'
66
+
67
+ if possible_yunmu.startswith('v'):
68
+ for yunmu in yunmu_list:
69
+ if possible_yunmu == yunmu:
70
+ return yunmu
71
+ for yunmu in yunmu_list:
72
+ if syllable == yunmu:
73
+ return yunmu
74
+ for yunmu in yunmu_list:
75
+ if syllable.endswith(yunmu):
76
+ return yunmu
77
+ return None
78
+
79
+
80
+ def print_results(rhymer, text, target_rhyme, top_results=8, beam_width=20, num_candidates=5000):
81
+ out = rhymer.get_rhymes(text, target_rhyme, beam_width=beam_width, num_candidates=num_candidates)
82
+ mask_count = text.count("[M]")
83
+ context = text.split('[M]')[0]
84
+
85
+ print(f"======= 韵脚: |{target_rhyme}|")
86
+ for i, (seq, log_prob) in enumerate(out[:top_results]):
87
+ rhymes = seq[-mask_count:].split()
88
+ colored_rhymes = [Fore.RED + part + Style.RESET_ALL if idx < mask_count else part for idx, part in enumerate(rhymes)]
89
+ colored_rhymes = ''.join(colored_rhymes) # Join the parts back together
90
+
91
+ print(f"{i+1}. {context}{colored_rhymes} (score: {log_prob:.3f})")
92
+ print("=" + "=" * 40)
93
+
94
+
95
+ if __name__ == "__main__":
96
+ syllable_to_yunmu = defaultdict(str)
97
+ for syllable in ALL_SYLLABLES:
98
+ yunmu = get_yunmu(syllable)
99
+ if yunmu:
100
+ syllable_to_yunmu[syllable] = yunmu
101
+
102
+ with open('rules/syllable_to_yunmu.json', 'w', encoding='utf-8') as f:
103
+ json.dump(syllable_to_yunmu, f, ensure_ascii=False, indent=4)