Spaces:
Build error
Build error
| import time | |
| import torch | |
| import comet.src.evaluate.generate as base_generate | |
| import comet.src.evaluate.sampler as sampling | |
| import comet.utils.utils as utils | |
| import comet.src.data.config as cfg | |
def make_generator(opt, *args):
    """Build and return a ConceptNetGenerator.

    ``opt`` and every extra positional argument are forwarded unchanged to
    the ConceptNetGenerator constructor.
    """
    generator_cls = ConceptNetGenerator
    return generator_cls(opt, *args)
class ConceptNetGenerator(base_generate.Generator):
    """Generate sequences for ConceptNet examples using a configured sampler.

    Batches are drawn from ``data_loader`` and decoded by the sampler built
    from ``opt.eval.sample``.
    """

    def __init__(self, opt, model, data_loader):
        """Store the model and data loader, and build the decoding sampler.

        Args:
            opt: Options object. ``opt.eval.sample`` selects the sampler;
                ``opt.eval.gs`` (read in ``generate``) limits the number of
                batches unless it equals "full".
            model: Model handed to the sampler; put in eval mode by
                ``generate``.
            data_loader: Loader supplying batches, ``vocab_decoder`` and the
                ``max_e1`` / ``max_r`` / ``max_e2`` length fields.
        """
        self.opt = opt
        self.model = model
        self.data_loader = data_loader
        self.sampler = sampling.make_sampler(
            opt.eval.sample, opt, data_loader)

    def reset_sequences(self):
        """Return a fresh, empty buffer for generated sequences."""
        return []

    def generate(self, split="dev"):
        """Generate sequences for ``split`` one example (bs=1) at a time.

        Stops when the data loader reports a reset, or earlier if
        ``cfg.toy`` is set (after more than 10 batches) or ``opt.eval.gs``
        is not "full" (after more than ``opt.eval.gs`` batches).

        Args:
            split: Name of the dataset split to generate for.

        Returns:
            Tuple ``(sequences, avg_scores, indiv_scores)``. ``sequences``
            is the list of sampler result dicts; both score entries are
            ``None`` because scoring is currently disabled.
        """
        print("Generating Sequences")

        # Set evaluation mode
        self.model.eval()

        # Reset evaluation set for dataset split
        self.data_loader.reset_offsets(splits=split, shuffle=False)

        start_time = time.time()
        count = 0

        # Reset generated sequence buffer
        sequences = self.reset_sequences()

        # NOTE(review): the bar total is half the split size, presumably
        # because generate_batch only samples the "positive" category --
        # confirm against the data loader.
        bar = utils.set_progress_bar(
            self.data_loader.total_size[split] / 2)

        reset = False

        with torch.no_grad():
            # Cycle through the split until the loader signals a reset
            while not reset:
                # Distinct names so the wall-clock ``start_time`` is not
                # overwritten (it used to be shadowed here, corrupting the
                # reported duration).
                n_before = len(sequences)

                # Generate a single batch
                reset = self.generate_batch(sequences, split, bs=1)

                n_after = len(sequences)

                if not reset:
                    bar.update(n_after - n_before)
                else:
                    print(n_after)

                count += 1
                if cfg.toy and count > 10:
                    break
                # ``opt`` is not a local here; read it from ``self``.
                if (self.opt.eval.gs != "full"
                        and count > self.opt.eval.gs):
                    break

        # Synchronize only when CUDA is present; the unconditional call
        # raises on CPU-only torch builds.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("{} generations completed in: {} s".format(
            split, time.time() - start_time))

        # Sequence scoring (e.g., BLEU, ROUGE) is disabled. To re-enable:
        # avg_scores, indiv_scores = self.compute_sequence_scores(
        #     sequences, split)
        avg_scores, indiv_scores = None, None

        return sequences, avg_scores, indiv_scores

    def _decode_ids(self, ids):
        """Convert a flat list of vocabulary ids into a stripped text string.

        Falsy ids (0) are skipped -- presumably the padding index; TODO
        confirm against the data loader. Every '</w>' marker in a token is
        replaced by a space before the tokens are concatenated.
        """
        decoder = self.data_loader.vocab_decoder
        return "".join(
            decoder[i].replace('</w>', ' ') for i in ids if i).strip()

    def generate_batch(self, sequences, split, verbose=False, bs=1):
        """Sample one batch, decode it and append the result to ``sequences``.

        The "e1" and "r" text fields are decoded from the context tokens of
        row 0; this decoding assumes ``bs == 1`` (``generate`` always
        passes 1).

        Args:
            sequences: List that the sampler result dict is appended to
                (mutated in place).
            split: Dataset split to sample from.
            verbose: Unused; kept for interface compatibility.
            bs: Batch size requested from the data loader.

        Returns:
            The ``reset`` flag returned by ``data_loader.sample_batch``.
        """
        # Sample batch from data loader
        batch, reset = self.data_loader.sample_batch(
            split, bs=bs, cat="positive")

        max_e1 = self.data_loader.max_e1
        start_idx = max_e1 + self.data_loader.max_r
        max_end_len = self.data_loader.max_e2

        # The first ``start_idx`` columns are the conditioning context:
        # columns [0, max_e1) hold e1 tokens, [max_e1, start_idx) hold r.
        context = batch["sequences"][:, :start_idx]

        # squeeze(0) rather than squeeze(): a length-1 slice must stay a
        # list after tolist() instead of collapsing to a bare int.
        init = self._decode_ids(
            context[:, :max_e1].squeeze(0).tolist())
        attr = self._decode_ids(
            context[:, max_e1:start_idx].squeeze(0).tolist())

        # Decode sequence
        sampling_result = self.sampler.generate_sequence(
            batch, self.model, self.data_loader, start_idx, max_end_len)

        sampling_result["key"] = batch["key"]
        sampling_result["e1"] = init
        sampling_result["r"] = attr

        sequences.append(sampling_result)

        return reset