feat: add held-out validation and dual-decoder lewm imf

This commit is contained in:
Logic
2026-04-17 19:26:56 +08:00
parent 74f4963613
commit 395f5a1645
8 changed files with 693 additions and 86 deletions

View File

@@ -118,6 +118,127 @@ def recursive_to_device(data, device):
return data return data
def build_agent_input(batch_data):
agent_input = {
'images': {
cam_name.replace('observation.', ''): value
for cam_name, value in batch_data.items()
if cam_name.startswith('observation.') and cam_name != 'observation.state'
},
'qpos': batch_data['observation.state'],
'action': batch_data['action'],
}
if 'action_is_pad' in batch_data:
agent_input['action_is_pad'] = batch_data['action_is_pad']
lewm_images = {
cam_name.replace('lewm.observation.', ''): value
for cam_name, value in batch_data.items()
if cam_name.startswith('lewm.observation.') and cam_name != 'lewm.observation.state'
}
if lewm_images:
agent_input['lewm_images'] = lewm_images
if 'lewm.observation.state' in batch_data:
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
lewm_future_images = {
cam_name.replace('lewm.future.', ''): value
for cam_name, value in batch_data.items()
if cam_name.startswith('lewm.future.') and cam_name != 'lewm.future.state'
}
if lewm_future_images:
agent_input['lewm_future_images'] = lewm_future_images
if 'lewm.future.state' in batch_data:
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
return agent_input
def _instantiate_dataset(cfg, dataset_image_resize_shape, episode_indices=None):
kwargs = {'image_resize_shape': dataset_image_resize_shape}
if episode_indices is not None:
kwargs['episode_indices'] = episode_indices
return instantiate(cfg.data, **kwargs)
def build_train_val_datasets(cfg, dataset_image_resize_shape):
val_episode_indices = cfg.train.get('val_episode_indices', None)
if val_episode_indices:
dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
available_episode_indices = list(getattr(dataset, 'available_episode_indices', []))
if not available_episode_indices:
raise ValueError('显式 val_episode_indices 需要数据集暴露 available_episode_indices')
requested_val_episode_indices = sorted(int(idx) for idx in val_episode_indices)
available_set = set(available_episode_indices)
missing = sorted(set(requested_val_episode_indices) - available_set)
if missing:
raise ValueError(
f'val_episode_indices {missing} 不存在于数据集可用 episodes {available_episode_indices}'
)
train_episode_indices = [
idx for idx in available_episode_indices
if idx not in set(requested_val_episode_indices)
]
if not train_episode_indices:
raise ValueError('显式 val_episode_indices 不能覆盖全部 episodes训练集将为空')
train_dataset = _instantiate_dataset(
cfg,
dataset_image_resize_shape,
episode_indices=train_episode_indices,
)
val_dataset = _instantiate_dataset(
cfg,
dataset_image_resize_shape,
episode_indices=requested_val_episode_indices,
)
return dataset, train_dataset, val_dataset, requested_val_episode_indices
dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
if val_size > 0:
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(seed)
)
else:
train_dataset, val_dataset = dataset, None
return dataset, train_dataset, val_dataset, None
def compute_action_mse_validation(agent, val_loader, device):
if val_loader is None:
return None
was_training = agent.training
agent.eval()
total_squared_error = 0.0
total_count = 0.0
with torch.no_grad():
for val_batch in val_loader:
val_batch = recursive_to_device(val_batch, device)
val_input = build_agent_input(val_batch)
pred_actions = agent.predict_action_chunk(val_input)
target_actions = val_input['action']
squared_error = (pred_actions - target_actions).pow(2)
action_is_pad = val_input.get('action_is_pad', None)
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)
total_squared_error += (squared_error * mask).sum().item()
total_count += mask.sum().item() * squared_error.shape[-1]
else:
total_squared_error += squared_error.sum().item()
total_count += target_actions.numel()
if was_training:
agent.train()
return total_squared_error / max(total_count, 1.0)
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir): def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
""" """
解析恢复训练用的 checkpoint 路径。 解析恢复训练用的 checkpoint 路径。
@@ -410,29 +531,29 @@ def _run_training(cfg: DictConfig):
vision_backbone_cfg = cfg.agent.get('vision_backbone', None) vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg: if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape') dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
dataset = instantiate( dataset, train_dataset, val_dataset, explicit_val_episode_indices = (
cfg.data, build_train_val_datasets(cfg, dataset_image_resize_shape)
image_resize_shape=dataset_image_resize_shape,
) )
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}") log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
except Exception as e: except Exception as e:
log.error(f"❌ 数据集加载失败: {e}") log.error(f"❌ 数据集加载失败: {e}")
raise raise
# 训练/验证集划分 if explicit_val_episode_indices is not None:
val_split = float(cfg.train.get('val_split', 0.1)) log.info(
seed = int(cfg.train.get('seed', 42)) "✅ 数据集划分: 训练集=%s, 验证集=%s (显式 held-out episodes=%s)",
val_size = int(len(dataset) * val_split) len(train_dataset),
train_size = len(dataset) - val_size len(val_dataset),
if val_size > 0: explicit_val_episode_indices,
train_dataset, val_dataset = random_split( )
dataset, else:
[train_size, val_size], val_split = float(cfg.train.get('val_split', 0.1))
generator=torch.Generator().manual_seed(seed) val_size = len(val_dataset) if val_dataset is not None else 0
if val_size > 0:
log.info(
f"✅ 数据集划分: 训练集={len(train_dataset)}, 验证集={val_size} (验证比例={val_split})"
) )
log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
else: else:
train_dataset, val_dataset = dataset, None
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)") log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
train_batch_size = int(cfg.train.batch_size) train_batch_size = int(cfg.train.batch_size)
@@ -674,44 +795,6 @@ def _run_training(cfg: DictConfig):
# ========================================================================= # =========================================================================
log.info("🏋️ 开始训练循环...") log.info("🏋️ 开始训练循环...")
def build_agent_input(batch_data):
"""构建 agent 输入格式"""
images = {}
# SimpleRobotDataset 返回 observation.{cam_name} 格式
for cam_name in cfg.data.camera_names:
key = f"observation.{cam_name}"
if key in batch_data:
images[cam_name] = batch_data[key]
agent_input = {
'images': images,
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
'action': batch_data['action'],
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
}
lewm_images = {}
lewm_future_images = {}
for cam_name in cfg.data.camera_names:
lewm_obs_key = f"lewm.observation.{cam_name}"
if lewm_obs_key in batch_data:
lewm_images[cam_name] = batch_data[lewm_obs_key]
lewm_future_key = f"lewm.future.{cam_name}"
if lewm_future_key in batch_data:
lewm_future_images[cam_name] = batch_data[lewm_future_key]
if 'lewm.observation.state' in batch_data:
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
if lewm_images:
agent_input['lewm_images'] = lewm_images
if 'lewm.future.state' in batch_data:
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
if lewm_future_images:
agent_input['lewm_future_images'] = lewm_future_images
return agent_input
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None): def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
agent_stats = agent.get_normalization_stats() agent_stats = agent.get_normalization_stats()
torch.save({ torch.save({
@@ -784,6 +867,7 @@ def _run_training(cfg: DictConfig):
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100) pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
steps_per_epoch = len(train_loader) steps_per_epoch = len(train_loader)
action_mse_val_freq_epochs = int(cfg.train.get('action_mse_val_freq_epochs', 0) or 0)
rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0) rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
rollout_validation_enabled = rollout_val_freq_epochs > 0 rollout_validation_enabled = rollout_val_freq_epochs > 0
best_loss = resume_best_loss best_loss = resume_best_loss
@@ -953,6 +1037,33 @@ def _run_training(cfg: DictConfig):
and completed_epoch > 0 and completed_epoch > 0
and completed_epoch % rollout_val_freq_epochs == 0 and completed_epoch % rollout_val_freq_epochs == 0
) )
should_run_action_mse_validation = (
action_mse_val_freq_epochs > 0
and val_loader is not None
and steps_per_epoch > 0
and completed_steps % steps_per_epoch == 0
and completed_epoch > 0
and completed_epoch % action_mse_val_freq_epochs == 0
)
if should_run_action_mse_validation:
action_mse = compute_action_mse_validation(
agent,
val_loader,
cfg.train.device,
)
if action_mse is not None:
log.info(
f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
f"held-out action MSE: {action_mse:.6f}"
)
_log_to_swanlab(
swanlab_module,
{
'val/action_mse': action_mse,
'val/action_mse_epoch': completed_epoch,
},
step=step,
)
if should_run_epoch_rollout: if should_run_epoch_rollout:
if checkpoint_path is None: if checkpoint_path is None:
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"

