lingyu98 committed on
Commit
f4134b8
·
verified ·
1 Parent(s): 99a698e

Create cijiang/rhyme.py

Browse files
Files changed (1) hide show
  1. cijiang/rhyme.py +240 -0
cijiang/rhyme.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from collections import namedtuple
6
+ from typing import List, Tuple, Dict
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from pypinyin import pinyin, Style
9
+
10
# One beam-search hypothesis: `sequence` is the text built so far, `log_prob`
# is the running sum of the chosen tokens' log-probabilities, and `position`
# is the number of target syllables already filled in.
BeamEntry = namedtuple('BeamEntry', ['sequence', 'log_prob', 'position'])
11
+
12
def is_pinyin(syllable):
    """Return True when *syllable* contains only ASCII characters.

    This is a heuristic, not a real pinyin validator: it only distinguishes
    romanised input (e.g. ``"bai3 tan1"``) from text containing non-ASCII
    characters such as Chinese characters. Any ASCII string, including the
    empty string, is reported as pinyin.
    """
    return syllable.isascii()
19
+
20
class CiJiangRhymer:
    """Rhyme-constrained text completion driven by a causal language model.

    The constructor loads the ``Qwen/Qwen3-0.6B-Base`` model and tokenizer,
    reads the rhyme rule files from the ``rules/`` directory (resolved
    relative to the current working directory), and caches the pinyin
    readings of every single-character CJK token in the vocabulary.
    ``get_rhymes`` then runs a beam search in which each generated character
    must rhyme with the corresponding target syllable.

    Args:
        strict: Selects the ``'strict'`` rhyme table when true, otherwise the
            ``'blurry'`` one.
        tone: When true, a candidate must also match the target tone (tone
            ``'5'`` on either side matches any tone).
        heteronym: When true, all readings of a character are considered,
            otherwise only its single default reading.
    """

    # Divisor applied to a beam's summed log-probability before it is
    # exponentiated into the score returned by ``get_rhymes``.
    _SCORE_TEMPERATURE = 10

    def __init__(self, strict=True, tone=True, heteronym=False):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()
        self._load_rules()
        self.tone = tone
        self.heteronym = heteronym
        self.mode = 'strict' if strict else 'blurry'

        # Pre-compute character mappings for efficiency
        self._build_character_cache()

    def _load_model(self):
        """Load the tokenizer and model and keep the tokenizer vocabulary."""
        model_name = "Qwen/Qwen3-0.6B-Base"  # base (non-chat) model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Add padding token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        self.model.eval()
        self.vocab = self.tokenizer.get_vocab()

    def _load_rules(self):
        """Read the syllable->yunmu map, the syllable list and the rhyme table."""
        with open('rules/syllable_to_yunmu.json', 'r', encoding='utf-8') as f:
            self.syllable_to_yunmu = json.load(f)

        with open('rules/ALL_SYLLABLES.txt', 'r', encoding='utf-8') as f:
            # str.split() without arguments never yields empty strings, so no
            # extra filtering is required.
            self.all_syllables = f.read().split()

        with open('rules/rhymes.json', 'r', encoding='utf-8') as f:
            self.rhymes = json.load(f)

    def _build_character_cache(self):
        """Pre-compute character to pinyin mappings for all vocabulary tokens.

        Populates ``self.token_to_char`` (token id -> single CJK character)
        and ``self.char_to_pinyins`` (character -> ``{"hetero": [...],
        "single": [...]}`` readings in TONE3 style with neutral tone as 5).
        """
        print("Building character cache for faster lookup...")
        self.char_to_pinyins = {}
        self.token_to_char: Dict[int, str] = {}

        # Iterate over the ids the tokenizer actually reports instead of
        # assuming they are exactly 0..len(vocab)-1.
        for token_id in tqdm(sorted(self.vocab.values()), desc="Caching characters"):
            char = self.tokenizer.decode(token_id).strip()

            # Only tokens that decode to one character in U+4E00..U+9FFF are kept.
            if len(char) != 1 or not ('\u4e00' <= char <= '\u9fff'):
                continue
            self.token_to_char[token_id] = char

            # Several token ids may decode to the same character; look its
            # pinyin up only once.
            if char in self.char_to_pinyins:
                continue
            hetero_pinyins = pinyin(char, style=Style.TONE3,
                                    heteronym=True, neutral_tone_with_five=True)[0]
            pinyins = pinyin(char, style=Style.TONE3,
                             heteronym=False, neutral_tone_with_five=True)[0]
            self.char_to_pinyins[char] = {
                "hetero": hetero_pinyins,
                "single": pinyins
            }

    def _prefilter_tokens_by_rhyme(self, top_tokens: torch.Tensor, top_log_probs: torch.Tensor,
                                   allowed_rhymes: set, target_tone: str) -> List[Tuple[str, float, int]]:
        """Keep the candidate tokens whose character satisfies the rhyme.

        Args:
            top_tokens: 1-D tensor of candidate token ids.
            top_log_probs: 1-D tensor of log-probabilities aligned with
                ``top_tokens``.
            allowed_rhymes: Set of yunmu values accepted at this position.
            target_tone: Tone digit of the target syllable; ``'5'`` matches
                any candidate tone.

        Returns:
            ``(character, log_prob, token_id)`` tuples, in the order of
            ``top_tokens``, for every token that is a cached CJK character
            with at least one reading whose yunmu is in ``allowed_rhymes``
            and whose tone is compatible.
        """
        if not allowed_rhymes:
            # Nothing can match an empty rhyme set.
            return []

        reading_key = "hetero" if self.heteronym else "single"
        ignore_tone = not self.tone

        # Token ids stay integers (no lossy float round-trip); log-probs are
        # widened to float32 first so half-precision tensors convert cleanly.
        token_ids = top_tokens.tolist()
        log_probs = top_log_probs.to(torch.float32).tolist()

        matching_candidates = []
        for token_id, log_prob in zip(token_ids, log_probs):
            char = self.token_to_char.get(token_id)
            if char is None:
                continue

            for candidate_pinyin in self.char_to_pinyins[char][reading_key]:
                # A usable reading needs at least one letter plus a tone digit.
                if len(candidate_pinyin) < 2:
                    continue

                candidate_syllable, candidate_tone = candidate_pinyin[:-1], candidate_pinyin[-1]
                if self.syllable_to_yunmu.get(candidate_syllable) not in allowed_rhymes:
                    continue

                tone_ok = (ignore_tone
                           or candidate_tone == target_tone
                           or target_tone == '5'
                           or candidate_tone == '5')
                if tone_ok:
                    matching_candidates.append((char, log_prob, token_id))
                    break  # one matching reading is enough for this token

        return matching_candidates

    def _get_next_token_probabilities(self, prompt: str, num_candidates: int = 200) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the top next-token ids and their log-probabilities.

        Args:
            prompt: Text fed to the model as-is (no chat template).
            num_candidates: Maximum number of highest-logit tokens to return.

        Returns:
            ``(top_tokens, top_log_probs)``: the ids of the highest-logit
            tokens at the last position and their log-softmax values.
        """
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.inference_mode():
            outputs = self.model(**model_inputs)

        # Logits for the next token (last position of the only batch item)
        next_token_logits = outputs.logits[0, -1, :]

        k = min(num_candidates, next_token_logits.size(0))
        top_tokens = next_token_logits.topk(k).indices
        # Normalise in float32: the model may run in half precision
        # (torch_dtype="auto"), which would coarsen the summed beam scores.
        top_log_probs = torch.log_softmax(next_token_logits.float(), dim=-1)[top_tokens]

        return top_tokens, top_log_probs

    def get_rhymes(self, text_with_placeholder: str, target_rhyme: str,
                   beam_width: int = 5, num_candidates: int = 200) -> List[Tuple[str, float]]:
        """
        Generate rhyming text using Qwen3 base language model

        All ``[M]`` placeholders are removed from the input and the generated
        characters are appended to the END of the remaining text, one per
        target syllable.

        Args:
            text_with_placeholder: Text with placeholders (e.g., "恰似一江春水[M][M][M]")
            target_rhyme: Target rhyme pattern, either whitespace-separated
                TONE3 pinyin (e.g. "bai3 tan1") or Chinese characters. A
                pinyin syllable without a trailing tone digit matches any tone.
            beam_width: Number of beams to maintain during search
            num_candidates: Number of top candidates to consider at each step

        Returns:
            List of (sequence, score) tuples, best first, where score is
            ``exp(sum_of_log_probs / 10)`` - a monotonic, softened score and
            not a true probability. If no candidate is found at some
            position, the beams built up to that position are returned.
            An empty list is returned when the target contains no syllables.
        """
        if is_pinyin(target_rhyme):
            # split() (not split(' ')) so repeated or surrounding spaces do
            # not produce empty syllables.
            target_rhyme_pinyin = target_rhyme.split()
        else:
            target_rhyme_pinyin = [pinyin(rhyme, style=Style.TONE3, heteronym=False,
                                          neutral_tone_with_five=True)[0][0] for rhyme in target_rhyme]

        if not target_rhyme_pinyin:
            return []

        # Count placeholders to replace
        placeholder_count = text_with_placeholder.count('[M]')
        if placeholder_count != len(target_rhyme_pinyin):
            print(f"Warning: Number of placeholders ({placeholder_count}) doesn't match target rhyme length ({len(target_rhyme_pinyin)})")

        # Initialize beam with the original sequence (placeholders removed)
        base_text = text_with_placeholder.replace('[M]', '')
        if len(base_text) == 0:
            # The model needs a non-empty prompt; fall back to a generic lead-in.
            base_text = "一个常见词汇是:"
        beam = [BeamEntry(sequence=base_text, log_prob=0.0, position=0)]

        # Process each syllable of the target rhyme
        for i, syl in enumerate(tqdm(target_rhyme_pinyin, desc="Generating rhymes")):
            if syl[-1].isdigit():
                syllable, tone = syl[:-1], syl[-1]
            else:
                # No tone digit supplied: use '5', which the filter treats
                # as "any tone", rather than mis-splitting the last letter.
                syllable, tone = syl, '5'

            yunmu = self.syllable_to_yunmu.get(syllable, None)
            allowed_rhymes = set(self.rhymes.get(yunmu, {}).get(self.mode, []))

            new_beam = []
            for beam_entry in beam:
                current_sequence = beam_entry.sequence
                current_log_prob = beam_entry.log_prob

                # Best-effort: a failure on one beam is reported and that
                # beam is dropped, the others still continue.
                try:
                    top_tokens, top_log_probs = self._get_next_token_probabilities(
                        current_sequence, num_candidates)
                except Exception as e:
                    print(f"Error getting probabilities: {e}")
                    continue

                matching_candidates = self._prefilter_tokens_by_rhyme(
                    top_tokens, top_log_probs, allowed_rhymes, tone
                )

                for char, log_prob_value, _ in matching_candidates:
                    new_beam.append(BeamEntry(
                        sequence=current_sequence + char,
                        log_prob=current_log_prob + log_prob_value,
                        position=i + 1
                    ))

            # Keep only top beam_width candidates
            if not new_beam:
                print(f"Warning: No valid candidates found for position {i} (syllable: {syl})")
                break
            new_beam.sort(key=lambda x: x.log_prob, reverse=True)
            beam = new_beam[:beam_width]

        # Return final results sorted by score
        if not beam:
            return []

        final_results = [(entry.sequence, np.exp(entry.log_prob / self._SCORE_TEMPERATURE))
                         for entry in beam]
        final_results.sort(key=lambda x: x[1], reverse=True)

        return final_results
224
+
225
# Example usage:
if __name__ == "__main__":
    # Blurry (non-strict) rhyme matching, with tone matching enabled.
    rhymer = CiJiangRhymer(strict=False, tone=True)

    # Prompt with four trailing [M] placeholders, one per target syllable.
    # NOTE(review): the leading character(s) of this literal appear to be
    # mis-encoded (U+FFFD replacement characters) - restore them from the
    # original commit.
    base_text = "��人给你[M][M][M][M]"
    # The target can also be given as Chinese characters, e.g. "摆摊算命";
    # here the pinyin form is used.
    target_rhyme = "bai3 tan1 suan4 ming4"

    results = rhymer.get_rhymes(
        base_text,
        target_rhyme,
        beam_width=10,
        num_candidates=5000,
    )

    print("Generated rhyming completions:")
    for rank, (sequence, prob) in enumerate(results, start=1):
        print(f"{rank}. {sequence} (probability: {prob:.4f})")