| import pandas as pd | |
| import pytorch_lightning as pl | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import BertTokenizer | |
| from src.utils import get_target_columns | |
def _drop_leading_singletons(tensor):
    """Strip size-1 leading axes from ``tensor``, keeping at least one dim."""
    # A bare ``tensor.squeeze()`` would also remove a genuine size-1 feature
    # axis (e.g. a single-target label vector), so only leading axes are
    # removed and a 1-D tensor is never reduced to a scalar.
    while tensor.dim() > 1 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    return tensor


def collate_fn(data):
    """Combine per-sample feature dicts into one dict of batched tensors.

    Args:
        data: Sequence of dicts, each holding a tensor under the keys
            ``input_ids``, ``token_type_ids``, ``attention_mask`` and
            ``labels``.

    Returns:
        A dict with the same four keys. Each value is the per-sample tensors
        stacked along a new leading batch dimension, after leading size-1
        axes of every sample tensor have been removed.

    Raises:
        KeyError: If a sample lacks one of the four keys.
        RuntimeError: Raised by ``torch.stack`` when ``data`` is empty or the
            per-sample tensors for a key differ in shape.
    """
    keys = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return {
        key: torch.stack([_drop_leading_singletons(sample[key]) for sample in data])
        for key in keys
    }
class ClassificationDataset(Dataset):
    """Map-style dataset that tokenizes every text once at construction.

    All rows of ``df.full_text`` are tokenized in a single call, padded to the
    longest sequence in the frame and truncated to ``config['max_length']``.
    Indexing then only slices the pre-computed tensors.
    """

    def __init__(self, tokenizer: BertTokenizer, df: pd.DataFrame, config: dict):
        """Tokenize ``df`` and attach a label tensor.

        Args:
            tokenizer: Tokenizer called on the list of texts.
            df: Frame with a ``full_text`` column. When it also has a
                ``cohesion`` column, the columns named by
                ``get_target_columns()`` are used as labels.
            config: Settings dict; ``max_length`` is read here.
        """
        self.config = config
        self.tokenizer = tokenizer
        self.df = df
        self.features = self.tokenizer(
            text=df.full_text.tolist(),
            max_length=self.config['max_length'],
            padding=True,
            truncation=True,
            return_attention_mask=True,
            # __getitem__ always reads 'token_type_ids', so request it
            # explicitly instead of relying on the tokenizer's default.
            return_token_type_ids=True,
            add_special_tokens=True,
            return_tensors='pt',
        )

        target_columns = get_target_columns()
        if 'cohesion' in self.df.columns:
            labels = torch.as_tensor(df[target_columns].values, dtype=torch.float32)
        else:
            # No targets in the frame (presumably unlabeled/inference data):
            # fill with -1 so batches keep the same keys. The width follows
            # the target column list rather than a hard-coded count, so it
            # stays consistent with the labeled branch above.
            labels = torch.full(
                (len(df), len(target_columns)), -1.0, dtype=torch.float32
            )
        self.features['labels'] = labels

    def __getitem__(self, item):
        """Return one sample as a dict of tensors.

        Args:
            item: Index applied to each pre-computed feature tensor.

        Returns:
            Dict with the keys ``input_ids``, ``token_type_ids``,
            ``attention_mask`` and ``labels``.
        """
        keys = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
        return {key: self.features[key][item] for key in keys}

    def __len__(self):
        """Return the number of rows in the source frame."""
        return len(self.df)
class ClassificationDataloader(pl.LightningDataModule):
    """Lightning data module serving the training and validation loaders.

    Both frames are wrapped in ``ClassificationDataset`` when the module is
    constructed; the loaders read ``batch_size`` and ``num_workers`` from
    ``config`` and batch samples with ``collate_fn``.
    """

    def __init__(
        self,
        tokenizer: BertTokenizer,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        config: dict,
    ):
        """Build the train and validation datasets.

        Args:
            tokenizer: Tokenizer handed to both datasets.
            train_df: Frame backing the training dataset.
            val_df: Frame backing the validation dataset.
            config: Settings dict shared by the datasets and the loaders.
        """
        super().__init__()
        self.config = config
        self.train_data = ClassificationDataset(tokenizer, train_df, config)
        self.val_data = ClassificationDataset(tokenizer, val_df, config)

    def _make_loader(self, dataset, shuffle):
        """Create a ``DataLoader`` over ``dataset`` with the configured sizes."""
        return DataLoader(
            dataset=dataset,
            shuffle=shuffle,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train_data, True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val_data, False)