View File

@@ -26,6 +26,8 @@ class IMFVLAAgent(VLAAgent):
lewm_query_offsets: Optional[Sequence[int]] = None, lewm_query_offsets: Optional[Sequence[int]] = None,
lewm_predictor: Optional[nn.Module] = None, lewm_predictor: Optional[nn.Module] = None,
lewm_pred_projector: Optional[nn.Module] = None, lewm_pred_projector: Optional[nn.Module] = None,
future_decoder: Optional[nn.Module] = None,
future_query_init_std: float = 0.02,
lewm_sigreg: Optional[nn.Module] = None, lewm_sigreg: Optional[nn.Module] = None,
lewm_sigreg_weight: float = 0.09, lewm_sigreg_weight: float = 0.09,
lewm_loss_weight: float = 0.0, lewm_loss_weight: float = 0.0,
@@ -39,11 +41,17 @@ class IMFVLAAgent(VLAAgent):
) )
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ())) lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0 inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens) default_extra_condition_tokens = (
0 if future_decoder is not None else inferred_extra_condition_tokens
)
kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1)) self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
self.__dict__['lewm_query_offsets'] = lewm_query_offsets self.__dict__['lewm_query_offsets'] = lewm_query_offsets
self.__dict__['lewm_predictor'] = lewm_predictor self.__dict__['lewm_predictor'] = lewm_predictor
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity() self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
self.__dict__['future_decoder'] = future_decoder
self.__dict__['future_query_tokens'] = None
self.__dict__['future_query_init_std'] = float(future_query_init_std)
self.__dict__['lewm_sigreg'] = lewm_sigreg self.__dict__['lewm_sigreg'] = lewm_sigreg
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight) self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight) self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
@@ -59,8 +67,16 @@ class IMFVLAAgent(VLAAgent):
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon) self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
self.lewm_predictor = lewm_predictor self.lewm_predictor = lewm_predictor
self.lewm_pred_projector = lewm_pred_projector or nn.Identity() self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
if future_decoder is not None and not isinstance(future_decoder, nn.Module):
self.future_decoder = future_decoder()
else:
self.future_decoder = future_decoder
self.future_query_tokens = None
self.future_query_init_std = float(future_query_init_std)
self.lewm_sigreg = lewm_sigreg self.lewm_sigreg = lewm_sigreg
self.lewm_sigreg_weight = float(lewm_sigreg_weight) self.lewm_sigreg_weight = float(lewm_sigreg_weight)
if self.lewm_predictor is not None and self.future_decoder is not None:
raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
if self.lewm_predictor is None and self.extra_condition_tokens > 0: if self.lewm_predictor is None and self.extra_condition_tokens > 0:
raise ValueError( raise ValueError(
'extra_condition_tokens > 0 requires lewm_predictor to be provided' 'extra_condition_tokens > 0 requires lewm_predictor to be provided'
@@ -69,6 +85,18 @@ class IMFVLAAgent(VLAAgent):
raise ValueError( raise ValueError(
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled' 'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
) )
if self.future_decoder is not None:
if inferred_extra_condition_tokens <= 0:
raise ValueError('future_decoder requires non-empty lewm_query_offsets')
if self.extra_condition_tokens != 0:
raise ValueError('future_decoder requires extra_condition_tokens=0')
self.future_query_tokens = nn.Parameter(
torch.randn(
1,
inferred_extra_condition_tokens,
self.per_step_cond_dim,
) * self.future_query_init_std
)
if lewm_pretrained_ckpt is not None: if lewm_pretrained_ckpt is not None:
self.load_lewm_pretrained_components(lewm_pretrained_ckpt) self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
@@ -232,15 +260,19 @@ class IMFVLAAgent(VLAAgent):
*, *,
query_key: str = 'query_tokens', query_key: str = 'query_tokens',
pos_key: str = 'pos_embedding', pos_key: str = 'pos_embedding',
) -> None: ) -> Dict[str, Sequence[str]]:
current_state_dict = module.state_dict() current_state_dict = module.state_dict()
adapted_state_dict = dict(current_state_dict) adapted_state_dict = dict(current_state_dict)
loaded_keys = []
mismatched_keys = []
missing_keys = []
for key, current_tensor in current_state_dict.items(): for key, current_tensor in current_state_dict.items():
if key not in incoming_state_dict: if key not in incoming_state_dict:
continue continue
source_tensor = incoming_state_dict[key] source_tensor = incoming_state_dict[key]
if source_tensor.shape == current_tensor.shape: if source_tensor.shape == current_tensor.shape:
adapted_state_dict[key] = source_tensor adapted_state_dict[key] = source_tensor
loaded_keys.append(key)
continue continue
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim: if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
@@ -256,8 +288,47 @@ class IMFVLAAgent(VLAAgent):
if copy_count < current_tensor.shape[1] and copy_count > 0: if copy_count < current_tensor.shape[1] and copy_count > 0:
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...] patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
adapted_state_dict[key] = patched adapted_state_dict[key] = patched
loaded_keys.append(key)
continue
mismatched_keys.append(key)
for key in incoming_state_dict.keys():
if key not in current_state_dict:
missing_keys.append(key)
module.load_state_dict(adapted_state_dict, strict=True) module.load_state_dict(adapted_state_dict, strict=True)
return {
'loaded_keys': tuple(sorted(loaded_keys)),
'mismatched_keys': tuple(sorted(set(mismatched_keys))),
'missing_keys': tuple(sorted(set(missing_keys))),
}
@staticmethod
def _load_state_dict_ignoring_shape_mismatches(
module: nn.Module,
incoming_state_dict: Mapping[str, torch.Tensor],
) -> Dict[str, Sequence[str]]:
current_state_dict = module.state_dict()
merged_state_dict = dict(current_state_dict)
loaded_keys = []
mismatched_keys = []
missing_keys = []
for key, value in incoming_state_dict.items():
if key not in current_state_dict:
missing_keys.append(key)
continue
if current_state_dict[key].shape != value.shape:
mismatched_keys.append(key)
continue
merged_state_dict[key] = value
loaded_keys.append(key)
module.load_state_dict(merged_state_dict, strict=True)
return {
'loaded_keys': tuple(sorted(loaded_keys)),
'mismatched_keys': tuple(sorted(mismatched_keys)),
'missing_keys': tuple(sorted(missing_keys)),
}
def load_lewm_pretrained_components( def load_lewm_pretrained_components(
self, self,
@@ -266,20 +337,24 @@ class IMFVLAAgent(VLAAgent):
state_dict = self._load_checkpoint_payload(checkpoint_or_path) state_dict = self._load_checkpoint_payload(checkpoint_or_path)
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'): if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
try:
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict}) self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
except RuntimeError:
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
else: else:
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.') vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
self.vision_encoder.load_state_dict(vision_state_dict, strict=True) self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.') state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True) self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.') projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
mapped_projector_state_dict = { mapped_projector_state_dict = {
f'linear.{key}': value f'linear.{key}': value
for key, value in projector_state_dict.items() for key, value in projector_state_dict.items()
} }
self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True) self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)
if self.lewm_predictor is not None: if self.lewm_predictor is not None:
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.') predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
@@ -287,7 +362,19 @@ class IMFVLAAgent(VLAAgent):
if self.lewm_pred_projector is not None: if self.lewm_pred_projector is not None:
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.') pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True) self._load_state_dict_ignoring_shape_mismatches(
self.lewm_pred_projector,
pred_projector_state_dict,
)
def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
if self.future_decoder is None or self.future_query_tokens is None:
raise RuntimeError('future_decoder path requested but not initialized')
batch_size = history_cond.shape[0]
query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
return self.future_decoder(query_tokens, r, t, cond=history_cond)
def _build_full_condition( def _build_full_condition(
self, self,
@@ -302,7 +389,7 @@ class IMFVLAAgent(VLAAgent):
predicted_future_tokens = None predicted_future_tokens = None
lewm_history_cond = None lewm_history_cond = None
if self.lewm_predictor is None: if self.lewm_predictor is None and self.future_decoder is None:
return history_cond, predicted_future_tokens, lewm_history_cond return history_cond, predicted_future_tokens, lewm_history_cond
lewm_images = lewm_images if lewm_images is not None else images lewm_images = lewm_images if lewm_images is not None else images
@@ -313,6 +400,8 @@ class IMFVLAAgent(VLAAgent):
lewm_images, lewm_images,
self._normalize_qpos_for_lewm(lewm_proprioception), self._normalize_qpos_for_lewm(lewm_proprioception),
) )
cond = history_cond
if self.lewm_predictor is not None:
predicted_future_tokens = self.lewm_predictor(lewm_history_cond) predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens) predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
cond = torch.cat([history_cond, predicted_future_tokens], dim=1) cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
@@ -324,6 +413,8 @@ class IMFVLAAgent(VLAAgent):
raise RuntimeError( raise RuntimeError(
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}" f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
) )
elif self.future_decoder is not None:
predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
return cond, predicted_future_tokens, lewm_history_cond return cond, predicted_future_tokens, lewm_history_cond
@staticmethod @staticmethod
@@ -452,12 +543,18 @@ class IMFVLAAgent(VLAAgent):
lewm_proprioception=None, lewm_proprioception=None,
): ):
batch_size = proprioception.shape[0] batch_size = proprioception.shape[0]
if self.lewm_predictor is not None:
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition( cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
images, images,
proprioception, proprioception,
lewm_images=lewm_images, lewm_images=lewm_images,
lewm_proprioception=lewm_proprioception, lewm_proprioception=lewm_proprioception,
) )
else:
cond = self._build_cond(
images,
self.normalization.normalize_qpos(proprioception),
)
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype) z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
action = self._sample_one_step(z_t, cond=cond) action = self._sample_one_step(z_t, cond=cond)
return self.normalization.denormalize_action(action) return self.normalization.denormalize_action(action)

