"""Inference script for the COSMIC commonsense emotion-recognition model."""
import argparse

import torch
from torch.utils.data import DataLoader

from commonsense_model import CommonsenseGRUModel
from dataloader import RobertaCometDataset
def load_model(model_path, args):
    """Load a trained CommonsenseGRUModel checkpoint ready for inference.

    Args:
        model_path: Path to the saved ``state_dict`` (.pt file).
        args: Parsed CLI namespace; reads ``cuda``, ``active_listener``,
            ``attention``, ``rec_dropout``, ``dropout``, ``mode1``,
            ``norm`` and ``residual``.

    Returns:
        The model in ``eval()`` mode, moved to GPU when ``args.cuda`` is set.
    """
    emo_gru = True
    n_classes = 15  # number of sentiment classes (see label_mapping in __main__)
    # Feature dimensions must match the checkpoint; changing them breaks
    # state_dict loading.
    D_m = 1024  # RoBERTa utterance feature size
    D_s = 768   # COMET commonsense feature size
    # Internal GRU state sizes -- presumably global/party/relation/intent
    # states; TODO(review): confirm naming against CommonsenseGRUModel.
    D_g = 150
    D_p = 150
    D_r = 150
    D_i = 150
    D_h = 100
    D_a = 100
    D_e = D_p + D_r + D_i  # emotion state concatenates the p/r/i states
    model = CommonsenseGRUModel(
        D_m,
        D_s,
        D_g,
        D_p,
        D_r,
        D_i,
        D_e,
        D_h,
        D_a,
        n_classes=n_classes,
        listener_state=args.active_listener,
        context_attention=args.attention,
        dropout_rec=args.rec_dropout,
        dropout=args.dropout,
        emo_gru=emo_gru,
        mode1=args.mode1,
        norm=args.norm,
        residual=args.residual,
    )
    # A single load path with map_location is robust in both directions:
    # GPU-saved weights load cleanly on CPU and vice versa.
    device = torch.device("cuda" if args.cuda else "cpu")
    if args.cuda:
        model.cuda()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model
def get_valid_dataloader(
    roberta_features_path: str,
    comet_features_path: str,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
):
    """Build a DataLoader over the validation split.

    Args:
        roberta_features_path: Pickled RoBERTa utterance features.
        comet_features_path: Pickled COMET commonsense features.
        batch_size: Conversations per batch.
        num_workers: DataLoader worker processes.
        pin_memory: Pin host memory for faster GPU transfer.

    Returns:
        A ``(loader, keys)`` pair where ``keys`` are the dataset's
        conversation identifiers.
    """
    dataset = RobertaCometDataset("valid", roberta_features_path, comet_features_path)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return loader, dataset.keys
def predict(model, data_loader, args):
    """Run inference over ``data_loader`` and collect class predictions.

    Args:
        model: Trained model; invoked with the reordered feature tensors.
        data_loader: Iterable of batches; the final element of each batch
            (the label) is dropped before the forward pass.
        args: Namespace; only ``args.cuda`` is read.

    Returns:
        A list with one 1-D numpy array of argmax class ids per batch.
    """
    predictions = []
    # Inference only: disabling autograd avoids needless graph bookkeeping.
    with torch.no_grad():
        for data in data_loader:
            r1, r2, r3, r4, x1, x2, x3, x4, x5, x6, o1, o2, o3, qmask, umask, label = (
                [d.cuda() for d in data[:-1]] if args.cuda else data[:-1]
            )
            # NOTE(review): the inputs are deliberately reordered here
            # (x5, x6, x1, o2, o3) -- confirm against the model's forward().
            log_prob, _, alpha, alpha_f, alpha_b, _ = model(
                r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask
            )
            # Flatten (seq, batch, classes) -> (batch*seq, classes).
            lp_ = log_prob.transpose(0, 1).contiguous().view(-1, log_prob.size()[2])
            preds = torch.argmax(lp_, dim=-1)
            predictions.append(preds.cpu().numpy())
    return predictions
def parse_cosmic_args(argv=None):
    """Parse command-line arguments for the COSMIC model.

    Args:
        argv: Optional list of argument strings; when ``None`` argparse
            falls back to ``sys.argv[1:]``.  Passing a list makes the
            function callable from tests without touching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with an additional ``cuda`` boolean set from
        GPU availability and the ``--no-cuda`` flag.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): action="store_true" with default=True means this flag can
    # never be switched off from the CLI, so CUDA is effectively always
    # disabled -- confirm this is intentional before changing the default.
    parser.add_argument(
        "--no-cuda", action="store_true", default=True, help="does not use GPU"
    )
    parser.add_argument(
        "--lr", type=float, default=0.0001, metavar="LR", help="learning rate"
    )
    parser.add_argument(
        "--l2",
        type=float,
        default=0.00003,
        metavar="L2",
        help="L2 regularization weight",
    )
    parser.add_argument(
        "--rec-dropout",
        type=float,
        default=0.3,
        metavar="rec_dropout",
        help="rec_dropout rate",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.5, metavar="dropout", help="dropout rate"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, metavar="BS", help="batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, metavar="E", help="number of epochs"
    )
    parser.add_argument(
        "--class-weight", action="store_true", default=True, help="use class weights"
    )
    parser.add_argument(
        "--active-listener", action="store_true", default=True, help="active listener"
    )
    parser.add_argument(
        "--attention", default="simple", help="Attention type in context GRU"
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="Enables tensorboard log",
    )
    parser.add_argument("--mode1", type=int, default=2, help="Roberta features to use")
    parser.add_argument("--seed", type=int, default=500, metavar="seed", help="seed")
    parser.add_argument("--norm", type=int, default=0, help="normalization strategy")
    parser.add_argument("--mu", type=float, default=0, help="class_weight_mu")
    parser.add_argument(
        "--residual", action="store_true", default=True, help="use residual connection"
    )
    args = parser.parse_args(argv)
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    if args.cuda:
        print("Running on GPU")
    else:
        print("Running on CPU")
    return args
if __name__ == "__main__":

    def pred_to_labels(preds):
        """Map per-conversation integer predictions to sentiment strings."""
        return [[label_mapping[label] for label in pred] for pred in preds]

    # Class index -> sentiment label; must match the training label order.
    label_mapping = {
        0: "Curiosity",
        1: "Obscene",
        2: "Informative",
        3: "Openness",
        4: "Acceptance",
        5: "Interest",
        6: "Greeting",
        7: "Disapproval",
        8: "Denial",
        9: "Anxious",
        10: "Uninterested",
        11: "Remorse",
        12: "Confused",
        13: "Accusatory",
        14: "Annoyed",
    }
    args = parse_cosmic_args()
    model = load_model("epik/best_model.pt", args)
    # BUG FIX: get_valid_dataloader requires the two feature paths; the old
    # call passed none and raised TypeError before any inference ran.
    # TODO(review): confirm the actual feature-file paths for this deployment.
    test_dataloader, ids = get_valid_dataloader(
        "epik/roberta_features.pkl", "epik/comet_features.pkl"
    )
    predicted_labels = pred_to_labels(predict(model, test_dataloader, args))
    # conv_id (not "id") avoids shadowing the builtin.
    for conv_id, labels in zip(ids, predicted_labels):
        print(f"Conversation ID: {conv_id}")
        print(f"Predicted Sentiment Labels: {labels}")
        print(len(labels))