# smolvla/modeling_lerobot_policy.py
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
# This import assumes 'lerobot' is installed in the user's environment
from lerobot.smolvla_base import SmolVLABasePolicy
class LerobotSmolVLAConfig(PretrainedConfig):
    model_type = "lerobot_smolvla"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.adapt_to_pi_aloha = False
        self.add_image_special_tokens = False
        self.attention_mode = 'cross_attn'
        self.chunk_size = 50
        self.device = 'cuda'
        self.empty_cameras = 0
        self.expert_width_multiplier = 0.75
        self.freeze_vision_encoder = True
        self.input_features = {
            'observation.image': {'shape': [3, 256, 256], 'type': 'VISUAL'},
            'observation.image2': {'shape': [3, 256, 256], 'type': 'VISUAL'},
            'observation.image3': {'shape': [3, 256, 256], 'type': 'VISUAL'},
            'observation.state': {'shape': [6], 'type': 'STATE'},
        }
        self.load_vlm_weights = True
        self.max_action_dim = 32
        self.max_period = 4
        self.max_state_dim = 32
        self.min_period = 0.004
        self.n_action_steps = 50
        self.n_obs_steps = 1
        self.normalization_mapping = {'ACTION': 'MEAN_STD', 'STATE': 'MEAN_STD', 'VISUAL': 'IDENTITY'}
        self.num_expert_layers = 0
        self.num_steps = 10
        self.num_vlm_layers = 16
        self.optimizer_betas = [0.9, 0.95]
        self.optimizer_eps = 1e-08
        self.optimizer_grad_clip_norm = 10
        self.optimizer_lr = 0.0001
        self.optimizer_weight_decay = 1e-10
        self.output_features = {'action': {'shape': [6], 'type': 'ACTION'}}
        self.pad_language_to = 'max_length'
        self.prefix_length = 0
        self.resize_imgs_with_padding = [512, 512]
        self.scheduler_decay_lr = 2.5e-06
        self.scheduler_decay_steps = 30000
        self.scheduler_warmup_steps = 1000
        self.self_attn_every_n_layers = 2
        self.tokenizer_max_length = 48
        self.train_expert_only = True
        self.train_state_proj = True
        self.type = 'smolvla'
        self.use_amp = False
        self.use_cache = True
        self.use_delta_joint_actions_aloha = False
        self.vlm_model_name = 'HuggingFaceTB/SmolVLM2-500M-Video-Instruct'
        # Let caller-provided kwargs (e.g. values loaded from config.json) override the defaults above.
        for k, v in kwargs.items():
            setattr(self, k, v)
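
# Usage sketch (illustrative, not part of the uploaded checkpoint): keyword arguments passed
# to the constructor override the hard-coded defaults above, so a different action horizon or
# camera setup can be described without editing this file, e.g.:
#
#   config = LerobotSmolVLAConfig(chunk_size=100, n_action_steps=100)
#   assert config.chunk_size == 100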
class LerobotSmolVLAWrappedModel(PreTrainedModel):
    config_class = LerobotSmolVLAConfig

    def __init__(self, config):
        super().__init__(config)
        # to_dict() extracts all config parameters to pass through to the policy
        policy_init_kwargs = config.to_dict()
        self.smolvla_policy = SmolVLABasePolicy(**policy_init_kwargs)

    def forward(self, observations, actions=None, language_instruction=None, timestep=None):
        # The explicit signature keeps the wrapper self-documenting for callers
        return self.smolvla_policy(
            observations=observations,
            actions=actions,
            language_instruction=language_instruction,
            timestep=timestep,
        )
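

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not shipped with the checkpoint): builds the config locally
    # and shows how the wrapper could be registered for AutoModel loading. The AutoConfig/AutoModel
    # registration calls are standard transformers APIs; instantiating LerobotSmolVLAWrappedModel
    # itself assumes 'lerobot' is installed and that SmolVLABasePolicy accepts the flattened
    # config dict, as the wrapper above expects.
    from transformers import AutoConfig, AutoModel

    config = LerobotSmolVLAConfig(chunk_size=100)
    print("chunk_size:", config.chunk_size)      # 100: kwargs override the default of 50
    print("vlm backbone:", config.vlm_model_name)

    # Register the custom classes so AutoConfig/AutoModel can resolve model_type="lerobot_smolvla".
    AutoConfig.register("lerobot_smolvla", LerobotSmolVLAConfig)
    AutoModel.register(LerobotSmolVLAConfig, LerobotSmolVLAWrappedModel)

    # With lerobot available, the policy could then be built and queried, for example:
    #   model = LerobotSmolVLAWrappedModel(config)
    #   actions = model(observations=batch, language_instruction="pick up the cube")
    # where 'batch' is assumed to follow the input_features layout declared in the config.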