View File

@@ -0,0 +1,74 @@
# @package agent
defaults:
- /backbone@vision_backbone: lewm_resnet_query_fusion
- /modules@state_encoder: lewm_state_encoder
- /modules@action_encoder: identity_action_encoder
- /modules@cond_projector: linear_condition_projector
- /head: imf_transformer1d
- /head@future_decoder: imf_transformer1d
- _self_
_target_: roboimi.vla.agent_imf.IMFVLAAgent
action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3
vision_backbone:
camera_names: ${agent.camera_names}
num_views: ${agent.num_cams}
cond_projector:
output_dim: 288
lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: 0
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null
future_query_init_std: 0.02
lewm_sigreg:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
knots: 17
num_proj: 1024
diffusion_steps: 100
inference_steps: 1
head_type: "transformer"
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: ${agent.cond_projector.output_dim}
n_emb: 384
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_head: 1
n_kv_head: 1
future_decoder:
input_dim: ${agent.cond_projector.output_dim}
output_dim: ${agent.cond_projector.output_dim}
horizon: ${len:${agent.lewm_query_offsets}}
n_obs_steps: ${agent.obs_horizon}
cond_dim: ${agent.cond_projector.output_dim}
n_emb: 384
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_head: 1
n_kv_head: 1

View File

@@ -18,6 +18,8 @@ train:
# 数据加载 # 数据加载
num_workers: 12 # DataLoader 工作进程数(调试时设为 0 num_workers: 12 # DataLoader 工作进程数(调试时设为 0
val_split: 0.0 # 验证集比例;默认使用全量数据训练 val_split: 0.0 # 验证集比例;默认使用全量数据训练
val_episode_indices: null # 显式按 episode 划出的验证集,例如 [100]
action_mse_val_freq_epochs: 0 # >0 时每隔多少个 epoch 在 held-out episode 上计算 action MSE
seed: 42 # 随机种子(用于数据划分) seed: 42 # 随机种子(用于数据划分)
# 日志和检查点 # 日志和检查点

