fix: 修复VLA设备与损失计算逻辑,并优化Transformer默认训练参数

This commit is contained in:
gouhanke
2026-03-03 17:56:12 +08:00
parent cdb887c9bf
commit 8bcad5844e
4 changed files with 49 additions and 51 deletions

View File

@@ -248,8 +248,11 @@ def main(cfg: DictConfig):
# ========================================================================= # =========================================================================
# 4. 设置优化器与学习率调度器 # 4. 设置优化器与学习率调度器
# ========================================================================= # =========================================================================
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5) weight_decay = float(cfg.train.get('weight_decay', 1e-5))
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})") grad_clip = float(cfg.train.get('grad_clip', 1.0))
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
# 设置带预热的学习率调度器 # 设置带预热的学习率调度器
warmup_steps = int(cfg.train.get('warmup_steps', 500)) warmup_steps = int(cfg.train.get('warmup_steps', 500))
@@ -353,7 +356,7 @@ def main(cfg: DictConfig):
loss.backward() loss.backward()
# 梯度裁剪以稳定训练 # 梯度裁剪以稳定训练
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0) torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()

View File

@@ -101,6 +101,22 @@ class VLAAgent(nn.Module):
# 初始化队列(用于在线推理) # 初始化队列(用于在线推理)
self.reset() self.reset()
def _get_model_device(self) -> torch.device:
"""获取模型当前所在设备。"""
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
"""递归地将张量数据移动到指定设备。"""
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
# ========================== # ==========================
# 训练阶段 (Training) # 训练阶段 (Training)
@@ -170,8 +186,9 @@ class VLAAgent(nn.Module):
# 如果提供了 action_is_pad,对padding位置进行mask # 如果提供了 action_is_pad,对padding位置进行mask
if action_is_pad is not None: if action_is_pad is not None:
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim) # action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
mask = ~action_is_pad.unsqueeze(-1) # True表示有效数据 mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype) # 1.0表示有效数据
loss = (loss * mask).sum() / mask.sum() # 只对有效位置计算平均 valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else: else:
loss = loss.mean() loss = loss.mean()
@@ -262,33 +279,10 @@ class VLAAgent(nn.Module):
Returns: Returns:
action: (action_dim,) 单个动作 action: (action_dim,) 单个动作
""" """
# 检测设备并确保所有组件在同一设备 # 使用模型当前设备作为唯一真值,将输入移动到模型设备
# 尝试从观测中获取设备 # 避免根据CPU观测把模型错误搬回CPU。
device = None device = self._get_model_device()
for v in observation.values(): observation = self._move_to_device(observation, device)
if isinstance(v, torch.Tensor):
device = v.device
break
if device is not None and self.normalization.enabled:
# 确保归一化参数在同一设备上
# 根据归一化类型获取正确的属性
if self.normalization.normalization_type == 'gaussian':
norm_device = self.normalization.qpos_mean.device
else: # min_max
norm_device = self.normalization.qpos_min.device
if device != norm_device:
self.normalization.to(device)
# 同时确保其他模块也在正确设备
self.vision_encoder.to(device)
self.state_encoder.to(device)
self.action_encoder.to(device)
self.noise_pred_net.to(device)
# 将所有 observation 移到正确设备
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()}
# 将新观测添加到队列 # 将新观测添加到队列
self._populate_queues(observation) self._populate_queues(observation)
@@ -355,6 +349,16 @@ class VLAAgent(nn.Module):
visual_features = self.vision_encoder(images) visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception) state_features = self.state_encoder(proprioception)
# 拼接条件(只计算一次)
# visual_features: (B, obs_horizon, vision_dim)
# state_features: (B, obs_horizon, obs_dim)
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond_flat = global_cond.flatten(start_dim=1)
if self.head_type == 'transformer':
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
else:
cond = None
# 2. 初始化纯高斯噪声动作 # 2. 初始化纯高斯噪声动作
# 形状: (B, pred_horizon, action_dim) # 形状: (B, pred_horizon, action_dim)
device = visual_features.device device = visual_features.device
@@ -368,17 +372,8 @@ class VLAAgent(nn.Module):
for t in self.infer_scheduler.timesteps: for t in self.infer_scheduler.timesteps:
model_input = current_actions model_input = current_actions
# 拼接全局条件并展平
# visual_features: (B, obs_horizon, vision_dim)
# state_features: (B, obs_horizon, obs_dim)
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond = global_cond.flatten(start_dim=1)
# 预测噪声根据head类型选择接口 # 预测噪声根据head类型选择接口
if self.head_type == 'transformer': if self.head_type == 'transformer':
# Transformer需要序列格式的条件
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
noise_pred = self.noise_pred_net( noise_pred = self.noise_pred_net(
sample=model_input, sample=model_input,
timestep=t, timestep=t,
@@ -388,7 +383,7 @@ class VLAAgent(nn.Module):
noise_pred = self.noise_pred_net( noise_pred = self.noise_pred_net(
sample=model_input, sample=model_input,
timestep=t, timestep=t,
global_cond=global_cond global_cond=global_cond_flat
) )
# 移除噪声,更新 current_actions # 移除噪声,更新 current_actions

View File

@@ -10,7 +10,7 @@ defaults:
train: train:
# 基础训练参数 # 基础训练参数
batch_size: 8 # 批次大小 batch_size: 8 # 批次大小
lr: 1e-4 # 学习率 lr: 5e-5 # 学习率Transformer建议更小
max_steps: 100000 # 最大训练步数 max_steps: 100000 # 最大训练步数
device: "cuda" # 设备: "cuda" 或 "cpu" device: "cuda" # 设备: "cuda" 或 "cpu"
@@ -24,7 +24,7 @@ train:
save_freq: 2000 # 保存检查点频率(步数) save_freq: 2000 # 保存检查点频率(步数)
# 学习率调度器(带预热) # 学习率调度器(带预热)
warmup_steps: 500 # 预热步数 warmup_steps: 2000 # 预热步数Transformer建议更长
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine" scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
min_lr: 1e-6 # 最小学习率(用于余弦退火) min_lr: 1e-6 # 最小学习率(用于余弦退火)
@@ -41,4 +41,4 @@ train:
experiment: experiment:
name: "vla_diffusion" # 实验名称 name: "vla_diffusion" # 实验名称
notes: "" # 实验备注 notes: "" # 实验备注
tags: [] # 实验标签 tags: [] # 实验标签

View File

@@ -5,18 +5,18 @@ _partial_: true
# ==================== # ====================
# Transformer 架构配置 # Transformer 架构配置
# ==================== # ====================
n_layer: 8 # Transformer层数 n_layer: 4 # Transformer层数(先用小模型提高收敛稳定性)
n_head: 8 # 注意力头数 n_head: 4 # 注意力头数
n_emb: 256 # 嵌入维度 n_emb: 128 # 嵌入维度
p_drop_emb: 0.1 # Embedding dropout p_drop_emb: 0.05 # Embedding dropout
p_drop_attn: 0.1 # Attention dropout p_drop_attn: 0.05 # Attention dropout
# ==================== # ====================
# 条件配置 # 条件配置
# ==================== # ====================
causal_attn: false # 是否使用因果注意力(自回归生成) causal_attn: false # 是否使用因果注意力(自回归生成)
obs_as_cond: true # 观测作为条件由cond_dim > 0决定 obs_as_cond: true # 观测作为条件由cond_dim > 0决定
n_cond_layers: 0 # 条件编码器层数(0表示使用MLP>0使用TransformerEncoder n_cond_layers: 1 # 条件编码器层数(1层先做稳定融合
# ==================== # ====================
# 注意事项 # 注意事项