feat: 冻结resnet
This commit is contained in:
@@ -116,12 +116,19 @@ class VLAAgent(nn.Module):
|
|||||||
action_features, noise, timesteps
|
action_features, noise, timesteps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 拼接全局条件并展平
|
||||||
|
# visual_features: (B, obs_horizon, vision_dim)
|
||||||
|
# state_features: (B, obs_horizon, obs_dim)
|
||||||
|
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
|
||||||
|
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||||
|
global_cond = global_cond.flatten(start_dim=1)
|
||||||
|
|
||||||
|
|
||||||
# 5. 网络预测噪声
|
# 5. 网络预测噪声
|
||||||
pred_noise = self.noise_pred_net(
|
pred_noise = self.noise_pred_net(
|
||||||
sample=noisy_actions,
|
sample=noisy_actions,
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
visual_features=visual_features,
|
global_cond=global_cond
|
||||||
proprioception=state_features
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 计算 Loss (MSE)
|
# 6. 计算 Loss (MSE)
|
||||||
@@ -314,12 +321,18 @@ class VLAAgent(nn.Module):
|
|||||||
for t in self.infer_scheduler.timesteps:
|
for t in self.infer_scheduler.timesteps:
|
||||||
model_input = current_actions
|
model_input = current_actions
|
||||||
|
|
||||||
|
# 拼接全局条件并展平
|
||||||
|
# visual_features: (B, obs_horizon, vision_dim)
|
||||||
|
# state_features: (B, obs_horizon, obs_dim)
|
||||||
|
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
|
||||||
|
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||||
|
global_cond = global_cond.flatten(start_dim=1)
|
||||||
|
|
||||||
# 预测噪声
|
# 预测噪声
|
||||||
noise_pred = self.noise_pred_net(
|
noise_pred = self.noise_pred_net(
|
||||||
sample=model_input,
|
sample=model_input,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
visual_features=visual_features,
|
global_cond=global_cond
|
||||||
proprioception=state_features
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 移除噪声,更新 current_actions
|
# 移除噪声,更新 current_actions
|
||||||
|
|||||||
@@ -4,7 +4,12 @@ _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
|
|||||||
# 骨干网络选择
|
# 骨干网络选择
|
||||||
# ====================
|
# ====================
|
||||||
vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
|
vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
|
||||||
pretrained_backbone_weights: null # 预训练权重路径或 null(ImageNet 权重)
|
pretrained_backbone_weights: "IMAGENET1K_V1" # 使用ImageNet预训练权重(torchvision>=0.13)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 冻结设置
|
||||||
|
# ====================
|
||||||
|
freeze_backbone: true # 冻结ResNet参数,只训练后面的pool和out层(推荐:true)
|
||||||
|
|
||||||
# ====================
|
# ====================
|
||||||
# 输入配置
|
# 输入配置
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class _SingleRgbEncoder(nn.Module):
|
|||||||
crop_is_random: bool,
|
crop_is_random: bool,
|
||||||
use_group_norm: bool,
|
use_group_norm: bool,
|
||||||
spatial_softmax_num_keypoints: int,
|
spatial_softmax_num_keypoints: int,
|
||||||
|
freeze_backbone: bool = True, # 新增:是否冻结backbone
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -133,6 +134,11 @@ class _SingleRgbEncoder(nn.Module):
|
|||||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 冻结backbone参数(可选)
|
||||||
|
if freeze_backbone:
|
||||||
|
for param in self.backbone.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
# 设置池化和最终层
|
# 设置池化和最终层
|
||||||
# 使用试运行来获取特征图形状
|
# 使用试运行来获取特征图形状
|
||||||
dummy_shape = (1, input_shape[0], *crop_shape)
|
dummy_shape = (1, input_shape[0], *crop_shape)
|
||||||
@@ -164,6 +170,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
spatial_softmax_num_keypoints: int = 32,
|
spatial_softmax_num_keypoints: int = 32,
|
||||||
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
|
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
|
||||||
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
|
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
|
||||||
|
freeze_backbone: bool = True, # 新增:是否冻结ResNet backbone(推荐True)
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -181,6 +188,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
crop_is_random=crop_is_random,
|
crop_is_random=crop_is_random,
|
||||||
use_group_norm=use_group_norm,
|
use_group_norm=use_group_norm,
|
||||||
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
|
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
|
||||||
|
freeze_backbone=freeze_backbone,
|
||||||
)
|
)
|
||||||
for _ in range(num_cameras)
|
for _ in range(num_cameras)
|
||||||
]
|
]
|
||||||
@@ -197,6 +205,7 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
crop_is_random=crop_is_random,
|
crop_is_random=crop_is_random,
|
||||||
use_group_norm=use_group_norm,
|
use_group_norm=use_group_norm,
|
||||||
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
|
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
|
||||||
|
freeze_backbone=freeze_backbone,
|
||||||
)
|
)
|
||||||
self.feature_dim = self.rgb_encoder.feature_dim
|
self.feature_dim = self.rgb_encoder.feature_dim
|
||||||
|
|
||||||
|
|||||||
@@ -225,27 +225,15 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
def forward(self,
|
def forward(self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
local_cond=None, global_cond=None,
|
local_cond=None, global_cond=None,
|
||||||
visual_features=None, proprioception=None,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
x: (B,T,input_dim)
|
x: (B,T,input_dim)
|
||||||
timestep: (B,) or int, diffusion step
|
timestep: (B,) or int, diffusion step
|
||||||
local_cond: (B,T,local_cond_dim)
|
local_cond: (B,T,local_cond_dim)
|
||||||
global_cond: (B,global_cond_dim)
|
global_cond: (B,global_cond_dim)
|
||||||
visual_features: (B, T_obs, D_vis)
|
|
||||||
proprioception: (B, T_obs, D_prop)
|
|
||||||
output: (B,T,input_dim)
|
output: (B,T,input_dim)
|
||||||
"""
|
"""
|
||||||
if global_cond is None:
|
|
||||||
conds = []
|
|
||||||
if visual_features is not None:
|
|
||||||
conds.append(visual_features.flatten(start_dim=1))
|
|
||||||
if proprioception is not None:
|
|
||||||
conds.append(proprioception.flatten(start_dim=1))
|
|
||||||
if len(conds) > 0:
|
|
||||||
global_cond = torch.cat(conds, dim=-1)
|
|
||||||
|
|
||||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||||
|
|
||||||
# 1. time
|
# 1. time
|
||||||
|
|||||||
Reference in New Issue
Block a user