Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import os | |
| import json | |
| from build_vocab import Vocabulary, JsonReader | |
| import numpy as np | |
| from torchvision import transforms | |
| import pickle | |
class ChestXrayDataSet(Dataset):
    """Dataset yielding chest X-ray images with tag labels and tokenised reports.

    Each item is a tuple
    ``(image, image_name, label, target, sentence_num, max_word_num)`` where
    ``label`` is the image's label row scaled to sum to 1, ``target`` is a
    list of token-id lists (one per kept sentence), ``sentence_num`` is
    ``len(target)`` and ``max_word_num`` is the longest token list length.

    Args:
        image_dir: Directory containing the ``<name>.png`` image files.
        caption_json: Path handed to ``JsonReader``; the resulting object is
            indexed by image file name (including ``.png``) to get the report.
        file_list: Text file with one image per line: the image name without
            extension followed by whitespace-separated integer labels.
        vocabulary: Callable mapping a token string to its id
            (presumably an int -- confirm against ``build_vocab``).
        s_max: Only the first ``s_max`` raw sentences of a report are
            considered.
        n_max: Sentences with more than ``n_max`` words are dropped.
        transforms: Optional callable applied to the loaded PIL image.
    """

    def __init__(self,
                 image_dir,
                 caption_json,
                 file_list,
                 vocabulary,
                 s_max=10,
                 n_max=50,
                 transforms=None):
        self.image_dir = image_dir
        self.caption = JsonReader(caption_json)
        self.file_names, self.labels = self.__load_label_list(file_list)
        self.vocab = vocabulary
        self.transform = transforms
        self.s_max = s_max
        self.n_max = n_max

    def __load_label_list(self, file_list):
        """Parse ``file_list`` into parallel lists of file names and labels.

        Blank (or whitespace-only) lines are skipped so that a trailing
        newline in the list file does not raise ``IndexError``.

        Returns:
            ``(filename_list, labels)`` where each file name has ``.png``
            appended and each label is a list of ints.
        """
        labels = []
        filename_list = []
        with open(file_list, 'r') as f:
            for line in f:
                items = line.split()
                if not items:
                    continue
                filename_list.append('{}.png'.format(items[0]))
                labels.append([int(i) for i in items[1:]])
        return filename_list, labels

    @staticmethod
    def _normalize_label(label):
        """Scale ``label`` so its entries sum to 1 and return it as a list.

        A label row whose sum is zero is returned as zeros (as floats)
        instead of being divided by zero, which would produce NaNs.
        """
        values = np.asarray(label, dtype=float)
        total = values.sum()
        if total != 0:
            values = values / total
        return list(values)

    def _encode_report(self, text):
        """Convert report ``text`` into per-sentence token-id lists.

        The text is split on ``'. '``. Only the first ``s_max`` raw
        sentences are examined (skipped sentences still count towards that
        limit). Sentences with fewer than two words or more than ``n_max``
        words are dropped. Each kept sentence is wrapped in ``<start>`` /
        ``<end>`` tokens, so a token list may be up to ``n_max + 2`` long.

        Returns:
            ``(target, max_word_num)``: the list of token-id lists and the
            length of the longest one (0 when nothing was kept).
        """
        target = []
        max_word_num = 0
        for i, sentence in enumerate(text.split('. ')):
            if i >= self.s_max:
                break
            words = sentence.split()
            if len(words) <= 1 or len(words) > self.n_max:
                continue
            tokens = [self.vocab('<start>')]
            tokens.extend(self.vocab(word) for word in words)
            tokens.append(self.vocab('<end>'))
            max_word_num = max(max_word_num, len(tokens))
            target.append(tokens)
        return target, max_word_num

    def __getitem__(self, index):
        """Load the image at ``index`` together with its label and report.

        Returns:
            ``(image, image_name, label, target, sentence_num, max_word_num)``
            as described in the class docstring.
        """
        image_name = self.file_names[index]
        image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        try:
            text = self.caption[image_name]
        except Exception:
            # Deliberate best-effort: any failed report lookup falls back to
            # a placeholder. NOTE: 'normal. ' contains only a one-word
            # sentence, which _encode_report drops, so the target is empty.
            text = 'normal. '
        target, max_word_num = self._encode_report(text)
        sentence_num = len(target)
        return (image, image_name, self._normalize_label(label), target,
                sentence_num, max_word_num)

    def __len__(self):
        """Return the number of images listed in ``file_list``."""
        return len(self.file_names)
def collate_fn(data):
    """Merge a list of dataset items into one padded batch.

    Args:
        data: Sequence of 6-tuples
            ``(image, image_id, label, captions, sentence_num, max_word_num)``.

    Returns:
        ``(images, image_id, label, targets, prob)`` where ``images`` is the
        stacked image tensor, ``image_id`` the tuple of ids, ``label`` a
        ``torch.Tensor`` built from the labels, ``targets`` a zero-padded
        NumPy array of shape ``(batch, max_sentences + 1, max_words)`` and
        ``prob`` a NumPy array of shape ``(batch, max_sentences + 1)``
        holding 1 where a non-empty sentence is present and 0 elsewhere.
    """
    images, image_id, label, captions, sentence_num, max_word_num = zip(*data)

    stacked_images = torch.stack(images, 0)

    batch_size = len(captions)
    # One extra, always-zero sentence row is allocated per sample
    # (presumably as a stop marker for the decoder -- TODO confirm).
    sentence_slots = max(sentence_num) + 1
    word_slots = max(max_word_num)

    targets = np.zeros((batch_size, sentence_slots, word_slots))
    prob = np.zeros((batch_size, sentence_slots))

    for row, report in enumerate(captions):
        for col, tokens in enumerate(report):
            length = len(tokens)
            targets[row, col, :length] = tokens
            prob[row][col] = length > 0

    return stacked_images, image_id, torch.Tensor(label), targets, prob
def get_loader(image_dir,
               caption_json,
               file_list,
               vocabulary,
               transform,
               batch_size,
               s_max=10,
               n_max=50,
               shuffle=False):
    """Build a ``DataLoader`` over a ``ChestXrayDataSet``.

    The dataset is constructed from the given paths, vocabulary, sentence and
    word limits, and image transform; batches are assembled with
    ``collate_fn``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    dataset_kwargs = {
        'image_dir': image_dir,
        'caption_json': caption_json,
        'file_list': file_list,
        'vocabulary': vocabulary,
        's_max': s_max,
        'n_max': n_max,
        'transforms': transform,
    }
    return torch.utils.data.DataLoader(
        dataset=ChestXrayDataSet(**dataset_kwargs),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
if __name__ == '__main__':
    # Smoke test: build a loader over the debugging split and print the
    # contents of its first batch.
    vocab_path = '../data/vocab.pkl'
    image_dir = '../data/images'
    caption_json = '../data/debugging_captions.json'
    file_list = '../data/debugging.txt'
    batch_size = 6
    resize = 256
    crop_size = 224

    # NOTE: unpickling can execute arbitrary code; only load a vocab file
    # that comes from a trusted source.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    preprocessing_steps = [
        transforms.Resize(resize),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(preprocessing_steps)

    data_loader = get_loader(image_dir=image_dir,
                             caption_json=caption_json,
                             file_list=file_list,
                             vocabulary=vocab,
                             transform=transform,
                             batch_size=batch_size,
                             shuffle=False)

    # Only the first batch is inspected.
    for image, image_id, label, target, prob in data_loader:
        print(image.shape)
        for item in (image_id, label, target, prob):
            print(item)
        break