In [89]:
from src import BertForSemanticEmbedding, getLabelModel
from src import DataTrainingArguments, ModelArguments, CustomTrainingArguments, read_yaml_config
from src import dataset_classification_type
from src import SemSupDataset
from transformers import AutoConfig, HfArgumentParser, AutoTokenizer
import torch

import json
from tqdm import tqdm

In [2]:
ARGS_FILE = 'configs/ablation_amzn_eda.yml'

In [3]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
model_args, data_args, training_args = parser.parse_dict(read_yaml_config(ARGS_FILE, output_dir = 'demo_tmp',  extra_args = {}))

Yaml Config is:
--------------------------------------------------------------------------------
{'task_name': 'amazon13k', 'dataset_name': 'amazon13k', 'dataset_config_name': None, 'max_seq_length': 160, 'overwrite_output_dir': False, 'overwrite_cache': False, 'pad_to_max_length': True, 'load_from_local': True, 'max_train_samples': None, 'max_eval_samples': 15000, 'max_predict_samples': None, 'train_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/train_split6500_2.jsonl', 'validation_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/test_unseen_split6500_2.jsonl', 'test_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/test_unseen_split6500_2.jsonl', 'label_max_seq_length': 160, 'descriptions_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/heir_withdescriptions_v3_v3_unseen_edaaug.json', 'test_descriptions_file': '/n/fs/nlp-pranjal/SemSup-LMLC/training/datasets/Amzn13K/heir_withdescriptions_v3_v3.json', 'all_labels': '/n/fs/n

In [4]:
config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)

config.model_name_or_path = model_args.model_name_or_path
config.problem_type = dataset_classification_type[data_args.task_name]
config.negative_sampling = model_args.negative_sampling
config.semsup = model_args.semsup
config.encoder_model_type = model_args.encoder_model_type
config.arch_type = model_args.arch_type
config.coil = model_args.coil
config.token_dim = model_args.token_dim
config.colbert = model_args.colbert

In [7]:
label_model, label_tokenizer = getLabelModel(data_args, model_args)
config.label_hidden_size = label_model.config.hidden_size
model = BertForSemanticEmbedding(config)
model.label_model = label_model
model.label_tokenizer = label_tokenizer
model.config.label2id = {l: i for i, l in enumerate(label_list)}
model.config.id2label = {id: label for label, id in config.label2id.items()}

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.t

Config is BertConfig {
  "_name_or_path": "bert-base-uncased",
  "arch_type": 2,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "coil": true,
  "colbert": false,
  "encoder_model_type": "bert",
  "finetuning_task": "amazon13k",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label_hidden_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_name_or_path": "bert-base-uncased",
  "model_type": "bert",
  "negative_sampling": "none",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "problem_type": "multi_label_classification",
  "semsup": true,
  "token_dim": 16,
  "transformers_version": "4.20.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



In [15]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [None]:
coil_cluster_map = json.load(open(''))

In [11]:
max_descs_per_label = 20

In [33]:
def tokenize_class_descs(label_descs, return_lengths = False):
    if return_lengths == 1:
        return {
            label_key: min(descs[0],max_descs_per_label)  for label_key, descs in label_descs.items() 
        } # descs 0 is the length
    else:
        return {
            label_key: descs[1][:max_descs_per_label] for label_key, descs in label_descs.items() 
        }
        
def save_tokenized_descs(add_label_name = False):
    class_descs_tokenized = dict()
    for label_key in tqdm(list(class_descs.keys())):
        descs_len = class_descs_len[label_key]
        descs = class_descs[label_key]
        class_descs_tokenized[label_key] = tokenizer(
            [label_key + ". " + x for x in descs] if add_label_name else
            descs,
            max_length = data_args.label_max_seq_length, padding = 'max_length', truncation= True)

In [23]:
js_file = json.load(open(data_args.test_descriptions_file, encoding = 'utf-8'))
class_descs_len = tokenize_class_descs(js_file, return_lengths = True)
class_descs = tokenize_class_descs(js_file)

In [34]:
tokenized_descs = save_tokenized_descs(model_args.add_label_name)

100%|██████████| 13330/13330 [00:23<00:00, 555.67it/s]


In [61]:
class DummyDset:
    def __init__(self, inp_text = None):
        self.data = [] 
        if inp_text is not None:
            self.set_text(inp_text)

    def set_text(self, tex):
        result = tokenizer(tex, padding='max_length', max_length=data_args.max_seq_length, truncation=True)
        result['label'] = [0 for x in range(len(label_list))]
        self.data = [result]

    def __len__(self,):
        return 1

    def __getitem__(self, idx):
        return self.data[0]

In [65]:
input_dset = DummyDset()

In [67]:
label_list = [x.strip() for x in open(data_args.all_labels).readlines()]

model.config.label2id = {l: i for i, l in enumerate(label_list)}
model.config.id2label = {id: label for label, id in config.label2id.items()}

id2label = model.config.id2label
label_to_id = model.config.label2id

test_labels = [x.strip() for x in open(data_args.test_labels).readlines()]
class_descs_tokenized = None

In [98]:
dset.semsup

True

In [69]:
dset = SemSupDataset(input_dset, data_args, data_args.test_descriptions_file, label_to_id, id2label, label_tokenizer, return_desc_embeddings=True, seen_labels = None if training_args.scenario == 'seen' else test_labels, add_label_name = model_args.add_label_name, max_descs_per_label = data_args.max_descs_per_label, use_precomputed_embeddings = model_args.use_precomputed_embeddings, class_descs_tokenized = class_descs_tokenized, isTrain = False)

100%|██████████| 13330/13330 [00:50<00:00, 263.17it/s]


In [71]:
len(dset)

1

In [77]:
for item in dset:
    break

In [83]:
dset.input_dataset.set_text("Hello this is a text dataset")

In [90]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [96]:
model.to(device)

BertForSemanticEmbedding(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [102]:
item['input_ids']

tensor([[  101,  7592,  2023,  2003,  1037,  3793,  2951, 13462,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [113]:
dset.input_dataset.set_text('Hello how is it going?')

In [None]:
dset.input_dataset.__getitem__(0)

In [None]:
model.eval()
with torch.no_grad():
    for item in dset:
        print('hah')
        item = {k:torch.tensor(v, device = device).unsqueeze(0) for k,v in item.items()}
        del item['label']
        logits = model(**item)
        print('Lets go')

In [105]:
logits

In [87]:
item

{'input_ids': [101, 7592, 2023, 2003, 1037, 3793, 2951, 13462, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

{'input_ids': [101, 7592, 2023, 2003, 1037, 3793, 2951, 13462, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 