import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

# This import assumes 'lerobot' is installed in the user's environment
from lerobot.smolvla_base import SmolVLABasePolicy


class LerobotSmolVLAConfig(PretrainedConfig):
    """Hugging Face config holding the SmolVLA policy hyper-parameters.

    Every SmolVLA field has a default (see ``_default_fields``). Any field can
    be overridden by passing it as a keyword argument; this is also the path
    ``from_pretrained`` uses to restore values saved in ``config.json``.
    Keyword arguments that are not SmolVLA fields are forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "lerobot_smolvla"

    def __init__(self, **kwargs):
        fields = self._default_fields()
        # Take SmolVLA fields out of kwargs *before* the base init so that
        # caller / config.json values win over the defaults. Assigning the
        # defaults after super().__init__ (as before) silently discarded
        # every override.
        for name in fields:
            if name in kwargs:
                fields[name] = kwargs.pop(name)

        super().__init__(**kwargs)

        for name, value in fields.items():
            setattr(self, name, value)

        # These defaults used to be the strings '1e-08' / '1e-10', so configs
        # saved by earlier versions of this class may still carry strings;
        # optimizers need numeric values.
        for name in ("optimizer_eps", "optimizer_weight_decay"):
            value = getattr(self, name)
            if isinstance(value, str):
                setattr(self, name, float(value))

    @staticmethod
    def _default_fields():
        """Return a fresh dict of SmolVLA field defaults.

        A new dict (with new nested lists/dicts) is built on every call so
        that instances never share mutable default values.
        """
        return {
            "adapt_to_pi_aloha": False,
            "add_image_special_tokens": False,
            "attention_mode": "cross_attn",
            "chunk_size": 50,
            "device": "cuda",
            "empty_cameras": 0,
            "expert_width_multiplier": 0.75,
            "freeze_vision_encoder": True,
            "input_features": {
                "observation.image": {"shape": [3, 256, 256], "type": "VISUAL"},
                "observation.image2": {"shape": [3, 256, 256], "type": "VISUAL"},
                "observation.image3": {"shape": [3, 256, 256], "type": "VISUAL"},
                "observation.state": {"shape": [6], "type": "STATE"},
            },
            "load_vlm_weights": True,
            "max_action_dim": 32,
            "max_period": 4,
            "max_state_dim": 32,
            "min_period": 0.004,
            "n_action_steps": 50,
            "n_obs_steps": 1,
            "normalization_mapping": {
                "ACTION": "MEAN_STD",
                "STATE": "MEAN_STD",
                "VISUAL": "IDENTITY",
            },
            "num_expert_layers": 0,
            "num_steps": 10,
            "num_vlm_layers": 16,
            "optimizer_betas": [0.9, 0.95],
            "optimizer_eps": 1e-08,
            "optimizer_grad_clip_norm": 10,
            "optimizer_lr": 0.0001,
            "optimizer_weight_decay": 1e-10,
            "output_features": {"action": {"shape": [6], "type": "ACTION"}},
            "pad_language_to": "max_length",
            "prefix_length": 0,
            "resize_imgs_with_padding": [512, 512],
            "scheduler_decay_lr": 2.5e-06,
            "scheduler_decay_steps": 30000,
            "scheduler_warmup_steps": 1000,
            "self_attn_every_n_layers": 2,
            "tokenizer_max_length": 48,
            "train_expert_only": True,
            "train_state_proj": True,
            "type": "smolvla",
            "use_amp": False,
            "use_cache": True,
            "use_delta_joint_actions_aloha": False,
            "vlm_model_name": "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
        }


class LerobotSmolVLAWrappedModel(PreTrainedModel):
    """``PreTrainedModel`` shell that owns a ``SmolVLABasePolicy``.

    The policy is constructed from the flattened config dictionary and kept
    as ``self.smolvla_policy``; ``forward`` delegates to it unchanged.
    """

    config_class = LerobotSmolVLAConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): config.to_dict() also contains generic PretrainedConfig
        # keys (e.g. model_type, transformers_version), so this presumably
        # relies on SmolVLABasePolicy accepting or ignoring extra keyword
        # arguments -- confirm against its signature.
        policy_kwargs = config.to_dict()
        self.smolvla_policy = SmolVLABasePolicy(**policy_kwargs)

    def forward(
        self,
        observations,
        actions=None,
        language_instruction=None,
        timestep=None,
    ):
        """Call the wrapped policy with the given inputs, passed by keyword.

        Args:
            observations: Forwarded as ``observations``.
            actions: Forwarded as ``actions`` (defaults to ``None``).
            language_instruction: Forwarded as ``language_instruction``
                (defaults to ``None``).
            timestep: Forwarded as ``timestep`` (defaults to ``None``).

        Returns:
            Whatever calling ``self.smolvla_policy`` returns.
        """
        policy_inputs = {
            "observations": observations,
            "actions": actions,
            "language_instruction": language_instruction,
            "timestep": timestep,
        }
        return self.smolvla_policy(**policy_inputs)