From 395f5a164547f163d34497c64851dbfd995b399c Mon Sep 17 00:00:00 2001
From: Logic
Date: Fri, 17 Apr 2026 19:26:56 +0800
Subject: [PATCH] feat: add held-out validation and dual-decoder LeWM IMF agent

---
 roboimi/demos/vla_scripts/train_vla.py        | 219 +++++++++++++-----
 roboimi/vla/agent_imf.py                      | 147 ++++++++++--
 .../lewm_resnet_dual_decoder_imf_attnres.yaml |  74 ++++++
 roboimi/vla/conf/config.yaml                  |   2 +
 roboimi/vla/data/simpe_robot_dataset.py       |  37 ++-
 tests/test_imf_vla_agent.py                   | 192 +++++++++++++++
 ...test_simple_robot_dataset_image_loading.py |  35 ++-
 tests/test_train_vla_swanlab_logging.py       |  73 ++++++
 8 files changed, 693 insertions(+), 86 deletions(-)
 create mode 100644 roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml

diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index afbd23f..d4e3842 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -118,6 +118,127 @@ def recursive_to_device(data, device):
     return data
 
 
+def build_agent_input(batch_data):
+    agent_input = {
+        'images': {
+            cam_name.replace('observation.', ''): value
+            for cam_name, value in batch_data.items()
+            if cam_name.startswith('observation.') and cam_name != 'observation.state'
+        },
+        'qpos': batch_data['observation.state'],
+        'action': batch_data['action'],
+    }
+
+    if 'action_is_pad' in batch_data:
+        agent_input['action_is_pad'] = batch_data['action_is_pad']
+
+    lewm_images = {
+        cam_name.replace('lewm.observation.', ''): value
+        for cam_name, value in batch_data.items()
+        if cam_name.startswith('lewm.observation.') and cam_name != 'lewm.observation.state'
+    }
+    if lewm_images:
+        agent_input['lewm_images'] = lewm_images
+    if 'lewm.observation.state' in batch_data:
+        agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
+
+    lewm_future_images = {
+        cam_name.replace('lewm.future.', ''): value
+        for cam_name, value in batch_data.items()
+        if cam_name.startswith('lewm.future.') and cam_name != 'lewm.future.state'
+    }
+    if lewm_future_images:
+        agent_input['lewm_future_images'] = lewm_future_images
+    if 'lewm.future.state' in batch_data:
+        agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
+
+    return agent_input
+
+
+def _instantiate_dataset(cfg, dataset_image_resize_shape, episode_indices=None):
+    kwargs = {'image_resize_shape': dataset_image_resize_shape}
+    if episode_indices is not None:
+        kwargs['episode_indices'] = episode_indices
+    return instantiate(cfg.data, **kwargs)
+
+
+def build_train_val_datasets(cfg, dataset_image_resize_shape):
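+    """Split the data into train/val sets, either by explicit held-out
+    episodes (cfg.train.val_episode_indices) or by a random val_split ratio."""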
+    val_episode_indices = cfg.train.get('val_episode_indices', None)
+    if val_episode_indices:
+        dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
+        available_episode_indices = list(getattr(dataset, 'available_episode_indices', []))
+        if not available_episode_indices:
+            raise ValueError('explicit val_episode_indices requires the dataset to expose available_episode_indices')
+        requested_val_episode_indices = sorted(int(idx) for idx in val_episode_indices)
+        available_set = set(available_episode_indices)
+        missing = sorted(set(requested_val_episode_indices) - available_set)
+        if missing:
+            raise ValueError(
+                f'val_episode_indices {missing} do not exist in the available dataset episodes {available_episode_indices}'
+            )
+        train_episode_indices = [
+            idx for idx in available_episode_indices
+            if idx not in set(requested_val_episode_indices)
+        ]
+        if not train_episode_indices:
+            raise ValueError('explicit val_episode_indices must not cover all episodes; the training set would be empty')
+
+        train_dataset = _instantiate_dataset(
+            cfg,
+            dataset_image_resize_shape,
+            episode_indices=train_episode_indices,
+        )
+        val_dataset = _instantiate_dataset(
+            cfg,
+            dataset_image_resize_shape,
+            episode_indices=requested_val_episode_indices,
+        )
+        return dataset, train_dataset, val_dataset, requested_val_episode_indices
+
+    dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
+    val_split = float(cfg.train.get('val_split', 0.1))
+    seed = int(cfg.train.get('seed', 42))
+    val_size = int(len(dataset) * val_split)
+    train_size = len(dataset) - val_size
+    if val_size > 0:
+        train_dataset, val_dataset = random_split(
+            dataset,
+            [train_size, val_size],
+            generator=torch.Generator().manual_seed(seed)
+        )
+    else:
+        train_dataset, val_dataset = dataset, None
+    return dataset, train_dataset, val_dataset, None
+
+
+def compute_action_mse_validation(agent, val_loader, device):
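+    """Masked MSE between predicted and ground-truth action chunks; steps
+    flagged in action_is_pad are excluded from both the error sum and the count."""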
+    if val_loader is None:
+        return None
+
+    was_training = agent.training
+    agent.eval()
+    total_squared_error = 0.0
+    total_count = 0.0
+    with torch.no_grad():
+        for val_batch in val_loader:
+            val_batch = recursive_to_device(val_batch, device)
+            val_input = build_agent_input(val_batch)
+            pred_actions = agent.predict_action_chunk(val_input)
+            target_actions = val_input['action']
+            squared_error = (pred_actions - target_actions).pow(2)
+            action_is_pad = val_input.get('action_is_pad', None)
+            if action_is_pad is not None:
+                mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)
+                total_squared_error += (squared_error * mask).sum().item()
+                total_count += mask.sum().item() * squared_error.shape[-1]
+            else:
+                total_squared_error += squared_error.sum().item()
+                total_count += target_actions.numel()
+    if was_training:
+        agent.train()
+    return total_squared_error / max(total_count, 1.0)
+
+
 def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
     """
     解析恢复训练用的 checkpoint 路径。
@@ -410,30 +531,30 @@ def _run_training(cfg: DictConfig):
         vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
         if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
             dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
-        dataset = instantiate(
-            cfg.data,
-            image_resize_shape=dataset_image_resize_shape,
+        dataset, train_dataset, val_dataset, explicit_val_episode_indices = (
+            build_train_val_datasets(cfg, dataset_image_resize_shape)
         )
         log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
     except Exception as e:
         log.error(f"❌ 数据集加载失败: {e}")
         raise
 
-    # 训练/验证集划分
-    val_split = float(cfg.train.get('val_split', 0.1))
-    seed = int(cfg.train.get('seed', 42))
-    val_size = int(len(dataset) * val_split)
-    train_size = len(dataset) - val_size
-    if val_size > 0:
-        train_dataset, val_dataset = random_split(
-            dataset,
-            [train_size, val_size],
-            generator=torch.Generator().manual_seed(seed)
+    if explicit_val_episode_indices is not None:
+        log.info(
+            "✅ Dataset split: train=%s, val=%s (explicit held-out episodes=%s)",
+            len(train_dataset),
+            len(val_dataset),
+            explicit_val_episode_indices,
         )
-        log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
     else:
-        train_dataset, val_dataset = dataset, None
-        log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
+        val_split = float(cfg.train.get('val_split', 0.1))
+        val_size = len(val_dataset) if val_dataset is not None else 0
+        if val_size > 0:
+            log.info(
+                f"✅ Dataset split: train={len(train_dataset)}, val={val_size} (val ratio={val_split})"
+            )
+        else:
+            log.info("✅ Dataset split: all data used for training, val=0 (val ratio=0)")
 
     train_batch_size = int(cfg.train.batch_size)
     train_drop_last = len(train_dataset) >= train_batch_size
@@ -674,44 +795,6 @@ def _run_training(cfg: DictConfig):
     # =========================================================================
     log.info("🏋️ 开始训练循环...")
 
-    def build_agent_input(batch_data):
-        """构建 agent 输入格式"""
-        images = {}
-        # SimpleRobotDataset 返回 observation.{cam_name} 格式
-        for cam_name in cfg.data.camera_names:
-            key = f"observation.{cam_name}"
-            if key in batch_data:
-                images[cam_name] = batch_data[key]
-
-        agent_input = {
-            'images': images,
-            'qpos': batch_data['observation.state'],  # SimpleRobotDataset 使用 observation.state
-            'action': batch_data['action'],
-            'action_is_pad': batch_data.get('action_is_pad', None)  # 传递padding mask
-        }
-
-        lewm_images = {}
-        lewm_future_images = {}
-        for cam_name in cfg.data.camera_names:
-            lewm_obs_key = f"lewm.observation.{cam_name}"
-            if lewm_obs_key in batch_data:
-                lewm_images[cam_name] = batch_data[lewm_obs_key]
-
-            lewm_future_key = f"lewm.future.{cam_name}"
-            if lewm_future_key in batch_data:
-                lewm_future_images[cam_name] = batch_data[lewm_future_key]
-
-        if 'lewm.observation.state' in batch_data:
-            agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
-        if lewm_images:
-            agent_input['lewm_images'] = lewm_images
-        if 'lewm.future.state' in batch_data:
-            agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
-        if lewm_future_images:
-            agent_input['lewm_future_images'] = lewm_future_images
-
-        return agent_input
-
     def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
         agent_stats = agent.get_normalization_stats()
         torch.save({
@@ -784,6 +867,7 @@ def _run_training(cfg: DictConfig):
     pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
     steps_per_epoch = len(train_loader)
 
+    action_mse_val_freq_epochs = int(cfg.train.get('action_mse_val_freq_epochs', 0) or 0)
     rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
     rollout_validation_enabled = rollout_val_freq_epochs > 0
     best_loss = resume_best_loss
@@ -953,6 +1037,33 @@ def _run_training(cfg: DictConfig):
                 and completed_epoch > 0
                 and completed_epoch % rollout_val_freq_epochs == 0
             )
+            should_run_action_mse_validation = (
+                action_mse_val_freq_epochs > 0
+                and val_loader is not None
+                and steps_per_epoch > 0
+                and completed_steps % steps_per_epoch == 0
+                and completed_epoch > 0
+                and completed_epoch % action_mse_val_freq_epochs == 0
+            )
+            if should_run_action_mse_validation:
+                action_mse = compute_action_mse_validation(
+                    agent,
+                    val_loader,
+                    cfg.train.device,
+                )
+                if action_mse is not None:
+                    log.info(
+                        f"Step {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
+                        f"held-out action MSE: {action_mse:.6f}"
+                    )
+                    _log_to_swanlab(
+                        swanlab_module,
+                        {
+                            'val/action_mse': action_mse,
+                            'val/action_mse_epoch': completed_epoch,
+                        },
+                        step=step,
+                    )
             if should_run_epoch_rollout:
                 if checkpoint_path is None:
                     checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
diff --git a/roboimi/vla/agent_imf.py b/roboimi/vla/agent_imf.py
index 557de7d..5a6fa93 100644
--- a/roboimi/vla/agent_imf.py
+++ b/roboimi/vla/agent_imf.py
@@ -26,6 +26,8 @@ class IMFVLAAgent(VLAAgent):
         lewm_query_offsets: Optional[Sequence[int]] = None,
         lewm_predictor: Optional[nn.Module] = None,
         lewm_pred_projector: Optional[nn.Module] = None,
+        future_decoder: Optional[nn.Module] = None,
+        future_query_init_std: float = 0.02,
         lewm_sigreg: Optional[nn.Module] = None,
         lewm_sigreg_weight: float = 0.09,
         lewm_loss_weight: float = 0.0,
@@ -39,11 +41,17 @@ class IMFVLAAgent(VLAAgent):
         )
         lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
         inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
-        kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens)
+        default_extra_condition_tokens = (
+            0 if future_decoder is not None else inferred_extra_condition_tokens
+        )
+        kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
         self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
         self.__dict__['lewm_query_offsets'] = lewm_query_offsets
         self.__dict__['lewm_predictor'] = lewm_predictor
         self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
+        self.__dict__['future_decoder'] = future_decoder
+        self.__dict__['future_query_tokens'] = None
+        self.__dict__['future_query_init_std'] = float(future_query_init_std)
         self.__dict__['lewm_sigreg'] = lewm_sigreg
         self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
         self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
@@ -59,8 +67,16 @@ class IMFVLAAgent(VLAAgent):
         self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
         self.lewm_predictor = lewm_predictor
         self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
+        if future_decoder is not None and not isinstance(future_decoder, nn.Module):
+            self.future_decoder = future_decoder()
+        else:
+            self.future_decoder = future_decoder
+        self.future_query_tokens = None
+        self.future_query_init_std = float(future_query_init_std)
         self.lewm_sigreg = lewm_sigreg
         self.lewm_sigreg_weight = float(lewm_sigreg_weight)
+        if self.lewm_predictor is not None and self.future_decoder is not None:
+            raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
         if self.lewm_predictor is None and self.extra_condition_tokens > 0:
             raise ValueError(
                 'extra_condition_tokens > 0 requires lewm_predictor to be provided'
@@ -69,6 +85,18 @@ class IMFVLAAgent(VLAAgent):
             raise ValueError(
                 'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
             )
+        if self.future_decoder is not None:
+            if inferred_extra_condition_tokens <= 0:
+                raise ValueError('future_decoder requires non-empty lewm_query_offsets')
+            if self.extra_condition_tokens != 0:
+                raise ValueError('future_decoder requires extra_condition_tokens=0')
+            self.future_query_tokens = nn.Parameter(
+                torch.randn(
+                    1,
+                    inferred_extra_condition_tokens,
+                    self.per_step_cond_dim,
+                ) * self.future_query_init_std
+            )
 
         if lewm_pretrained_ckpt is not None:
             self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
@@ -232,15 +260,19 @@ class IMFVLAAgent(VLAAgent):
         *,
         query_key: str = 'query_tokens',
         pos_key: str = 'pos_embedding',
-    ) -> None:
+    ) -> Dict[str, Sequence[str]]:
         current_state_dict = module.state_dict()
         adapted_state_dict = dict(current_state_dict)
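+        # Track which incoming tensors are copied verbatim, shape-adapted,
+        # or skipped, for the summary dict returned by this method.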
+        loaded_keys = []
+        mismatched_keys = []
+        missing_keys = []
         for key, current_tensor in current_state_dict.items():
             if key not in incoming_state_dict:
                 continue
             source_tensor = incoming_state_dict[key]
             if source_tensor.shape == current_tensor.shape:
                 adapted_state_dict[key] = source_tensor
+                loaded_keys.append(key)
                 continue
 
             if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
@@ -256,8 +288,47 @@
                 if copy_count < current_tensor.shape[1] and copy_count > 0:
                     patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
                 adapted_state_dict[key] = patched
+                loaded_keys.append(key)
+                continue
+
+            mismatched_keys.append(key)
+
+        for key in incoming_state_dict.keys():
+            if key not in current_state_dict:
+                missing_keys.append(key)
 
         module.load_state_dict(adapted_state_dict, strict=True)
+        return {
+            'loaded_keys': tuple(sorted(loaded_keys)),
+            'mismatched_keys': tuple(sorted(set(mismatched_keys))),
+            'missing_keys': tuple(sorted(set(missing_keys))),
+        }
+
+    @staticmethod
+    def _load_state_dict_ignoring_shape_mismatches(
+        module: nn.Module,
+        incoming_state_dict: Mapping[str, torch.Tensor],
+    ) -> Dict[str, Sequence[str]]:
+        current_state_dict = module.state_dict()
+        merged_state_dict = dict(current_state_dict)
+        loaded_keys = []
+        mismatched_keys = []
+        missing_keys = []
+
+        for key, value in incoming_state_dict.items():
+            if key not in current_state_dict:
+                missing_keys.append(key)
+                continue
+            if current_state_dict[key].shape != value.shape:
+                mismatched_keys.append(key)
+                continue
+            merged_state_dict[key] = value
+            loaded_keys.append(key)
+
+        module.load_state_dict(merged_state_dict, strict=True)
+        return {
+            'loaded_keys': tuple(sorted(loaded_keys)),
+            'mismatched_keys': tuple(sorted(mismatched_keys)),
+            'missing_keys': tuple(sorted(missing_keys)),
+        }
 
     def load_lewm_pretrained_components(
         self,
@@ -266,20 +337,24 @@
         state_dict = self._load_checkpoint_payload(checkpoint_or_path)
 
         if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
-            self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
+            try:
+                self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
+            except RuntimeError:
+                vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
+                self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
         else:
             vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
-            self.vision_encoder.load_state_dict(vision_state_dict, strict=True)
+            self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
 
         state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
-        self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True)
+        self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)
 
         projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
         mapped_projector_state_dict = {
             f'linear.{key}': value
             for key, value in projector_state_dict.items()
         }
-        self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True)
+        self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)
 
         if self.lewm_predictor is not None:
             predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
@@ -287,7 +362,19 @@
 
         if self.lewm_pred_projector is not None:
             pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
-            self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True)
+            self._load_state_dict_ignoring_shape_mismatches(
+                self.lewm_pred_projector,
+                pred_projector_state_dict,
+            )
+
+    def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
+        if self.future_decoder is None or self.future_query_tokens is None:
+            raise RuntimeError('future_decoder path requested but not initialized')
+        batch_size = history_cond.shape[0]
+        query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
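+        # Single decoder pass: the learned query tokens act as the input
+        # sample, with fixed r=0 / t=1 and the encoded history as condition.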
+        r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
+        t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
+        return self.future_decoder(query_tokens, r, t, cond=history_cond)
 
     def _build_full_condition(
         self,
@@ -302,7 +389,7 @@
         predicted_future_tokens = None
         lewm_history_cond = None
 
-        if self.lewm_predictor is None:
+        if self.lewm_predictor is None and self.future_decoder is None:
             return history_cond, predicted_future_tokens, lewm_history_cond
 
         lewm_images = lewm_images if lewm_images is not None else images
@@ -313,17 +400,21 @@
             lewm_images,
             self._normalize_qpos_for_lewm(lewm_proprioception),
         )
-        predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
-        predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
-        cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
-        if cond.shape[1] != self.condition_sequence_length:
-            raise RuntimeError(
-                f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
-            )
-        if cond.shape[-1] != self.per_step_cond_dim:
-            raise RuntimeError(
-                f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
-            )
+        cond = history_cond
+        if self.lewm_predictor is not None:
+            predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
+            predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
+            cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
+            if cond.shape[1] != self.condition_sequence_length:
+                raise RuntimeError(
+                    f"Full condition sequence length mismatch: got {cond.shape[1]}, expected {self.condition_sequence_length}"
+                )
+            if cond.shape[-1] != self.per_step_cond_dim:
+                raise RuntimeError(
+                    f"Full condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
+                )
+        elif self.future_decoder is not None:
+            predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
         return cond, predicted_future_tokens, lewm_history_cond
 
     @staticmethod
@@ -452,12 +543,18 @@
         lewm_proprioception=None,
     ):
         batch_size = proprioception.shape[0]
-        cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
-            images,
-            proprioception,
-            lewm_images=lewm_images,
-            lewm_proprioception=lewm_proprioception,
-        )
+        if self.lewm_predictor is not None:
+            cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
+                images,
+                proprioception,
+                lewm_images=lewm_images,
+                lewm_proprioception=lewm_proprioception,
+            )
+        else:
+            cond = self._build_cond(
+                images,
+                self.normalization.normalize_qpos(proprioception),
+            )
         z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
         action = self._sample_one_step(z_t, cond=cond)
         return self.normalization.denormalize_action(action)
diff --git a/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml b/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml
new file mode 100644
index 0000000..3437fbb
--- /dev/null
+++ b/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml
@@ -0,0 +1,74 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: lewm_resnet_query_fusion
+  - /modules@state_encoder: lewm_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /modules@cond_projector: linear_condition_projector
+  - /head: imf_transformer1d
+  - /head@future_decoder: imf_transformer1d
+  - _self_
+
+_target_: roboimi.vla.agent_imf.IMFVLAAgent
+
+action_dim: 16
+obs_dim: 16
+normalization_type: "min_max"
+pred_horizon: 8
+obs_horizon: 2
+num_action_steps: 8
+camera_names: ${data.camera_names}
+num_cams: 3
+
+vision_backbone:
+  camera_names: ${agent.camera_names}
+  num_views: ${agent.num_cams}
+
+cond_projector:
+  output_dim: 288
+
+lewm_history_horizon: 3
+lewm_query_offsets: [8]
+extra_condition_tokens: 0
+lewm_loss_weight: 1.0
+lewm_sigreg_weight: 0.09
+lewm_pretrained_ckpt: null
+future_query_init_std: 0.02
+
+lewm_sigreg:
+  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
+  knots: 17
+  num_proj: 1024
+
+diffusion_steps: 100
+inference_steps: 1
+head_type: "transformer"
+
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: ${agent.cond_projector.output_dim}
+  n_emb: 384
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
+
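+# Dual decoder: a second IMF transformer that maps the encoded LeWM history
+# plus learned query tokens to future condition tokens, instead of appending
+# lewm_predictor outputs to the action head's condition sequence.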
+future_decoder:
+  input_dim: ${agent.cond_projector.output_dim}
+  output_dim: ${agent.cond_projector.output_dim}
+  horizon: ${len:${agent.lewm_query_offsets}}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: ${agent.cond_projector.output_dim}
+  n_emb: 384
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml
index 7f991e0..05818a0 100644
--- a/roboimi/vla/conf/config.yaml
+++ b/roboimi/vla/conf/config.yaml
@@ -18,6 +18,8 @@ train:
   # 数据加载
   num_workers: 12 # DataLoader 工作进程数(调试时设为 0)
   val_split: 0.0 # 验证集比例;默认使用全量数据训练
+  val_episode_indices: null # explicit held-out validation episodes, e.g. [100]
+  action_mse_val_freq_epochs: 0 # if > 0, compute action MSE on the held-out episodes every N epochs
   seed: 42 # 随机种子(用于数据划分)
 
   # 日志和检查点
diff --git a/roboimi/vla/data/simpe_robot_dataset.py b/roboimi/vla/data/simpe_robot_dataset.py
index 94156fc..e9aabb0 100644
--- a/roboimi/vla/data/simpe_robot_dataset.py
+++ b/roboimi/vla/data/simpe_robot_dataset.py
@@ -26,6 +26,7 @@ class SimpleRobotDataset(Dataset):
         max_open_files: int = 64,
         lewm_history_horizon: Optional[int] = None,
         lewm_query_offsets: Optional[Sequence[int]] = None,
+        episode_indices: Optional[Sequence[int]] = None,
     ):
         """
         Args:
@@ -57,6 +58,9 @@ class SimpleRobotDataset(Dataset):
         )
         self.max_open_files = max(1, int(max_open_files))
         self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
+        self.requested_episode_indices = (
+            None if episode_indices is None else tuple(sorted(int(idx) for idx in episode_indices))
+        )
 
         self.dataset_dir = Path(dataset_dir)
         if not self.dataset_dir.exists():
@@ -69,20 +73,45 @@ class SimpleRobotDataset(Dataset):
         if not self.hdf5_files:
             raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")
 
+        if self.requested_episode_indices is not None:
+            requested = set(self.requested_episode_indices)
+            filtered = []
+            for hdf5_path in self.hdf5_files:
+                stem = hdf5_path.stem
+                if stem.startswith("episode_"):
+                    try:
+                        idx = int(stem.split("_")[-1])
+                    except ValueError:
+                        continue
+                    if idx in requested:
+                        filtered.append(hdf5_path)
+            self.hdf5_files = filtered
+            if not self.hdf5_files:
+                raise FileNotFoundError(
+                    f"No HDF5 files matching episode_indices={sorted(requested)} were found in {dataset_dir}"
+                )
+
         # 构建 episode 索引(只存储元数据,不加载数据)
         self.episodes = {}
         self.frame_meta = []  # 存储 (ep_idx, frame_idx, hdf5_path)
         for ep_idx, hdf5_path in enumerate(self.hdf5_files):
             with h5py.File(hdf5_path, 'r') as f:
                 T = f['action'].shape[0]
+            dataset_episode_idx = ep_idx
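+            # Prefer the episode id encoded in the filename (episode_N.hdf5)
+            # so ids stay stable when the file list has been filtered.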
+            stem = hdf5_path.stem
+            if stem.startswith("episode_"):
+                try:
+                    dataset_episode_idx = int(stem.split("_")[-1])
+                except ValueError:
+                    pass
             start_idx = len(self.frame_meta)
             for t in range(T):
                 self.frame_meta.append({
-                    "ep_idx": ep_idx,
+                    "ep_idx": dataset_episode_idx,
                     "frame_idx": t,
                     "hdf5_path": hdf5_path,
                 })
-            self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
+            self.episodes[dataset_episode_idx] = list(range(start_idx, len(self.frame_meta)))
 
         print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")
 
@@ -290,6 +319,10 @@ class SimpleRobotDataset(Dataset):
         """获取所有相机键名 (LeRobotDataset 格式)"""
         return [f"observation.{cam_name}" for cam_name in self.camera_names]
 
+    @property
+    def available_episode_indices(self) -> List[int]:
+        return sorted(self.episodes.keys())
+
     @property
     def camera_info(self) -> dict:
         """获取相机信息"""
diff --git a/tests/test_imf_vla_agent.py b/tests/test_imf_vla_agent.py
index d611009..b14112b 100644
--- a/tests/test_imf_vla_agent.py
+++ b/tests/test_imf_vla_agent.py
@@ -388,6 +388,26 @@ class _StubFutureTokenPredictor(nn.Module):
         return summary.repeat(1, self.num_future_tokens, 1)
 
 
+class _RecordingDirectFutureDecoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.scale = nn.Parameter(torch.tensor(0.5))
+        self.calls = []
+
+    def forward(self, sample, r, t, cond=None):
+        record = {
+            'sample': sample.detach().clone(),
+            'r': r.detach().clone(),
+            't': t.detach().clone(),
+            'cond': None if cond is None else cond.detach().clone(),
+        }
+        self.calls.append(record)
+        cond_term = 0.0
+        if cond is not None:
+            cond_term = cond.mean(dim=1, keepdim=True)
+        return self.scale * sample + cond_term
+
+
 class _RecordingSigReg(nn.Module):
     def __init__(self, value=0.5):
         super().__init__()
@@ -687,6 +707,148 @@ class IMFVLAAgentTest(unittest.TestCase):
         )
         torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
 
+    def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
+        agent_cls, agent_module = _load_imf_agent_class()
+        head = _RecordingLinearIMFHead()
+        future_decoder = _RecordingDirectFutureDecoder()
+        agent = agent_cls(
+            vision_backbone=_StubVisionBackbone(),
+            state_encoder=nn.Identity(),
+            action_encoder=nn.Identity(),
+            head=head,
+            future_decoder=future_decoder,
+            action_dim=2,
+            obs_dim=1,
+            pred_horizon=3,
+            obs_horizon=2,
+            diffusion_steps=10,
+            inference_steps=1,
+            num_cams=len(_CAMERA_NAMES),
+            camera_names=_CAMERA_NAMES,
+            num_action_steps=2,
+            head_type='transformer',
+            lewm_history_horizon=3,
+            lewm_query_offsets=[8],
+            lewm_loss_weight=1.0,
+        )
+        agent.infer_scheduler = _ForbiddenScheduler()
+        with torch.no_grad():
+            agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))
+
+        images = _make_images(
+            batch_size=1,
+            obs_horizon=2,
+            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
+        )
+        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
+        initial_noise = torch.tensor(
+            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
+            dtype=torch.float32,
+        )
+
+        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
+            _ = agent.predict_action(images, qpos)
+
+        expected_history = torch.tensor(
+            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
+            dtype=torch.float32,
+        )
+        self.assertEqual(len(head.calls), 1)
+        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
+        self.assertEqual(len(future_decoder.calls), 0)
+
+    def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
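+        """Total loss should decompose into the action loss plus the LeWM term
+        (lewm_pred_loss + lewm_sigreg_weight * lewm_sigreg_loss)."""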
+        agent_cls, agent_module = _load_imf_agent_class()
+        head = _RecordingLinearIMFHead()
+        future_decoder = _RecordingDirectFutureDecoder()
+        sigreg = _RecordingSigReg(value=0.75)
+        agent = agent_cls(
+            vision_backbone=_StubVisionBackbone(),
+            state_encoder=nn.Identity(),
+            action_encoder=nn.Identity(),
+            head=head,
+            future_decoder=future_decoder,
+            action_dim=2,
+            obs_dim=1,
+            pred_horizon=3,
+            obs_horizon=2,
+            diffusion_steps=10,
+            inference_steps=1,
+            num_cams=len(_CAMERA_NAMES),
+            camera_names=_CAMERA_NAMES,
+            num_action_steps=2,
+            head_type='transformer',
+            lewm_history_horizon=3,
+            lewm_query_offsets=[8],
+            lewm_sigreg=sigreg,
+            lewm_sigreg_weight=0.09,
+            lewm_loss_weight=1.0,
+        )
+        with torch.no_grad():
+            agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))
+
+        images = _make_images(
+            batch_size=1,
+            obs_horizon=2,
+            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
+        )
+        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
+        actions = torch.tensor(
+            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
+            dtype=torch.float32,
+        )
+        lewm_images = _make_images(
+            batch_size=1,
+            obs_horizon=3,
+            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
+        )
+        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
+        lewm_future_images = _make_images(
+            batch_size=1,
+            obs_horizon=1,
+            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
+        )
+        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
+        noise = torch.tensor(
+            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
+            dtype=torch.float32,
+        )
+        t_sample = torch.tensor([0.8], dtype=torch.float32)
+        r_sample = torch.tensor([0.25], dtype=torch.float32)
+
+        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
+                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
+            loss = agent.compute_loss(
+                {
+                    'images': images,
+                    'qpos': qpos,
+                    'action': actions,
+                    'lewm_images': lewm_images,
+                    'lewm_qpos': lewm_qpos,
+                    'lewm_future_images': lewm_future_images,
+                    'lewm_future_qpos': lewm_future_qpos,
+                }
+            )
+
+        metrics = agent.get_last_loss_breakdown()
+        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
+        self.assertEqual(len(head.calls), 2)
+        self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
+        self.assertEqual(len(future_decoder.calls), 1)
+        self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
+        self.assertAlmostEqual(
+            metrics['loss'],
+            metrics['action_loss'] + metrics['lewm_loss'],
+            places=5,
+        )
+        self.assertAlmostEqual(
+            metrics['lewm_loss'],
+            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
+            places=5,
+        )
+        self.assertGreater(metrics['lewm_pred_loss'], 0.0)
+        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
+
     def test_select_action_only_regenerates_when_action_queue_is_empty(self):
         agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
         observation = {
@@ -1077,6 +1239,36 @@ class IMFVLAAgentTest(unittest.TestCase):
         )
         self.assertIsNotNone(agent.lewm_sigreg)
 
+    def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent=lewm_resnet_dual_decoder_imf_attnres',
+                'agent.head.n_layer=1',
+                'agent.head.n_emb=16',
+                'agent.future_decoder.n_layer=1',
+                'agent.future_decoder.n_emb=16',
+                'agent.lewm_query_offsets=[8]',
+            ]
+        )
+
+        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
+        self.assertEqual(cfg.agent.extra_condition_tokens, 0)
+        self.assertEqual(
+            cfg.agent.future_decoder._target_,
+            'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
+        )
+        self.assertEqual(cfg.agent.head.cond_dim, 288)
+        self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)
+
+        with _stub_optional_modules(include_imf_head=True):
+            agent = instantiate(cfg.agent)
+
+        self.assertEqual(agent.per_step_cond_dim, 288)
+        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
+        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
+        self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
+        self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
+
     def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
         cfg = _compose_cfg(
diff --git a/tests/test_simple_robot_dataset_image_loading.py b/tests/test_simple_robot_dataset_image_loading.py
index 2dd5321..1227b6e 100644
--- a/tests/test_simple_robot_dataset_image_loading.py
+++ b/tests/test_simple_robot_dataset_image_loading.py
@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
 
 
 class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
-    def _write_episode(self, dataset_dir: Path) -> None:
-        episode_path = dataset_dir / "episode_0.hdf5"
+    def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
+        episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
         with h5py.File(episode_path, "w") as root:
-            root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
+            root.create_dataset(
+                "action",
+                data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
+            )
             root.create_dataset(
                 "observations/qpos",
-                data=np.arange(16, dtype=np.float32).reshape(4, 4),
+                data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
             )
             root.create_dataset("task", data=np.array([b"sim_transfer"]))
             root.create_dataset(
                 "observations/images/front",
-                data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
+                data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
             )
 
     def test_getitem_only_resizes_observation_horizon_images(self):
@@ -100,3 +103,25 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
         self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
         self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
         self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
+
+    def test_dataset_can_limit_loading_to_specific_episode_indices(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir)
+            self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
+            self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)
+
+            dataset = SimpleRobotDataset(
+                dataset_dir,
+                obs_horizon=2,
+                pred_horizon=3,
+                camera_names=["front"],
+                image_resize_shape=None,
+                episode_indices=[1],
+            )
+
+            sample = dataset[0]
+
+            self.assertEqual(len(dataset.hdf5_files), 1)
+            self.assertEqual(dataset.available_episode_indices, [1])
+            self.assertEqual(len(dataset), 4)
+            self.assertTrue(np.allclose(sample["observation.state"][0].numpy(), np.array([100.0, 101.0, 102.0, 103.0])))
diff --git a/tests/test_train_vla_swanlab_logging.py b/tests/test_train_vla_swanlab_logging.py
index a191e5e..e13e479 100644
--- a/tests/test_train_vla_swanlab_logging.py
+++ b/tests/test_train_vla_swanlab_logging.py
@@ -41,6 +41,19 @@ class FakeDataset:
         return 4
 
 
+class SplitAwareFakeDataset(FakeDataset):
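+    """FakeDataset variant exposing available_episode_indices, so the
+    explicit held-out episode split path can be exercised in tests."""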
+    def __init__(self, episode_indices=None):
+        self.episode_indices = None if episode_indices is None else list(episode_indices)
+        if self.episode_indices is None:
+            self.episodes = {0: [0], 1: [1], 2: [2]}
+        else:
+            self.episodes = {idx: [idx] for idx in self.episode_indices}
+
+    @property
+    def available_episode_indices(self):
+        return sorted(self.episodes.keys())
+
+
 class FakeLoader:
     def __init__(self, batch):
         self.batch = batch
@@ -123,6 +136,10 @@ class RecordingAgent(FakeAgent):
         self.seen_inputs.append(agent_input)
         return super().compute_loss(agent_input)
 
+    def predict_action_chunk(self, agent_input):
+        self.seen_inputs.append({'predict_action_chunk': agent_input})
+        return torch.ones_like(agent_input['action'])
+
 
 class ShapeMixedFakeAgent(FakeAgent):
     def __init__(self):
@@ -355,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
             batch_size=2,
             num_workers=0,
             val_split=0.25,
+            val_episode_indices=None,
+            action_mse_val_freq_epochs=0,
             seed=0,
             lr=1e-3,
             max_steps=2,
@@ -479,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
                 'batch_size': 2,
                 'num_workers': 0,
                 'val_split': 0.25,
+                'val_episode_indices': None,
+                'action_mse_val_freq_epochs': 0,
                 'seed': 0,
                 'lr': 1e-3,
                 'max_steps': 2,
@@ -561,6 +582,58 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
         self.assertIn('front', first_input['lewm_images'])
         self.assertIn('front', first_input['lewm_future_images'])
 
+    def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
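+        """An explicit held-out episode with action_mse_val_freq_epochs=1
+        should emit a val/action_mse metric to SwanLab after the first epoch."""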
+        module = self._load_train_vla_module()
+        run_training = self._get_run_training(module)
+        cfg = self._make_cfg()
+        cfg.train.max_steps = 1
+        cfg.train.save_freq = 100
+        cfg.train.val_split = 0.0
+        cfg.train.val_episode_indices = [2]
+        cfg.train.action_mse_val_freq_epochs = 1
+        agent = RecordingAgent()
+        fake_swanlab = FakeSwanLab()
+        real_import_module = importlib.import_module
+
+        def fake_instantiate(config_node, **kwargs):
+            if config_node is cfg.data:
+                return SplitAwareFakeDataset(kwargs.get('episode_indices'))
+            if config_node is cfg.agent:
+                return agent
+            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
+
+        def fake_loader_factory(dataset, *, shuffle, **_kwargs):
+            action_value = 0.0 if shuffle else 2.0
+            batch = {
+                'observation.front': torch.zeros(1, 3, 2, 2),
+                'observation.state': torch.zeros(1, 4),
+                'action': torch.full((1, 1, 2), action_value),
+                'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
+            }
+            return FakeLoader(batch)
+
+        def fake_import_module(name, package=None):
+            if name == 'swanlab':
+                return fake_swanlab
+            return real_import_module(name, package)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            previous_cwd = os.getcwd()
+            try:
+                os.chdir(tempdir)
+                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
+                        mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
+                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
+                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
+                        mock.patch.object(module.torch, 'save', return_value=None), \
+                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
+                    run_training(cfg)
+            finally:
+                os.chdir(previous_cwd)
+
+        logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
+        self.assertIn('val/action_mse', logged_keys)
+
     def test_run_training_skips_swanlab_when_disabled(self):
         module = self._load_train_vla_module()
         run_training = self._get_run_training(module)