View File

@@ -26,6 +26,7 @@ class SimpleRobotDataset(Dataset):
max_open_files: int = 64, max_open_files: int = 64,
lewm_history_horizon: Optional[int] = None, lewm_history_horizon: Optional[int] = None,
lewm_query_offsets: Optional[Sequence[int]] = None, lewm_query_offsets: Optional[Sequence[int]] = None,
episode_indices: Optional[Sequence[int]] = None,
): ):
""" """
Args: Args:
@@ -57,6 +58,9 @@ class SimpleRobotDataset(Dataset):
) )
self.max_open_files = max(1, int(max_open_files)) self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict() self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
self.requested_episode_indices = (
None if episode_indices is None else tuple(sorted(int(idx) for idx in episode_indices))
)
self.dataset_dir = Path(dataset_dir) self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists(): if not self.dataset_dir.exists():
@@ -69,20 +73,45 @@ class SimpleRobotDataset(Dataset):
if not self.hdf5_files: if not self.hdf5_files:
raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件") raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")
if self.requested_episode_indices is not None:
requested = set(self.requested_episode_indices)
filtered = []
for hdf5_path in self.hdf5_files:
stem = hdf5_path.stem
if stem.startswith("episode_"):
try:
idx = int(stem.split("_")[-1])
except ValueError:
continue
if idx in requested:
filtered.append(hdf5_path)
self.hdf5_files = filtered
if not self.hdf5_files:
raise FileNotFoundError(
f"{dataset_dir} 中未找到 episode_indices={sorted(requested)} 对应的 HDF5 文件"
)
# 构建 episode 索引(只存储元数据,不加载数据) # 构建 episode 索引(只存储元数据,不加载数据)
self.episodes = {} self.episodes = {}
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path) self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
for ep_idx, hdf5_path in enumerate(self.hdf5_files): for ep_idx, hdf5_path in enumerate(self.hdf5_files):
with h5py.File(hdf5_path, 'r') as f: with h5py.File(hdf5_path, 'r') as f:
T = f['action'].shape[0] T = f['action'].shape[0]
dataset_episode_idx = ep_idx
stem = hdf5_path.stem
if stem.startswith("episode_"):
try:
dataset_episode_idx = int(stem.split("_")[-1])
except ValueError:
pass
start_idx = len(self.frame_meta) start_idx = len(self.frame_meta)
for t in range(T): for t in range(T):
self.frame_meta.append({ self.frame_meta.append({
"ep_idx": ep_idx, "ep_idx": dataset_episode_idx,
"frame_idx": t, "frame_idx": t,
"hdf5_path": hdf5_path, "hdf5_path": hdf5_path,
}) })
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta))) self.episodes[dataset_episode_idx] = list(range(start_idx, len(self.frame_meta)))
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}") print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")
@@ -290,6 +319,10 @@ class SimpleRobotDataset(Dataset):
"""获取所有相机键名 (LeRobotDataset 格式)""" """获取所有相机键名 (LeRobotDataset 格式)"""
return [f"observation.{cam_name}" for cam_name in self.camera_names] return [f"observation.{cam_name}" for cam_name in self.camera_names]
@property
def available_episode_indices(self) -> List[int]:
return sorted(self.episodes.keys())
@property @property
def camera_info(self) -> dict: def camera_info(self) -> dict:
"""获取相机信息""" """获取相机信息"""

