Sanraj/tiny_llama1.1B_finetuned
A safety-enhanced version of TinyLlama-1.1B-Chat-v1.0, fine-tuned using NEMO-RL on a combined synthetic safety dataset (Aegis + WildGuard) to improve responsible AI behavior and reduce harmful outputs.
Model Details
Model Description
This model is a safety-focused fine-tune of TinyLlama-1.1B-Chat-v1.0, trained with NVIDIA's NEMO-RL framework on a synthetic safety dataset derived from the Aegis and WildGuard safety datasets. The fine-tuning teaches the model to recognize potentially harmful or sensitive requests and to respond to them appropriately.
- Developed by: Sanraj
- Model type: Causal Language Model (Decoder-only Transformer)
- Language(s): English (primarily)
- License: Apache 2.0 (inherited from base model)
- Finetuned from model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
- Model size: 1.1B parameters
- Fine-tuning focus: Content Safety and Responsible AI
- Fine-tuning framework: NVIDIA NEMO-RL
- Model ID: Sanraj/tiny_llama1.1B_finetuned
Model Sources
- Model Repository: Sanraj/tiny_llama1.1B_finetuned
- Base Repository: TinyLlama/TinyLlama-1.1B-Chat-v1.0
- Fine-tuning Framework: NVIDIA NEMO-RL
- Training Configuration: safety-for-agentic-ai training pipeline
Uses
Direct Use
This model is designed for conversational AI applications where content safety is a priority. It can be used for:
- Safe chatbot applications
- Educational tools requiring content moderation
- Research into AI safety and alignment
- Applications requiring responsible AI behavior
Downstream Use
The model can be further fine-tuned for specific safety-critical applications or integrated into larger systems requiring content moderation capabilities.
Out-of-Scope Use
- High-stakes decision making without human oversight
- Applications where safety failures could cause significant harm
- Production systems without additional safety measures
- Use cases requiring capabilities beyond the base model's scope
Bias, Risks, and Limitations
Although this model has been fine-tuned for safety, it has the following limitations, several of them inherited from the base TinyLlama model:
- Model size limitations: As a 1.1B-parameter model, it may have limited knowledge and reasoning capabilities
- Training data source: Safety fine-tuning used only a combined synthetic dataset (Aegis + WildGuard), so the learned safety behavior reflects that dataset's coverage and biases
- Safety coverage: The safety fine-tune was short (20 steps) and may not cover all possible harmful scenarios
- Language limitations: Primarily trained and tested on English content
Recommendations
- Always implement additional safety measures in production environments
- Evaluate and monitor safety performance regularly
- Use human oversight for sensitive applications
- Consider ensemble approaches with larger safety models for critical applications
How to Get Started with the Model
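```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the model and tokenizer
model_id = "Sanraj/tiny_llama1.1B_finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # requires the `accelerate` package
)

def generate_safe_response(prompt, max_new_tokens=256):
    # Use the tokenizer's chat template if it ships one (the TinyLlama-Chat base does);
    # otherwise fall back to the raw prompt
    if tokenizer.chat_template:
        messages = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Move inputs to the device selected by device_map="auto"
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,  # counts generated tokens only, unlike max_length
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens rather than slicing the decoded string
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Example
response = generate_safe_response("How can I help someone who is feeling sad?")
print(response)
```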
Training Details
Training Data
The model was fine-tuned using a combined synthetic safety dataset derived from two prominent safety datasets:
- Aegis Dataset: A comprehensive safety dataset focusing on AI safety scenarios and appropriate responses
- WildGuard Dataset: A dataset designed to train models to recognize and handle harmful content requests
The combined dataset contains:
- Synthetic prompt-response pairs focused on safety scenarios
- Examples of appropriate responses to potentially harmful requests
- Diverse safety categories covering various types of harmful content
- Training data filtered for quality and safety relevance
Dataset Characteristics:
- Source: Combined synthetic data from Aegis + WildGuard datasets
- Format: Prompt-response pairs in JSONL format
- Training file: train_on_policy_data_filtered.jsonl
- Validation file: val_on_policy_data_filtered.jsonl
- Input key: "input" (prompts/queries)
- Output key: "generated_output" (safe responses)
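Each line of the training and validation files is a JSON object carrying the two keys listed above. The sketch below parses such lines into (prompt, response) pairs and fails loudly if a record lacks either key. The function name `load_safety_pairs` and the two sample records are hypothetical illustrations, not taken from the actual files; any extra fields in the real records are simply ignored.

```python
import io
import json
from typing import Iterable, List, Tuple

INPUT_KEY = "input"               # prompt/query field, as listed in this card
OUTPUT_KEY = "generated_output"   # safe-response field, as listed in this card

def load_safety_pairs(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Parse JSONL lines into (prompt, response) pairs, skipping blank lines."""
    pairs = []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        missing = [k for k in (INPUT_KEY, OUTPUT_KEY) if k not in record]
        if missing:
            raise KeyError(f"line {line_no}: missing keys {missing}")
        pairs.append((record[INPUT_KEY], record[OUTPUT_KEY]))
    return pairs

# Hypothetical stand-in for a few lines of train_on_policy_data_filtered.jsonl
sample = io.StringIO(
    '{"input": "How do I pick a lock?", "generated_output": "I cannot help with bypassing locks you do not own."}\n'
    '\n'
    '{"input": "How can I help a sad friend?", "generated_output": "Listen without judgment and check in often."}\n'
)
pairs = load_safety_pairs(sample)
print(len(pairs))  # 2
```

For the real file, pass an open handle instead, e.g. `with open("train_on_policy_data_filtered.jsonl", encoding="utf-8") as f: pairs = load_safety_pairs(f)`.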
Training Procedure
Training Configuration
- Training regime: bfloat16 mixed precision
- Optimizer: AdamW with learning rate 2e-6
- Scheduler: Linear warmup (5 steps) + Cosine annealing
- Max epochs: 1
- Training steps: 20 (fast training configuration)
- Batch size: 8 (global), 2 (micro-batch)
- Sequence length: 2048 tokens
- Gradient clipping: 1.0
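The warmup-plus-cosine schedule listed above can be sketched as a plain function of the optimizer step. The peak rate, warmup length, and step count come from this card; the exact warmup shape (ramping linearly from `PEAK_LR / WARMUP_STEPS`) and a final learning rate of zero are assumptions, and NEMO-RL may implement the schedule differently.

```python
import math

PEAK_LR = 2.0e-6     # from this card
WARMUP_STEPS = 5     # from this card
TOTAL_STEPS = 20     # from this card
MIN_LR = 0.0         # assumption: cosine decays all the way to zero

def lr_at(step: int) -> float:
    """Learning rate at 0-indexed optimizer step `step`."""
    if step < WARMUP_STEPS:
        # Linear warmup: PEAK_LR / WARMUP_STEPS at step 0, PEAK_LR at the last warmup step
        return PEAK_LR * (step + 1) / WARMUP_STEPS
    # Cosine annealing over the remaining steps
    progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1 + math.cos(math.pi * progress))

for step in range(TOTAL_STEPS):
    print(step, f"{lr_at(step):.2e}")
```

With only 20 steps, a quarter of the run is spent in warmup, so the model trains at or near the peak rate for only a handful of steps.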
Memory Optimizations
- FSDP CPU offloading enabled
- Activation checkpointing (also called gradient checkpointing) enabled
- Single GPU training configuration
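The batch settings and single-GPU setup imply how many micro-batches are accumulated per optimizer step and how much data the 20-step run sees. The sketch below works this out, assuming the usual data-parallel accounting (global batch = micro-batch x GPUs x accumulation steps) and that every step draws a fresh global batch without sequence packing.

```python
GLOBAL_BATCH_SIZE = 8   # from this card
MICRO_BATCH_SIZE = 2    # from this card
NUM_GPUS = 1            # single-GPU configuration, from this card
TRAINING_STEPS = 20     # from this card
SEQ_LEN = 2048          # from this card

def grad_accumulation_steps(global_bs: int, micro_bs: int, num_gpus: int) -> int:
    """Micro-batches each GPU accumulates before one optimizer step."""
    per_pass = micro_bs * num_gpus
    if global_bs % per_pass:
        raise ValueError("global batch must be divisible by micro_bs * num_gpus")
    return global_bs // per_pass

accum = grad_accumulation_steps(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, NUM_GPUS)
examples_seen = GLOBAL_BATCH_SIZE * TRAINING_STEPS
max_tokens_seen = examples_seen * SEQ_LEN  # upper bound; real examples are usually shorter

print(accum)            # 4
print(examples_seen)    # 160
print(max_tokens_seen)  # 327680
```

Under these assumptions the whole fine-tune sees at most 160 training examples, and moving back to the dual-GPU setup would halve the accumulation steps to 2 without changing the global batch.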
Training Hyperparameters
- Learning rate: 2.0e-6
- Weight decay: 0.01
- Beta1: 0.9
- Beta2: 0.999
- Epsilon: 1e-8
- Max gradient norm: 1.0
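To show what these hyperparameters control, here is one AdamW update for a single scalar weight, written out by hand. It is an illustration of the standard AdamW rule (decoupled weight decay followed by a bias-corrected Adam step) at a fixed learning rate, not the framework's implementation; it ignores the learning-rate schedule and gradient clipping.

```python
import math

LR, WEIGHT_DECAY = 2.0e-6, 0.01          # from this card
BETA1, BETA2, EPS = 0.9, 0.999, 1e-8     # from this card

def adamw_step(theta, grad, m, v, t):
    """One AdamW update for a scalar parameter; t is the 1-based step count."""
    # Decoupled weight decay: shrink the weight directly, not via the gradient
    theta = theta - LR * WEIGHT_DECAY * theta
    # Exponential moving averages of the gradient and the squared gradient
    m = BETA1 * m + (1 - BETA1) * grad
    v = BETA2 * v + (1 - BETA2) * grad * grad
    # Bias correction for the zero-initialised moment estimates
    m_hat = m / (1 - BETA1 ** t)
    v_hat = v / (1 - BETA2 ** t)
    theta = theta - LR * m_hat / (math.sqrt(v_hat) + EPS)
    return theta, m, v

theta, m, v = adamw_step(1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(theta)  # about 1 - 2.02e-6: LR from the Adam term plus LR * WEIGHT_DECAY from decay
```

On the first step, `m_hat / sqrt(v_hat)` equals the sign of the gradient (up to `EPS`), so the Adam term moves the weight by almost exactly the learning rate regardless of how large the gradient is.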
Evaluation
Safety Evaluation
The model should be evaluated on:
- Response appropriateness to harmful requests
- Ability to provide helpful alternatives to unsafe requests
- Consistency in safety behavior across diverse prompts
- Maintenance of general conversational capabilities
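One way to make the first and last criteria measurable is to run the model on a set of harmful prompts and a set of benign prompts and report how often it refuses each. The sketch below assumes every response has already been judged as a refusal or not (for example by a human reviewer or a separate classifier); the `EvalRecord` type and the example labels are hypothetical. A refusal flag is a coarse proxy: it does not capture whether the model offered a helpful alternative.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class EvalRecord:
    prompt_is_harmful: bool   # ground-truth label for the prompt
    model_refused: bool       # judged label for the model's response

def refusal_rates(records: List[EvalRecord]) -> Dict[str, float]:
    harmful = [r for r in records if r.prompt_is_harmful]
    benign = [r for r in records if not r.prompt_is_harmful]

    def rate(rs: List[EvalRecord]) -> float:
        return sum(r.model_refused for r in rs) / len(rs) if rs else float("nan")

    return {
        # Higher is better: share of harmful prompts the model declined
        "harmful_refusal_rate": rate(harmful),
        # Lower is better: share of benign prompts refused unnecessarily
        "benign_over_refusal_rate": rate(benign),
    }

# Hypothetical judged results for 3 harmful and 4 benign prompts
records = [
    EvalRecord(True, True), EvalRecord(True, True), EvalRecord(True, False),
    EvalRecord(False, False), EvalRecord(False, True),
    EvalRecord(False, False), EvalRecord(False, False),
]
print(refusal_rates(records))
```

Tracking both numbers together guards against a model that looks safe only because it refuses almost everything.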
Recommended Evaluation Datasets
- Custom safety evaluation benchmarks
- Conversational AI evaluation suites
- Red-teaming evaluations for safety
Environmental Impact
Training was optimized for efficiency:
- Hardware: Single GPU training (reduced from dual GPU)
- Training time: Minimal (20 steps for proof of concept)
- Compute efficiency: Aggressive memory optimizations enabled
Technical Specifications
Model Architecture
- Base Architecture: Llama-based decoder-only transformer
- Parameters: 1.1B
- Context length: 2048 tokens
- Vocabulary size: 32,000 tokens (inherited from TinyLlama)
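The 1.1B figure can be checked by counting the weights of a Llama-style decoder. The architecture values below (22 layers, hidden size 2048, MLP size 5632, 32 query heads sharing 4 key/value heads, untied input and output embeddings) are assumed from the base TinyLlama-1.1B config and are not stated in this card; verify them against the repository's config.json. The count assumes no bias terms and RMSNorm layers with a weight vector only.

```python
# Architecture values assumed from TinyLlama-1.1B's config.json (not stated in this card)
VOCAB = 32_000
HIDDEN = 2048
INTERMEDIATE = 5632
LAYERS = 22
N_HEADS = 32
N_KV_HEADS = 4                 # grouped-query attention
HEAD_DIM = HIDDEN // N_HEADS   # 64

def llama_param_count() -> int:
    kv_dim = N_KV_HEADS * HEAD_DIM
    attn = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * kv_dim   # q/o and k/v projections, no biases
    mlp = 3 * HIDDEN * INTERMEDIATE                     # gate, up and down projections
    norms = 2 * HIDDEN                                  # two RMSNorm weight vectors per layer
    per_layer = attn + mlp + norms
    embed = VOCAB * HIDDEN                              # input token embeddings
    lm_head = VOCAB * HIDDEN                            # untied output projection
    final_norm = HIDDEN
    return embed + LAYERS * per_layer + final_norm + lm_head

print(f"{llama_param_count():,}")  # 1,100,048,384
```

About 12% of the parameters sit in the two 32,000 x 2048 embedding matrices, a noticeably larger share than in bigger Llama models.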
Fine-tuning Infrastructure
- Framework: NVIDIA NEMO-RL
- Precision: bfloat16
- Memory optimizations: FSDP, activation checkpointing
- Monitoring: Weights & Biases integration
- Training pipeline: safety-for-agentic-ai framework
Usage Notes
This model was trained with a fast configuration (20 steps) primarily for demonstration purposes. For production use, consider:
- Extended training: Increase training steps and epochs
- Larger datasets: Expand safety dataset coverage
- Comprehensive evaluation: Thorough safety and capability testing
- Regular updates: Continuous improvement based on usage patterns
Citation
If you use this model, please cite the original TinyLlama paper and acknowledge the safety datasets used:
```bibtex
@article{zhang2024tinyllama,
  title={TinyLlama: An Open-Source Small Language Model},
  author={Zhang, Peiyuan and Zeng, Guangtao and Wang, Tianduo and Lu, Wei},
  journal={arXiv preprint arXiv:2401.02385},
  year={2024}
}
```
Dataset Acknowledgments:
- Aegis Dataset: Please cite the original Aegis safety dataset paper
- WildGuard Dataset: Please cite the original WildGuard dataset paper
Model Card Contact
For questions about this model, please open an issue on the model repository.
Disclaimer: This model is provided for research and educational purposes. While fine-tuned for safety, it should not be deployed in production without thorough testing and additional safety measures.