fix: 修复VLA设备与损失计算逻辑,并优化Transformer默认训练参数
This commit is contained in:
@@ -248,8 +248,11 @@ def main(cfg: DictConfig):
|
||||
# =========================================================================
|
||||
# 4. 设置优化器与学习率调度器
|
||||
# =========================================================================
|
||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
||||
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})")
|
||||
weight_decay = float(cfg.train.get('weight_decay', 1e-5))
|
||||
grad_clip = float(cfg.train.get('grad_clip', 1.0))
|
||||
|
||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
|
||||
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
|
||||
|
||||
# 设置带预热的学習率调度器
|
||||
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
||||
@@ -353,7 +356,7 @@ def main(cfg: DictConfig):
|
||||
loss.backward()
|
||||
|
||||
# 梯度裁剪以稳定训练
|
||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
@@ -101,6 +101,22 @@ class VLAAgent(nn.Module):
|
||||
# 初始化队列(用于在线推理)
|
||||
self.reset()
|
||||
|
||||
def _get_model_device(self) -> torch.device:
|
||||
"""获取模型当前所在设备。"""
|
||||
return next(self.parameters()).device
|
||||
|
||||
def _move_to_device(self, data, device: torch.device):
|
||||
"""递归地将张量数据移动到指定设备。"""
|
||||
if torch.is_tensor(data):
|
||||
return data.to(device)
|
||||
if isinstance(data, dict):
|
||||
return {k: self._move_to_device(v, device) for k, v in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [self._move_to_device(v, device) for v in data]
|
||||
if isinstance(data, tuple):
|
||||
return tuple(self._move_to_device(v, device) for v in data)
|
||||
return data
|
||||
|
||||
|
||||
# ==========================
|
||||
# 训练阶段 (Training)
|
||||
@@ -170,8 +186,9 @@ class VLAAgent(nn.Module):
|
||||
# 如果提供了 action_is_pad,对padding位置进行mask
|
||||
if action_is_pad is not None:
|
||||
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
|
||||
mask = ~action_is_pad.unsqueeze(-1) # True表示有效数据
|
||||
loss = (loss * mask).sum() / mask.sum() # 只对有效位置计算平均
|
||||
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype) # 1.0表示有效数据
|
||||
valid_count = mask.sum() * loss.shape[-1]
|
||||
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||
else:
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -262,33 +279,10 @@ class VLAAgent(nn.Module):
|
||||
Returns:
|
||||
action: (action_dim,) 单个动作
|
||||
"""
|
||||
# 检测设备并确保所有组件在同一设备上
|
||||
# 尝试从观测中获取设备
|
||||
device = None
|
||||
for v in observation.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
device = v.device
|
||||
break
|
||||
|
||||
if device is not None and self.normalization.enabled:
|
||||
# 确保归一化参数在同一设备上
|
||||
# 根据归一化类型获取正确的属性
|
||||
if self.normalization.normalization_type == 'gaussian':
|
||||
norm_device = self.normalization.qpos_mean.device
|
||||
else: # min_max
|
||||
norm_device = self.normalization.qpos_min.device
|
||||
|
||||
if device != norm_device:
|
||||
self.normalization.to(device)
|
||||
# 同时确保其他模块也在正确设备
|
||||
self.vision_encoder.to(device)
|
||||
self.state_encoder.to(device)
|
||||
self.action_encoder.to(device)
|
||||
self.noise_pred_net.to(device)
|
||||
|
||||
# 将所有 observation 移到正确设备
|
||||
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()}
|
||||
# 使用模型当前设备作为唯一真值,将输入移动到模型设备
|
||||
# 避免根据CPU观测把模型错误搬回CPU。
|
||||
device = self._get_model_device()
|
||||
observation = self._move_to_device(observation, device)
|
||||
|
||||
# 将新观测添加到队列
|
||||
self._populate_queues(observation)
|
||||
@@ -355,6 +349,16 @@ class VLAAgent(nn.Module):
|
||||
visual_features = self.vision_encoder(images)
|
||||
state_features = self.state_encoder(proprioception)
|
||||
|
||||
# 拼接条件(只计算一次)
|
||||
# visual_features: (B, obs_horizon, vision_dim)
|
||||
# state_features: (B, obs_horizon, obs_dim)
|
||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
global_cond_flat = global_cond.flatten(start_dim=1)
|
||||
if self.head_type == 'transformer':
|
||||
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||
else:
|
||||
cond = None
|
||||
|
||||
# 2. 初始化纯高斯噪声动作
|
||||
# 形状: (B, pred_horizon, action_dim)
|
||||
device = visual_features.device
|
||||
@@ -368,17 +372,8 @@ class VLAAgent(nn.Module):
|
||||
for t in self.infer_scheduler.timesteps:
|
||||
model_input = current_actions
|
||||
|
||||
# 拼接全局条件并展平
|
||||
# visual_features: (B, obs_horizon, vision_dim)
|
||||
# state_features: (B, obs_horizon, obs_dim)
|
||||
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
|
||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
global_cond = global_cond.flatten(start_dim=1)
|
||||
|
||||
# 预测噪声(根据head类型选择接口)
|
||||
if self.head_type == 'transformer':
|
||||
# Transformer需要序列格式的条件
|
||||
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||
noise_pred = self.noise_pred_net(
|
||||
sample=model_input,
|
||||
timestep=t,
|
||||
@@ -388,7 +383,7 @@ class VLAAgent(nn.Module):
|
||||
noise_pred = self.noise_pred_net(
|
||||
sample=model_input,
|
||||
timestep=t,
|
||||
global_cond=global_cond
|
||||
global_cond=global_cond_flat
|
||||
)
|
||||
|
||||
# 移除噪声,更新 current_actions
|
||||
|
||||
@@ -10,7 +10,7 @@ defaults:
|
||||
train:
|
||||
# 基础训练参数
|
||||
batch_size: 8 # 批次大小
|
||||
lr: 1e-4 # 学习率
|
||||
lr: 5e-5 # 学习率(Transformer建议更小)
|
||||
max_steps: 100000 # 最大训练步数
|
||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||
|
||||
@@ -24,7 +24,7 @@ train:
|
||||
save_freq: 2000 # 保存检查点频率(步数)
|
||||
|
||||
# 学习率调度器(带预热)
|
||||
warmup_steps: 500 # 预热步数
|
||||
warmup_steps: 2000 # 预热步数(Transformer建议更长)
|
||||
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
|
||||
min_lr: 1e-6 # 最小学习率(用于余弦退火)
|
||||
|
||||
@@ -41,4 +41,4 @@ train:
|
||||
experiment:
|
||||
name: "vla_diffusion" # 实验名称
|
||||
notes: "" # 实验备注
|
||||
tags: [] # 实验标签
|
||||
tags: [] # 实验标签
|
||||
|
||||
@@ -5,18 +5,18 @@ _partial_: true
|
||||
# ====================
|
||||
# Transformer 架构配置
|
||||
# ====================
|
||||
n_layer: 8 # Transformer层数
|
||||
n_head: 8 # 注意力头数
|
||||
n_emb: 256 # 嵌入维度
|
||||
p_drop_emb: 0.1 # Embedding dropout
|
||||
p_drop_attn: 0.1 # Attention dropout
|
||||
n_layer: 4 # Transformer层数(先用小模型提高收敛稳定性)
|
||||
n_head: 4 # 注意力头数
|
||||
n_emb: 128 # 嵌入维度
|
||||
p_drop_emb: 0.05 # Embedding dropout
|
||||
p_drop_attn: 0.05 # Attention dropout
|
||||
|
||||
# ====================
|
||||
# 条件配置
|
||||
# ====================
|
||||
causal_attn: false # 是否使用因果注意力(自回归生成)
|
||||
obs_as_cond: true # 观测作为条件(由cond_dim > 0决定)
|
||||
n_cond_layers: 0 # 条件编码器层数(0表示使用MLP,>0使用TransformerEncoder)
|
||||
n_cond_layers: 1 # 条件编码器层数(1层先做稳定融合)
|
||||
|
||||
# ====================
|
||||
# 注意事项
|
||||
|
||||
Reference in New Issue
Block a user