View File

@@ -388,6 +388,26 @@ class _StubFutureTokenPredictor(nn.Module):
return summary.repeat(1, self.num_future_tokens, 1) return summary.repeat(1, self.num_future_tokens, 1)
class _RecordingDirectFutureDecoder(nn.Module):
def __init__(self):
super().__init__()
self.scale = nn.Parameter(torch.tensor(0.5))
self.calls = []
def forward(self, sample, r, t, cond=None):
record = {
'sample': sample.detach().clone(),
'r': r.detach().clone(),
't': t.detach().clone(),
'cond': None if cond is None else cond.detach().clone(),
}
self.calls.append(record)
cond_term = 0.0
if cond is not None:
cond_term = cond.mean(dim=1, keepdim=True)
return self.scale * sample + cond_term
class _RecordingSigReg(nn.Module): class _RecordingSigReg(nn.Module):
def __init__(self, value=0.5): def __init__(self, value=0.5):
super().__init__() super().__init__()
@@ -687,6 +707,148 @@ class IMFVLAAgentTest(unittest.TestCase):
) )
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1)) torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_decoder = _RecordingDirectFutureDecoder()
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
future_decoder=future_decoder,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_loss_weight=1.0,
)
agent.infer_scheduler = _ForbiddenScheduler()
with torch.no_grad():
agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
_ = agent.predict_action(images, qpos)
expected_history = torch.tensor(
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
dtype=torch.float32,
)
self.assertEqual(len(head.calls), 1)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
self.assertEqual(len(future_decoder.calls), 0)
def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_decoder = _RecordingDirectFutureDecoder()
sigreg = _RecordingSigReg(value=0.75)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
future_decoder=future_decoder,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_sigreg=sigreg,
lewm_sigreg_weight=0.09,
lewm_loss_weight=1.0,
)
with torch.no_grad():
agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
actions = torch.tensor(
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
dtype=torch.float32,
)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
lewm_future_images = _make_images(
batch_size=1,
obs_horizon=1,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
noise = torch.tensor(
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
dtype=torch.float32,
)
t_sample = torch.tensor([0.8], dtype=torch.float32)
r_sample = torch.tensor([0.25], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
loss = agent.compute_loss(
{
'images': images,
'qpos': qpos,
'action': actions,
'lewm_images': lewm_images,
'lewm_qpos': lewm_qpos,
'lewm_future_images': lewm_future_images,
'lewm_future_qpos': lewm_future_qpos,
}
)
metrics = agent.get_last_loss_breakdown()
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
self.assertEqual(len(head.calls), 2)
self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
self.assertEqual(len(future_decoder.calls), 1)
self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
self.assertAlmostEqual(
metrics['loss'],
metrics['action_loss'] + metrics['lewm_loss'],
places=5,
)
self.assertAlmostEqual(
metrics['lewm_loss'],
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
places=5,
)
self.assertGreater(metrics['lewm_pred_loss'], 0.0)
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
def test_select_action_only_regenerates_when_action_queue_is_empty(self): def test_select_action_only_regenerates_when_action_queue_is_empty(self):
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2) agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
observation = { observation = {
@@ -1077,6 +1239,36 @@ class IMFVLAAgentTest(unittest.TestCase):
) )
self.assertIsNotNone(agent.lewm_sigreg) self.assertIsNotNone(agent.lewm_sigreg)
def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_resnet_dual_decoder_imf_attnres',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.future_decoder.n_layer=1',
'agent.future_decoder.n_emb=16',
'agent.lewm_query_offsets=[8]',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(cfg.agent.extra_condition_tokens, 0)
self.assertEqual(
cfg.agent.future_decoder._target_,
'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
)
self.assertEqual(cfg.agent.head.cond_dim, 288)
self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, 288)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self): def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
cfg = _compose_cfg( cfg = _compose_cfg(

View File

@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase): class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
def _write_episode(self, dataset_dir: Path) -> None: def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
episode_path = dataset_dir / "episode_0.hdf5" episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
with h5py.File(episode_path, "w") as root: with h5py.File(episode_path, "w") as root:
root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2)) root.create_dataset(
"action",
data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
)
root.create_dataset( root.create_dataset(
"observations/qpos", "observations/qpos",
data=np.arange(16, dtype=np.float32).reshape(4, 4), data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
) )
root.create_dataset("task", data=np.array([b"sim_transfer"])) root.create_dataset("task", data=np.array([b"sim_transfer"]))
root.create_dataset( root.create_dataset(
"observations/images/front", "observations/images/front",
data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3), data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
) )
def test_getitem_only_resizes_observation_horizon_images(self): def test_getitem_only_resizes_observation_horizon_images(self):
@@ -100,3 +103,25 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8)) self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4)) self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8)) self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
def test_dataset_can_limit_loading_to_specific_episode_indices(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
episode_indices=[1],
)
sample = dataset[0]
self.assertEqual(len(dataset.hdf5_files), 1)
self.assertEqual(dataset.available_episode_indices, [1])
self.assertEqual(len(dataset), 4)
self.assertTrue(np.allclose(sample["observation.state"][0].numpy(), np.array([100.0, 101.0, 102.0, 103.0])))

