import os
from tqdm import tqdm
from nltk import tokenize
import numpy as np
import pickle
import torch
import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.models.utils as model_utils
import comet.src.interactive.functions as interactive


class CSKFeatureExtractor:
    def __init__(self, dir=".", device=0):
        super(CSKFeatureExtractor, self).__init__()

        # Path to the COMET model pretrained on the ATOMIC knowledge graph.
        model_file = os.path.join(
            dir, "comet/pretrained_models/atomic_pretrained_model.pickle"
        )
        sampling_algorithm = "beam-5"  # kept from COMET's interactive setup; unused below
        category = "all"  # kept from COMET's interactive setup; unused below

        opt, state_dict = interactive.load_model_file(model_file)
        data_loader, text_encoder = interactive.load_data("atomic", opt, dir)

        self.opt = opt
        self.data_loader = data_loader
        self.text_encoder = text_encoder

        # Context length and vocabulary size used to build the COMET transformer.
        n_ctx = data_loader.max_event + data_loader.max_effect
        n_vocab = len(text_encoder.encoder) + n_ctx
        self.model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
        self.model.eval()

        # Run on the requested GPU, or fall back to the CPU.
        if device != "cpu":
            cfg.device = int(device)
            cfg.do_gpu = True
            torch.cuda.set_device(cfg.device)
            self.model.cuda(cfg.device)
        else:
            cfg.device = "cpu"

    def set_atomic_inputs(self, input_event, category, data_loader, text_encoder):
        # Encode the event text and place the relation token (e.g. <xIntent>)
        # in the final input position.
        XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
        prefix, suffix = data.atomic_data.do_example(
            text_encoder, input_event, None, True, None
        )

        # Truncate events that exceed the maximum event length.
        if len(prefix) > data_loader.max_event + 1:
            prefix = prefix[: data_loader.max_event + 1]

        XMB[:, : len(prefix)] = torch.LongTensor(prefix)
        XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

        batch = {}
        batch["sequences"] = XMB
        batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
        return batch

    def extract(self, sentence):
        # `sentence` maps an id to a list of utterance strings; the result is one
        # feature dict per ATOMIC relation, keyed by the same ids.
        atomic_keys = [
            "xIntent",
            "xAttr",
            "xNeed",
            "xWant",
            "xEffect",
            "xReact",
            "oWant",
            "oEffect",
            "oReact",
        ]
        map1 = [{}, {}, {}, {}, {}, {}, {}, {}, {}]
        all_keys = list(sentence.keys())

        for i in tqdm(range(len(all_keys))):
            item = all_keys[i]
            list1 = [[], [], [], [], [], [], [], [], []]

            for x in sentence[item]:
                # Drop non-ASCII characters before encoding.
                input_event = x.encode("ascii", errors="ignore").decode("utf-8")
                m1 = []

                for sent in tokenize.sent_tokenize(input_event):
                    # Build one input sequence per ATOMIC relation and run them as a batch.
                    seqs = []
                    masks = []
                    for category in atomic_keys:
                        batch = self.set_atomic_inputs(
                            sent, category, self.data_loader, self.text_encoder
                        )
                        seqs.append(batch["sequences"])
                        masks.append(batch["attention_mask"])

                    XMB = torch.cat(seqs)
                    MMB = torch.cat(masks)
                    XMB = model_utils.prepare_position_embeddings(
                        self.opt, self.data_loader.vocab_encoder, XMB.unsqueeze(-1)
                    )

                    # The hidden state at the final position (the relation token) is
                    # taken as the commonsense feature; `last_index` is computed but
                    # not used below.
                    h, _ = self.model(XMB.unsqueeze(1), sequence_mask=MMB)
                    last_index = MMB[0][:-1].nonzero()[-1].cpu().numpy()[0] + 1
                    m1.append(h[:, -1, :].detach().cpu().numpy())

                # Average the per-sentence features of the utterance.
                m1 = np.mean(np.array(m1), axis=0)
                for k, l1 in enumerate(list1):
                    l1.append(m1[k])

            for k, v1 in enumerate(map1):
                v1[item] = list1[k]

        return map1
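

# Usage sketch (an illustration with assumed inputs): `utterances` maps a dialogue id
# to its list of utterance strings, mirroring how `extract` iterates over
# `sentence[item]`. It assumes the ATOMIC COMET checkpoint is available under
# ./comet/pretrained_models/ and that NLTK's "punkt" sentence tokenizer is installed.
if __name__ == "__main__":
    extractor = CSKFeatureExtractor(dir=".", device=0)  # pass device="cpu" to run without a GPU

    utterances = {
        "dialogue_0": [
            "I finally passed my driving test!",
            "My instructor was really patient with me.",
        ]
    }

    features = extractor.extract(utterances)
    # `features` is a list of 9 dicts, one per ATOMIC relation (xIntent, xAttr, xNeed,
    # xWant, xEffect, xReact, oWant, oEffect, oReact); each maps "dialogue_0" to one
    # COMET hidden-state vector per utterance.
    for relation, feature_map in zip(
        ["xIntent", "xAttr", "xNeed", "xWant", "xEffect",
         "xReact", "oWant", "oEffect", "oReact"],
        features,
    ):
        print(relation, feature_map["dialogue_0"][0].shape)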