{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The primary codes below are based on [akpe12/JP-KR-ocr-translator-for-travel](https://github.com/akpe12/JP-KR-ocr-translator-for-travel)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TrHlPFqwFAgj"
      },
      "source": [
        "## Import"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t-jXeSJKE1WM"
      },
      "outputs": [],
      "source": [
        "\n",
        "from typing import Dict, List\n",
        "import csv\n",
        "import torch\n",
        "from transformers import (\n",
        "    EncoderDecoderModel,\n",
        "    GPT2Tokenizer as BaseGPT2Tokenizer,\n",
        "    PreTrainedTokenizer, BertTokenizerFast,\n",
        "    PreTrainedTokenizerFast,\n",
        "    DataCollatorForSeq2Seq,\n",
        "    Seq2SeqTrainingArguments,\n",
        "    AutoTokenizer,\n",
        "    XLMRobertaTokenizerFast,\n",
        "    BertJapaneseTokenizer,\n",
        "    Trainer\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n",
        "\n",
        "# encoder_model_name = \"xlm-roberta-base\"\n",
        "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n",
        "decoder_model_name = \"skt/kogpt2-base-v2\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nEW5trBtbykK"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "# device = torch.device(\"cpu\")\n",
        "device, torch.cuda.device_count()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5ic7pUUBFU_v"
      },
      "outputs": [],
      "source": [
        "class GPT2Tokenizer(PreTrainedTokenizerFast):\n",
        "    def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n",
        "        return token_ids + [self.eos_token_id]        \n",
        "\n",
        "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n",
        "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='</s>', eos_token='</s>', unk_token='<unk>',\n",
        "  pad_token='<pad>', mask_token='<mask>')"
      ]
    },
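    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional sanity check (not in the original notebook): tokenize one made-up Japanese-Korean pair and confirm that the custom `build_inputs_with_special_tokens` appends the GPT-2 `eos_token_id` to the target ids."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: the sentences below are invented examples, not project data.\n",
        "sample_ja = \"駅はどこですか。\"\n",
        "sample_ko = \"역은 어디입니까?\"\n",
        "\n",
        "src_ids = src_tokenizer(sample_ja)[\"input_ids\"]\n",
        "trg_ids = trg_tokenizer.build_inputs_with_special_tokens(\n",
        "    trg_tokenizer(sample_ko, return_attention_mask=False)[\"input_ids\"]\n",
        ")\n",
        "print(src_tokenizer.convert_ids_to_tokens(src_ids))\n",
        "print(trg_tokenizer.convert_ids_to_tokens(trg_ids))\n",
        "assert trg_ids[-1] == trg_tokenizer.eos_token_id"
      ]
    },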
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DTf4U1fmFQFh"
      },
      "source": [
        "## Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65L4O1c5FLKt"
      },
      "outputs": [],
      "source": [
        "class PairedDataset:\n",
        "    def __init__(self, \n",
        "        src_tokenizer: PreTrainedTokenizerFast, tgt_tokenizer: PreTrainedTokenizerFast,\n",
        "        file_path: str\n",
        "    ):\n",
        "        self.src_tokenizer = src_tokenizer\n",
        "        self.trg_tokenizer = tgt_tokenizer\n",
        "        with open(file_path, 'r') as fd:\n",
        "            reader = csv.reader(fd)\n",
        "            next(reader)\n",
        "            self.data = [row for row in reader]\n",
        "\n",
        "    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n",
        "        src, trg = self.data[index]\n",
        "        embeddings = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)\n",
        "        embeddings['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(self.trg_tokenizer(trg, return_attention_mask=False)['input_ids'])\n",
        "\n",
        "        return embeddings\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "    \n",
        "DATA_ROOT = './output'\n",
        "FILE_FFAC_FULL = 'ffac_full.csv'\n",
        "FILE_FFAC_TEST = 'ffac_test.csv'\n",
        "# FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n",
        "# FILE_JA_KO_TEST = 'ja_ko_test.csv'\n",
        "\n",
        "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n",
        "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_TEST}') \n",
        "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n",
        "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TEST}')        "
      ]
    },
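    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional inspection cell (not in the original notebook): print the dataset sizes and the first few encoder input ids and decoder label ids of one example. It assumes the CSV files referenced above exist under `./output`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Peek at one training example to verify the expected keys and id sequences.\n",
        "print(len(train_dataset), \"training pairs /\", len(eval_dataset), \"eval pairs\")\n",
        "example = train_dataset[0]\n",
        "print(example[\"input_ids\"][:20])\n",
        "print(example[\"labels\"][:20])"
      ]
    },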
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uCBiLouSFiZY"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I7uFbFYJFje8"
      },
      "outputs": [],
      "source": [
        "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n",
        "    encoder_model_name,\n",
        "    decoder_model_name,\n",
        "    pad_token_id=trg_tokenizer.bos_token_id,\n",
        ")\n",
        "model.config.decoder_start_token_id = trg_tokenizer.bos_token_id"
      ]
    },
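    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional check (not in the original notebook): print the special-token ids the encoder-decoder will rely on during generation, and confirm that cross-attention layers were added to the KoGPT2 decoder."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\"decoder_start_token_id:\", model.config.decoder_start_token_id)\n",
        "print(\"pad_token_id:\", model.config.pad_token_id)\n",
        "print(\"target eos_token_id:\", trg_tokenizer.eos_token_id)\n",
        "print(\"decoder has cross-attention:\", model.config.decoder.add_cross_attention)"
      ]
    },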
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YFq2GyOAUV0W"
      },
      "outputs": [],
      "source": [
        "# for Trainer\n",
        "import wandb\n",
        "\n",
        "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n",
        "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n",
        "\n",
        "arguments = Seq2SeqTrainingArguments(\n",
        "    output_dir='dump',\n",
        "    do_train=True,\n",
        "    do_eval=True,\n",
        "    evaluation_strategy=\"epoch\",\n",
        "    save_strategy=\"epoch\",\n",
        "#     num_train_epochs=5,\n",
        "    num_train_epochs=25,\n",
        "#     per_device_train_batch_size=32,\n",
        "    per_device_train_batch_size=64,\n",
        "#     per_device_eval_batch_size=32,\n",
        "    per_device_eval_batch_size=64,\n",
        "    warmup_ratio=0.1,\n",
        "    gradient_accumulation_steps=4,\n",
        "    save_total_limit=5,\n",
        "    dataloader_num_workers=1,\n",
        "    fp16=True,\n",
        "    load_best_model_at_end=True,\n",
        "    report_to='wandb'\n",
        ")\n",
        "\n",
        "trainer = Trainer(\n",
        "    model,\n",
        "    arguments,\n",
        "    data_collator=collate_fn,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=eval_dataset\n",
        ")"
      ]
    },
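    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional dry run (not in the original notebook): collate a few examples by hand to check the padded tensor shapes that the `Trainer` will see."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Build one small batch manually and inspect its shapes.\n",
        "batch = collate_fn([train_dataset[i] for i in range(4)])\n",
        "print({k: tuple(v.shape) for k, v in batch.items()})"
      ]
    },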
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pPsjDHO5Vc3y"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_T4P4XunmK-C"
      },
      "outputs": [],
      "source": [
        "# model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"xlm-roberta-base\",  \"skt/kogpt2-base-v2\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7vTqAgW6Ve3J"
      },
      "outputs": [],
      "source": [
        "trainer.train()\n",
        "\n",
        "model.save_pretrained(\"dump/best_model\")"
      ]
    }
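    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional smoke test (not part of the original notebook): greedy decoding of one made-up Japanese sentence with the freshly trained model. It assumes the training cell above has finished."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: the input sentence is an invented example, not project data.\n",
        "model.eval()\n",
        "inputs = src_tokenizer(\"駅はどこですか。\", return_tensors=\"pt\").to(device)\n",
        "with torch.no_grad():\n",
        "    generated = model.generate(\n",
        "        input_ids=inputs[\"input_ids\"],\n",
        "        attention_mask=inputs[\"attention_mask\"],\n",
        "        max_length=64,\n",
        "        eos_token_id=trg_tokenizer.eos_token_id,\n",
        "    )\n",
        "print(trg_tokenizer.decode(generated[0], skip_special_tokens=True))"
      ]
    }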
  ],
  "metadata": {
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "gpuClass": "premium",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}