feat: add held-out validation and dual-decoder lewm imf
This commit is contained in:
@@ -118,6 +118,127 @@ def recursive_to_device(data, device):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def build_agent_input(batch_data):
|
||||||
|
agent_input = {
|
||||||
|
'images': {
|
||||||
|
cam_name.replace('observation.', ''): value
|
||||||
|
for cam_name, value in batch_data.items()
|
||||||
|
if cam_name.startswith('observation.') and cam_name != 'observation.state'
|
||||||
|
},
|
||||||
|
'qpos': batch_data['observation.state'],
|
||||||
|
'action': batch_data['action'],
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'action_is_pad' in batch_data:
|
||||||
|
agent_input['action_is_pad'] = batch_data['action_is_pad']
|
||||||
|
|
||||||
|
lewm_images = {
|
||||||
|
cam_name.replace('lewm.observation.', ''): value
|
||||||
|
for cam_name, value in batch_data.items()
|
||||||
|
if cam_name.startswith('lewm.observation.') and cam_name != 'lewm.observation.state'
|
||||||
|
}
|
||||||
|
if lewm_images:
|
||||||
|
agent_input['lewm_images'] = lewm_images
|
||||||
|
if 'lewm.observation.state' in batch_data:
|
||||||
|
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
|
||||||
|
|
||||||
|
lewm_future_images = {
|
||||||
|
cam_name.replace('lewm.future.', ''): value
|
||||||
|
for cam_name, value in batch_data.items()
|
||||||
|
if cam_name.startswith('lewm.future.') and cam_name != 'lewm.future.state'
|
||||||
|
}
|
||||||
|
if lewm_future_images:
|
||||||
|
agent_input['lewm_future_images'] = lewm_future_images
|
||||||
|
if 'lewm.future.state' in batch_data:
|
||||||
|
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
|
||||||
|
|
||||||
|
return agent_input
|
||||||
|
|
||||||
|
|
||||||
|
def _instantiate_dataset(cfg, dataset_image_resize_shape, episode_indices=None):
|
||||||
|
kwargs = {'image_resize_shape': dataset_image_resize_shape}
|
||||||
|
if episode_indices is not None:
|
||||||
|
kwargs['episode_indices'] = episode_indices
|
||||||
|
return instantiate(cfg.data, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def build_train_val_datasets(cfg, dataset_image_resize_shape):
|
||||||
|
val_episode_indices = cfg.train.get('val_episode_indices', None)
|
||||||
|
if val_episode_indices:
|
||||||
|
dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
|
||||||
|
available_episode_indices = list(getattr(dataset, 'available_episode_indices', []))
|
||||||
|
if not available_episode_indices:
|
||||||
|
raise ValueError('显式 val_episode_indices 需要数据集暴露 available_episode_indices')
|
||||||
|
requested_val_episode_indices = sorted(int(idx) for idx in val_episode_indices)
|
||||||
|
available_set = set(available_episode_indices)
|
||||||
|
missing = sorted(set(requested_val_episode_indices) - available_set)
|
||||||
|
if missing:
|
||||||
|
raise ValueError(
|
||||||
|
f'val_episode_indices {missing} 不存在于数据集可用 episodes {available_episode_indices}'
|
||||||
|
)
|
||||||
|
train_episode_indices = [
|
||||||
|
idx for idx in available_episode_indices
|
||||||
|
if idx not in set(requested_val_episode_indices)
|
||||||
|
]
|
||||||
|
if not train_episode_indices:
|
||||||
|
raise ValueError('显式 val_episode_indices 不能覆盖全部 episodes,训练集将为空')
|
||||||
|
|
||||||
|
train_dataset = _instantiate_dataset(
|
||||||
|
cfg,
|
||||||
|
dataset_image_resize_shape,
|
||||||
|
episode_indices=train_episode_indices,
|
||||||
|
)
|
||||||
|
val_dataset = _instantiate_dataset(
|
||||||
|
cfg,
|
||||||
|
dataset_image_resize_shape,
|
||||||
|
episode_indices=requested_val_episode_indices,
|
||||||
|
)
|
||||||
|
return dataset, train_dataset, val_dataset, requested_val_episode_indices
|
||||||
|
|
||||||
|
dataset = _instantiate_dataset(cfg, dataset_image_resize_shape)
|
||||||
|
val_split = float(cfg.train.get('val_split', 0.1))
|
||||||
|
seed = int(cfg.train.get('seed', 42))
|
||||||
|
val_size = int(len(dataset) * val_split)
|
||||||
|
train_size = len(dataset) - val_size
|
||||||
|
if val_size > 0:
|
||||||
|
train_dataset, val_dataset = random_split(
|
||||||
|
dataset,
|
||||||
|
[train_size, val_size],
|
||||||
|
generator=torch.Generator().manual_seed(seed)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_dataset, val_dataset = dataset, None
|
||||||
|
return dataset, train_dataset, val_dataset, None
|
||||||
|
|
||||||
|
|
||||||
|
def compute_action_mse_validation(agent, val_loader, device):
|
||||||
|
if val_loader is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
was_training = agent.training
|
||||||
|
agent.eval()
|
||||||
|
total_squared_error = 0.0
|
||||||
|
total_count = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for val_batch in val_loader:
|
||||||
|
val_batch = recursive_to_device(val_batch, device)
|
||||||
|
val_input = build_agent_input(val_batch)
|
||||||
|
pred_actions = agent.predict_action_chunk(val_input)
|
||||||
|
target_actions = val_input['action']
|
||||||
|
squared_error = (pred_actions - target_actions).pow(2)
|
||||||
|
action_is_pad = val_input.get('action_is_pad', None)
|
||||||
|
if action_is_pad is not None:
|
||||||
|
mask = (~action_is_pad).unsqueeze(-1).to(squared_error.dtype)
|
||||||
|
total_squared_error += (squared_error * mask).sum().item()
|
||||||
|
total_count += mask.sum().item() * squared_error.shape[-1]
|
||||||
|
else:
|
||||||
|
total_squared_error += squared_error.sum().item()
|
||||||
|
total_count += target_actions.numel()
|
||||||
|
if was_training:
|
||||||
|
agent.train()
|
||||||
|
return total_squared_error / max(total_count, 1.0)
|
||||||
|
|
||||||
|
|
||||||
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
|
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
|
||||||
"""
|
"""
|
||||||
解析恢复训练用的 checkpoint 路径。
|
解析恢复训练用的 checkpoint 路径。
|
||||||
@@ -410,30 +531,30 @@ def _run_training(cfg: DictConfig):
|
|||||||
vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
|
vision_backbone_cfg = cfg.agent.get('vision_backbone', None)
|
||||||
if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
|
if vision_backbone_cfg is not None and 'dataset_image_resize_shape' in vision_backbone_cfg:
|
||||||
dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
|
dataset_image_resize_shape = vision_backbone_cfg.get('dataset_image_resize_shape')
|
||||||
dataset = instantiate(
|
dataset, train_dataset, val_dataset, explicit_val_episode_indices = (
|
||||||
cfg.data,
|
build_train_val_datasets(cfg, dataset_image_resize_shape)
|
||||||
image_resize_shape=dataset_image_resize_shape,
|
|
||||||
)
|
)
|
||||||
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
|
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"❌ 数据集加载失败: {e}")
|
log.error(f"❌ 数据集加载失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 训练/验证集划分
|
if explicit_val_episode_indices is not None:
|
||||||
val_split = float(cfg.train.get('val_split', 0.1))
|
log.info(
|
||||||
seed = int(cfg.train.get('seed', 42))
|
"✅ 数据集划分: 训练集=%s, 验证集=%s (显式 held-out episodes=%s)",
|
||||||
val_size = int(len(dataset) * val_split)
|
len(train_dataset),
|
||||||
train_size = len(dataset) - val_size
|
len(val_dataset),
|
||||||
if val_size > 0:
|
explicit_val_episode_indices,
|
||||||
train_dataset, val_dataset = random_split(
|
|
||||||
dataset,
|
|
||||||
[train_size, val_size],
|
|
||||||
generator=torch.Generator().manual_seed(seed)
|
|
||||||
)
|
)
|
||||||
log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
|
|
||||||
else:
|
else:
|
||||||
train_dataset, val_dataset = dataset, None
|
val_split = float(cfg.train.get('val_split', 0.1))
|
||||||
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
|
val_size = len(val_dataset) if val_dataset is not None else 0
|
||||||
|
if val_size > 0:
|
||||||
|
log.info(
|
||||||
|
f"✅ 数据集划分: 训练集={len(train_dataset)}, 验证集={val_size} (验证比例={val_split})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
|
||||||
|
|
||||||
train_batch_size = int(cfg.train.batch_size)
|
train_batch_size = int(cfg.train.batch_size)
|
||||||
train_drop_last = len(train_dataset) >= train_batch_size
|
train_drop_last = len(train_dataset) >= train_batch_size
|
||||||
@@ -674,44 +795,6 @@ def _run_training(cfg: DictConfig):
|
|||||||
# =========================================================================
|
# =========================================================================
|
||||||
log.info("🏋️ 开始训练循环...")
|
log.info("🏋️ 开始训练循环...")
|
||||||
|
|
||||||
def build_agent_input(batch_data):
|
|
||||||
"""构建 agent 输入格式"""
|
|
||||||
images = {}
|
|
||||||
# SimpleRobotDataset 返回 observation.{cam_name} 格式
|
|
||||||
for cam_name in cfg.data.camera_names:
|
|
||||||
key = f"observation.{cam_name}"
|
|
||||||
if key in batch_data:
|
|
||||||
images[cam_name] = batch_data[key]
|
|
||||||
|
|
||||||
agent_input = {
|
|
||||||
'images': images,
|
|
||||||
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
|
||||||
'action': batch_data['action'],
|
|
||||||
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
|
||||||
}
|
|
||||||
|
|
||||||
lewm_images = {}
|
|
||||||
lewm_future_images = {}
|
|
||||||
for cam_name in cfg.data.camera_names:
|
|
||||||
lewm_obs_key = f"lewm.observation.{cam_name}"
|
|
||||||
if lewm_obs_key in batch_data:
|
|
||||||
lewm_images[cam_name] = batch_data[lewm_obs_key]
|
|
||||||
|
|
||||||
lewm_future_key = f"lewm.future.{cam_name}"
|
|
||||||
if lewm_future_key in batch_data:
|
|
||||||
lewm_future_images[cam_name] = batch_data[lewm_future_key]
|
|
||||||
|
|
||||||
if 'lewm.observation.state' in batch_data:
|
|
||||||
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
|
|
||||||
if lewm_images:
|
|
||||||
agent_input['lewm_images'] = lewm_images
|
|
||||||
if 'lewm.future.state' in batch_data:
|
|
||||||
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
|
|
||||||
if lewm_future_images:
|
|
||||||
agent_input['lewm_future_images'] = lewm_future_images
|
|
||||||
|
|
||||||
return agent_input
|
|
||||||
|
|
||||||
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
|
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
|
||||||
agent_stats = agent.get_normalization_stats()
|
agent_stats = agent.get_normalization_stats()
|
||||||
torch.save({
|
torch.save({
|
||||||
@@ -784,6 +867,7 @@ def _run_training(cfg: DictConfig):
|
|||||||
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
|
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
|
||||||
|
|
||||||
steps_per_epoch = len(train_loader)
|
steps_per_epoch = len(train_loader)
|
||||||
|
action_mse_val_freq_epochs = int(cfg.train.get('action_mse_val_freq_epochs', 0) or 0)
|
||||||
rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
|
rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0)
|
||||||
rollout_validation_enabled = rollout_val_freq_epochs > 0
|
rollout_validation_enabled = rollout_val_freq_epochs > 0
|
||||||
best_loss = resume_best_loss
|
best_loss = resume_best_loss
|
||||||
@@ -953,6 +1037,33 @@ def _run_training(cfg: DictConfig):
|
|||||||
and completed_epoch > 0
|
and completed_epoch > 0
|
||||||
and completed_epoch % rollout_val_freq_epochs == 0
|
and completed_epoch % rollout_val_freq_epochs == 0
|
||||||
)
|
)
|
||||||
|
should_run_action_mse_validation = (
|
||||||
|
action_mse_val_freq_epochs > 0
|
||||||
|
and val_loader is not None
|
||||||
|
and steps_per_epoch > 0
|
||||||
|
and completed_steps % steps_per_epoch == 0
|
||||||
|
and completed_epoch > 0
|
||||||
|
and completed_epoch % action_mse_val_freq_epochs == 0
|
||||||
|
)
|
||||||
|
if should_run_action_mse_validation:
|
||||||
|
action_mse = compute_action_mse_validation(
|
||||||
|
agent,
|
||||||
|
val_loader,
|
||||||
|
cfg.train.device,
|
||||||
|
)
|
||||||
|
if action_mse is not None:
|
||||||
|
log.info(
|
||||||
|
f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} "
|
||||||
|
f"held-out action MSE: {action_mse:.6f}"
|
||||||
|
)
|
||||||
|
_log_to_swanlab(
|
||||||
|
swanlab_module,
|
||||||
|
{
|
||||||
|
'val/action_mse': action_mse,
|
||||||
|
'val/action_mse_epoch': completed_epoch,
|
||||||
|
},
|
||||||
|
step=step,
|
||||||
|
)
|
||||||
if should_run_epoch_rollout:
|
if should_run_epoch_rollout:
|
||||||
if checkpoint_path is None:
|
if checkpoint_path is None:
|
||||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
lewm_query_offsets: Optional[Sequence[int]] = None,
|
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||||
lewm_predictor: Optional[nn.Module] = None,
|
lewm_predictor: Optional[nn.Module] = None,
|
||||||
lewm_pred_projector: Optional[nn.Module] = None,
|
lewm_pred_projector: Optional[nn.Module] = None,
|
||||||
|
future_decoder: Optional[nn.Module] = None,
|
||||||
|
future_query_init_std: float = 0.02,
|
||||||
lewm_sigreg: Optional[nn.Module] = None,
|
lewm_sigreg: Optional[nn.Module] = None,
|
||||||
lewm_sigreg_weight: float = 0.09,
|
lewm_sigreg_weight: float = 0.09,
|
||||||
lewm_loss_weight: float = 0.0,
|
lewm_loss_weight: float = 0.0,
|
||||||
@@ -39,11 +41,17 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
)
|
)
|
||||||
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
|
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
|
||||||
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
|
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
|
||||||
kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens)
|
default_extra_condition_tokens = (
|
||||||
|
0 if future_decoder is not None else inferred_extra_condition_tokens
|
||||||
|
)
|
||||||
|
kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
|
||||||
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
|
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
|
||||||
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
|
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
|
||||||
self.__dict__['lewm_predictor'] = lewm_predictor
|
self.__dict__['lewm_predictor'] = lewm_predictor
|
||||||
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
|
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
|
||||||
|
self.__dict__['future_decoder'] = future_decoder
|
||||||
|
self.__dict__['future_query_tokens'] = None
|
||||||
|
self.__dict__['future_query_init_std'] = float(future_query_init_std)
|
||||||
self.__dict__['lewm_sigreg'] = lewm_sigreg
|
self.__dict__['lewm_sigreg'] = lewm_sigreg
|
||||||
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
|
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
|
||||||
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
|
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
|
||||||
@@ -59,8 +67,16 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
|
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
|
||||||
self.lewm_predictor = lewm_predictor
|
self.lewm_predictor = lewm_predictor
|
||||||
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
|
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
|
||||||
|
if future_decoder is not None and not isinstance(future_decoder, nn.Module):
|
||||||
|
self.future_decoder = future_decoder()
|
||||||
|
else:
|
||||||
|
self.future_decoder = future_decoder
|
||||||
|
self.future_query_tokens = None
|
||||||
|
self.future_query_init_std = float(future_query_init_std)
|
||||||
self.lewm_sigreg = lewm_sigreg
|
self.lewm_sigreg = lewm_sigreg
|
||||||
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
|
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
|
||||||
|
if self.lewm_predictor is not None and self.future_decoder is not None:
|
||||||
|
raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
|
||||||
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
|
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
|
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
|
||||||
@@ -69,6 +85,18 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
|
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
|
||||||
)
|
)
|
||||||
|
if self.future_decoder is not None:
|
||||||
|
if inferred_extra_condition_tokens <= 0:
|
||||||
|
raise ValueError('future_decoder requires non-empty lewm_query_offsets')
|
||||||
|
if self.extra_condition_tokens != 0:
|
||||||
|
raise ValueError('future_decoder requires extra_condition_tokens=0')
|
||||||
|
self.future_query_tokens = nn.Parameter(
|
||||||
|
torch.randn(
|
||||||
|
1,
|
||||||
|
inferred_extra_condition_tokens,
|
||||||
|
self.per_step_cond_dim,
|
||||||
|
) * self.future_query_init_std
|
||||||
|
)
|
||||||
if lewm_pretrained_ckpt is not None:
|
if lewm_pretrained_ckpt is not None:
|
||||||
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
|
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
|
||||||
|
|
||||||
@@ -232,15 +260,19 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
*,
|
*,
|
||||||
query_key: str = 'query_tokens',
|
query_key: str = 'query_tokens',
|
||||||
pos_key: str = 'pos_embedding',
|
pos_key: str = 'pos_embedding',
|
||||||
) -> None:
|
) -> Dict[str, Sequence[str]]:
|
||||||
current_state_dict = module.state_dict()
|
current_state_dict = module.state_dict()
|
||||||
adapted_state_dict = dict(current_state_dict)
|
adapted_state_dict = dict(current_state_dict)
|
||||||
|
loaded_keys = []
|
||||||
|
mismatched_keys = []
|
||||||
|
missing_keys = []
|
||||||
for key, current_tensor in current_state_dict.items():
|
for key, current_tensor in current_state_dict.items():
|
||||||
if key not in incoming_state_dict:
|
if key not in incoming_state_dict:
|
||||||
continue
|
continue
|
||||||
source_tensor = incoming_state_dict[key]
|
source_tensor = incoming_state_dict[key]
|
||||||
if source_tensor.shape == current_tensor.shape:
|
if source_tensor.shape == current_tensor.shape:
|
||||||
adapted_state_dict[key] = source_tensor
|
adapted_state_dict[key] = source_tensor
|
||||||
|
loaded_keys.append(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
|
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
|
||||||
@@ -256,8 +288,47 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||||
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||||
adapted_state_dict[key] = patched
|
adapted_state_dict[key] = patched
|
||||||
|
loaded_keys.append(key)
|
||||||
|
continue
|
||||||
|
mismatched_keys.append(key)
|
||||||
|
|
||||||
|
for key in incoming_state_dict.keys():
|
||||||
|
if key not in current_state_dict:
|
||||||
|
missing_keys.append(key)
|
||||||
module.load_state_dict(adapted_state_dict, strict=True)
|
module.load_state_dict(adapted_state_dict, strict=True)
|
||||||
|
return {
|
||||||
|
'loaded_keys': tuple(sorted(loaded_keys)),
|
||||||
|
'mismatched_keys': tuple(sorted(set(mismatched_keys))),
|
||||||
|
'missing_keys': tuple(sorted(set(missing_keys))),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_state_dict_ignoring_shape_mismatches(
|
||||||
|
module: nn.Module,
|
||||||
|
incoming_state_dict: Mapping[str, torch.Tensor],
|
||||||
|
) -> Dict[str, Sequence[str]]:
|
||||||
|
current_state_dict = module.state_dict()
|
||||||
|
merged_state_dict = dict(current_state_dict)
|
||||||
|
loaded_keys = []
|
||||||
|
mismatched_keys = []
|
||||||
|
missing_keys = []
|
||||||
|
|
||||||
|
for key, value in incoming_state_dict.items():
|
||||||
|
if key not in current_state_dict:
|
||||||
|
missing_keys.append(key)
|
||||||
|
continue
|
||||||
|
if current_state_dict[key].shape != value.shape:
|
||||||
|
mismatched_keys.append(key)
|
||||||
|
continue
|
||||||
|
merged_state_dict[key] = value
|
||||||
|
loaded_keys.append(key)
|
||||||
|
|
||||||
|
module.load_state_dict(merged_state_dict, strict=True)
|
||||||
|
return {
|
||||||
|
'loaded_keys': tuple(sorted(loaded_keys)),
|
||||||
|
'mismatched_keys': tuple(sorted(mismatched_keys)),
|
||||||
|
'missing_keys': tuple(sorted(missing_keys)),
|
||||||
|
}
|
||||||
|
|
||||||
def load_lewm_pretrained_components(
|
def load_lewm_pretrained_components(
|
||||||
self,
|
self,
|
||||||
@@ -266,20 +337,24 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
state_dict = self._load_checkpoint_payload(checkpoint_or_path)
|
state_dict = self._load_checkpoint_payload(checkpoint_or_path)
|
||||||
|
|
||||||
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
|
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
|
||||||
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
|
try:
|
||||||
|
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
|
||||||
|
except RuntimeError:
|
||||||
|
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
|
||||||
|
self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
|
||||||
else:
|
else:
|
||||||
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
|
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
|
||||||
self.vision_encoder.load_state_dict(vision_state_dict, strict=True)
|
self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
|
||||||
|
|
||||||
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
|
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
|
||||||
self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True)
|
self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)
|
||||||
|
|
||||||
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
|
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
|
||||||
mapped_projector_state_dict = {
|
mapped_projector_state_dict = {
|
||||||
f'linear.{key}': value
|
f'linear.{key}': value
|
||||||
for key, value in projector_state_dict.items()
|
for key, value in projector_state_dict.items()
|
||||||
}
|
}
|
||||||
self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True)
|
self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)
|
||||||
|
|
||||||
if self.lewm_predictor is not None:
|
if self.lewm_predictor is not None:
|
||||||
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
|
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
|
||||||
@@ -287,7 +362,19 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
|
|
||||||
if self.lewm_pred_projector is not None:
|
if self.lewm_pred_projector is not None:
|
||||||
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
|
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
|
||||||
self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True)
|
self._load_state_dict_ignoring_shape_mismatches(
|
||||||
|
self.lewm_pred_projector,
|
||||||
|
pred_projector_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.future_decoder is None or self.future_query_tokens is None:
|
||||||
|
raise RuntimeError('future_decoder path requested but not initialized')
|
||||||
|
batch_size = history_cond.shape[0]
|
||||||
|
query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
|
||||||
|
r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
|
||||||
|
t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
|
||||||
|
return self.future_decoder(query_tokens, r, t, cond=history_cond)
|
||||||
|
|
||||||
def _build_full_condition(
|
def _build_full_condition(
|
||||||
self,
|
self,
|
||||||
@@ -302,7 +389,7 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
predicted_future_tokens = None
|
predicted_future_tokens = None
|
||||||
lewm_history_cond = None
|
lewm_history_cond = None
|
||||||
|
|
||||||
if self.lewm_predictor is None:
|
if self.lewm_predictor is None and self.future_decoder is None:
|
||||||
return history_cond, predicted_future_tokens, lewm_history_cond
|
return history_cond, predicted_future_tokens, lewm_history_cond
|
||||||
|
|
||||||
lewm_images = lewm_images if lewm_images is not None else images
|
lewm_images = lewm_images if lewm_images is not None else images
|
||||||
@@ -313,17 +400,21 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
lewm_images,
|
lewm_images,
|
||||||
self._normalize_qpos_for_lewm(lewm_proprioception),
|
self._normalize_qpos_for_lewm(lewm_proprioception),
|
||||||
)
|
)
|
||||||
predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
|
cond = history_cond
|
||||||
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
|
if self.lewm_predictor is not None:
|
||||||
cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
|
predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
|
||||||
if cond.shape[1] != self.condition_sequence_length:
|
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
|
||||||
raise RuntimeError(
|
cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
|
||||||
f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
|
if cond.shape[1] != self.condition_sequence_length:
|
||||||
)
|
raise RuntimeError(
|
||||||
if cond.shape[-1] != self.per_step_cond_dim:
|
f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
|
||||||
raise RuntimeError(
|
)
|
||||||
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
if cond.shape[-1] != self.per_step_cond_dim:
|
||||||
)
|
raise RuntimeError(
|
||||||
|
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
||||||
|
)
|
||||||
|
elif self.future_decoder is not None:
|
||||||
|
predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
|
||||||
return cond, predicted_future_tokens, lewm_history_cond
|
return cond, predicted_future_tokens, lewm_history_cond
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -452,12 +543,18 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
lewm_proprioception=None,
|
lewm_proprioception=None,
|
||||||
):
|
):
|
||||||
batch_size = proprioception.shape[0]
|
batch_size = proprioception.shape[0]
|
||||||
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
|
if self.lewm_predictor is not None:
|
||||||
images,
|
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
|
||||||
proprioception,
|
images,
|
||||||
lewm_images=lewm_images,
|
proprioception,
|
||||||
lewm_proprioception=lewm_proprioception,
|
lewm_images=lewm_images,
|
||||||
)
|
lewm_proprioception=lewm_proprioception,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond = self._build_cond(
|
||||||
|
images,
|
||||||
|
self.normalization.normalize_qpos(proprioception),
|
||||||
|
)
|
||||||
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
|
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
|
||||||
action = self._sample_one_step(z_t, cond=cond)
|
action = self._sample_one_step(z_t, cond=cond)
|
||||||
return self.normalization.denormalize_action(action)
|
return self.normalization.denormalize_action(action)
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
# @package agent
|
||||||
|
defaults:
|
||||||
|
- /backbone@vision_backbone: lewm_resnet_query_fusion
|
||||||
|
- /modules@state_encoder: lewm_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /modules@cond_projector: linear_condition_projector
|
||||||
|
- /head: imf_transformer1d
|
||||||
|
- /head@future_decoder: imf_transformer1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: roboimi.vla.agent_imf.IMFVLAAgent
|
||||||
|
|
||||||
|
action_dim: 16
|
||||||
|
obs_dim: 16
|
||||||
|
normalization_type: "min_max"
|
||||||
|
pred_horizon: 8
|
||||||
|
obs_horizon: 2
|
||||||
|
num_action_steps: 8
|
||||||
|
camera_names: ${data.camera_names}
|
||||||
|
num_cams: 3
|
||||||
|
|
||||||
|
vision_backbone:
|
||||||
|
camera_names: ${agent.camera_names}
|
||||||
|
num_views: ${agent.num_cams}
|
||||||
|
|
||||||
|
cond_projector:
|
||||||
|
output_dim: 288
|
||||||
|
|
||||||
|
lewm_history_horizon: 3
|
||||||
|
lewm_query_offsets: [8]
|
||||||
|
extra_condition_tokens: 0
|
||||||
|
lewm_loss_weight: 1.0
|
||||||
|
lewm_sigreg_weight: 0.09
|
||||||
|
lewm_pretrained_ckpt: null
|
||||||
|
future_query_init_std: 0.02
|
||||||
|
|
||||||
|
lewm_sigreg:
|
||||||
|
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
|
||||||
|
knots: 17
|
||||||
|
num_proj: 1024
|
||||||
|
|
||||||
|
diffusion_steps: 100
|
||||||
|
inference_steps: 1
|
||||||
|
head_type: "transformer"
|
||||||
|
|
||||||
|
head:
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
cond_dim: ${agent.cond_projector.output_dim}
|
||||||
|
n_emb: 384
|
||||||
|
causal_attn: false
|
||||||
|
time_as_cond: true
|
||||||
|
obs_as_cond: true
|
||||||
|
n_cond_layers: 0
|
||||||
|
backbone_type: attnres_full
|
||||||
|
n_head: 1
|
||||||
|
n_kv_head: 1
|
||||||
|
|
||||||
|
future_decoder:
|
||||||
|
input_dim: ${agent.cond_projector.output_dim}
|
||||||
|
output_dim: ${agent.cond_projector.output_dim}
|
||||||
|
horizon: ${len:${agent.lewm_query_offsets}}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
cond_dim: ${agent.cond_projector.output_dim}
|
||||||
|
n_emb: 384
|
||||||
|
causal_attn: false
|
||||||
|
time_as_cond: true
|
||||||
|
obs_as_cond: true
|
||||||
|
n_cond_layers: 0
|
||||||
|
backbone_type: attnres_full
|
||||||
|
n_head: 1
|
||||||
|
n_kv_head: 1
|
||||||
@@ -18,6 +18,8 @@ train:
|
|||||||
# 数据加载
|
# 数据加载
|
||||||
num_workers: 12 # DataLoader 工作进程数(调试时设为 0)
|
num_workers: 12 # DataLoader 工作进程数(调试时设为 0)
|
||||||
val_split: 0.0 # 验证集比例;默认使用全量数据训练
|
val_split: 0.0 # 验证集比例;默认使用全量数据训练
|
||||||
|
val_episode_indices: null # 显式按 episode 划出的验证集,例如 [100]
|
||||||
|
action_mse_val_freq_epochs: 0 # >0 时每隔多少个 epoch 在 held-out episode 上计算 action MSE
|
||||||
seed: 42 # 随机种子(用于数据划分)
|
seed: 42 # 随机种子(用于数据划分)
|
||||||
|
|
||||||
# 日志和检查点
|
# 日志和检查点
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class SimpleRobotDataset(Dataset):
|
|||||||
max_open_files: int = 64,
|
max_open_files: int = 64,
|
||||||
lewm_history_horizon: Optional[int] = None,
|
lewm_history_horizon: Optional[int] = None,
|
||||||
lewm_query_offsets: Optional[Sequence[int]] = None,
|
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||||
|
episode_indices: Optional[Sequence[int]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -57,6 +58,9 @@ class SimpleRobotDataset(Dataset):
|
|||||||
)
|
)
|
||||||
self.max_open_files = max(1, int(max_open_files))
|
self.max_open_files = max(1, int(max_open_files))
|
||||||
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
|
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
|
||||||
|
self.requested_episode_indices = (
|
||||||
|
None if episode_indices is None else tuple(sorted(int(idx) for idx in episode_indices))
|
||||||
|
)
|
||||||
|
|
||||||
self.dataset_dir = Path(dataset_dir)
|
self.dataset_dir = Path(dataset_dir)
|
||||||
if not self.dataset_dir.exists():
|
if not self.dataset_dir.exists():
|
||||||
@@ -69,20 +73,45 @@ class SimpleRobotDataset(Dataset):
|
|||||||
if not self.hdf5_files:
|
if not self.hdf5_files:
|
||||||
raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")
|
raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")
|
||||||
|
|
||||||
|
if self.requested_episode_indices is not None:
|
||||||
|
requested = set(self.requested_episode_indices)
|
||||||
|
filtered = []
|
||||||
|
for hdf5_path in self.hdf5_files:
|
||||||
|
stem = hdf5_path.stem
|
||||||
|
if stem.startswith("episode_"):
|
||||||
|
try:
|
||||||
|
idx = int(stem.split("_")[-1])
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if idx in requested:
|
||||||
|
filtered.append(hdf5_path)
|
||||||
|
self.hdf5_files = filtered
|
||||||
|
if not self.hdf5_files:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"在 {dataset_dir} 中未找到 episode_indices={sorted(requested)} 对应的 HDF5 文件"
|
||||||
|
)
|
||||||
|
|
||||||
# 构建 episode 索引(只存储元数据,不加载数据)
|
# 构建 episode 索引(只存储元数据,不加载数据)
|
||||||
self.episodes = {}
|
self.episodes = {}
|
||||||
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
|
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
|
||||||
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
|
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
|
||||||
with h5py.File(hdf5_path, 'r') as f:
|
with h5py.File(hdf5_path, 'r') as f:
|
||||||
T = f['action'].shape[0]
|
T = f['action'].shape[0]
|
||||||
|
dataset_episode_idx = ep_idx
|
||||||
|
stem = hdf5_path.stem
|
||||||
|
if stem.startswith("episode_"):
|
||||||
|
try:
|
||||||
|
dataset_episode_idx = int(stem.split("_")[-1])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
start_idx = len(self.frame_meta)
|
start_idx = len(self.frame_meta)
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
self.frame_meta.append({
|
self.frame_meta.append({
|
||||||
"ep_idx": ep_idx,
|
"ep_idx": dataset_episode_idx,
|
||||||
"frame_idx": t,
|
"frame_idx": t,
|
||||||
"hdf5_path": hdf5_path,
|
"hdf5_path": hdf5_path,
|
||||||
})
|
})
|
||||||
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
|
self.episodes[dataset_episode_idx] = list(range(start_idx, len(self.frame_meta)))
|
||||||
|
|
||||||
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")
|
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")
|
||||||
|
|
||||||
@@ -290,6 +319,10 @@ class SimpleRobotDataset(Dataset):
|
|||||||
"""获取所有相机键名 (LeRobotDataset 格式)"""
|
"""获取所有相机键名 (LeRobotDataset 格式)"""
|
||||||
return [f"observation.{cam_name}" for cam_name in self.camera_names]
|
return [f"observation.{cam_name}" for cam_name in self.camera_names]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_episode_indices(self) -> List[int]:
|
||||||
|
return sorted(self.episodes.keys())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_info(self) -> dict:
|
def camera_info(self) -> dict:
|
||||||
"""获取相机信息"""
|
"""获取相机信息"""
|
||||||
|
|||||||
@@ -388,6 +388,26 @@ class _StubFutureTokenPredictor(nn.Module):
|
|||||||
return summary.repeat(1, self.num_future_tokens, 1)
|
return summary.repeat(1, self.num_future_tokens, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingDirectFutureDecoder(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.tensor(0.5))
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def forward(self, sample, r, t, cond=None):
|
||||||
|
record = {
|
||||||
|
'sample': sample.detach().clone(),
|
||||||
|
'r': r.detach().clone(),
|
||||||
|
't': t.detach().clone(),
|
||||||
|
'cond': None if cond is None else cond.detach().clone(),
|
||||||
|
}
|
||||||
|
self.calls.append(record)
|
||||||
|
cond_term = 0.0
|
||||||
|
if cond is not None:
|
||||||
|
cond_term = cond.mean(dim=1, keepdim=True)
|
||||||
|
return self.scale * sample + cond_term
|
||||||
|
|
||||||
|
|
||||||
class _RecordingSigReg(nn.Module):
|
class _RecordingSigReg(nn.Module):
|
||||||
def __init__(self, value=0.5):
|
def __init__(self, value=0.5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -687,6 +707,148 @@ class IMFVLAAgentTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
|
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
|
||||||
|
|
||||||
|
def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
|
||||||
|
agent_cls, agent_module = _load_imf_agent_class()
|
||||||
|
head = _RecordingLinearIMFHead()
|
||||||
|
future_decoder = _RecordingDirectFutureDecoder()
|
||||||
|
agent = agent_cls(
|
||||||
|
vision_backbone=_StubVisionBackbone(),
|
||||||
|
state_encoder=nn.Identity(),
|
||||||
|
action_encoder=nn.Identity(),
|
||||||
|
head=head,
|
||||||
|
future_decoder=future_decoder,
|
||||||
|
action_dim=2,
|
||||||
|
obs_dim=1,
|
||||||
|
pred_horizon=3,
|
||||||
|
obs_horizon=2,
|
||||||
|
diffusion_steps=10,
|
||||||
|
inference_steps=1,
|
||||||
|
num_cams=len(_CAMERA_NAMES),
|
||||||
|
camera_names=_CAMERA_NAMES,
|
||||||
|
num_action_steps=2,
|
||||||
|
head_type='transformer',
|
||||||
|
lewm_history_horizon=3,
|
||||||
|
lewm_query_offsets=[8],
|
||||||
|
lewm_loss_weight=1.0,
|
||||||
|
)
|
||||||
|
agent.infer_scheduler = _ForbiddenScheduler()
|
||||||
|
with torch.no_grad():
|
||||||
|
agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))
|
||||||
|
|
||||||
|
images = _make_images(
|
||||||
|
batch_size=1,
|
||||||
|
obs_horizon=2,
|
||||||
|
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||||
|
)
|
||||||
|
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||||
|
initial_noise = torch.tensor(
|
||||||
|
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
|
||||||
|
_ = agent.predict_action(images, qpos)
|
||||||
|
|
||||||
|
expected_history = torch.tensor(
|
||||||
|
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(head.calls), 1)
|
||||||
|
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
|
||||||
|
self.assertEqual(len(future_decoder.calls), 0)
|
||||||
|
|
||||||
|
def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
|
||||||
|
agent_cls, agent_module = _load_imf_agent_class()
|
||||||
|
head = _RecordingLinearIMFHead()
|
||||||
|
future_decoder = _RecordingDirectFutureDecoder()
|
||||||
|
sigreg = _RecordingSigReg(value=0.75)
|
||||||
|
agent = agent_cls(
|
||||||
|
vision_backbone=_StubVisionBackbone(),
|
||||||
|
state_encoder=nn.Identity(),
|
||||||
|
action_encoder=nn.Identity(),
|
||||||
|
head=head,
|
||||||
|
future_decoder=future_decoder,
|
||||||
|
action_dim=2,
|
||||||
|
obs_dim=1,
|
||||||
|
pred_horizon=3,
|
||||||
|
obs_horizon=2,
|
||||||
|
diffusion_steps=10,
|
||||||
|
inference_steps=1,
|
||||||
|
num_cams=len(_CAMERA_NAMES),
|
||||||
|
camera_names=_CAMERA_NAMES,
|
||||||
|
num_action_steps=2,
|
||||||
|
head_type='transformer',
|
||||||
|
lewm_history_horizon=3,
|
||||||
|
lewm_query_offsets=[8],
|
||||||
|
lewm_sigreg=sigreg,
|
||||||
|
lewm_sigreg_weight=0.09,
|
||||||
|
lewm_loss_weight=1.0,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))
|
||||||
|
|
||||||
|
images = _make_images(
|
||||||
|
batch_size=1,
|
||||||
|
obs_horizon=2,
|
||||||
|
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||||
|
)
|
||||||
|
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
|
||||||
|
actions = torch.tensor(
|
||||||
|
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
lewm_images = _make_images(
|
||||||
|
batch_size=1,
|
||||||
|
obs_horizon=3,
|
||||||
|
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||||
|
)
|
||||||
|
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
|
||||||
|
lewm_future_images = _make_images(
|
||||||
|
batch_size=1,
|
||||||
|
obs_horizon=1,
|
||||||
|
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||||
|
)
|
||||||
|
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
|
||||||
|
noise = torch.tensor(
|
||||||
|
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
t_sample = torch.tensor([0.8], dtype=torch.float32)
|
||||||
|
r_sample = torch.tensor([0.25], dtype=torch.float32)
|
||||||
|
|
||||||
|
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
|
||||||
|
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
|
||||||
|
loss = agent.compute_loss(
|
||||||
|
{
|
||||||
|
'images': images,
|
||||||
|
'qpos': qpos,
|
||||||
|
'action': actions,
|
||||||
|
'lewm_images': lewm_images,
|
||||||
|
'lewm_qpos': lewm_qpos,
|
||||||
|
'lewm_future_images': lewm_future_images,
|
||||||
|
'lewm_future_qpos': lewm_future_qpos,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = agent.get_last_loss_breakdown()
|
||||||
|
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
|
||||||
|
self.assertEqual(len(head.calls), 2)
|
||||||
|
self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
|
||||||
|
self.assertEqual(len(future_decoder.calls), 1)
|
||||||
|
self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
metrics['loss'],
|
||||||
|
metrics['action_loss'] + metrics['lewm_loss'],
|
||||||
|
places=5,
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
metrics['lewm_loss'],
|
||||||
|
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
|
||||||
|
places=5,
|
||||||
|
)
|
||||||
|
self.assertGreater(metrics['lewm_pred_loss'], 0.0)
|
||||||
|
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
|
||||||
|
|
||||||
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
||||||
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
||||||
observation = {
|
observation = {
|
||||||
@@ -1077,6 +1239,36 @@ class IMFVLAAgentTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsNotNone(agent.lewm_sigreg)
|
self.assertIsNotNone(agent.lewm_sigreg)
|
||||||
|
|
||||||
|
def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
|
||||||
|
cfg = _compose_cfg(
|
||||||
|
overrides=[
|
||||||
|
'agent=lewm_resnet_dual_decoder_imf_attnres',
|
||||||
|
'agent.head.n_layer=1',
|
||||||
|
'agent.head.n_emb=16',
|
||||||
|
'agent.future_decoder.n_layer=1',
|
||||||
|
'agent.future_decoder.n_emb=16',
|
||||||
|
'agent.lewm_query_offsets=[8]',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||||
|
self.assertEqual(cfg.agent.extra_condition_tokens, 0)
|
||||||
|
self.assertEqual(
|
||||||
|
cfg.agent.future_decoder._target_,
|
||||||
|
'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
|
||||||
|
)
|
||||||
|
self.assertEqual(cfg.agent.head.cond_dim, 288)
|
||||||
|
self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)
|
||||||
|
|
||||||
|
with _stub_optional_modules(include_imf_head=True):
|
||||||
|
agent = instantiate(cfg.agent)
|
||||||
|
|
||||||
|
self.assertEqual(agent.per_step_cond_dim, 288)
|
||||||
|
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
|
||||||
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
|
||||||
|
self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
|
||||||
|
self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
|
||||||
|
|
||||||
|
|
||||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||||
cfg = _compose_cfg(
|
cfg = _compose_cfg(
|
||||||
|
|||||||
@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
|
|||||||
|
|
||||||
|
|
||||||
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
||||||
def _write_episode(self, dataset_dir: Path) -> None:
|
def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
|
||||||
episode_path = dataset_dir / "episode_0.hdf5"
|
episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
|
||||||
with h5py.File(episode_path, "w") as root:
|
with h5py.File(episode_path, "w") as root:
|
||||||
root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
|
root.create_dataset(
|
||||||
|
"action",
|
||||||
|
data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
|
||||||
|
)
|
||||||
root.create_dataset(
|
root.create_dataset(
|
||||||
"observations/qpos",
|
"observations/qpos",
|
||||||
data=np.arange(16, dtype=np.float32).reshape(4, 4),
|
data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
|
||||||
)
|
)
|
||||||
root.create_dataset("task", data=np.array([b"sim_transfer"]))
|
root.create_dataset("task", data=np.array([b"sim_transfer"]))
|
||||||
root.create_dataset(
|
root.create_dataset(
|
||||||
"observations/images/front",
|
"observations/images/front",
|
||||||
data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
|
data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_getitem_only_resizes_observation_horizon_images(self):
|
def test_getitem_only_resizes_observation_horizon_images(self):
|
||||||
@@ -100,3 +103,25 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
|||||||
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
|
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
|
||||||
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
|
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
|
||||||
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
|
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
|
||||||
|
|
||||||
|
def test_dataset_can_limit_loading_to_specific_episode_indices(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
dataset_dir = Path(tmpdir)
|
||||||
|
self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
|
||||||
|
self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)
|
||||||
|
|
||||||
|
dataset = SimpleRobotDataset(
|
||||||
|
dataset_dir,
|
||||||
|
obs_horizon=2,
|
||||||
|
pred_horizon=3,
|
||||||
|
camera_names=["front"],
|
||||||
|
image_resize_shape=None,
|
||||||
|
episode_indices=[1],
|
||||||
|
)
|
||||||
|
|
||||||
|
sample = dataset[0]
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset.hdf5_files), 1)
|
||||||
|
self.assertEqual(dataset.available_episode_indices, [1])
|
||||||
|
self.assertEqual(len(dataset), 4)
|
||||||
|
self.assertTrue(np.allclose(sample["observation.state"][0].numpy(), np.array([100.0, 101.0, 102.0, 103.0])))
|
||||||
|
|||||||
@@ -41,6 +41,19 @@ class FakeDataset:
|
|||||||
return 4
|
return 4
|
||||||
|
|
||||||
|
|
||||||
|
class SplitAwareFakeDataset(FakeDataset):
|
||||||
|
def __init__(self, episode_indices=None):
|
||||||
|
self.episode_indices = None if episode_indices is None else list(episode_indices)
|
||||||
|
if self.episode_indices is None:
|
||||||
|
self.episodes = {0: [0], 1: [1], 2: [2]}
|
||||||
|
else:
|
||||||
|
self.episodes = {idx: [idx] for idx in self.episode_indices}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_episode_indices(self):
|
||||||
|
return sorted(self.episodes.keys())
|
||||||
|
|
||||||
|
|
||||||
class FakeLoader:
|
class FakeLoader:
|
||||||
def __init__(self, batch):
|
def __init__(self, batch):
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
@@ -123,6 +136,10 @@ class RecordingAgent(FakeAgent):
|
|||||||
self.seen_inputs.append(agent_input)
|
self.seen_inputs.append(agent_input)
|
||||||
return super().compute_loss(agent_input)
|
return super().compute_loss(agent_input)
|
||||||
|
|
||||||
|
def predict_action_chunk(self, agent_input):
|
||||||
|
self.seen_inputs.append({'predict_action_chunk': agent_input})
|
||||||
|
return torch.ones_like(agent_input['action'])
|
||||||
|
|
||||||
|
|
||||||
class ShapeMixedFakeAgent(FakeAgent):
|
class ShapeMixedFakeAgent(FakeAgent):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -355,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
batch_size=2,
|
batch_size=2,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
val_split=0.25,
|
val_split=0.25,
|
||||||
|
val_episode_indices=None,
|
||||||
|
action_mse_val_freq_epochs=0,
|
||||||
seed=0,
|
seed=0,
|
||||||
lr=1e-3,
|
lr=1e-3,
|
||||||
max_steps=2,
|
max_steps=2,
|
||||||
@@ -479,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
'batch_size': 2,
|
'batch_size': 2,
|
||||||
'num_workers': 0,
|
'num_workers': 0,
|
||||||
'val_split': 0.25,
|
'val_split': 0.25,
|
||||||
|
'val_episode_indices': None,
|
||||||
|
'action_mse_val_freq_epochs': 0,
|
||||||
'seed': 0,
|
'seed': 0,
|
||||||
'lr': 1e-3,
|
'lr': 1e-3,
|
||||||
'max_steps': 2,
|
'max_steps': 2,
|
||||||
@@ -561,6 +582,58 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
self.assertIn('front', first_input['lewm_images'])
|
self.assertIn('front', first_input['lewm_images'])
|
||||||
self.assertIn('front', first_input['lewm_future_images'])
|
self.assertIn('front', first_input['lewm_future_images'])
|
||||||
|
|
||||||
|
def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
run_training = self._get_run_training(module)
|
||||||
|
cfg = self._make_cfg()
|
||||||
|
cfg.train.max_steps = 1
|
||||||
|
cfg.train.save_freq = 100
|
||||||
|
cfg.train.val_split = 0.0
|
||||||
|
cfg.train.val_episode_indices = [2]
|
||||||
|
cfg.train.action_mse_val_freq_epochs = 1
|
||||||
|
agent = RecordingAgent()
|
||||||
|
fake_swanlab = FakeSwanLab()
|
||||||
|
real_import_module = importlib.import_module
|
||||||
|
|
||||||
|
def fake_instantiate(config_node, **kwargs):
|
||||||
|
if config_node is cfg.data:
|
||||||
|
return SplitAwareFakeDataset(kwargs.get('episode_indices'))
|
||||||
|
if config_node is cfg.agent:
|
||||||
|
return agent
|
||||||
|
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||||
|
|
||||||
|
def fake_loader_factory(dataset, *, shuffle, **_kwargs):
|
||||||
|
action_value = 0.0 if shuffle else 2.0
|
||||||
|
batch = {
|
||||||
|
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||||
|
'observation.state': torch.zeros(1, 4),
|
||||||
|
'action': torch.full((1, 1, 2), action_value),
|
||||||
|
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||||
|
}
|
||||||
|
return FakeLoader(batch)
|
||||||
|
|
||||||
|
def fake_import_module(name, package=None):
|
||||||
|
if name == 'swanlab':
|
||||||
|
return fake_swanlab
|
||||||
|
return real_import_module(name, package)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
previous_cwd = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(tempdir)
|
||||||
|
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||||
|
mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
|
||||||
|
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||||
|
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||||
|
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||||
|
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||||
|
run_training(cfg)
|
||||||
|
finally:
|
||||||
|
os.chdir(previous_cwd)
|
||||||
|
|
||||||
|
logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
|
||||||
|
self.assertIn('val/action_mse', logged_keys)
|
||||||
|
|
||||||
def test_run_training_skips_swanlab_when_disabled(self):
|
def test_run_training_skips_swanlab_when_disabled(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
run_training = self._get_run_training(module)
|
run_training = self._get_run_training(module)
|
||||||
|
|||||||
Reference in New Issue
Block a user