Speculative Decoding

This model implements Medusa, an efficient speculative decoding approach that can achieve up to 3x faster inference for large language models. The implementation consists of the base Vicuna-7B model augmented with specialized prediction heads that enable parallel token generation.

Model Description

  • Model type: Causal language model with speculative decoding
  • Base model: Vicuna-7B-v1.3
  • Language: English
  • License: MIT

This implementation adds multiple speculative heads on top of the base Vicuna model. Each head predicts the token at a specific future offset, so several candidate tokens are proposed in a single forward pass; the base model then verifies the candidates and commits only those that match its own predictions, which is where the speedup comes from.
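The actual implementation organizes multiple candidates into a tree and verifies them with a specialized attention mask; the toy sketch below shows only the simplest greedy version of the accept step, and every name and number in it is a hypothetical illustration rather than code from this repository.

import torch

def accepted_prefix_length(candidates: torch.Tensor, base_predictions: torch.Tensor) -> int:
    """Count how many leading candidate tokens agree with the base model's own
    next-token predictions; only that prefix is committed, the rest is discarded."""
    matches = (candidates == base_predictions).int()
    return int(matches.cumprod(dim=-1).sum().item())

# Toy example: 3 heads proposed [12, 7, 99]; the base model, run once over the same
# positions, predicts [12, 7, 42] -> the first two proposals are accepted.
candidates = torch.tensor([12, 7, 99])
base_predictions = torch.tensor([12, 7, 42])
print(accepted_prefix_length(candidates, base_predictions))  # 2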

Technical Specifications

  • Base model: lmsys/vicuna-7b-v1.3
  • Parameters: ~7B (base model) + speculative heads
  • Context length: 2048 tokens
  • Speculative heads: 3
  • Layers per head: 1
  • Architecture: Each head uses residual blocks (ResBlock) with SiLU activation
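As a rough illustration of the head architecture in the list above, here is a minimal residual block with SiLU activation followed by a vocabulary projection. Layer sizes and helper names are assumptions for the sketch, not the exact code shipped with this model.

import torch.nn as nn

class ResBlock(nn.Module):
    # Residual block: a linear layer with SiLU activation, added back to its input.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

def make_speculative_head(hidden_size: int, vocab_size: int) -> nn.Sequential:
    # One head: a single ResBlock (layers per head = 1) plus a projection to the vocabulary.
    return nn.Sequential(ResBlock(hidden_size), nn.Linear(hidden_size, vocab_size, bias=False))

# With the specifications above: 3 such heads on top of Vicuna-7B's 4096-dim hidden state.
heads = nn.ModuleList(make_speculative_head(4096, 32000) for _ in range(3))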

Direct Inference

import torch
from model import MedusaModel
from transformers import AutoTokenizer

# Load model (requires Medusa codebase and compatible weights)
model = MedusaModel.from_pretrained("theharshithh/vicuna-7b-speculative")
model = model.to("cuda")  # keep the model on the same device as the inputs below
tokenizer = model.get_tokenizer()

prompt = "Human: What is machine learning?\nAssistant:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

# Generate text
generated_text = ""
with torch.no_grad():
    for output in model.medusa_generate(
        input_ids=input_ids,
        temperature=0.7,
        max_steps=512
    ):
        generated_text = output["text"]

print(prompt + generated_text)

ONNX Inference

For accelerated inference, first export the model to ONNX format:

from inference.infer_onnx import export_model, inference

# Export model to ONNX
export_model(model_checkpoint="theharshithh/vicuna-7b-speculative", save_directory="onnx/")

# Run inference with ONNX model
inference(model_checkpoint="theharshithh/vicuna-7b-speculative", save_directory="onnx/")
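To sanity-check the exported graph independently of the repository helpers, you can open it directly with onnxruntime. The file name below is an assumption; use whichever .onnx file export_model wrote into the save directory.

import onnxruntime as ort

# Inspect the exported graph's inputs and outputs before wiring up real inference.
session = ort.InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])

for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)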

Training Details

This model was trained on a processed dataset derived from the ShareGPT conversations dataset. Training was distributed across multiple GPUs with the key hyperparameters listed below (a minimal configuration sketch follows the list):

  • Learning rate: 1e-3
  • Weight decay: 0.0
  • Warmup ratio: 0.1
  • Scheduler: Cosine
  • Batch size: 4 (per device)
  • Gradient accumulation steps: 4
  • Precision: BF16
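As a minimal configuration sketch, the hyperparameters above map onto Hugging Face TrainingArguments roughly as follows; output_dir and num_train_epochs are illustrative assumptions, and the actual training script in the repository may wire things differently.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="medusa-heads",          # assumption, not the actual run's path
    learning_rate=1e-3,
    weight_decay=0.0,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    bf16=True,
    num_train_epochs=1,                 # assumption
)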

Performance

The model achieves significant inference speedups compared to standard autoregressive generation (a simple timing sketch follows the list):

  • Up to 3x faster inference speed by generating multiple tokens in parallel
  • Minimal impact on output quality compared to standard generation
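One way to verify the speedup on your own hardware is a simple wall-clock measurement. The sketch below assumes the model, tokenizer, and input_ids from the Direct Inference example are already in scope and a CUDA device is in use; compare the resulting tokens-per-second figure against the base Vicuna model's generate loop.

import time
import torch

torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    for output in model.medusa_generate(input_ids=input_ids, temperature=0.7, max_steps=512):
        text = output["text"]
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# Rough throughput estimate from the generated text length.
new_tokens = len(tokenizer(text).input_ids)
print(f"{new_tokens} tokens in {elapsed:.2f}s ({new_tokens / elapsed:.1f} tok/s)")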

The training loss curves show steady convergence, though longer training would likely yield further improvements.

Limitations

  • Currently optimized for inference with batch size 1
  • Performance varies based on text complexity and token predictability
  • Training was conducted with limited GPU resources (extended training recommended for optimal results)

Citation

If you use this model in your research, please cite:

GitHub Repository:

@misc{medusa-vicuna-7b-speculative-github,
  author = {theharshithh},
  title = {Medusa: Fast LLM Inference with Speculative Decoding},
  year = {2025},
  publisher = {GitHub},
  howpublished = {\url{https://github.com/theharshithh/speculative-decoding}}
}

Repository & Source Code

For implementation details, source code, and further documentation, see the GitHub repository: https://github.com/theharshithh/speculative-decoding
