Spaces:
Sleeping
Sleeping
import torch | |
from torch.utils.data import Dataset | |
from model import * | |
class DataCollector(Dataset):
    """Map-style dataset producing fixed-length tensors for a translation model.

    Each row of ``dataset`` must expose ``'src_ids'`` and ``'tgt_ids'``
    (sequences of token ids). Both sequences are truncated or right-padded
    to exactly ``max_length`` tokens using the pad token id of the matching
    tokenizer.

    Args:
        dataset: Indexable collection of rows; must support ``len()``.
        english_tokenizer: Object exposing ``pad_token_id`` for the source side.
        french_tokenizer: Object exposing ``pad_token_id`` for the target side.
        max_length: Fixed length of every returned tensor.
    """

    def __init__(self, dataset, english_tokenizer, french_tokenizer, max_length=512):
        super().__init__()
        self.dataset = dataset
        self.english_tokenizer = english_tokenizer
        self.french_tokenizer = french_tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _pad_or_truncate(ids, pad_token_id, max_length):
        """Return ``ids`` as a 1-D int64 tensor of exactly ``max_length`` items.

        Raises:
            ValueError: If ``pad_token_id`` is None (tokenizer without a pad token).
        """
        if pad_token_id is None:
            raise ValueError(
                "Tokenizer has no pad_token_id; set a pad token before building batches."
            )
        # An explicit dtype keeps empty sequences integer-typed; torch.tensor([])
        # would otherwise default to float32 and contaminate the padded result.
        ids = torch.as_tensor(ids, dtype=torch.long)[:max_length]
        missing = max_length - ids.size(0)
        if missing <= 0:
            # Clone so callers never receive a view onto the dataset's storage.
            return ids.clone()
        padding = torch.full((missing,), pad_token_id, dtype=torch.long)
        return torch.cat([ids, padding])

    def __getitem__(self, index):
        """Build the padded tensors and masks for the row at ``index``.

        Returns:
            A dict with five tensors, all of length ``max_length``:

            * ``src_input_ids`` (int64): padded/truncated source ids.
            * ``src_pad_mask`` (bool): True where the source token is not padding.
            * ``tgt_input_ids`` (int64): padded/truncated target ids.
            * ``tgt_pad_mask`` (bool): True where the target token is not
              padding; the last position is always False because it has no
              next token to predict.
            * ``tgt_labels`` (int64): target ids shifted left by one, with
              padding and the last position set to -100 (the default
              ``ignore_index`` of ``torch.nn.CrossEntropyLoss``).
        """
        # Fetch the row once; indexing the underlying dataset may be costly.
        row = self.dataset[index]
        src_pad_token = self.english_tokenizer.pad_token_id
        tgt_pad_token = self.french_tokenizer.pad_token_id

        src_ids = self._pad_or_truncate(row["src_ids"], src_pad_token, self.max_length)
        tgt_ids = self._pad_or_truncate(row["tgt_ids"], tgt_pad_token, self.max_length)

        src_pad_mask = src_ids != src_pad_token

        # Shift the target by one: position i of the input predicts token i + 1.
        shifted_inputs = tgt_ids[:-1]
        labels = tgt_ids[1:].clone()
        labels[labels == tgt_pad_token] = -100

        # Re-extend both to max_length. The filler for the mask is bool so the
        # mask keeps the same dtype as src_pad_mask instead of being promoted
        # to int64 by torch.cat.
        tgt_pad_mask = torch.cat(
            [shifted_inputs != tgt_pad_token, torch.zeros(1, dtype=torch.bool)]
        )
        tgt_labels = torch.cat([labels, torch.full((1,), -100, dtype=torch.long)])

        return {
            "src_input_ids": src_ids,  # fixed size: (max_length,)
            "src_pad_mask": src_pad_mask,
            "tgt_input_ids": tgt_ids,  # fixed size: (max_length,)
            "tgt_pad_mask": tgt_pad_mask,
            "tgt_labels": tgt_labels,
        }