|
|
|
import torch.nn as nn |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
from lerobot.smolvla_base import SmolVLABasePolicy |
|
|
|
class LerobotSmolVLAConfig(PretrainedConfig):
    """Hugging Face configuration for the LeRobot SmolVLA policy wrapper.

    Every SmolVLA field listed in ``__init__`` has a default. A value passed as
    a keyword argument for that field overrides the default, including values
    fed back in when a saved config is reloaded. Keyword arguments that are not
    SmolVLA fields are handed to :class:`PretrainedConfig`.
    """

    model_type = "lerobot_smolvla"

    def __init__(self, **kwargs):
        """Build the config from defaults overridden by ``kwargs``.

        Args:
            **kwargs: Any SmolVLA field name below overrides its default.
                Remaining keyword arguments are forwarded to
                ``PretrainedConfig.__init__``. Any of them still missing as an
                attribute afterwards is set directly on the instance.
        """
        # Built fresh on every call so instances never share the mutable
        # (dict/list) defaults.
        defaults = {
            "adapt_to_pi_aloha": False,
            "add_image_special_tokens": False,
            "attention_mode": "cross_attn",
            "chunk_size": 50,
            "device": "cuda",
            "empty_cameras": 0,
            "expert_width_multiplier": 0.75,
            "freeze_vision_encoder": True,
            "input_features": {
                "observation.image": {"shape": [3, 256, 256], "type": "VISUAL"},
                "observation.image2": {"shape": [3, 256, 256], "type": "VISUAL"},
                "observation.image3": {"shape": [3, 256, 256], "type": "VISUAL"},
                "observation.state": {"shape": [6], "type": "STATE"},
            },
            "load_vlm_weights": True,
            "max_action_dim": 32,
            "max_period": 4,
            "max_state_dim": 32,
            "min_period": 0.004,
            "n_action_steps": 50,
            "n_obs_steps": 1,
            "normalization_mapping": {
                "ACTION": "MEAN_STD",
                "STATE": "MEAN_STD",
                "VISUAL": "IDENTITY",
            },
            "num_expert_layers": 0,
            "num_steps": 10,
            "num_vlm_layers": 16,
            "optimizer_betas": [0.9, 0.95],
            "optimizer_eps": 1e-08,
            "optimizer_grad_clip_norm": 10,
            "optimizer_lr": 0.0001,
            "optimizer_weight_decay": 1e-10,
            "output_features": {"action": {"shape": [6], "type": "ACTION"}},
            "pad_language_to": "max_length",
            "prefix_length": 0,
            "resize_imgs_with_padding": [512, 512],
            "scheduler_decay_lr": 2.5e-06,
            "scheduler_decay_steps": 30000,
            "scheduler_warmup_steps": 1000,
            "self_attn_every_n_layers": 2,
            "tokenizer_max_length": 48,
            "train_expert_only": True,
            "train_state_proj": True,
            "type": "smolvla",
            "use_amp": False,
            "use_cache": True,
            "use_delta_joint_actions_aloha": False,
            "vlm_model_name": "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
        }

        # Pull SmolVLA fields out of kwargs first: caller-supplied values win
        # over the defaults. Previously the defaults were assigned after the
        # base __init__ and silently clobbered every caller override.
        fields = {name: kwargs.pop(name, default) for name, default in defaults.items()}

        super().__init__(**kwargs)

        # These were historically stored as strings ('1e-08', '1e-10'), likely
        # a YAML 1.1 parsing artifact. Coerce string values (e.g. from older
        # saved configs) to float so optimizers receive numbers.
        for name in ("optimizer_eps", "optimizer_weight_decay"):
            if isinstance(fields[name], str):
                fields[name] = float(fields[name])

        # Assigned after the base __init__ so it cannot overwrite these fields.
        for name, value in fields.items():
            setattr(self, name, value)

        # Keep any extra kwargs the base class did not already expose as attributes.
        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, value)
|
|
|
class LerobotSmolVLAWrappedModel(PreTrainedModel):
    """Expose a ``SmolVLABasePolicy`` through the Hugging Face ``PreTrainedModel`` API.

    The wrapper owns one policy instance, built from the serialized fields of
    its :class:`LerobotSmolVLAConfig`, and delegates ``forward`` to it.
    """

    config_class = LerobotSmolVLAConfig

    def __init__(self, config):
        """Create the wrapped policy from ``config``.

        Args:
            config: A ``LerobotSmolVLAConfig``. Its ``to_dict()`` output is
                passed as keyword arguments to ``SmolVLABasePolicy``.
        """
        super().__init__(config)
        # NOTE(review): to_dict() serializes the whole config, presumably
        # including generic PretrainedConfig fields as well as SmolVLA ones.
        # This assumes SmolVLABasePolicy tolerates the extra keys; confirm.
        self.smolvla_policy = SmolVLABasePolicy(**config.to_dict())

    def forward(self, observations, actions=None, language_instruction=None, timestep=None):
        """Run the wrapped policy and return its result unchanged.

        Args:
            observations: Passed through as ``observations``.
            actions: Passed through as ``actions`` (default ``None``).
            language_instruction: Passed through as ``language_instruction``
                (default ``None``).
            timestep: Passed through as ``timestep`` (default ``None``).

        Returns:
            Whatever ``SmolVLABasePolicy.__call__`` returns.
        """
        policy_inputs = dict(
            observations=observations,
            actions=actions,
            language_instruction=language_instruction,
            timestep=timestep,
        )
        return self.smolvla_policy(**policy_inputs)
|
|