from collections import Counter, defaultdict
from typing import Union
from dataclasses import make_dataclass, field
import ctypes
import os
import platform
import re

import torch
import numpy as np
import ujson
from datasketch import MinHash, MinHashLSH
from transformers import T5Config, TrainingArguments, TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
# from nltk import ngrams
from nltk.translate.bleu_score import sentence_bleu

from config import T5ModelConfig

# Sentence-ending punctuation marks
END_PUN = set(".。!!))》}】??\"”")


class MyTrainerCallback(TrainerCallback):
    log_cnt = 0

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Empty the CUDA cache after every n logging steps.
        Useful on low-VRAM devices to help prevent OOM.
        '''
        self.log_cnt += 1
        if self.log_cnt % 2 == 0:
            torch.cuda.empty_cache()

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Save the model once at the end of every epoch.
        In TrainingArguments, save_strategy cannot be 'epoch' and 'steps' at the same time.
        To also save a checkpoint every save_steps steps while keeping only the latest N
        checkpoints (to limit disk usage), it is enough to set should_save=True here and return.
        '''
        control.should_save = True
        return control
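
# Illustrative usage (a minimal sketch; assumes a transformers.Trainer together with a model
# and dataset that are not defined in this file):
# >>> from transformers import Trainer
# >>> trainer = Trainer(
# ...     model=model,
# ...     args=TrainingArguments(output_dir='./output', save_total_limit=3),
# ...     train_dataset=train_dataset,
# ...     callbacks=[MyTrainerCallback()],
# ... )
# >>> trainer.train()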

# Keep Chinese characters, English letters, digits and underscores; drop punctuation.
NON_CHAR = re.compile("[^\u4E00-\u9FA5A-Za-z0-9_]")


def _get_doc_mini_hash(doc: list[str] | str, num_perm: int) -> MinHash:
    '''
    Compute the MinHash of a piece of text.
    '''
    mini_hash = MinHash(num_perm=num_perm)
    for s in doc:
        mini_hash.update(s.encode('utf-8'))
    return mini_hash


class DropDatasetDuplicate:

    def __init__(self, threshold: float = 0.85, num_perm: int = 256) -> None:
        '''
        Collect the indexes of all duplicate documents (similarity above threshold) in a dataset.
        Input: list[str], where each str element is one document (doc).
        E.g. for the input [a, b, c, d, c, d, e] it returns {4, 5} (the indexes of the second c and d).
        '''
        self.similar_index_cluster = defaultdict(set)
        self.data_lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        self.num_perm = num_perm

    def add_doc(self, index: object, doc: str) -> None:
        '''
        Add a document.
        index: index of the document
        doc: the document itself
        '''
        # Keep Chinese characters, English letters, digits and underscores; drop punctuation.
        doc = ''.join(NON_CHAR.split(doc))
        # doc = [''.join(t) for t in list(ngrams(doc, 3))]

        doc_hash = _get_doc_mini_hash(doc, self.num_perm)
        close_duplicates = self.data_lsh.query(doc_hash)

        self.data_lsh.insert(index, doc_hash)

        # All similar docs share the earliest index as their key in similar_index_cluster.
        # E.g. if the indexes 2, 7, 8, 9, 10, 12 are similar, this shows up as {2: {7, 8, 9, 10, 12}}.
        if len(close_duplicates) > 0:
            min_idx = min(close_duplicates)
            self.similar_index_cluster[min_idx].add(index)

    def get_duplicate_indexs(self):
        '''
        Return the indexes of all duplicate documents.
        '''
        similar_index_cluster = self.similar_index_cluster
        need_to_remove_idx = set()

        for key_idx in similar_index_cluster.keys():
            need_to_remove_idx |= similar_index_cluster[key_idx]

        return need_to_remove_idx
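
# Illustrative usage (a minimal sketch with made-up documents; the exact duplicates found
# depend on the MinHash approximation and the threshold):
# >>> dropper = DropDatasetDuplicate(threshold=0.85, num_perm=256)
# >>> docs = ['the cat sat on the mat', 'the cat sat on the mat!', 'a completely different line']
# >>> for i, doc in enumerate(docs):
# ...     dropper.add_doc(index=i, doc=doc)
# >>> dropper.get_duplicate_indexs()   # e.g. {1}: index 1 is a near-duplicate of index 0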


def get_T5_config(config: T5ModelConfig, vocab_size: int, decoder_start_token_id: int = 0, eos_token_id: int = 1) -> T5Config:
    '''
    Convert the user config into a T5Config.
    '''
    t5_config = T5Config()
    # t5_config.model_type = 'TextToTextModel'

    # Initialize from the user config
    t5_config.d_ff = config.d_ff
    t5_config.d_kv = config.d_kv
    t5_config.d_model = config.d_model
    t5_config.num_decoder_layers = config.num_decoder_layers
    t5_config.num_heads = config.num_heads
    t5_config.num_layers = config.num_layers
    t5_config.vocab_size = vocab_size
    t5_config.decoder_start_token_id = decoder_start_token_id
    t5_config.eos_token_id = eos_token_id

    return t5_config
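
# Illustrative usage (a minimal sketch; assumes T5ModelConfig exposes exactly the fields read
# above and uses arbitrary example sizes):
# >>> model_config = T5ModelConfig(d_ff=2048, d_kv=64, d_model=512,
# ...                              num_decoder_layers=6, num_heads=8, num_layers=6)
# >>> t5_config = get_T5_config(model_config, vocab_size=29298, decoder_start_token_id=0, eos_token_id=1)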


def f1_p_r_compute(spo_list_pred: list, spo_list_true: list, repair: bool = False):
    '''
    Compute the F1 score, precision and recall of SPO triples.
    spo_list: [[(s, p, o), ...], [(s, p, o), ...]], each row [(s, p, o), ...] holds the SPOs of one sentence.
    '''
    assert len(spo_list_pred) == len(spo_list_true)

    def repair_song_album(spo_list: list, song: list, album: list):
        '''
        Repair the song/album SPOs of one sentence. For the relations 歌手 (singer), 作词 (lyricist)
        and 作曲 (composer) of a song x (subject), x must be present in both song and album.
        '''
        if len(song) == 0 and len(album) == 0:
            return spo_list

        ps = ['歌手', '作词', '作曲']
        new_spo_list = []
        for spo in spo_list:
            s, p = spo[0], spo[1]
            if p in ps and s in album and s not in song:
                continue
            new_spo_list.append(spo)

        return new_spo_list

    def repair_song_album_list(spo_list: list):
        '''
        Apply repair_song_album to every sentence in the list.
        '''
        new_spo_list = []
        for spos in spo_list:
            song, album = [], []
            for spo in spos:
                s, p, o = spo
                if p == '所属专辑':
                    song.append(s)
                    album.append(o)
            new_spo_list.append(repair_song_album(spos, song, album))

        return new_spo_list

    if repair:
        spo_list_pred = repair_song_album_list(spo_list_pred)
        spo_list_true = repair_song_album_list(spo_list_true)

    TP = 1e-10      # predicted positive and actually positive, A
    # TN = 1e-10    # predicted negative and actually negative
    TP_FP = 1e-10   # everything retrieved, A + B
    TP_FN = 1e-10   # everything actually wanted, A + C
    # FP = 1e-10    # predicted positive but actually negative
    # FN = 1e-10    # predicted negative but actually positive

    # p = a / (a + b)
    # r = a / (a + c)
    # f1 = 2pr / (p + r)

    for i in range(len(spo_list_true)):
        pred_set = set(spo_list_pred[i])
        true_set = set(spo_list_true[i])

        pred_true_set = pred_set & true_set     # intersection of prediction and ground truth

        TP += len(pred_true_set)    # retrieved and wanted, A
        TP_FP += len(pred_set)      # everything retrieved, wanted or not, A + B
        TP_FN += len(true_set)      # everything wanted, retrieved or not, A + C

    p = TP / TP_FP
    r = TP / TP_FN
    f1 = (2 * p * r) / (p + r)

    return f1, p, r
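
# Illustrative example (made-up triples):
# >>> pred = [[('s1', 'p1', 'o1')]]
# >>> true = [[('s1', 'p1', 'o1'), ('s2', 'p2', 'o2')]]
# >>> f1_p_r_compute(pred, true)   # f1 ≈ 0.67, precision ≈ 1.0, recall ≈ 0.5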


def fixed_response(item: str) -> str:
    '''
    Repair a truncated answer: scan backwards from the end for the first sentence-ending
    punctuation mark and cut the text there.
    '''
    if len(item) <= 1: return item
    if item[-1] in END_PUN: return item

    n = len(item)
    i = n - 1
    while i > 0 and item[i] not in END_PUN:
        i -= 1

    return ''.join(item[0: i + 1])
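
# Illustrative example:
# >>> fixed_response('你好,今天的天气很好。我觉得')
# '你好,今天的天气很好。'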


def fixed_space(sentence: str) -> str:
    '''
    Drop single spaces; collapse two consecutive spaces into one.
    '''
    n = len(sentence)
    new_sentence = []
    i = 0
    while i < n:
        word = sentence[i]
        if word != ' ':
            new_sentence.append(word)
        elif i + 1 < n and sentence[i + 1] == ' ':
            new_sentence.append(word)
            i += 1  # two spaces: keep one, move the pointer one extra step
        i += 1

    return ''.join(new_sentence)
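
# Illustrative example:
# >>> fixed_space('a  b c')
# 'a bc'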


def get_free_space_of_disk(folder: str = './') -> float:
    '''
    Get the free space of the disk the given directory lives on, in GB.
    '''
    res_val = 0.0
    if platform.system() == 'Windows':
        free_bytes = ctypes.c_ulonglong(0)
        ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(folder), None, None, ctypes.pointer(free_bytes))
        res_val = free_bytes.value
    else:
        st = os.statvfs(folder)
        res_val = st.f_bavail * st.f_frsize

    return res_val / (1024 ** 3)


def my_average(arry_list: list[float]) -> float:
    '''
    Custom mean: an empty array returns 0.0.
    '''
    if len(arry_list) == 0: return 0.0

    return np.average(arry_list)


def json_to_dataclass(json_file: str, class_name: str = 'Config') -> type:
    '''
    Convert a JSON config file into a dataclass.
    >>> example:
    >>> data_class = json_to_dataclass('my_config.json', 'Config')
    >>> my_config = data_class()
    >>> assert my_config.name == 'Alice'
    >>> my_config.name = 'Bob'
    '''
    json_dict = {}
    with open(json_file, 'r', encoding='utf-8') as f:
        json_dict = ujson.load(f)

    # Turn the dict into an iterable of (field name, field type, default value)
    fields_list = []
    for k, v in json_dict.items():
        fields_list.append((k, type(v), field(default=v)))

    data_class = make_dataclass(cls_name=class_name, fields=fields_list)

    return data_class
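
# Illustrative sketch (writes a throwaway JSON file matching the docstring example above):
# >>> with open('my_config.json', 'w', encoding='utf-8') as f:
# ...     ujson.dump({'name': 'Alice', 'lr': 0.001}, f)
# >>> Config = json_to_dataclass('my_config.json', 'Config')
# >>> cfg = Config()
# >>> (cfg.name, cfg.lr)
# ('Alice', 0.001)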


def get_path_of_suffix_files(root: str, suffix: str, with_create_time: bool = False) -> list:
    '''
    Get the absolute paths of all files with the given suffix under the given directory.
    '''
    suffix_files = []
    for root, _, files in os.walk(root):
        for file in files:
            if file.endswith(suffix):
                full_path = '{}/{}'.format(root, file)
                if with_create_time:
                    suffix_files.append((full_path, os.path.getctime(full_path)))
                else:
                    suffix_files.append(full_path)

    return suffix_files


def get_bleu4_score(reference: Union[str, list[str]], outputs: Union[str, list[str]], n_gram: int = 4) -> float:
    '''
    Compute the BLEU-4 score.
    '''
    weights = np.ones(n_gram) * (1.0 / n_gram)
    outputs_len, reference_len = len(outputs), len(reference)

    if not type(reference) is list:
        reference = list(reference)
    if not type(outputs) is list:
        outputs = list(outputs)

    outputs_counter = extract_Ngram(outputs, n_gram=n_gram)
    reference_counter = extract_Ngram(reference, n_gram=n_gram)

    ngram_counter_clip = outputs_counter & reference_counter

    clip_counter = np.zeros(n_gram)
    output_ngram_counter = np.zeros(n_gram)

    for (key, ngram), cnt in ngram_counter_clip.items():
        clip_counter[ngram - 1] += cnt

    for (key, ngram), cnt in outputs_counter.items():
        output_ngram_counter[ngram - 1] += cnt

    # print(clip_counter, output_ngram_counter)
    if np.min(clip_counter) == 0.0:
        return np.array(0.0)

    precision_scores = clip_counter / output_ngram_counter

    # BLEU: weighted geometric mean of the n-gram precisions, times a brevity penalty
    log_precision_scores = weights * np.log(precision_scores)
    geometric_mean = np.exp(np.sum(log_precision_scores))
    brevity_penalty = np.exp(1 - (reference_len / outputs_len))

    # brevity_penalty = 1.0, bleu = sentence_bleu([reference], outputs)
    # brevity_penalty = 1.0

    bleu = brevity_penalty * geometric_mean

    return bleu


def extract_Ngram(words_list: list[str], n_gram: int) -> Counter:
    '''
    Extract all 1-gram to n_gram-gram counts of a sentence.
    return:
        ngram_counter: key = ('w1 w2 ... wn', n_gram), value = count of that key
    '''
    n = len(words_list)
    ngram_counter = Counter()
    for i in range(1, n_gram + 1):
        for j in range(n - i + 1):
            key = ' '.join(words_list[j: j + i])
            ngram_counter[(key, i)] += 1

    return ngram_counter
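
# Illustrative example:
# >>> extract_Ngram(['a', 'b', 'a'], n_gram=2)
# Counter({('a', 1): 2, ('b', 1): 1, ('a b', 2): 1, ('b a', 2): 1})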


def save_model_config(config_dict: dict, file: str) -> None:
    '''
    Write the model config to a JSON file; takes the model save directory and file name.
    '''
    # file = file.replace('\\', '/')
    # file = '{}/model_config.json'.format('/'.join(file.split('/')[0: -1]))

    with open(file, 'w', encoding='utf-8') as f:
        ujson.dump(config_dict, f, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    ref = '抱歉,我不知道ABB代表什么意思'
    out = '我不明白ABB是什么意思'
    b1 = sentence_bleu([list(out)], list(ref), weights=(0.25, 0.25, 0.25, 0.25))
    print(b1)

    b2 = get_bleu4_score(out, ref)
    print(b2)

    candidate_corpus = ['i', 'have', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'c', 'd', 'f', 'f']
    reference_corpus = ['there', 'is', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'd', 'd', 'fd']

    print('----')
    print(sentence_bleu([reference_corpus], candidate_corpus, weights=(0.25, 0.25, 0.25, 0.25)))
    print(get_bleu4_score(reference_corpus, candidate_corpus))