{ "cells": [ { "cell_type": "markdown", "id": "a8b6caed", "metadata": {}, "source": [ "# πͺπΊ π·οΈ Eurovoc Model Training Notebook" ] }, { "cell_type": "code", "execution_count": 1, "id": "c4c73793", "metadata": {}, "outputs": [], "source": [ "import pickle \n", "import pandas as pd\n", "from transformers import AutoTokenizer, AutoModel\n", "\n", "from datasets import list_datasets, load_dataset\n", "\n", "from sklearn.preprocessing import MultiLabelBinarizer\n", "import torch\n", "\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.callbacks import ModelCheckpoint" ] }, { "cell_type": "markdown", "id": "dc770f0b", "metadata": { "tags": [] }, "source": [ "---\n", "\n", "## 1. Data loading\n", "### Choose our dataset, extracted from ep registry or eurlex57k" ] }, { "cell_type": "code", "execution_count": 2, "id": "9fdc5328", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset json (/home/scampion/.cache/huggingface/datasets/EuropeanParliament___json/EuropeanParliament--cellar_eurovoc-3a27a019ebbf0296/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5bf91bf9dc2416faefe96d680217da6", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#dataset = load_dataset('json', data_files='ep_registry.jsonl')\n", "\n", "#dataset = load_dataset('eurlex')\n", "dataset = load_dataset('EuropeanParliament/cellar_eurovoc')\n" ] }, { "cell_type": "markdown", "id": "94967fc2", "metadata": {}, "source": [ "### Merge train, test and validation" ] }, { "cell_type": "code", "execution_count": 3, "id": "ce5f764f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | title | \n", "date | \n", "eurovoc_concepts | \n", "url | \n", "lang | \n", "formats | \n", "text | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "Corrigendum to Commission Implementing Regulat... | \n", "2023-07-20 | \n", "[China, Malaysia, anti-dumping duty, business ... | \n", "http://publications.europa.eu/resource/cellar/... | \n", "eng | \n", "[fmx4, pdfa2a, xhtml] | \n", "L_2023183EN. 01005801. xml 20. 7. 2023Β Β Β EN O... | \n", "
1 | \n", "Council Decision (CFSP) 2023/1501 of 20Β July 2... | \n", "2023-07-20 | \n", "[EU restrictive measure, Russia, Ukraine, econ... | \n", "http://publications.europa.eu/resource/cellar/... | \n", "eng | \n", "[fmx4, pdfa2a, xhtml] | \n", "LI2023183EN. 01004801. xml 20. 7. 2023Β Β Β EN O... | \n", "
2 | \n", "Council Decision (CFSP) 2023/1502 of 20Β July 2... | \n", "2023-07-20 | \n", "[Burma/Myanmar, EU restrictive measure, econom... | \n", "http://publications.europa.eu/resource/cellar/... | \n", "eng | \n", "[fmx4, pdfa2a, xhtml] | \n", "LI2023183EN. 01005201. xml 20. 7. 2023Β Β Β EN O... | \n", "
3 | \n", "The Committee of the Regions welcomes Croatian... | \n", "2023-07-20 | \n", "[Croatia, EU regional policy, European Committ... | \n", "http://publications.europa.eu/resource/cellar/... | \n", "eng | \n", "[pdf] | \n", "EUROPEAN UNION Committee of the Regions The Co... | \n", "
4 | \n", "Corrigendum to Commission Implementing Regulat... | \n", "2023-07-20 | \n", "[India, TΓΌrkiye, anti-dumping duty, building m... | \n", "http://publications.europa.eu/resource/cellar/... | \n", "eng | \n", "[fmx4, pdfa2a, xhtml] | \n", "L_2023183EN. 01005901. xml 20. 7. 2023Β Β Β EN O... | \n", "
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃        Test metric        ┃       DataLoader 0        ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│         test_loss         │   0.0031269278842955828   │\n", "└───────────────────────────┴───────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m  0.0031269278842955828  \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[{'test_loss': 0.0031269278842955828}]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.test(dataloaders=dataloader)" ] }, { "cell_type": "markdown", "id": "66b871ec", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "code", "execution_count": 16, "id": "ba317c3e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/scampion/training/lightning_logs/version_9/checkpoints/EurovocTagger-epoch=06-val_loss=0.00.ckpt'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_model_path = trainer.checkpoint_callback.best_model_path\n", "best_model_path" ] }, { "cell_type": "code", "execution_count": 17, "id": "fe9751a1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "100%|██████████| 23243/23243 [16:20<00:00, 23.72it/s] \n" ] } ], "source": [ "from tqdm import tqdm\n", "from transformers import AutoTokenizer\n", "\n", "trained_model = EurovocTagger.load_from_checkpoint(best_model_path,\n", "                                                   bert_model_name=BERT_MODEL_NAME,\n", "                                                   n_classes=len(mlb.classes_))\n", "trained_model.eval()\n", "trained_model.freeze()\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "trained_model = trained_model.to(device)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)\n", "\n", "val_dataset = EurovocDataset(x_test, y_test, tokenizer, max_token_len=MAX_LEN)\n", "predictions = []\n", "labels = []\n", "\n", "for item in tqdm(val_dataset):\n", "    _, prediction = trained_model(\n", "        item[\"input_ids\"].unsqueeze(dim=0).to(device),\n", "        item[\"attention_mask\"].unsqueeze(dim=0).to(device)\n", "    )\n", "    predictions.append(prediction.flatten())\n", "    labels.append(item[\"labels\"].int())\n", "\n", "predictions = torch.stack(predictions).detach().cpu()\n", "labels = torch.stack(labels).detach().cpu()" ] }, { "cell_type": "markdown", "id": "67477f7f", "metadata": {}, "source": [ "### F1 Score" ] }, { "cell_type": "code", "execution_count": 18, "id": "f0265f6e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.01 tensor(0.2188)\n", "0.06 tensor(0.3929)\n", "0.11 tensor(0.4353)\n", "0.16 tensor(0.4462)\n", "0.21 tensor(0.4437)\n", "0.26 tensor(0.4364)\n", "0.31 tensor(0.4249)\n", "0.36 tensor(0.4106)\n", "0.41 tensor(0.3947)\n", "0.46 tensor(0.3780)\n", "0.51 tensor(0.3597)\n", "0.56 tensor(0.3404)\n", "0.61 tensor(0.3209)\n", "0.66 tensor(0.3007)\n" ] } ], "source": [ "from torchmetrics import F1Score\n", "for i in range(1, 70, 5):\n", "    f1 = F1Score(task=\"multilabel\", num_labels=len(mlb.classes_), average='weighted', threshold=i / 100.0)\n", "    print(i / 100.0, f1(predictions, labels))" ] }
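, { "cell_type": "markdown", "id": "c7d1e5a9", "metadata": {}, "source": [ "The sweep above peaks at a threshold of roughly 0.16. The next cell is a small sketch that was not part of the original run: it repeats the same sweep and keeps the best-scoring decision threshold, reusing only `predictions`, `labels` and `mlb` from the cells above." ] }, { "cell_type": "code", "execution_count": null, "id": "d8f2a6b0", "metadata": {}, "outputs": [], "source": [ "# Sketch: pick the decision threshold that maximises the weighted F1 score\n", "from torchmetrics import F1Score\n", "\n", "best_threshold, best_f1 = 0.0, 0.0\n", "for i in range(1, 70, 5):\n", "    t = i / 100.0\n", "    f1 = F1Score(task=\"multilabel\", num_labels=len(mlb.classes_), average='weighted', threshold=t)\n", "    score = f1(predictions, labels).item()\n", "    if score > best_f1:\n", "        best_threshold, best_f1 = t, score\n", "print(f\"best threshold: {best_threshold:.2f}, weighted F1: {best_f1:.4f}\")" ] }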
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }