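"""Inference-only sampler for the image report generation model.

Loads a trained checkpoint (visual extractor, MLC, co-attention, sentence LSTM
and word LSTM) and generates a multi-sentence caption for a single image,
running on CPU by default.
"""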
import os
import json
import time
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image
import cv2

from models import *
from dataset import *
from loss import *
from build_tag import *
from build_vocab import *
class CaptionSampler(object):
    def __init__(self):
        # Default configuration values
        self.args = {
            "model_dir": "",
            "image_dir": "",
            "caption_json": "",
            "vocab_path": "vocab.pkl",
            "file_lists": "",
            "load_model_path": "train_best_loss.pth.tar",
            "resize": 224,
            "cam_size": 224,
            "generate_dir": "cam",
            "result_path": "results",
            "result_name": "debug",
            "momentum": 0.1,
            "visual_model_name": "densenet201",
            "pretrained": False,
            "classes": 210,
            "sementic_features_dim": 512,
            "k": 10,
            "attention_version": "v4",
            "embed_size": 512,
            "hidden_size": 512,
            "sent_version": "v1",
            "sentence_num_layers": 2,
            "dropout": 0.1,
            "word_num_layers": 1,
            "s_max": 10,
            "n_max": 30,
            "batch_size": 8,
            "lambda_tag": 10000,
            "lambda_stop": 10,
            "lambda_word": 1,
            "cuda": False  # Keep CUDA disabled by default
        }
        self.vocab = self.__init_vocab()
        self.tagger = self.__init_tagger()
        self.transform = self.__init_transform()
        self.model_state_dict = self.__load_model_state_dict()
        self.extractor = self.__init_visual_extractor()
        self.mlc = self.__init_mlc()
        self.co_attention = self.__init_co_attention()
        self.sentence_model = self.__init_sentence_model()
        self.word_model = self.__init_word_model()
        self.ce_criterion = self._init_ce_criterion()
        self.mse_criterion = self._init_mse_criterion()
    @staticmethod
    def _init_ce_criterion():
        # reduction='none' replaces the deprecated size_average=False, reduce=False
        return nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _init_mse_criterion():
        return nn.MSELoss()
    def sample(self, image_file):
        self.extractor.eval()
        self.mlc.eval()
        self.co_attention.eval()
        self.sentence_model.eval()
        self.word_model.eval()

        # Accept either a file path or an already-loaded PIL image.
        if isinstance(image_file, str):
            image_file = Image.open(image_file).convert('RGB')
        image_data = self.transform(image_file)
        image_data = image_data.unsqueeze_(0)
        image = self.__to_var(image_data, requires_grad=False)

        visual_features, avg_features = self.extractor.forward(image)
        tags, semantic_features = self.mlc(avg_features)

        sentence_states = None
        prev_hidden_states = self.__to_var(torch.zeros(image.shape[0], 1, self.args["hidden_size"]))
        pred_sentences = []

        for i in range(self.args["s_max"]):
            ctx, alpha_v, alpha_a = self.co_attention.forward(avg_features, semantic_features, prev_hidden_states)
            topic, p_stop, hidden_state, sentence_states = self.sentence_model.forward(ctx,
                                                                                       prev_hidden_states,
                                                                                       sentence_states)
            p_stop = p_stop.squeeze(1)
            p_stop = torch.max(p_stop, 1)[1].unsqueeze(1)

            start_tokens = np.zeros((topic.shape[0], 1))
            start_tokens[:, 0] = self.vocab('<start>')
            start_tokens = self.__to_var(torch.Tensor(start_tokens).long(), requires_grad=False)

            sampled_ids = self.word_model.sample(topic, start_tokens)
            prev_hidden_states = hidden_state
            # Detach before converting to NumPy; p_stop may still carry grad history.
            sampled_ids = sampled_ids * p_stop.detach().cpu().numpy()
            pred_sentences.append(self.__vec2sent(sampled_ids[0]))

        return pred_sentences
    def __init_cam_path(self, image_file):
        generate_dir = os.path.join(self.args["model_dir"], self.args["generate_dir"])
        if not os.path.exists(generate_dir):
            os.makedirs(generate_dir)

        image_dir = os.path.join(generate_dir, image_file)
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        return image_dir
    def __save_json(self, result):
        result_path = os.path.join(self.args["model_dir"], self.args["result_path"])
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        with open(os.path.join(result_path, '{}.json'.format(self.args["result_name"])), 'w') as f:
            json.dump(result, f)
    def __load_model_state_dict(self):
        try:
            model_state_dict = torch.load(os.path.join(self.args["model_dir"], self.args["load_model_path"]),
                                          map_location=torch.device('cpu'))
            print("[Load Model-{} Succeed!]".format(self.args["load_model_path"]))
            print("Load From Epoch {}".format(model_state_dict['epoch']))
            return model_state_dict
        except Exception as err:
            print("[Load Model Failed] {}".format(err))
            raise err
    def __init_tagger(self):
        return Tag()

    def __vec2sent(self, array):
        sampled_caption = []
        for word_id in array:
            word = self.vocab.get_word_by_id(word_id)
            if word == '<start>':
                continue
            if word == '<end>' or word == '<pad>':
                break
            sampled_caption.append(word)
        return ' '.join(sampled_caption)

    def __init_vocab(self):
        # Use the configured vocabulary path rather than a hard-coded file name.
        with open(self.args["vocab_path"], 'rb') as f:
            vocab = pickle.load(f)
        return vocab
    def __init_data_loader(self, file_list):
        # self.args is a dict, so use key access rather than attribute access.
        data_loader = get_loader(image_dir=self.args["image_dir"],
                                 caption_json=self.args["caption_json"],
                                 file_list=file_list,
                                 vocabulary=self.vocab,
                                 transform=self.transform,
                                 batch_size=self.args["batch_size"],
                                 s_max=self.args["s_max"],
                                 n_max=self.args["n_max"],
                                 shuffle=False)
        return data_loader
    def __init_transform(self):
        transform = transforms.Compose([
            transforms.Resize((self.args["resize"], self.args["resize"])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        return transform

    def __to_var(self, x, requires_grad=True):
        if self.args["cuda"]:
            x = x.cuda()
        return Variable(x, requires_grad=requires_grad)
    def __init_visual_extractor(self):
        model = VisualFeatureExtractor(model_name=self.args["visual_model_name"],
                                       pretrained=self.args["pretrained"])
        if self.model_state_dict is not None:
            print("Visual Extractor Loaded!")
            model.load_state_dict(self.model_state_dict['extractor'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_mlc(self):
        model = MLC(classes=self.args["classes"],
                    sementic_features_dim=self.args["sementic_features_dim"],
                    fc_in_features=self.extractor.out_features,
                    k=self.args["k"])
        if self.model_state_dict is not None:
            print("MLC Loaded!")
            model.load_state_dict(self.model_state_dict['mlc'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_co_attention(self):
        model = CoAttention(version=self.args["attention_version"],
                            embed_size=self.args["embed_size"],
                            hidden_size=self.args["hidden_size"],
                            visual_size=self.extractor.out_features,
                            k=self.args["k"],
                            momentum=self.args["momentum"])
        if self.model_state_dict is not None:
            print("Co-Attention Loaded!")
            model.load_state_dict(self.model_state_dict['co_attention'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_sentence_model(self):
        model = SentenceLSTM(version=self.args["sent_version"],
                             embed_size=self.args["embed_size"],
                             hidden_size=self.args["hidden_size"],
                             num_layers=self.args["sentence_num_layers"],
                             dropout=self.args["dropout"],
                             momentum=self.args["momentum"])
        if self.model_state_dict is not None:
            print("Sentence Model Loaded!")
            model.load_state_dict(self.model_state_dict['sentence_model'])
        if self.args["cuda"]:
            model = model.cuda()
        return model
    def __init_word_model(self):
        model = WordLSTM(vocab_size=len(self.vocab),
                         embed_size=self.args["embed_size"],
                         hidden_size=self.args["hidden_size"],
                         num_layers=self.args["word_num_layers"],
                         n_max=self.args["n_max"])
        if self.model_state_dict is not None:
            print("Word Model Loaded!")
            model.load_state_dict(self.model_state_dict['word_model'])
        if self.args["cuda"]:
            model = model.cuda()
        return model
def main(image):
    sampler = CaptionSampler()
    # image = 'sample_images/CXR195_IM-0618-1001.png'
    caption = sampler.sample(image)
    print(caption[0])
    return caption[0]
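# Minimal usage sketch, assuming the sample image path from the comment in
# main() exists locally, and that vocab.pkl and train_best_loss.pth.tar sit in
# the working directory (per the defaults in self.args).
if __name__ == '__main__':
    main('sample_images/CXR195_IM-0618-1001.png')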