Implement evaluation

Files changed:
- README.md +22 -4
- dataset.py +33 -0
- eval.py +159 -0
- sample_text/en2es.m2m100_1.2B.json +1 -0
- sample_text/en2es.m2m100_418M.json +1 -0
- sample_text/{en2es.translation.txt → en2es.translation.m2m100_1.2B.txt} +1 -1
- sample_text/en2es.translation.m2m100_418M.txt +0 -0
- translate.py +6 -2
README.md
CHANGED

@@ -68,7 +68,7 @@ Run `python translate.py -h` for more info.
 ```bash
 accelerate launch translate.py \
 --sentences_path sample_text/en.txt \
---output_path sample_text/en2es.translation.txt \
+--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
 --source_lang en \
 --target_lang es \
 --model_name facebook/m2m100_1.2B
@@ -83,7 +83,7 @@ You can use the Accelerate CLI to configure the Accelerate environment (Run
 ```bash
 accelerate launch --multi_gpu --num_processes 2 --num_machines 1 translate.py \
 --sentences_path sample_text/en.txt \
---output_path sample_text/en2es.translation.txt \
+--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
 --source_lang en \
 --target_lang es \
 --model_name facebook/m2m100_1.2B
@@ -102,7 +102,7 @@ Use the `--precision` flag to choose the precision of the model. You can choose
 ```bash
 accelerate launch translate.py \
 --sentences_path sample_text/en.txt \
---output_path sample_text/en2es.translation.txt \
+--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
 --source_lang en \
 --target_lang es \
 --model_name facebook/m2m100_1.2B \
@@ -111,6 +111,24 @@ accelerate launch translate.py \
 
 ## Evaluate translations
 
-
+To run the evaluation script, you need to install [bert_score](https://github.com/Tiiiger/bert_score): `pip install bert_score`
+
+The evaluation script will calculate the following metrics:
+* [SacreBLEU](https://github.com/huggingface/datasets/tree/master/metrics/sacrebleu)
+* [BLEU](https://github.com/huggingface/datasets/tree/master/metrics/bleu)
+* [ROUGE](https://github.com/huggingface/datasets/tree/master/metrics/rouge)
+* [METEOR](https://github.com/huggingface/datasets/tree/master/metrics/meteor)
+* [TER](https://github.com/huggingface/datasets/tree/master/metrics/ter)
+* [BertScore](https://github.com/huggingface/datasets/tree/master/metrics/bertscore)
+
+Run the following command to evaluate the translations:
+
+```bash
+accelerate launch eval.py \
+--pred_path sample_text/en2es.translation.m2m100_1.2B.txt \
+--gold_path sample_text/es.txt
+```
+
+If you want to save the results to a file, use the `--output_path` flag.
 
 
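The `--output_path` flag writes every computed metric into a single JSON file; the `sample_text/en2es.m2m100_1.2B.json` and `sample_text/en2es.m2m100_418M.json` files added by this commit are examples of that output. A minimal sketch, assuming the evaluation has been run with `--output_path sample_text/en2es.m2m100_1.2B.json`, of reading the headline scores back:

```python
# Minimal sketch: load the metrics JSON written by eval.py's --output_path flag.
# The key layout below matches sample_text/en2es.m2m100_1.2B.json from this commit.
import json

with open("sample_text/en2es.m2m100_1.2B.json") as f:
    results = json.load(f)

print("SacreBLEU:", results["sacrebleu"]["score"])
print("BLEU:", results["bleu"]["bleu"])
print("METEOR:", results["meteor"]["meteor"])
print("TER:", results["ter"]["score"])
print("BERTScore F1:", results["bert_score"]["f1"])
```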
dataset.py
CHANGED

@@ -38,3 +38,36 @@ class DatasetReader(IterableDataset):
         file_itr = open(self.filename, "r")
         mapped_itr = map(self.preprocess, file_itr)
         return mapped_itr
+
+
+class ParallelTextReader(IterableDataset):
+    def __init__(self, pred_path: str, gold_path: str):
+        self.pred_path = pred_path
+        self.gold_path = gold_path
+        pref_filename_lines = count_lines(pred_path)
+        gold_path_lines = count_lines(gold_path)
+        assert pref_filename_lines == gold_path_lines, (
+            f"Lines in {pred_path} and {gold_path} do not match "
+            f"{pref_filename_lines} vs {gold_path_lines}"
+        )
+        self.num_sentences = gold_path_lines
+        self.current_line = 0
+
+    def preprocess(self, pred: str, gold: str):
+        self.current_line += 1
+        pred = pred.rstrip().strip()
+        gold = gold.rstrip().strip()
+        if len(pred) == 0:
+            print(f"Warning: Pred empty sentence at line {self.current_line}")
+        if len(gold) == 0:
+            print(f"Warning: Gold empty sentence at line {self.current_line}")
+        return pred, [gold]
+
+    def __iter__(self):
+        pred_itr = open(self.pred_path, "r")
+        gold_itr = open(self.gold_path, "r")
+        mapped_itr = map(self.preprocess, pred_itr, gold_itr)
+        return mapped_itr
+
+    def __len__(self):
+        return self.num_sentences
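For context on how these pairs feed the metrics: `ParallelTextReader` yields one `(pred, [gold])` tuple per line pair, and the `collate_fn` in `eval.py` (added below) transposes each batch of tuples into parallel `predictions` and `references` lists. A standalone sketch of that pairing and collate step, using inline example sentences instead of the repository's files:

```python
# Standalone sketch (plain Python, no repo imports): mimic how ParallelTextReader
# pairs prediction/gold lines and how eval.py's collate_fn transposes a batch of
# (pred, [gold]) tuples into predictions/references lists for the metrics.
preds = ["hola mundo", "buenos días"]    # example model outputs
golds = ["hola a todos", "buenos días"]  # example reference sentences

batch = [(p.strip(), [g.strip()]) for p, g in zip(preds, golds)]


def collate_fn(batch):
    # zip(*batch) transposes [(pred, refs), ...] into ([pred, ...], [refs, ...])
    return list(map(list, zip(*batch)))


predictions, references = collate_fn(batch)
print(predictions)  # ['hola mundo', 'buenos días']
print(references)   # [['hola a todos'], ['buenos días']]
```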
eval.py
ADDED

@@ -0,0 +1,159 @@

from dataset import ParallelTextReader
from torch.utils.data import DataLoader
from accelerate.memory_utils import find_executable_batch_size
from datasets import load_metric
from tqdm import tqdm
import torch
import json
import argparse
import numpy as np


def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
    """
    Returns a dataloader for the given files.
    """

    def collate_fn(batch):
        return list(map(list, zip(*batch)))

    reader = ParallelTextReader(pred_path=pred_path, gold_path=gold_path)
    dataloader = DataLoader(reader, batch_size=batch_size, collate_fn=collate_fn)
    return dataloader


def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """
    Evaluates the given files.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print("We will use a GPU to calculate BertScore.")
    else:
        device = "cpu"
        print(
            f"We will use the CPU to calculate BertScore, this can be slow for large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)
    print("Loading sacrebleu...")
    sacrebleu = load_metric("sacrebleu")
    print("Loading rouge...")
    rouge = load_metric("rouge")
    print("Loading bleu...")
    bleu = load_metric("bleu")
    print("Loading meteor...")
    meteor = load_metric("meteor")
    print("Loading ter...")
    ter = load_metric("ter")
    print("Loading BertScore...")
    bert_score = load_metric("bertscore")

    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            sacrebleu.add_batch(predictions=predictions, references=references)
            rouge.add_batch(predictions=predictions, references=references)
            bleu.add_batch(
                predictions=[p.split() for p in predictions],
                references=[[r[0].split()] for r in references],
            )
            meteor.add_batch(predictions=predictions, references=references)
            ter.add_batch(predictions=predictions, references=references)
            bert_score.add_batch(predictions=predictions, references=references)
            pbar.update(len(predictions))

    result_dictionary = {}
    print(f"Computing sacrebleu")
    result_dictionary["sacrebleu"] = sacrebleu.compute()
    print(f"Computing rouge score")
    result_dictionary["rouge"] = rouge.compute()
    print(f"Computing bleu score")
    result_dictionary["bleu"] = bleu.compute()
    print(f"Computing meteor score")
    result_dictionary["meteor"] = meteor.compute()
    print(f"Computing ter score")
    result_dictionary["ter"] = ter.compute()

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        nonlocal bert_score, bert_score_model
        print(f"Computing bert score with batch size {batch_size} on {device}")
        results = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )

        results["precision"] = np.average(results["precision"])
        results["recall"] = np.average(results["recall"])
        results["f1"] = np.average(results["f1"])

        return results

    result_dictionary["bert_score"] = inference()

    if output_path is not None:
        with open(output_path, "w") as f:
            json.dump(result_dictionary, f, indent=4)

    print(f"Results: {json.dumps(result_dictionary, indent=4)}")

    return result_dictionary


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )

    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )

    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore; we will automatically reduce it if we find an OOM error.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. If not given, the results will be printed to the console.",
    )

    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore "
        "and https://github.com/Tiiiger/bert_score for more details.",
    )

    args = parser.parse_args()

    eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )
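`eval_files` can also be called programmatically instead of through the CLI. A minimal sketch, assuming `eval.py` and `dataset.py` are importable from the working directory and that the prediction and gold files exist and are line-aligned (the paths below are the ones used elsewhere in this commit):

```python
# Minimal sketch: call eval_files() directly rather than via "accelerate launch eval.py".
# Paths and the BertScore model mirror the defaults and sample files in this commit.
from eval import eval_files

results = eval_files(
    pred_path="sample_text/en2es.translation.m2m100_1.2B.txt",
    gold_path="sample_text/es.txt",
    bert_score_model="microsoft/deberta-xlarge-mnli",
    starting_batch_size=64,
    output_path="sample_text/en2es.m2m100_1.2B.json",
)
print(results["sacrebleu"]["score"])
```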
sample_text/en2es.m2m100_1.2B.json
ADDED

@@ -0,0 +1 @@
{"sacrebleu": {"score": 32.101150640281695, "counts": [19160, 11392, 7558, 5186], "totals": [31477, 30479, 29481, 28485], "precisions": [60.86984147155066, 37.37655434889596, 25.636850853091822, 18.20607337195015], "bp": 1.0, "sys_len": 31477, "ref_len": 30102}, "rouge": {"rouge1": [[0.5852396804366098, 0.6089057437338691, 0.5919486437026797], [0.5964621218261164, 0.6200342221830797, 0.6029705008756368], [0.6068321807422377, 0.6311106822798185, 0.61324805661008]], "rouge2": [[0.3710985389559613, 0.38708055355385995, 0.3761201217327784], [0.3844850790869714, 0.40017782122170353, 0.38920434271970195], [0.3968990790506025, 0.41382310483690327, 0.4022299418726329]], "rougeL": [[0.5351505034410595, 0.5564838960633809, 0.5410602618870524], [0.5457898501195475, 0.5677049056091881, 0.5519189480892548], [0.5575497491149766, 0.5787856637940312, 0.5630101422167583]], "rougeLsum": [[0.5352116089085267, 0.5570236521823667, 0.5415939934790461], [0.5463246235983789, 0.5676427704754348, 0.5522237812823654], [0.5581141358005033, 0.5796683147249665, 0.5630221371759908]]}, "bleu": {"bleu": 0.2842153038526809, "precisions": [0.5535070989616444, 0.33646946844340314, 0.22383069265549602, 0.15653135365661033], "brevity_penalty": 1.0, "length_ratio": 1.0469217970049918, "translation_length": 28314, "reference_length": 27045}, "meteor": {"meteor": 0.4880039569987408}, "ter": {"score": 59.500831946755405, "num_edits": 16092, "ref_length": 27045.0}, "bert_score": {"precision": 0.8192511852383614, "recall": 0.8262866012752056, "f1": 0.8223477345705033, "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.11(hug_trans=4.18.0)_fast-tokenizer"}}
sample_text/en2es.m2m100_418M.json
ADDED

@@ -0,0 +1 @@
{"sacrebleu": {"score": 29.035496917461597, "counts": [18582, 10514, 6681, 4387], "totals": [31477, 30479, 29481, 28485], "precisions": [59.033580074339994, 34.49588241084025, 22.662053525999795, 15.401088292083553], "bp": 1.0, "sys_len": 31477, "ref_len": 30388}, "rouge": {"rouge1": [[0.5661701202298134, 0.5806961045770566, 0.5693885562082325], [0.5768745925790656, 0.5926959547911554, 0.5803693779677083], [0.5871085218904836, 0.6035331460243276, 0.5900979805085623]], "rouge2": [[0.34243414046469267, 0.35226400857606666, 0.34469210847048837], [0.3545484183384055, 0.36470783370743065, 0.3569058648048812], [0.36612813327517263, 0.37717476449671, 0.3689653665404565]], "rougeL": [[0.5129704896656746, 0.526995889564155, 0.5162056185006965], [0.523632841460358, 0.5375452284094455, 0.5267080806612512], [0.5350158816319085, 0.5480980981777757, 0.5372302857012781]], "rougeLsum": [[0.5126805856827783, 0.5265189554049317, 0.5155154093959223], [0.5239559133309495, 0.5380410013947112, 0.5271022617246641], [0.5351934954578494, 0.5491115103854219, 0.5381174565735956]]}, "bleu": {"bleu": 0.2546886610724999, "precisions": [0.5339761248852158, 0.30784155806120955, 0.19560013678331242, 0.1308640025272469], "brevity_penalty": 1.0, "length_ratio": 1.0353982300884956, "translation_length": 28314, "reference_length": 27346}, "meteor": {"meteor": 0.4630996837124251}, "ter": {"score": 61.848167922182405, "num_edits": 16913, "ref_length": 27346.0}, "bert_score": {"precision": 0.8128398380875588, "recall": 0.8185442119538784, "f1": 0.8153291321396827, "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.11(hug_trans=4.18.0)_fast-tokenizer"}}
sample_text/{en2es.translation.txt → en2es.translation.m2m100_1.2B.txt}
RENAMED

@@ -997,4 +997,4 @@ Quiero felicitarle, lamentablemente en su ausencia, por la forma exhaustiva y ri
 Él mencionó anteriormente que el informe se llevó a cabo con una mayoría significativa, pero no con mi apoyo.
 Por lo tanto, aunque no comparto sus conclusiones, creo que él ha ilustrado en su informe muchas de las cuestiones que la Comisión debe abordar.
 La primera es la posibilidad de renacentización de la política de competencia.
-Sé que la Comisión se opone a esto, pero el potencial existe.
+Sé que la Comisión se opone a esto, pero el potencial existe.
sample_text/en2es.translation.m2m100_418M.txt
ADDED

The diff for this file is too large to render.
translate.py
CHANGED

@@ -122,6 +122,7 @@ def main(
         total=total_lines, desc="Dataset translation", leave=True, ascii=True
     ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
         with torch.no_grad():
+            first_batch = True
             for batch in data_loader:
                 batch["input_ids"] = batch["input_ids"]
                 batch["attention_mask"] = batch["attention_mask"]
@@ -141,8 +142,11 @@
                 tgt_text = tokenizer.batch_decode(
                     generated_tokens, skip_special_tokens=True
                 )
-
-
+                if not first_batch:
+                    print(file=output_file)
+                else:
+                    first_batch = False
+                print("\n".join(tgt_text), file=output_file, end="")
 
                 pbar.update(len(tgt_text))
 
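The `first_batch` logic added above prints the newline that separates two batches before each batch except the first, and writes each batch with `end=""`, so the output file does not end with a trailing newline. A small self-contained sketch of the same write pattern (the batch contents and file name are illustrative):

```python
# Self-contained sketch of the write pattern introduced in translate.py:
# print a separating newline *before* every batch except the first, and write each
# batch without a trailing newline, so the file never ends with a blank line.
batches = [["hola", "adiós"], ["gracias"]]  # illustrative translated batches

with open("example_output.txt", "w", encoding="utf-8") as output_file:
    first_batch = True
    for tgt_text in batches:
        if not first_batch:
            print(file=output_file)  # newline separating this batch from the previous one
        else:
            first_batch = False
        print("\n".join(tgt_text), file=output_file, end="")

# example_output.txt now contains "hola\nadiós\ngracias" with no trailing newline
```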