bart-large-cnn-finetuned-samsum-lora

This model is a version of facebook/bart-large-cnn further fine-tuned on the samsum dataset.

The base model, bart-large-cnn, is itself a version of BART fine-tuned on the CNN/Daily Mail dataset.

Check out sooolee/flan-t5-base-cnn-samsum-lora, a model fine-tuned for the same purpose.

Model description

  • This model further fine-tunes 'bart-large-cnn' on the more conversational samsum dataset.
  • The Hugging Face PEFT library's LoRA (r = 8) was used to speed up training and reduce the adapter size (see the sketch after this list).
  • Fewer than 1.2M parameters were trained (about 0.23% of bart-large's 510M parameters).
  • The model checkpoint is less than 5MB.
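
Below is a minimal sketch of the LoRA setup described above, using the Hugging Face PEFT library. Only r = 8 is stated in this card; lora_alpha, lora_dropout, and the target modules are illustrative assumptions.

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

# Load the base model that the adapter was trained on
base_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# LoRA configuration: only r=8 comes from this card; the rest are assumed values
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,      # sequence-to-sequence summarization
    r=8,                                  # LoRA rank reported in this card
    lora_alpha=32,                        # assumed scaling factor
    lora_dropout=0.05,                    # assumed dropout
    target_modules=["q_proj", "v_proj"],  # assumed attention projections to adapt
)

# Wrap the base model; only the small LoRA matrices are trainable
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # prints the small trainable-parameter fraction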

Intended uses & limitations

The model summarizes transcripts, such as YouTube video transcripts.

Training and evaluation data

The model was trained and evaluated on the samsum dataset of dialogues paired with human-written summaries.

Training hyperparameters

The following hyperparameters were used during training (a configuration sketch follows the list):

  • learning_rate: 0.0005
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 5
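
The sketch below shows how these hyperparameters might map onto Seq2SeqTrainingArguments; the output directory and batch size are assumptions not stated in this card.

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="bart-large-cnn-samsum-lora",  # assumed output path
    learning_rate=5e-4,
    num_train_epochs=5,
    lr_scheduler_type="linear",
    per_device_train_batch_size=8,            # assumed batch size
    adam_beta1=0.9,                           # Adam betas/epsilon as listed above
    adam_beta2=0.999,
    adam_epsilon=1e-8,
)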

Training results

  • train_loss: 1.28
  • rouge1: 43.115465%
  • rouge2: 21.563061%
  • rougeL: 33.409979%
  • rougeLsum: 33.414162%
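
ROUGE scores like these can be computed with the evaluate library. The sketch below assumes the samsum test split and its 'summary' column; the generated summaries are left as a placeholder to be produced with the model as shown under "How to use".

import evaluate
from datasets import load_dataset

rouge = evaluate.load("rouge")
samsum_test = load_dataset("samsum", split="test")

# Placeholder: generate these with the fine-tuned model for each test dialogue
predictions = ["..."]
references = samsum_test["summary"][: len(predictions)]

# Returns rouge1/rouge2/rougeL/rougeLsum as fractions (the card reports percentages)
scores = rouge.compute(predictions=predictions, references=references)
print(scores)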

How to use

Note that 'max_new_tokens=60' is used in the example below to control the summary length. The BART model's default max generation length is 142 and its default min generation length is 56.

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load peft config for pre-trained checkpoint etc.
peft_model_id = "sooolee/bart-large-cnn-finetuned-samsum-lora"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(model, peft_model_id, device_map='auto')

# Tokenize the text inputs
texts = "<e.g. Transcript>"
inputs = tokenizer(texts, return_tensors="pt", padding=True)  # add truncation=True to cap inputs at the model's max length

# Run inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
    output = model.generate(input_ids=inputs["input_ids"].to(device), max_new_tokens=60, do_sample=True, top_p=0.9)
    summary = tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=True)

print(summary)
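
Long transcripts can exceed the 1024-token encoder limit of bart-large-cnn. One possible approach, not covered in this card, is to split the transcript into chunks and summarize each chunk separately; the chunk size and the summarize_long helper below are hypothetical.

# Hypothetical helper reusing the tokenizer, model, and device defined above.
# chunk_tokens=1000 is an assumed size chosen to stay under the 1024-token limit.
def summarize_long(text, chunk_tokens=1000):
    token_ids = tokenizer(text, truncation=False)["input_ids"]
    chunks = [token_ids[i:i + chunk_tokens] for i in range(0, len(token_ids), chunk_tokens)]
    summaries = []
    for chunk in chunks:
        chunk_text = tokenizer.decode(chunk, skip_special_tokens=True)
        chunk_inputs = tokenizer(chunk_text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            chunk_output = model.generate(input_ids=chunk_inputs["input_ids"].to(device),
                                          max_new_tokens=60, do_sample=True, top_p=0.9)
        summaries.append(tokenizer.batch_decode(chunk_output, skip_special_tokens=True)[0])
    return " ".join(summaries)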

Framework versions

  • Transformers 4.27.2
  • Pytorch 1.13.1+cu116
  • Datasets 2.9.0
  • Tokenizers 0.13.3
