Spaces:
Runtime error
Runtime error
| import torch | |
| import tqdm | |
| import numpy as np | |
| import nltk | |
| from utils import DEVICE, FeatureExtractor, HWT, MGT | |
| from roberta_model_loader import roberta_model | |
| from meta_train import net | |
| from data_loader import load_HC3, filter_data | |
# Build a reference bank of sentence-level features for the `target` class from
# the HC3 dataset and save the concatenated tensor to feature_ref_{target}.pt.
feature_extractor = FeatureExtractor(roberta_model, net)
target = HWT

# Upper bound on how many sentences are embedded into the reference bank.
NUM_REF_SENTENCES = 2000
# Sentences with this many whitespace-separated words or fewer are discarded.
MIN_SENTENCE_WORDS = 5

# load target data
data_o = load_HC3()
data = filter_data(data_o)
data = data[target]

# split with nltk
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
# Keep only the interior sentences of each paragraph ([1:-1] drops the first and
# last one), so paragraphs with fewer than three sentences contribute nothing.
paragraphs = [nltk.sent_tokenize(paragraph)[1:-1] for paragraph in data]
data = [
    sent
    for paragraph in paragraphs
    for sent in paragraph
    if len(sent.split()) > MIN_SENTENCE_WORDS
]

# Slicing clamps to the available sentences, so a small dataset no longer
# raises IndexError the way a fixed range(2000) over data[i] did.
ref_sentences = data[:NUM_REF_SENTENCES]
if not ref_sentences:
    raise ValueError(
        f"No sentences left for {target} after filtering; cannot build feature ref"
    )

# extract features
feature_ref = []
# The features are only stored, never back-propagated through, so skip building
# the autograd graph. .detach() is kept so stored tensors never carry grad history.
# NOTE(review): assumes FeatureExtractor.process needs no autograd internally;
# drop the no_grad context if it does.
with torch.no_grad():
    for sentence in tqdm.tqdm(
        ref_sentences, desc=f"Generating feature ref for {target}"
    ):
        feature_ref.append(feature_extractor.process(sentence, False).detach())
torch.save(torch.cat(feature_ref, dim=0), f"feature_ref_{target}.pt")