|
|
|
"""
|
|
Created on Sun Sep 15 18:27:17 2024
|
|
|
|
@author: salikha4
|
|
"""
|
|
|
|
import os
|
|
import csv
|
|
import json
|
|
import shutil
|
|
import random
|
|
import argparse
|
|
from datetime import datetime
|
|
import pandas as pd
|
|
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
from torch.optim import Adam
|
|
import numpy as np
|
|
import warnings
|
|
warnings.filterwarnings('ignore')
|
|
|
|
def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
    """Embed preprocessed channel data with a trained LWM model.

    Args:
        preprocessed_chs: Iterable of (input_ids, masked_tokens, masked_pos)
            triples, as produced by the preprocessing pipeline.
        input_type: 'cls_emb' keeps only the leading CLS-token embedding,
            'channel_emb' keeps every token after it; any other value keeps
            the full embedding sequence unchanged.
        lwm_model: Trained LWM model run in evaluation mode.
        device: Torch device (or device string) for the input tensors.

    Returns:
        Float tensor of embeddings, sliced according to ``input_type``.
    """
    loader = prepare_for_lwm(preprocessed_chs, device)

    # The evaluation loss is computed but not needed for inference.
    _, embedding_data = evaluate(lwm_model, loader)

    # Select the requested slice along the token dimension.
    if input_type == 'cls_emb':
        selected = embedding_data[:, 0]
    elif input_type == 'channel_emb':
        selected = embedding_data[:, 1:]
    else:
        selected = embedding_data

    return selected.float()
|
|
|
|
def prepare_for_lwm(data, device, batch_size=64, shuffle=False):
    """Wrap preprocessed (input_ids, masked_tokens, masked_pos) triples in a DataLoader.

    Args:
        data: Iterable of 3-tuples ``(input_ids, masked_tokens, masked_pos)``,
            each element an array-like of numbers.
        device: Torch device (or device string) the tensors are placed on.
        batch_size: Mini-batch size of the returned DataLoader.
        shuffle: Whether the DataLoader reshuffles samples each epoch.

    Returns:
        torch.utils.data.DataLoader yielding (input_ids, masked_tokens,
        masked_pos) batches with dtypes (float32, float32, int64).
    """
    input_ids, masked_tokens, masked_pos = zip(*data)

    # Convert through a single ndarray first: torch.tensor() on a Python
    # sequence of per-sample arrays is far slower (and emits a UserWarning)
    # than one conversion from a contiguous ndarray.
    input_ids_tensor = torch.tensor(np.asarray(input_ids), device=device).float()
    masked_tokens_tensor = torch.tensor(np.asarray(masked_tokens), device=device).float()
    masked_pos_tensor = torch.tensor(np.asarray(masked_pos), device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|
|
def evaluate(model, dataloader):
    """Run masked-channel-modeling evaluation over a dataloader.

    For each batch, computes the MSE between the model's masked-token logits
    and the ground-truth masked tokens, normalized by the variance of the
    targets (an NMSE-style loss), and collects the model's full outputs.

    Args:
        model: Module called as ``model(input_ids, masked_pos)`` returning a
            ``(logits_lm, output)`` pair.
        dataloader: DataLoader yielding (input_ids, masked_tokens, masked_pos)
            batches, e.g. from ``prepare_for_lwm``.

    Returns:
        Tuple ``(average_loss, output_total)``: the mean normalized loss over
        batches, and all batch outputs concatenated along dim 0.
    """
    model.eval()
    running_loss = 0.0
    outputs = []
    criterionMCM = nn.MSELoss()

    with torch.no_grad():
        # Unpack batches directly; the loop index was unused.
        for input_ids, masked_tokens, masked_pos in dataloader:
            logits_lm, output = model(input_ids, masked_pos)
            outputs.append(output)

            loss_lm = criterionMCM(logits_lm, masked_tokens)
            # Normalize by target variance (NMSE). NaN if a batch's targets
            # are constant (variance 0) — same as the original behavior.
            loss = loss_lm / torch.var(masked_tokens)
            running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    output_total = torch.cat(outputs, dim=0)

    return average_loss, output_total
|
|
|
|
def create_raw_dataset(data, device):
    """Return raw channel inputs as a float tensor, with the first token dropped.

    Args:
        data: Iterable of (input_ids, masked_tokens, masked_pos) triples; only
            the input_ids element of each triple is used.
        device: Torch device (or device string) the tensor is placed on.

    Returns:
        Float tensor of shape (num_samples, seq_len - 1).
    """
    raw_inputs = tuple(sample[0] for sample in data)
    stacked = torch.tensor(raw_inputs, device=device)
    # Strip the leading token from every sequence before returning.
    return stacked[:, 1:].float()
|