SmolVLM DSE (Document Screenshot Embedding)
This model is SmolVLM fine-tuned for document screenshot embedding (DSE) using contrastive learning.
Model Description
Overview
- Model Type: SmolVLM (IDEFICS3) with DSE architecture
- Base Model: HuggingFaceTB/SmolVLM-256M-Base
- Task: Document Screenshot Embedding (Visual-Text Retrieval)
- Training Data: wiki-ss-nq (queries) and wiki-ss-corpus (document screenshots)
Architecture Details
- Vision Transformer for image encoding (768d)
- LLaMA model for text encoding (576d)
- Linear projection layer to align text representations (576d → 768d)
- Last-token pooling with L2 normalization (see the sketch after this list)
- Temperature scaling (0.02)
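The embedding head can be sketched roughly as follows. This is an illustrative reconstruction from the bullet points above, not the model's actual implementation; the class and attribute names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DSEHead(nn.Module):
    """Illustrative sketch of the text-side embedding head (names are hypothetical)."""

    def __init__(self, text_dim: int = 576, embed_dim: int = 768):
        super().__init__()
        # Linear projection aligning text representations with the 768d vision space
        self.text_proj = nn.Linear(text_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Project every token's hidden state (576d -> 768d)
        projected = self.text_proj(hidden_states)
        # Last-token pooling: pick the final non-padding position of each sequence
        last_idx = attention_mask.sum(dim=1) - 1
        pooled = projected[torch.arange(projected.size(0)), last_idx]
        # L2-normalize so similarities reduce to dot products
        return F.normalize(pooled, dim=-1)
```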
Training
- Epochs: 1
- Batch Size: 8 (effective batch size with gradient accumulation: 32)
- Learning Rate: 1e-5
- Optimizer: AdamW with weight decay 0.01
- Hardware: Single GPU
- Training Time: ~8 hours
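The contrastive objective implied by these settings (last-token pooling, normalization, temperature 0.02) is the standard InfoNCE loss. Below is a simplified sketch using in-batch negatives only; the actual Tevatron training groups 16 passages per query (train_group_size), so this is not the exact loss code.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(q: torch.Tensor, p: torch.Tensor, temperature: float = 0.02) -> torch.Tensor:
    """InfoNCE with in-batch negatives.

    q: (B, D) L2-normalized query embeddings
    p: (B, D) L2-normalized passage embeddings, aligned row-wise with q
    """
    # Temperature-scaled similarity matrix; the diagonal holds the positive pairs
    scores = q @ p.T / temperature
    targets = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(scores, targets)
```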
Current Performance
- Top-1 accuracy on the wiki-ss-nq test set: 1.16%
- Note: the low accuracy suggests potential issues that need investigation (a quick diagnostic sketch follows this list):
  - Text-vision alignment quality
  - Embedding space misalignment
  - Need for additional training epochs
  - Potential hyperparameter tuning
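One quick way to probe the suspected misalignment is to compare similarity scores of matched query/passage pairs against mismatched ones. A rough diagnostic sketch, assuming you already have L2-normalized embeddings for N aligned pairs:

```python
import torch

def alignment_report(q: torch.Tensor, p: torch.Tensor) -> None:
    """q, p: (N, D) L2-normalized embeddings of matched query/passage pairs."""
    scores = q @ p.T
    mask = torch.eye(len(q), dtype=torch.bool)
    pos = scores[mask]    # matched pairs (diagonal)
    neg = scores[~mask]   # mismatched pairs (off-diagonal)
    print(f"positive mean: {pos.mean():.4f}  random mean: {neg.mean():.4f}")
    # A healthy embedding space puts positives well above the random baseline;
    # near-identical means would point to text-vision misalignment.
```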
Usage
```python
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# Load model and processor
processor = AutoProcessor.from_pretrained("sugiv/smolvlm-dse")
model = AutoModelForVision2Seq.from_pretrained("sugiv/smolvlm-dse")

# Process the query (the "Query: " prefix matches the training setup)
query_text = "Query: who wrote the declaration of independence"
query_inputs = processor(
    text=query_text,
    return_tensors="pt",
    padding=True,
    truncation=True,
)

# Process a document screenshot
document_image = Image.open("document_screenshot.png")
image_inputs = processor(
    images=document_image,
    return_tensors="pt",
)

# Get embeddings
query_embedding = model.encode_query(query_inputs)
doc_embedding = model.encode_passage(image_inputs)
```
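Since the embeddings are L2-normalized (matching the --normalize training flag), retrieval scores are plain dot products. A minimal ranking sketch, assuming encode_query and encode_passage return (batch, dim) tensors:

```python
import torch

# Stack passage embeddings into an (N, D) matrix; here the corpus is a single page
corpus_embeddings = torch.cat([doc_embedding], dim=0)

# For L2-normalized embeddings, cosine similarity is a plain dot product
scores = query_embedding @ corpus_embeddings.T
best = scores.argmax(dim=-1).item()
print(f"best match: passage {best}, score {scores.max().item():.4f}")
```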
Limitations and Future Work
Current accuracy is significantly lower than expected. Investigation is needed into:
- Embedding space analysis
- Training dynamics
- Hyperparameter optimization
- Additional training epochs
- Text-vision alignment quality
Training Configuration
```bash
deepspeed --include localhost:0 --master_port 60000 train.py \
    --deepspeed ds_zero2_config.json \
    --output_dir retriever-smolvlm \
    --model_name_or_path HuggingFaceTB/SmolVLM-256M-Base \
    --save_steps 50 \
    --dataset_name Tevatron/wiki-ss-nq \
    --corpus_name Tevatron/wiki-ss-corpus \
    --cache_dir ./cached_datasets \
    --query_prefix "Query: " \
    --passage_prefix "Passage: " \
    --bf16 \
    --pooling last \
    --normalize \
    --temperature 0.02 \
    --per_device_train_batch_size 8 \
    --gradient_checkpointing \
    --train_group_size 16 \
    --learning_rate 1e-5 \
    --weight_decay 0.01 \
    --query_max_len 128 \
    --passage_max_len 512 \
    --num_train_epochs 1
```
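The command references ds_zero2_config.json, which is not included in this card. A minimal ZeRO-2 configuration along these lines would be compatible with the flags above; the values here are assumptions, not the actual file:

```json
{
  "zero_optimization": { "stage": 2 },
  "bf16": { "enabled": true },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto"
}
```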
License
Same as the base model, HuggingFaceTB/SmolVLM-256M-Base.
Citation
The training setup builds on the Tevatron toolkit:

```bibtex
@article{Gao2022TevatronAE,
  title   = {Tevatron: An Efficient and Flexible Toolkit for Dense Retrieval},
  author  = {Luyu Gao and Xueguang Ma and Jimmy J. Lin and Jamie Callan},
  journal = {ArXiv},
  year    = {2022},
  volume  = {abs/2203.05765}
}
```