SmolVLM DSE (Document Screenshot Embedding)

This model fine-tunes SmolVLM for document screenshot embedding tasks using contrastive learning.

Model Description

Overview

  • Model Type: SmolVLM (IDEFICS3) with DSE architecture
  • Base Model: HuggingFaceTB/SmolVLM-256M-Base
  • Task: Document Screenshot Embedding (Visual-Text Retrieval)
  • Training Data: wiki-ss-nq (queries) and wiki-ss-corpus (document screenshots)

Architecture Details

  • Vision Transformer for image encoding (768d)
  • LLaMA model for text encoding (576d)
  • Linear projection layer to align text representations (576d → 768d)
  • Last token pooling with normalization
  • Temperature scaling (0.02)
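
The text-side steps listed above (last token pooling, projection, normalization, temperature scaling) can be sketched in plain Python. This is a minimal illustration, not the model's actual code: the function names are hypothetical, right padding is assumed, and toy dimensions (3 → 4) stand in for the real 576 → 768 projection.

```python
import math

def last_token_pool(hidden_states, attention_mask):
    """Pick the hidden state of the last real token (right padding assumed)."""
    last_index = sum(attention_mask) - 1
    return hidden_states[last_index]

def linear_project(vector, weight, bias):
    """Linear layer; weight has shape (out_dim, in_dim)."""
    return [sum(w * x for w, x in zip(row, vector)) + b
            for row, b in zip(weight, bias)]

def l2_normalize(vector, eps=1e-12):
    """Scale a vector to unit length so dot products equal cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / max(norm, eps) for x in vector]

def scaled_similarity(query_emb, doc_emb, temperature=0.02):
    """Dot product of two unit vectors, divided by the temperature."""
    return sum(q * d for q, d in zip(query_emb, doc_emb)) / temperature

# Toy sizes stand in for the real ones (3 -> 4 instead of 576 -> 768).
hidden_states = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [9.0, 9.0, 9.0]]
attention_mask = [1, 1, 1, 0]  # the final position is padding
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
bias = [0.0, 0.0, 0.0, 0.0]

pooled = last_token_pool(hidden_states, attention_mask)
query_emb = l2_normalize(linear_project(pooled, weight, bias))
doc_emb = l2_normalize([0.0, 0.0, 2.0, 0.0])
logit = scaled_similarity(query_emb, doc_emb)
```

With left padding the last real token is simply the final position, so the pooling index would change accordingly. Dividing by 0.02 multiplies every cosine similarity by 50, which makes the softmax over candidates much sharper.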

Training

  • Epochs: 1
  • Batch Size: 8 per device (effective batch size of 32 with gradient accumulation)
  • Learning Rate: 1e-5
  • Optimizer: AdamW with weight decay 0.01
  • Hardware: Single GPU
  • Training Time: ~8 hours
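
The relationship between the per-device batch size and the effective batch size is simple arithmetic, sketched below. The helper names are hypothetical, and the accumulation step count of 4 is inferred from the two stated numbers (8 and 32 on a single GPU); it is not given explicitly in this card.

```python
def effective_batch_size(per_device_batch, accumulation_steps, num_devices=1):
    """Number of examples contributing to each optimizer step."""
    return per_device_batch * accumulation_steps * num_devices

def implied_accumulation_steps(effective_batch, per_device_batch, num_devices=1):
    """Accumulation steps needed to reach a target effective batch size."""
    steps, remainder = divmod(effective_batch, per_device_batch * num_devices)
    if remainder:
        raise ValueError("effective batch is not a multiple of the per-step batch")
    return steps

# Values from the Training section: batch size 8, effective 32, single GPU.
steps = implied_accumulation_steps(effective_batch=32, per_device_batch=8)
total = effective_batch_size(per_device_batch=8, accumulation_steps=steps)
```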

Current Performance

  • Top-1 Accuracy on wiki-ss-nq test: 1.16%
  • Note: Accuracy this low points to issues that still need investigation:
    • Text-vision alignment quality
    • Embedding space misalignment
    • Too few training epochs
    • Hyperparameter tuning
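
For reference, top-k accuracy over a query-by-document score matrix can be computed as in the sketch below. The function name and the toy scores are illustrative assumptions; the evaluation script actually used for the number above is not shown in this card.

```python
def top_k_accuracy(score_matrix, gold_indices, k=1):
    """Fraction of queries whose gold document appears among the k best scores."""
    hits = 0
    for scores, gold in zip(score_matrix, gold_indices):
        # Document indices sorted from highest to lowest score.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if gold in ranked[:k]:
            hits += 1
    return hits / len(gold_indices)

# Toy example: 3 queries scored against 3 documents.
scores = [
    [0.9, 0.1, 0.2],  # gold document 0 is ranked first
    [0.3, 0.2, 0.8],  # gold document 0 is ranked second
    [0.1, 0.7, 0.4],  # gold document 1 is ranked first
]
gold = [0, 0, 1]

top1 = top_k_accuracy(scores, gold, k=1)
top2 = top_k_accuracy(scores, gold, k=2)
```

A useful sanity check is the chance level: with N candidate documents and a random ranking, the expected top-k accuracy is k / N.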

Usage

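from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# Load model and processor
processor = AutoProcessor.from_pretrained("sugiv/smolvlm-dse")
model = AutoModelForVision2Seq.from_pretrained("sugiv/smolvlm-dse")
model.eval()

# Example inputs (placeholders; replace with your own query and screenshot).
# The "Query: " prefix matches the --query_prefix used during training.
query_text = "Query: who wrote the declaration of independence"
document_image = Image.open("screenshot.png").convert("RGB")  # hypothetical path

# Process query
query_inputs = processor(
    text=query_text,
    return_tensors="pt",
    padding=True,
    truncation=True
)

# Process document image
image_inputs = processor(
    images=document_image,
    return_tensors="pt"
)

# Get embeddings
# Note: encode_query and encode_passage are not defined on the stock
# transformers model class; this assumes the retrieval wrapper used for
# training (which provides them) is applied to the loaded model.
query_embedding = model.encode_query(query_inputs)
doc_embedding = model.encode_passage(image_inputs)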
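
Once query and document embeddings are available, retrieval reduces to ranking documents by dot product. The sketch below assumes the embeddings are already L2-normalized and have been converted to plain Python lists (for example with `.tolist()`); `rank_documents` is a hypothetical helper, not part of this model's API.

```python
def rank_documents(query_embedding, doc_embeddings, top_k=3):
    """Return (document index, score) pairs, best match first."""
    scores = [
        sum(q * d for q, d in zip(query_embedding, doc))
        for doc in doc_embeddings
    ]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in order[:top_k]]

# Toy 2-d unit vectors in place of real 768-d embeddings.
query = [1.0, 0.0]
docs = [[0.0, 1.0], [1.0, 0.0], [0.6, 0.8]]

ranking = rank_documents(query, docs)
best_index, best_score = ranking[0]
```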

Limitations and Future Work

  1. Current accuracy is significantly lower than expected
  2. Investigation needed for:
    • Embedding space analysis
    • Training dynamics
    • Hyperparameter optimization
    • Additional training epochs
    • Text-vision alignment quality

Training Configuration

deepspeed --include localhost:0 --master_port 60000 train.py \
  --deepspeed ds_zero2_config.json \
  --output_dir retriever-smolvlm \
  --model_name_or_path HuggingFaceTB/SmolVLM-256M-Base \
  --save_steps 50 \
  --dataset_name Tevatron/wiki-ss-nq \
  --corpus_name Tevatron/wiki-ss-corpus \
  --cache_dir ./cached_datasets \
  --query_prefix "Query: " \
  --passage_prefix "Passage: " \
  --bf16 \
  --pooling last \
  --normalize \
  --temperature 0.02 \
  --per_device_train_batch_size 8 \
  --gradient_checkpointing \
  --train_group_size 16 \
  --learning_rate 1e-5 \
  --weight_decay 0.01 \
  --query_max_len 128 \
  --passage_max_len 512 \
  --num_train_epochs 1
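
To illustrate how `--temperature 0.02` and `--train_group_size 16` interact, here is a minimal sketch of a softmax contrastive (InfoNCE-style) loss for a single query. It assumes a group of 16 means one positive plus 15 negatives and that the positive sits at index 0; in-batch negatives from other queries, which toolkits of this kind commonly add, are left out. The function name is hypothetical and this is not the training code itself.

```python
import math

def contrastive_loss(similarities, positive_index=0, temperature=0.02):
    """Cross-entropy over temperature-scaled similarities for one query."""
    logits = [s / temperature for s in similarities]
    # Subtract the max logit before exponentiating for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum_exp - logits[positive_index]

group_size = 16  # assumed: 1 positive + 15 negatives per query

# Untrained case: every candidate looks equally similar to the query.
uniform_loss = contrastive_loss([0.5] * group_size)

# Well-separated case: positive at cosine 0.9, negatives at 0.1.
separated_loss = contrastive_loss([0.9] + [0.1] * (group_size - 1))
```

When all candidates score the same, the loss equals log(16), about 2.77, which is the value to expect at the start of training if the embeddings carry no signal. Because the temperature multiplies similarities by 50, even a modest cosine gap between positive and negatives drives the loss close to zero.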

License

Same as base model HuggingFaceTB/SmolVLM-256M-Base

Citation

@article{Gao2022TevatronAE,
  title={Tevatron: An Efficient and Flexible Toolkit for Dense Retrieval},
  author={Luyu Gao and Xueguang Ma and Jimmy J. Lin and Jamie Callan},
  journal={ArXiv},
  year={2022},
  volume={abs/2203.05765}
}
