princepride commited on
Commit
5f97682
·
verified ·
1 Parent(s): 0fc9e07

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -343
model.py CHANGED
@@ -7,22 +7,6 @@ from abc import ABC, abstractmethod
7
  from typing import List
8
  import re
9
 
10
- class FilterPipeline():
11
- def __init__(self, filter_list):
12
- self._filter_list:List[Filter] = filter_list
13
-
14
- def append(self, filter):
15
- self._filter_list.append(filter)
16
-
17
- def batch_encoder(self, inputs):from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
18
- import torch
19
- from modules.file import ExcelFileWriter
20
- import os
21
-
22
- from abc import ABC, abstractmethod
23
- from typing import List
24
- import re
25
-
26
  class FilterPipeline():
27
  def __init__(self, filter_list):
28
  self._filter_list:List[Filter] = filter_list
@@ -385,330 +369,3 @@ class Model():
385
  inputs = filter.decoder(inputs)
386
  return inputs
387
 
388
- class Filter(ABC):
389
- def __init__(self):
390
- self.name = 'filter'
391
- self.code = []
392
- @abstractmethod
393
- def encoder(self, inputs):
394
- pass
395
-
396
- @abstractmethod
397
- def decoder(self, inputs):
398
- pass
399
-
400
- class SpecialTokenFilter(Filter):
401
- def __init__(self):
402
- self.name = 'special token filter'
403
- self.code = []
404
- self.special_tokens = ['!', '!']
405
-
406
- def encoder(self, inputs):
407
- filtered_inputs = []
408
- self.code = []
409
- for i, input_str in enumerate(inputs):
410
- if not all(char in self.special_tokens for char in input_str):
411
- filtered_inputs.append(input_str)
412
- else:
413
- self.code.append([i, input_str])
414
- return filtered_inputs
415
-
416
- def decoder(self, inputs):
417
- original_inputs = inputs.copy()
418
- for removed_indice in self.code:
419
- original_inputs.insert(removed_indice[0], removed_indice[1])
420
- return original_inputs
421
-
422
- class SperSignFilter(Filter):
423
- def __init__(self):
424
- self.name = 's persign filter'
425
- self.code = []
426
-
427
- def encoder(self, inputs):
428
- encoded_inputs = []
429
- self.code = [] # 清空 self.code
430
- for i, input_str in enumerate(inputs):
431
- if 's%' in input_str:
432
- encoded_str = input_str.replace('s%', '*')
433
- self.code.append(i) # 将包含 's%' 的字符串的索引存储到 self.code 中
434
- else:
435
- encoded_str = input_str
436
- encoded_inputs.append(encoded_str)
437
- return encoded_inputs
438
-
439
- def decoder(self, inputs):
440
- decoded_inputs = inputs.copy()
441
- for i in self.code:
442
- decoded_inputs[i] = decoded_inputs[i].replace('*', 's%') # 使用 self.code 中的索引还原原始字符串
443
- return decoded_inputs
444
-
445
- class SimilarFilter(Filter):
446
- def __init__(self):
447
- self.name = 'similar filter'
448
- self.code = []
449
-
450
- def is_similar(self, str1, str2):
451
- # 判断两个字符串是否相似(只有数字上有区别)
452
- pattern = re.compile(r'\d+')
453
- return pattern.sub('', str1) == pattern.sub('', str2)
454
-
455
- def encoder(self, inputs):
456
- encoded_inputs = []
457
- self.code = [] # 清空 self.code
458
- i = 0
459
- while i < len(inputs):
460
- encoded_inputs.append(inputs[i])
461
- similar_strs = [inputs[i]]
462
- j = i + 1
463
- while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
464
- similar_strs.append(inputs[j])
465
- j += 1
466
- if len(similar_strs) > 1:
467
- self.code.append((i, similar_strs)) # 将相似字符串的起始索引和实际字符串列表存储到 self.code 中
468
- i = j
469
- return encoded_inputs
470
-
471
- def decoder(self, inputs):
472
- decoded_inputs = []
473
- index = 0
474
- for i, similar_strs in self.code:
475
- decoded_inputs.extend(inputs[index:i])
476
- decoded_inputs.extend(similar_strs) # 直接将实际的相似字符串添加到 decoded_inputs 中
477
- index = i + 1
478
- decoded_inputs.extend(inputs[index:])
479
- return decoded_inputs
480
-
481
- script_dir = os.path.dirname(os.path.abspath(__file__))
482
- parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
483
-
484
- class Model():
485
- def __init__(self, modelname, selected_lora_model, selected_gpu):
486
- def get_gpu_index(gpu_info, target_gpu_name):
487
- """
488
- 从 GPU 信息中获取目标 GPU 的索引
489
- Args:
490
- gpu_info (list): 包含 GPU 名称的列表
491
- target_gpu_name (str): 目标 GPU 的名称
492
-
493
- Returns:
494
- int: 目标 GPU 的索引,如果未找到则返回 -1
495
- """
496
- for i, name in enumerate(gpu_info):
497
- if target_gpu_name.lower() in name.lower():
498
- return i
499
- return -1
500
- if selected_gpu != "cpu":
501
- gpu_count = torch.cuda.device_count()
502
- gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
503
- selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
504
- self.device_name = f"cuda:{selected_gpu_index}"
505
- else:
506
- self.device_name = "cpu"
507
- print("device_name", self.device_name)
508
- self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
509
- self.tokenizer = AutoTokenizer.from_pretrained(modelname)
510
- # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
511
-
512
- def generate(self, inputs, original_language, target_languages, max_batch_size):
513
- filter_list = [SpecialTokenFilter(), SperSignFilter(), SimilarFilter()]
514
- filter_pipeline = FilterPipeline(filter_list)
515
- def language_mapping(original_language):
516
- d = {
517
- "Achinese (Arabic script)": "ace_Arab",
518
- "Achinese (Latin script)": "ace_Latn",
519
- "Mesopotamian Arabic": "acm_Arab",
520
- "Ta'izzi-Adeni Arabic": "acq_Arab",
521
- "Tunisian Arabic": "aeb_Arab",
522
- "Afrikaans": "afr_Latn",
523
- "South Levantine Arabic": "ajp_Arab",
524
- "Akan": "aka_Latn",
525
- "Amharic": "amh_Ethi",
526
- "North Levantine Arabic": "apc_Arab",
527
- "Standard Arabic": "arb_Arab",
528
- "Najdi Arabic": "ars_Arab",
529
- "Moroccan Arabic": "ary_Arab",
530
- "Egyptian Arabic": "arz_Arab",
531
- "Assamese": "asm_Beng",
532
- "Asturian": "ast_Latn",
533
- "Awadhi": "awa_Deva",
534
- "Central Aymara": "ayr_Latn",
535
- "South Azerbaijani": "azb_Arab",
536
- "North Azerbaijani": "azj_Latn",
537
- "Bashkir": "bak_Cyrl",
538
- "Bambara": "bam_Latn",
539
- "Balinese": "ban_Latn",
540
- "Belarusian": "bel_Cyrl",
541
- "Bemba": "bem_Latn",
542
- "Bengali": "ben_Beng",
543
- "Bhojpuri": "bho_Deva",
544
- "Banjar (Arabic script)": "bjn_Arab",
545
- "Banjar (Latin script)": "bjn_Latn",
546
- "Tibetan": "bod_Tibt",
547
- "Bosnian": "bos_Latn",
548
- "Buginese": "bug_Latn",
549
- "Bulgarian": "bul_Cyrl",
550
- "Catalan": "cat_Latn",
551
- "Cebuano": "ceb_Latn",
552
- "Czech": "ces_Latn",
553
- "Chokwe": "cjk_Latn",
554
- "Central Kurdish": "ckb_Arab",
555
- "Crimean Tatar": "crh_Latn",
556
- "Welsh": "cym_Latn",
557
- "Danish": "dan_Latn",
558
- "German": "deu_Latn",
559
- "Dinka": "dik_Latn",
560
- "Jula": "dyu_Latn",
561
- "Dzongkha": "dzo_Tibt",
562
- "Greek": "ell_Grek",
563
- "English": "eng_Latn",
564
- "Esperanto": "epo_Latn",
565
- "Estonian": "est_Latn",
566
- "Basque": "eus_Latn",
567
- "Ewe": "ewe_Latn",
568
- "Faroese": "fao_Latn",
569
- "Persian": "pes_Arab",
570
- "Fijian": "fij_Latn",
571
- "Finnish": "fin_Latn",
572
- "Fon": "fon_Latn",
573
- "French": "fra_Latn",
574
- "Friulian": "fur_Latn",
575
- "Nigerian Fulfulde": "fuv_Latn",
576
- "Scottish Gaelic": "gla_Latn",
577
- "Irish": "gle_Latn",
578
- "Galician": "glg_Latn",
579
- "Guarani": "grn_Latn",
580
- "Gujarati": "guj_Gujr",
581
- "Haitian Creole": "hat_Latn",
582
- "Hausa": "hau_Latn",
583
- "Hebrew": "heb_Hebr",
584
- "Hindi": "hin_Deva",
585
- "Chhattisgarhi": "hne_Deva",
586
- "Croatian": "hrv_Latn",
587
- "Hungarian": "hun_Latn",
588
- "Armenian": "hye_Armn",
589
- "Igbo": "ibo_Latn",
590
- "Iloko": "ilo_Latn",
591
- "Indonesian": "ind_Latn",
592
- "Icelandic": "isl_Latn",
593
- "Italian": "ita_Latn",
594
- "Javanese": "jav_Latn",
595
- "Japanese": "jpn_Jpan",
596
- "Kabyle": "kab_Latn",
597
- "Kachin": "kac_Latn",
598
- "Arabic": "ar_AR",
599
- "Chinese": "zho_Hans",
600
- "Spanish": "spa_Latn",
601
- "Dutch": "nld_Latn",
602
- "Kazakh": "kaz_Cyrl",
603
- "Korean": "kor_Hang",
604
- "Lithuanian": "lit_Latn",
605
- "Malayalam": "mal_Mlym",
606
- "Marathi": "mar_Deva",
607
- "Nepali": "ne_NP",
608
- "Polish": "pol_Latn",
609
- "Portuguese": "por_Latn",
610
- "Russian": "rus_Cyrl",
611
- "Sinhala": "sin_Sinh",
612
- "Tamil": "tam_Taml",
613
- "Turkish": "tur_Latn",
614
- "Ukrainian": "ukr_Cyrl",
615
- "Urdu": "urd_Arab",
616
- "Vietnamese": "vie_Latn",
617
- "Thai":"tha_Thai"
618
- }
619
- return d[original_language]
620
- def process_gpu_translate_result(temp_outputs):
621
- outputs = []
622
- for temp_output in temp_outputs:
623
- length = len(temp_output[0]["generated_translation"])
624
- for i in range(length):
625
- temp = []
626
- for trans in temp_output:
627
- temp.append({
628
- "target_language": trans["target_language"],
629
- "generated_translation": trans['generated_translation'][i],
630
- })
631
- outputs.append(temp)
632
- excel_writer = ExcelFileWriter()
633
- excel_writer.write_text(os.path.join(parent_dir,r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
634
- self.tokenizer.src_lang = language_mapping(original_language)
635
- if self.device_name == "cpu":
636
- # Tokenize input
637
- input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
638
- output = []
639
- for target_language in target_languages:
640
- # Get language code for the target language
641
- target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
642
- # Generate translation
643
- generated_tokens = self.model.generate(
644
- **input_ids,
645
- forced_bos_token_id=target_lang_code,
646
- max_length=128
647
- )
648
- generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
649
- # Append result to output
650
- output.append({
651
- "target_language": target_language,
652
- "generated_translation": generated_translation,
653
- })
654
- outputs = []
655
- length = len(output[0]["generated_translation"])
656
- for i in range(length):
657
- temp = []
658
- for trans in output:
659
- temp.append({
660
- "target_language": trans["target_language"],
661
- "generated_translation": trans['generated_translation'][i],
662
- })
663
- outputs.append(temp)
664
- return outputs
665
- else:
666
- # 最大批量大小 = 可用 GPU 内存字节数 / 4 / (张量大小 + 可训练参数)
667
- # max_batch_size = 10
668
- # Ensure batch size is within model limits:
669
- print("length of inputs: ",len(inputs))
670
- batch_size = min(len(inputs), int(max_batch_size))
671
- batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
672
- print("length of batches size: ", len(batches))
673
- temp_outputs = []
674
- processed_num = 0
675
- for index, batch in enumerate(batches):
676
- # Tokenize input
677
- batch = filter_pipeline.batch_encoder(batch)
678
- print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
679
- print(batch)
680
- input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
681
- temp = []
682
- for target_language in target_languages:
683
- target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
684
- generated_tokens = self.model.generate(
685
- **input_ids,
686
- forced_bos_token_id=target_lang_code,
687
- )
688
- generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
689
- print(generated_translation)
690
- generated_translation = filter_pipeline.batch_decoder(generated_translation)
691
- # Append result to output
692
- temp.append({
693
- "target_language": target_language,
694
- "generated_translation": generated_translation,
695
- })
696
- input_ids.to('cpu')
697
- del input_ids
698
- temp_outputs.append(temp)
699
- processed_num += len(batch)
700
- if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
701
- print("Already processed number: ", len(temp_outputs))
702
- process_gpu_translate_result(temp_outputs)
703
- outputs = []
704
- for temp_output in temp_outputs:
705
- length = len(temp_output[0]["generated_translation"])
706
- for i in range(length):
707
- temp = []
708
- for trans in temp_output:
709
- temp.append({
710
- "target_language": trans["target_language"],
711
- "generated_translation": trans['generated_translation'][i],
712
- })
713
- outputs.append(temp)
714
- return outputs
 
7
  from typing import List
8
  import re
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class FilterPipeline():
11
  def __init__(self, filter_list):
12
  self._filter_list:List[Filter] = filter_list
 
369
  inputs = filter.decoder(inputs)
370
  return inputs
371