PPO-QLoRA Trained Model (spark-model-QLoRA)

This repository contains an agent (actor and critic models) trained using Proximal Policy Optimization (PPO) with QLoRA. The training was performed using the scripts and models available in the spark_rl directory of the explore-rl project.
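For readers new to the adapter format, here is the standard LoRA formulation. A frozen weight matrix W gets a low-rank update scaled by alpha/r, and only the two small factors A and B are trained and saved. QLoRA applies the same idea on top of a 4-bit quantized base model. The plain-Python sketch below shows that forward pass and the number of parameters one adapter adds. It assumes the textbook LoRA definition. The actual target modules, rank (lora_r) and scaling used by LLMActorLora and LLMCriticLora are defined in models.py and hyperparams.txt, and the helper names here are hypothetical.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, alpha: float, r: int) -> Vector:
    """Textbook LoRA forward pass: y = W x + (alpha / r) * B (A x).

    W: frozen base weight, shape (d_out, d_in)
    A: trainable down-projection, shape (r, d_in)
    B: trainable up-projection, shape (d_out, r)
    """
    base = matvec(W, x)
    low_rank = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * l for b, l in zip(base, low_rank)]


def lora_param_count(d_out: int, d_in: int, r: int) -> int:
    """Trainable parameters added by one adapter: A has r*d_in, B has d_out*r."""
    return r * (d_in + d_out)


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 identity
    A = [[1.0, 1.0]]              # rank r = 1
    B = [[0.5], [0.0]]
    print(lora_forward(W, A, B, [2.0, 3.0], alpha=2.0, r=1))  # [7.0, 3.0]
    print(lora_param_count(4096, 4096, 16))                   # 131072
```

For a 4096 x 4096 projection at rank 16, the adapter adds 131,072 trainable values, compared with about 16.8 million in the frozen matrix. This is why the actor/ and critic/ folders hold only small adapter files rather than full model weights.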

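Base Model: meta-llama/Llama-3-8B-Instruct (if a different base model was passed to train.py, use that one instead; note that the official Hugging Face Hub ID for Llama 3 8B Instruct is meta-llama/Meta-Llama-3-8B-Instruct)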

Model Components

The contents of the model_final directory are uploaded as the root of this repository:

  • actor/: LoRA adapters for the actor (policy) model.
  • critic/: LoRA adapters for the critic (value) model's base LLM, and a value_head.pt file for its custom value prediction head.
  • tokenizer/: The Hugging Face tokenizer used during training.
  • hyperparams.txt: Key hyperparameters used for the PPO training.
  • models.py: Contains the LLMActorLora and LLMCriticLora class definitions required to load and use these models.
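Before loading the models, you may want to confirm that a downloaded copy contains the components listed above. You may also want to read hyperparams.txt so that constructor arguments such as lora_r match training. The sketch below uses only the standard library. The format of hyperparams.txt is not documented here, so the parser assumes one `key: value` or `key=value` pair per line and keeps every value as a string. The function names are hypothetical.

```python
import os
from typing import Dict, Iterable, List

# Components listed in this model card, relative to the repository root.
EXPECTED_PATHS = [
    "actor",
    "critic",
    os.path.join("critic", "value_head.pt"),
    "tokenizer",
    "hyperparams.txt",
    "models.py",
]


def missing_components(repo_dir: str) -> List[str]:
    """Return the expected components that are absent from repo_dir."""
    return [p for p in EXPECTED_PATHS if not os.path.exists(os.path.join(repo_dir, p))]


def parse_hyperparams(lines: Iterable[str]) -> Dict[str, str]:
    """Parse 'key: value' or 'key=value' lines (ASSUMED format).

    Blank lines, '#' comments and lines without a separator are skipped.
    """
    params: Dict[str, str] = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        # Split on whichever separator appears first in the line.
        positions = [i for i in (line.find(":"), line.find("=")) if i != -1]
        if not positions:
            continue
        cut = min(positions)
        params[line[:cut].strip()] = line[cut + 1:].strip()
    return params


def read_hyperparams(path: str) -> Dict[str, str]:
    """Read and parse a hyperparams.txt file."""
    with open(path, encoding="utf-8") as fh:
        return parse_hyperparams(fh)
```

For a complete download, `missing_components(local_dir)` should return an empty list. `read_hyperparams(os.path.join(local_dir, "hyperparams.txt")).get("lora_r")` returns the rank to pass to the constructors, if the file uses that key.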

How to Use

To use these models, you will need the LLMActorLora and LLMCriticLora classes from the included models.py file.

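import os
import sys

import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

# --- Configuration ---
BASE_MODEL_ID = "meta-llama/Llama-3-8B-Instruct" # IMPORTANT: must match the base model used for training!
# (The official Hub ID of Llama 3 8B Instruct is "meta-llama/Meta-Llama-3-8B-Instruct".)
REPO_ID = "gabrielbo/spark-model-QLoRA"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the repository; snapshot_download returns the local directory.
# If you already have a local copy, set MODEL_REPO_PATH to that path instead.
MODEL_REPO_PATH = snapshot_download(repo_id=REPO_ID)

# models.py ships with this repository, so make it importable.
sys.path.insert(0, MODEL_REPO_PATH)
from models import LLMActorLora, LLMCriticLora

# --- Load Tokenizer ---
tokenizer_dir = os.path.join(MODEL_REPO_PATH, "tokenizer")
if not os.path.isdir(tokenizer_dir): # Fallback if the tokenizer files are in the repository root
    tokenizer_dir = MODEL_REPO_PATH
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # Left padding keeps new tokens directly after the prompt (and matches the PPO agent, if it used left padding)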

# --- Load Actor ---
actor = LLMActorLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # Optional: pass lora_r / disable_quantization if hyperparams.txt differs from the class defaults
)
# Path to actor adapters within the model repo
actor_adapters_path = f"{MODEL_REPO_PATH}/actor"
actor.load_pretrained(actor_adapters_path)
actor.model.eval()
print("Actor loaded successfully.")

# --- Load Critic ---
critic = LLMCriticLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # Optional: pass lora_r / disable_quantization if hyperparams.txt differs from the class defaults
)
# Path to critic components within the model repo
critic_components_path = f"{MODEL_REPO_PATH}/critic"
critic.load_pretrained(critic_components_path)
critic.model.eval()
critic.value_head.eval()
print("Critic loaded successfully.")

# --- Example: Generating an action (conceptual) ---
# The exact input format depends on how PPOAgent prepared inputs during training.
# The following is a generic example and will likely need adapting.

# Example input construction (refer to PPOAgent.prepare_batch)
question = "What is the capital of France?"
state_text = "The current context is a geography quiz."
input_text = f"Question: {question}\n\nState: {state_text}\n\nAction:"

inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)

print(f"\nGenerating action for: {input_text}")
with torch.no_grad():
    # Actor generates token IDs
    # Note: Generation kwargs might be needed (e.g., temperature, top_p from hyperparams.txt or evaluate.py)
    generated_ids = actor.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=50, # Adjust as needed
        # temperature=0.7, # Example
        # top_p=0.9,       # Example
        do_sample=True   # Example, if sampling was used
    )
    
    # Decode the generated action.
    # Assuming actor.generate returns the same layout as Hugging Face generate() for a
    # decoder-only model, each output row is the (left-padded) prompt followed by the
    # new tokens, so slicing off the first input_ids.shape[-1] positions leaves only
    # the generated response.
    response_ids = generated_ids[0][inputs.input_ids.shape[-1]:]
    action_text = tokenizer.decode(response_ids, skip_special_tokens=True)

    print(f"Generated Action: {action_text.strip()}")

    # --- Example: Getting a value estimate (conceptual) ---
    # Assumes the critic returns one scalar value per sequence (batch size 1 here).
    value_prediction = critic.forward(inputs.input_ids, attention_mask=inputs.attention_mask)
    print(f"Value prediction for the state: {value_prediction.item()}")
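The response-slicing step in the example depends on how decoder-only generation lays out its output. With left padding, every prompt in a batch ends at the same position, and generation appends new tokens after that position. Dropping the first input_ids.shape[-1] tokens therefore leaves only the response. The plain-Python simulation below illustrates this with made-up token IDs (PAD = 0). It does not call the actual model, and the helper names are hypothetical.

```python
from typing import List

PAD = 0  # hypothetical pad token ID


def left_pad(batch: List[List[int]], pad_id: int = PAD) -> List[List[int]]:
    """Left-pad every sequence to the length of the longest one."""
    width = max(len(seq) for seq in batch)
    return [[pad_id] * (width - len(seq)) + seq for seq in batch]


def strip_prompt(generated: List[List[int]], prompt_width: int) -> List[List[int]]:
    """Keep only the tokens produced after the (padded) prompt."""
    return [seq[prompt_width:] for seq in generated]


prompts = left_pad([[11, 12, 13], [21]])  # [[11, 12, 13], [0, 0, 21]]
width = len(prompts[0])

# Simulated generation: each output row is the padded prompt followed by new tokens.
new_tokens = [[101, 102], [201, 202]]
generated = [p + n for p, n in zip(prompts, new_tokens)]

print(strip_prompt(generated, width))  # [[101, 102], [201, 202]]
```

With right padding, the shorter prompt would end in pad tokens, and the model would be asked to continue after them. This is why the example sets padding_side = "left".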

Training Details

The model was trained using the PPO algorithm with the following key settings (see hyperparams.txt for more details):

  • Learning Rate (Actor): (Refer to lr in hyperparams.txt)
  • Learning Rate (Critic): (Refer to critic_lr in hyperparams.txt)
  • PPO Clip Ratio: (Refer to clip_ratio in hyperparams.txt)
  • KL Coefficient: (Refer to kl_coef in hyperparams.txt)
  • Target KL: (Refer to target_kl in hyperparams.txt)
  • Batch Size: (see args.batch in the training script)
  • PPO Epochs: (see args.ppo_epochs in the training script)
  • Total PPO Iterations: (see args.steps in the training script)
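The sketch below illustrates what clip_ratio, kl_coef and target_kl control. It computes a per-token PPO loss in plain Python using three parts: the standard clipped surrogate objective, a KL penalty weighted by kl_coef, and an early-stopping check against target_kl. This is the textbook formulation, not necessarily the exact loss in spark_rl. Several details are assumptions: that kl_coef penalizes divergence from a frozen reference model, which simple log-ratio KL estimator is used, and the threshold for early stopping.

```python
import math
from typing import List, Tuple


def ppo_loss(
    logp_new: List[float],
    logp_old: List[float],
    logp_ref: List[float],
    advantages: List[float],
    clip_ratio: float,
    kl_coef: float,
) -> Tuple[float, float]:
    """Return (loss, approx_kl), both averaged over tokens.

    loss = -mean(min(r*A, clip(r, 1-eps, 1+eps)*A)) + kl_coef * mean(logp_new - logp_ref)
    with r = exp(logp_new - logp_old) and eps = clip_ratio.
    """
    n = len(advantages)
    surrogate = 0.0
    kl_ref = 0.0
    approx_kl = 0.0
    for new, old, ref, adv in zip(logp_new, logp_old, logp_ref, advantages):
        ratio = math.exp(new - old)
        clipped = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
        # Pessimistic bound: take the smaller of the clipped and unclipped objectives.
        surrogate += min(ratio * adv, clipped * adv)
        kl_ref += new - ref      # simple per-token KL estimate vs. the reference model (assumed)
        approx_kl += old - new   # per-token estimate of KL(rollout policy || updated policy)
    loss = -surrogate / n + kl_coef * kl_ref / n
    return loss, approx_kl / n


def should_stop_early(approx_kl: float, target_kl: float) -> bool:
    """Skip the remaining PPO epochs once the policy has drifted past target_kl."""
    return approx_kl > target_kl
```

Clipping only limits the update when it would be favorable. Consider a token whose probability ratio rose to 1.5, with clip_ratio = 0.2. With a positive advantage, its contribution is capped at 1.2. With a negative advantage, the full -1.5 is kept. Many implementations use a slightly different stopping threshold (for example 1.5 x target_kl) or a lower-variance KL estimator.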

The specific dataset used for training was MMLU trajectories.

Intended Use

This model is intended for sequential decision-making and reasoning tasks similar to its MMLU-based training data. It can be used as a starting point for further fine-tuning or applied directly in related domains.

Limitations

  • The model's performance is tied to the quality and characteristics of the offline trajectory data it was trained on.
  • As a LoRA-adapted model, it relies on the capabilities of the base meta-llama/Llama-3-8B-Instruct model.
  • Generation quality may depend on reproducing the prompt format used during training (see PPOAgent.prepare_batch) and may require careful prompt engineering.

Citation

If you use this model or the spark_rl codebase, please consider citing the original explore-rl repository: [Link to your explore-rl GitHub repository, if public]
