🧠 Fine-Tuned GPT-2 Medium for Conversational AI

This project fine-tunes the gpt2-medium language model to support natural, casual conversational dialogue using PEFT + LoRA.


🚀 Model Summary

  • Base model: gpt2-medium
  • Objective: Enable natural question-answering and dialogue
  • Training method: Supervised Fine-Tuning (SFT) using PEFT with LoRA adapters (a configuration sketch follows below)
  • Tokenizer: gpt2 (same as base model)
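
The exact LoRA configuration isn't published in this card; the sketch below shows a typical PEFT + LoRA setup for gpt2-medium, with the rank, alpha, and dropout values as illustrative assumptions rather than the values actually used.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2-medium")
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                        # low-rank dimension (assumed)
    lora_alpha=16,              # LoRA scaling factor (assumed)
    lora_dropout=0.05,          # (assumed)
    target_modules=["c_attn"],  # GPT-2's fused query/key/value projection
)
peft_model = get_peft_model(base, config)
peft_model.print_trainable_parameters()  # only the adapter weights are trainable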

📈 Training Metrics

Metric                    Value
Global steps              2611
Final training loss       2.185
Training runtime          430.61 s
Samples per second        138.41
Steps per second          17.32
Total FLOPs               1.12 × 10¹⁵
Epochs                    7.0
These values are the training statistics reported at the end of the full 7-epoch run.


💬 Inference Script

Chat with the model using the talk() function below:
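
The function assumes torch, the tokenizer, and the LoRA-adapted model are already loaded. A minimal loading sketch (the adapter id kunjcr2/gpt2_conv is taken from this repo):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
base = AutoModelForCausalLM.from_pretrained("gpt2-medium")
peft_model = PeftModel.from_pretrained(base, "kunjcr2/gpt2_conv").to(device)
peft_model.eval()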

def talk(model=peft_model, tokenizer=tokenizer, device=device):
    print("Start chatting with the bot! Type 'exit' to stop.\n")
    while True:
        question = input("You: ")
        if question.lower() == "exit":
            print("Goodbye!")
            break

        # Frame the turn as a User/Bot exchange
        prompt = f"User: {question}\nBot:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=True,       # sample rather than decode greedily
                temperature=0.7,
                top_p=0.9,            # nucleus sampling
                pad_token_id=tokenizer.eos_token_id  # GPT-2 has no pad token
            )

        # Decode only the newly generated tokens, not the echoed prompt
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )

        # Drop any trailing incomplete sentence
        if "." in response:
            response = response[: response.rfind(".") + 1]
        print("Bot:", response.strip())

  • 🤖 Stateless: No memory across turns (yet).
  • 🌱 Future idea: Add memory/context for multi-turn dialogue; a hypothetical sketch follows below.
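
A hypothetical sketch of that future idea (not part of this repo): carry the running transcript in the prompt and truncate it from the left so it fits GPT-2's 1024-token context window.

def talk_with_memory(model=peft_model, tokenizer=tokenizer, device=device,
                     max_context=1024, max_new_tokens=40):
    history = ""
    while True:
        question = input("You: ")
        if question.lower() == "exit":
            break
        history += f"User: {question}\nBot:"
        # Keep only the most recent tokens, leaving room for the reply
        ids = tokenizer(history, return_tensors="pt")["input_ids"]
        ids = ids[:, -(max_context - max_new_tokens):].to(device)
        with torch.no_grad():
            out = model.generate(
                ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        reply = tokenizer.decode(out[0][ids.shape[-1]:], skip_special_tokens=True)
        history += f" {reply}\n"  # the reply stays in context for later turns
        print("Bot:", reply.strip())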

⚙️ Quick Setup

To use this model locally:

pip install transformers peft accelerate