View File

@@ -41,6 +41,19 @@ class FakeDataset:
return 4 return 4
class SplitAwareFakeDataset(FakeDataset):
def __init__(self, episode_indices=None):
self.episode_indices = None if episode_indices is None else list(episode_indices)
if self.episode_indices is None:
self.episodes = {0: [0], 1: [1], 2: [2]}
else:
self.episodes = {idx: [idx] for idx in self.episode_indices}
@property
def available_episode_indices(self):
return sorted(self.episodes.keys())
class FakeLoader: class FakeLoader:
def __init__(self, batch): def __init__(self, batch):
self.batch = batch self.batch = batch
@@ -123,6 +136,10 @@ class RecordingAgent(FakeAgent):
self.seen_inputs.append(agent_input) self.seen_inputs.append(agent_input)
return super().compute_loss(agent_input) return super().compute_loss(agent_input)
def predict_action_chunk(self, agent_input):
self.seen_inputs.append({'predict_action_chunk': agent_input})
return torch.ones_like(agent_input['action'])
class ShapeMixedFakeAgent(FakeAgent): class ShapeMixedFakeAgent(FakeAgent):
def __init__(self): def __init__(self):
@@ -355,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
batch_size=2, batch_size=2,
num_workers=0, num_workers=0,
val_split=0.25, val_split=0.25,
val_episode_indices=None,
action_mse_val_freq_epochs=0,
seed=0, seed=0,
lr=1e-3, lr=1e-3,
max_steps=2, max_steps=2,
@@ -479,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
'batch_size': 2, 'batch_size': 2,
'num_workers': 0, 'num_workers': 0,
'val_split': 0.25, 'val_split': 0.25,
'val_episode_indices': None,
'action_mse_val_freq_epochs': 0,
'seed': 0, 'seed': 0,
'lr': 1e-3, 'lr': 1e-3,
'max_steps': 2, 'max_steps': 2,
@@ -561,6 +582,58 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertIn('front', first_input['lewm_images']) self.assertIn('front', first_input['lewm_images'])
self.assertIn('front', first_input['lewm_future_images']) self.assertIn('front', first_input['lewm_future_images'])
def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
cfg.train.max_steps = 1
cfg.train.save_freq = 100
cfg.train.val_split = 0.0
cfg.train.val_episode_indices = [2]
cfg.train.action_mse_val_freq_epochs = 1
agent = RecordingAgent()
fake_swanlab = FakeSwanLab()
real_import_module = importlib.import_module
def fake_instantiate(config_node, **kwargs):
if config_node is cfg.data:
return SplitAwareFakeDataset(kwargs.get('episode_indices'))
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_loader_factory(dataset, *, shuffle, **_kwargs):
action_value = 0.0 if shuffle else 2.0
batch = {
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.full((1, 1, 2), action_value),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
return FakeLoader(batch)
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
run_training(cfg)
finally:
os.chdir(previous_cwd)
logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
self.assertIn('val/action_mse', logged_keys)
def test_run_training_skips_swanlab_when_disabled(self): def test_run_training_skips_swanlab_when_disabled(self):
module = self._load_train_vla_module() module = self._load_train_vla_module()
run_training = self._get_run_training(module) run_training = self._get_run_training(module)