refactor:大重构

This commit is contained in:
gouhanke
2026-02-11 15:53:55 +08:00
parent 1e95d40bf9
commit 130d4bb3c5
19 changed files with 1411 additions and 1223 deletions

View File

@@ -0,0 +1,238 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(PreTrainedConfig):
    """Configuration class for DiffusionPolicy.

    Defaults are configured for training with PushT providing proprioceptive and single camera
    observations.

    The parameters you will most likely need to change are the ones which depend on the environment /
    sensors. Those are the input / output features (resolved by the parent ``PreTrainedConfig``) and
    `normalization_mapping`.

    Notes on the inputs and outputs:
        - "observation.state" is required as an input key.
        - Either:
            - At least one key starting with "observation.image" is required as an input.
              AND/OR
            - The key "observation.environment_state" is required as input.
        - If there are multiple keys beginning with "observation.image" they are treated as multiple
          camera views. Right now we only support all images having the same shape.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the
            policy. See `DiffusionPolicy.select_action` for more details.
        normalization_mapping: A dictionary mapping a feature group (e.g. "VISUAL", "STATE", "ACTION")
            to the normalization mode to apply. The two available modes are "MEAN_STD" which subtracts
            the mean and divides by the standard deviation, and "MIN_MAX" which rescales into a [-1, 1]
            range. Note that the "ACTION" mapping is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must
            fit within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in
            eval mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
        use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
        down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling
            Unet. You may provide a variable number of dimensions, therefore also controlling the degree
            of downsampling.
        kernel_size: The convolutional kernel size of the diffusion modeling Unet.
        n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
        diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small
            non-linear network. This is the output dimension of that network, i.e., the embedding
            dimension.
        use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet
            conditioning. Bias modulation is used by default, while this parameter indicates whether to
            also use scale modulation.
        noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face
            diffusers.
        beta_start: Beta value for the first forward-diffusion step.
        beta_end: Beta value for the last forward-diffusion step.
        prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from
            "epsilon" or "sample". These have equivalent outcomes from a latent variable modeling
            perspective, but "epsilon" has been shown to work better in many deep neural network
            settings.
        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
            denoising step at inference time. WARNING: you will need to make sure your action-space is
            normalized to fit within this range.
        clip_sample_range: The magnitude of the clipping range as described above.
        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are
            evenly spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this
            defaults to False as the original Diffusion Policy implementation does the same.
    """

    # Inputs / output structure.
    n_obs_steps: int = 2
    horizon: int = 16
    n_action_steps: int = 8

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    use_separate_rgb_encoder_per_camera: bool = False
    # Unet.
    down_dims: tuple[int, ...] = (512, 1024, 2048)
    kernel_size: int = 5
    n_groups: int = 8
    diffusion_step_embed_dim: int = 128
    use_film_scale_modulation: bool = True
    # Noise scheduler.
    noise_scheduler_type: str = "DDPM"
    num_train_timesteps: int = 100
    beta_schedule: str = "squaredcos_cap_v2"
    beta_start: float = 0.0001
    beta_end: float = 0.02
    prediction_type: str = "epsilon"
    clip_sample: bool = True
    clip_sample_range: float = 1.0

    # Inference
    num_inference_steps: int | None = None

    # Loss computation
    do_mask_loss_for_padding: bool = False

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-6
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 500

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
            raise ValueError(
                f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
            )
        supported_noise_schedulers = ["DDPM", "DDIM"]
        if self.noise_scheduler_type not in supported_noise_schedulers:
            raise ValueError(
                f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
                f"Got {self.noise_scheduler_type}."
            )

        # Check that the horizon size and U-Net downsampling is compatible.
        # U-Net downsamples by 2 with each stage.
        downsampling_factor = 2 ** len(self.down_dims)
        if self.horizon % downsampling_factor != 0:
            raise ValueError(
                "The horizon should be an integer multiple of the downsampling factor (which is determined "
                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
            )

    def get_optimizer_preset(self) -> AdamConfig:
        """Return the Adam optimizer settings recommended for this policy."""
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        """Return the LR scheduler settings recommended for this policy."""
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        """Check that the resolved input features satisfy the policy's requirements.

        Raises:
            ValueError: If no image or environment-state input is present, if `crop_shape`
                does not fit within an image, or if the image shapes differ across cameras.
        """
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                    )

        # Check that all input images have the same shape.
        if len(self.image_features) > 0:
            first_image_key, first_image_ft = next(iter(self.image_features.items()))
            for key, image_ft in self.image_features.items():
                if image_ft.shape != first_image_ft.shape:
                    raise ValueError(
                        f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                    )

    @property
    def observation_delta_indices(self) -> list:
        # The last `n_obs_steps` observations up to and including the current step.
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        # A `horizon`-long action window starting at the first buffered observation.
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        # Diffusion Policy does not use reward targets.
        return None

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_diffusion_pre_post_processors(
    config: DiffusionConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the pre- and post-processing pipelines for a diffusion policy.

    The pre-processor prepares model inputs by renaming features, adding a batch
    dimension, moving tensors to the configured device, and normalizing both
    input and output features with the dataset statistics. The post-processor
    unnormalizes the predicted actions back to their original scale and moves
    them to the CPU.

    Args:
        config: Diffusion policy configuration providing feature definitions,
            the normalization mapping, and the target device.
        dataset_stats: Per-feature statistics used for (un)normalization.
            Defaults to None.

    Returns:
        A `(preprocessor, postprocessor)` tuple of configured pipelines.
    """
    # Normalization covers inputs and outputs alike, so merge both feature sets.
    all_features = {**config.input_features, **config.output_features}

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=[
            RenameObservationsProcessorStep(rename_map={}),
            AddBatchDimensionProcessorStep(),
            DeviceProcessorStep(device=config.device),
            NormalizerProcessorStep(
                features=all_features,
                norm_map=config.normalization_mapping,
                stats=dataset_stats,
            ),
        ],
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )

    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=[
            UnnormalizerProcessorStep(
                features=config.output_features,
                norm_map=config.normalization_mapping,
                stats=dataset_stats,
            ),
            DeviceProcessorStep(device="cpu"),
        ],
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )

    return preprocessor, postprocessor

View File

@@ -1,13 +1,13 @@
"""
VLA Policy Evaluation Script (Hydra-based)
VLA 策略评估脚本(简化版)
This script evaluates a trained Vision-Language-Action (VLA) policy
in the MuJoCo simulation environment.
该脚本使用 agent 内置的队列管理来评估训练好的 VLA 策略。
无需单独的评估器类 - agent 处理一切!
Usage:
python roboimi/demos/eval_vla.py
python roboimi/demos/eval_vla.py ckpt_path=checkpoints/vla_model_step_8000.pt num_episodes=5
python roboimi/demos/eval_vla.py use_smoothing=true smooth_alpha=0.5
使用方法:
python roboimi/demos/eval_vla_simple.py
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_final.pt
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_best.pt
"""
import sys
@@ -19,314 +19,152 @@ import torch
import numpy as np
import hydra
from pathlib import Path
from typing import Dict, List
from typing import Dict
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from einops import rearrange
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose
from einops import rearrange
# Ensure correct import path
sys.path.append(os.getcwd())
log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x))
class VLAEvaluator:
    """
    VLA Policy Evaluator for MuJoCo Simulation.

    Wraps a trained VLA agent with observation buffering (stacking the last
    `obs_horizon` frames), action caching so the model is only queried every
    `num_queries` steps, optional EMA action smoothing, and per-episode timing
    statistics.
    """
    def __init__(
        self,
        agent: torch.nn.Module,
        device: str = 'cuda',
        camera_names: List[str] = ['r_vis', 'top', 'front'],  # NOTE(review): mutable default argument — shared across instances; consider None + fallback
        num_queries: int = 1,
        obs_horizon: int = 2,
        pred_horizon: int = 16,
        use_smoothing: bool = False,
        smooth_method: str = 'ema',
        smooth_alpha: float = 0.3,
        dataset_stats: dict = None
    ):
        """
        Args:
            agent: Trained policy exposing `predict_action(images=..., proprioception=...)`.
            device: Torch device string that the agent and inputs are moved to.
            camera_names: Camera keys expected in each observation's `images` dict.
            num_queries: Re-run model inference every `num_queries` steps; in
                between, actions are replayed from the cached prediction.
            obs_horizon: Number of past observations stacked as model input.
            pred_horizon: Length of the predicted action sequence (stored but not
                used directly in this class).
            use_smoothing: Whether to EMA-smooth the emitted actions.
            smooth_method: Smoothing method forwarded to `ActionSmoother`.
            smooth_alpha: EMA coefficient; higher weights the newest action more.
            dataset_stats: Normalization statistics from the checkpoint; when
                None, observations/actions pass through un-normalized.
        """
        self.agent = agent.to(device)
        self.device = device
        self.camera_names = camera_names
        self.num_queries = num_queries
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        # Dataset statistics for normalization/denormalization
        self.stats = dataset_stats
        if self.stats is not None:
            # 'gaussian' (mean/std) is the default; 'min_max' rescales to [-1, 1].
            self.normalization_type = self.stats.get('normalization_type', 'gaussian')
            self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
            self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
            # min/max entries may be absent for gaussian-normalized checkpoints, hence the .get defaults.
            self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
            self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
            self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
            self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
            self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
            self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
        else:
            self.normalization_type = None
        # Action smoothing
        self.use_smoothing = use_smoothing
        self.smooth_method = smooth_method
        self.smooth_alpha = smooth_alpha
        # NOTE(review): action_dim=16 is hard-coded — presumably the robot's action dimension; confirm against the environment.
        self.smoother = ActionSmoother(
            action_dim=16,
            method=smooth_method,
            alpha=smooth_alpha
        ) if use_smoothing else None
        # Observation buffer for obs_horizon
        self.obs_buffer = {
            'images': {cam: [] for cam in camera_names},
            'qpos': []
        }
        # Cached action chunk from the last model query and the replay index into it.
        self.cached_actions = None
        self.query_step = 0
        # Timing statistics
        self.inference_times = []  # Model inference time only
        self.total_times = []  # Total prediction time (including preprocessing)
    def reset(self):
        """Reset evaluator state"""
        self.obs_buffer = {
            'images': {cam: [] for cam in self.camera_names},
            'qpos': []
        }
        self.cached_actions = None
        self.query_step = 0
        if self.smoother is not None:
            self.smoother.reset()
        # Reset timing stats for each episode
        self.inference_times = []
        self.total_times = []
    def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
        """Convert current images to tensors, push them into the rolling buffer,
        and return per-camera stacks of shape (1, obs_horizon, C, H, W)."""
        images = {}
        for cam_name in self.camera_names:
            img = obs['images'][cam_name]
            # HWC uint8 -> CHW float in [0, 1].
            img = rearrange(img, 'h w c -> c h w')
            img = torch.from_numpy(img / 255.0).float()
            images[cam_name] = img
        image_dict = {}
        for cam_name in self.camera_names:
            cam_images = self.obs_buffer['images'][cam_name]
            cam_images.append(images[cam_name])
            # Pad by repeating the oldest frame until the horizon is filled (episode start).
            while len(cam_images) < self.obs_horizon:
                cam_images.insert(0, cam_images[0])
            # Keep only the most recent obs_horizon frames.
            if len(cam_images) > self.obs_horizon:
                cam_images = cam_images[-self.obs_horizon:]
            img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
            image_dict[cam_name] = img_tensor
            self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]
        return image_dict
    def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
        """Buffer the current qpos and return a normalized (1, obs_horizon, obs_dim) tensor."""
        qpos = obs['qpos']
        qpos = torch.from_numpy(qpos).float()
        self.obs_buffer['qpos'].append(qpos)
        # Pad with the oldest entry at episode start, then trim to the horizon.
        while len(self.obs_buffer['qpos']) < self.obs_horizon:
            self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])
        if len(self.obs_buffer['qpos']) > self.obs_horizon:
            self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]
        qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)  # (1, obs_horizon, obs_dim)
        # Normalize qpos
        if self.stats is not None:
            if self.normalization_type == 'gaussian':
                qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
            else:  # min_max: normalize to [-1, 1]
                qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
        return qpos_tensor
    @torch.no_grad()
    def predict_action(self, obs: Dict) -> np.ndarray:
        """Return the next (denormalized, optionally smoothed) action for `obs`.

        The model is queried only when the cache is empty or every
        `num_queries` steps; otherwise the next action from the cached chunk
        is replayed. Also records inference/total timing for statistics.
        """
        start_total = time.time()
        images = self._get_image_dict(obs)
        qpos = self._get_qpos_dict(obs)
        if self.cached_actions is None or self.query_step % self.num_queries == 0:
            images = {k: v.to(self.device) for k, v in images.items()}
            qpos = qpos.to(self.device)
            # Measure pure model inference time
            start_inference = time.time()
            predicted_actions = self.agent.predict_action(
                images=images,
                proprioception=qpos
            )
            # Synchronize CUDA if using GPU to get accurate timing
            if self.device == 'cuda':
                torch.cuda.synchronize()
            end_inference = time.time()
            inference_time = end_inference - start_inference
            self.inference_times.append(inference_time)
            # Denormalize actions
            if self.stats is not None:
                if self.normalization_type == 'gaussian':
                    predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
                else:  # min_max
                    predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)
            self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
            self.query_step = 0
        raw_action = self.cached_actions[self.query_step]
        self.query_step += 1
        if self.smoother is not None:
            raw_action = self.smoother.smooth(raw_action)
        end_total = time.time()
        total_time = end_total - start_total
        self.total_times.append(total_time)
        return raw_action
    def get_timing_stats(self) -> Dict:
        """Get timing statistics"""
        if len(self.inference_times) == 0:
            return {
                'inference_fps': 0.0,
                'control_fps': 0.0,
                'avg_inference_time_ms': 0.0,
                'avg_total_time_ms': 0.0
            }
        avg_inference_time = np.mean(self.inference_times)
        avg_total_time = np.mean(self.total_times)
        return {
            'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
            'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
            'avg_inference_time_ms': avg_inference_time * 1000,
            'avg_total_time_ms': avg_total_time * 1000,
            'num_inferences': len(self.inference_times),
            'num_steps': len(self.total_times)
        }
class ActionSmoother:
    """Smooths a stream of actions for steadier execution.

    Currently only exponential moving average ('ema') is implemented; any other
    `method` value makes the smoother a pass-through.
    """
    def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        self.prev_action = None
    def smooth(self, action: np.ndarray) -> np.ndarray:
        """Blend `action` with the previous output and return the result."""
        if self.method != 'ema':
            # Unknown method: pass the action through unchanged.
            return action
        previous = self.prev_action
        blended = action if previous is None else self.alpha * action + (1 - self.alpha) * previous
        self.prev_action = blended
        return blended
    def reset(self):
        """Clear the smoothing state (call between episodes)."""
        self.prev_action = None
def load_checkpoint(
ckpt_path: str,
agent_cfg: DictConfig,
device: str = 'cuda'
) -> torch.nn.Module:
"""
Load trained VLA model from checkpoint using Hydra agent config.
从检查点加载训练好的 VLA 模型,使用 Hydra agent 配置。
Args:
ckpt_path: Path to checkpoint file (.pt)
agent_cfg: Hydra agent config for instantiation
device: Device to load model on
ckpt_path: 检查点文件路径 (.pt)
agent_cfg: Hydra agent 配置,用于实例化
device: 加载模型的设备
Returns:
Loaded VLAAgent model
加载后的 VLAAgent 模型
"""
from pathlib import Path as PathLib
ckpt_path = PathLib(ckpt_path).absolute()
if not ckpt_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
raise FileNotFoundError(f"检查点未找到: {ckpt_path}")
log.info(f"Loading checkpoint from {ckpt_path}")
log.info(f" {ckpt_path} 加载检查点")
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
log.info(f"Checkpoint keys: {checkpoint.keys()}")
log.info(f"检查点键值: {checkpoint.keys()}")
# Instantiate agent from Hydra config
log.info("Instantiating agent from config...")
agent = instantiate(agent_cfg)
# Load model state
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
# Load dataset statistics for denormalization
# 加载数据集统计信息用于归一化
stats = checkpoint.get('dataset_stats', None)
# 使用数据集统计信息从 Hydra 配置实例化 agent
log.info("从配置实例化 agent...")
agent = instantiate(agent_cfg, dataset_stats=stats)
# 加载模型状态
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ 模型状态已加载 (步数: {checkpoint.get('step', 'unknown')})")
if stats is not None:
log.info(f"Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})")
log.info(f"数据集统计信息已加载 (归一化: {stats.get('normalization_type', 'gaussian')})")
else:
# Fallback: try external JSON file (兼容旧 checkpoint)
# 后备方案:尝试从外部 JSON 文件加载(兼容旧检查点)
stats_path = ckpt_path.parent / 'dataset_stats.json'
if stats_path.exists():
with open(stats_path, 'r') as f:
stats = json.load(f)
log.info("Dataset statistics loaded from external JSON (legacy)")
log.info("数据集统计信息已从外部 JSON 加载(旧版本兼容)")
else:
log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!")
log.warning("⚠️ 未找到数据集统计信息。动作将无法反归一化!")
agent.eval()
agent.to(device)
log.info(f"Model loaded successfully on {device}")
log.info(f"模型已成功加载到 {device}")
return agent, stats
def prepare_observation(obs: Dict, camera_names: list) -> Dict:
    """Convert a raw environment observation into the agent's input format.

    Args:
        obs: Environment observation dict containing 'images' (per-camera HWC
            uint8 arrays) and 'qpos' (a numpy array of joint positions).
        camera_names: Names of the cameras to extract from `obs['images']`.

    Returns:
        Dict with 'qpos' (float tensor) and 'images' (dict of CHW float
        tensors scaled to [0, 1], keyed by camera name).
    """
    images = {}
    for cam_name in camera_names:
        img = obs['images'][cam_name]
        # HWC -> CHW; a plain numpy transpose replaces the einops.rearrange
        # call, dropping the third-party dependency for a trivial permute.
        img = np.transpose(img, (2, 0, 1))
        # Scale to [0, 1] and convert to a float32 tensor.
        img = torch.from_numpy(img / 255.0).float()
        images[cam_name] = img
    # qpos: numpy -> float32 tensor.
    qpos = torch.from_numpy(obs['qpos']).float()
    return {'qpos': qpos, 'images': images}
class ActionSmoother:
    """Exponential-moving-average (EMA) action smoother.

    Blends each new action with the previously emitted one to produce
    steadier control commands during execution.
    """
    def __init__(self, alpha: float = 0.3):
        """
        Args:
            alpha: Smoothing coefficient in (0, 1]; higher values weight the
                current action more heavily.
        """
        self.alpha = alpha
        self.prev_action = None
    def smooth(self, action: np.ndarray) -> np.ndarray:
        """Return the EMA-blended action and update the internal state.

        Args:
            action: The current raw action.

        Returns:
            The smoothed action (equal to `action` on the first call).
        """
        previous = self.prev_action
        blended = action if previous is None else self.alpha * action + (1 - self.alpha) * previous
        self.prev_action = blended
        return blended
    def reset(self):
        """Forget the previous action (call at episode boundaries)."""
        self.prev_action = None
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
"""
VLA Evaluation Script with Hydra Configuration.
使用 agent 内置队列管理的简化版 VLA 评估
All eval parameters come from vla/conf/eval.yaml, merged into cfg.
Override on command line: python eval_vla.py eval.ckpt_path=... eval.num_episodes=5
所有评估参数来自 vla/conf/eval.yaml,合并到 cfg 中。
命令行覆盖: python eval_vla_simple.py eval.ckpt_path=... eval.num_episodes=5
"""
# Print configuration
# 打印配置
print("=" * 80)
print("VLA Evaluation Configuration:")
print("VLA 评估配置:")
print("=" * 80)
print(OmegaConf.to_yaml(cfg))
print("=" * 80)
@@ -335,67 +173,114 @@ def main(cfg: DictConfig):
device = eval_cfg.device
camera_names = list(eval_cfg.camera_names)
# Load model
log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
# =========================================================================
# 加载模型
# =========================================================================
log.info(f"🚀 从 {eval_cfg.ckpt_path} 加载模型...")
agent, dataset_stats = load_checkpoint(
ckpt_path=eval_cfg.ckpt_path,
agent_cfg=cfg.agent,
device=device
)
# Create evaluator
evaluator = VLAEvaluator(
agent=agent,
device=device,
camera_names=camera_names,
num_queries=eval_cfg.num_queries,
obs_horizon=eval_cfg.obs_horizon,
use_smoothing=eval_cfg.use_smoothing,
smooth_method=eval_cfg.smooth_method,
smooth_alpha=eval_cfg.smooth_alpha,
dataset_stats=dataset_stats
)
# 重置 agent 的队列
agent.reset()
# Create environment
# 可选:动作平滑器
smoother = ActionSmoother(alpha=eval_cfg.smooth_alpha) if eval_cfg.use_smoothing else None
# =========================================================================
# 创建环境
# =========================================================================
env = make_sim_env(eval_cfg.task_name)
# Run episodes
# =========================================================================
# 运行评估回合
# =========================================================================
all_stats = []
for episode_idx in range(eval_cfg.num_episodes):
print(f"\n{'='*60}")
print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
print(f"{'='*60}\n")
box_pos = sample_transfer_pose()
env.reset(box_pos)
evaluator.reset()
# 为新回合重置 agent 队列
agent.reset()
if smoother:
smoother.reset()
# 计时统计
inference_times = []
total_times = []
with torch.inference_mode():
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
start_total = time.time()
# 从环境获取观测
obs = env._get_image_obs()
qpos_obs = env._get_qpos_obs()
obs['qpos'] = qpos_obs['qpos']
action = evaluator.predict_action(obs)
env.step_jnt(action)
# 准备给 agent 的观测
observation = prepare_observation(obs, camera_names)
# 选择动作agent 内部处理队列管理)
start_inference = time.time()
action = agent.select_action(observation)
if device == 'cuda':
torch.cuda.synchronize()
end_inference = time.time()
# 转换为 numpy
action = action.cpu().numpy()
# 可选:平滑动作
if smoother:
action = smoother.smooth(action)
# 执行动作
env.step_jnt(action)
env.render()
# Get timing statistics for this episode
stats = evaluator.get_timing_stats()
end_total = time.time()
# 记录计时
inference_times.append(end_inference - start_inference)
total_times.append(end_total - start_total)
# =========================================================================
# 打印回合统计
# =========================================================================
avg_inference_time = np.mean(inference_times)
avg_total_time = np.mean(total_times)
stats = {
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
'avg_inference_time_ms': avg_inference_time * 1000,
'avg_total_time_ms': avg_total_time * 1000,
'num_inferences': len([t for t in inference_times if t > 0.001]), # 统计实际推理次数
'num_steps': len(total_times)
}
all_stats.append(stats)
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
print(f" Model Inference FPS: {stats['inference_fps']:.2f} Hz")
print(f" Control Loop FPS: {stats['control_fps']:.2f} Hz")
print(f" Avg Inference Time: {stats['avg_inference_time_ms']:.2f} ms")
print(f" Avg Total Time: {stats['avg_total_time_ms']:.2f} ms")
print(f" Total Inferences: {stats['num_inferences']}")
print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
print(f" 总推理次数: {stats['num_inferences']}")
# Print overall statistics
# =========================================================================
# 总体统计
# =========================================================================
print(f"\n{'='*60}")
print("Evaluation complete!")
print("评估完成!")
print(f"{'='*60}")
if all_stats:
@@ -404,11 +289,11 @@ def main(cfg: DictConfig):
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
print(f"\nOverall Statistics ({eval_cfg.num_episodes} episodes):")
print(f" Average Model Inference FPS: {avg_inference_fps:.2f} Hz")
print(f" Average Control Loop FPS: {avg_control_fps:.2f} Hz")
print(f" Average Inference Time: {avg_inference_time:.2f} ms")
print(f" Average Total Time: {avg_total_time:.2f} ms")
print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
print(f" 平均推理时间: {avg_inference_time:.2f} ms")
print(f" 平均总时间: {avg_total_time:.2f} ms")
print()

View File

@@ -12,28 +12,28 @@ from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from pathlib import Path
# Ensure correct import path
# 确保正确的导入路径
sys.path.append(os.getcwd())
from hydra.utils import instantiate
log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
# 注册列表长度解析器(用于配置中如 ${len:${data.camera_names}}
if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x))
def recursive_to_device(data, device):
"""
Recursively move nested dictionaries/lists of tensors to specified device.
递归地将嵌套字典/列表中的张量移动到指定设备。
Args:
data: Dictionary, list, or tensor
device: Target device (e.g., 'cuda', 'cpu')
data: 字典、列表或张量
device: 目标设备 (例如 'cuda', 'cpu')
Returns:
Data structure with all tensors moved to device
所有张量已移动到指定设备的数据结构
"""
if isinstance(data, torch.Tensor):
return data.to(device)
@@ -46,36 +46,36 @@ def recursive_to_device(data, device):
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
"""
Create a learning rate scheduler with warmup.
创建带预热的学习率调度器。
Args:
optimizer: PyTorch optimizer
warmup_steps: Number of warmup steps
max_steps: Total training steps
scheduler_type: Type of scheduler after warmup ('cosine' or 'constant')
min_lr: Minimum learning rate (for cosine decay)
optimizer: PyTorch 优化器
warmup_steps: 预热步数
max_steps: 总训练步数
scheduler_type: 预热后的调度器类型 ('cosine' 'constant')
min_lr: 最小学习率(用于余弦衰减)
Returns:
LambdaLR scheduler
LambdaLR 调度器
"""
import math
# Capture initial lr before LambdaLR modifies it
# 在 LambdaLR 修改前捕获初始学习率
base_lr = optimizer.param_groups[0]['lr']
min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0
def lr_lambda(step):
# Warmup phase: linear increase from 0 to 1
# 预热阶段:从 0 线性增加到 1
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
# Post-warmup phase
# 预热后阶段
if scheduler_type == 'cosine':
# Cosine annealing from 1 to min_lr_ratio
# 1 min_lr_ratio 的余弦退火
progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
return max(min_lr_ratio, cosine_decay)
else:
# Constant learning rate
# 恒定学习率
return 1.0
return LambdaLR(optimizer, lr_lambda)
@@ -84,40 +84,40 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
"""
VLA Training Script with ResNet Backbone and Diffusion Policy.
VLA 训练脚本ResNet 骨干网络 + Diffusion 策略)
This script:
1. Loads dataset from HDF5 files
2. Instantiates VLAAgent with ResNet vision encoder
3. Trains diffusion-based action prediction
4. Saves checkpoints periodically
该脚本功能:
1. 从 HDF5 文件加载数据集
2. 实例化带 ResNet 视觉编码器的 VLAAgent
3. 训练基于扩散的动作预测模型
4. 定期保存检查点
"""
# Print configuration
# 打印配置
print("=" * 80)
print("VLA Training Configuration:")
print("VLA 训练配置:")
print("=" * 80)
print(OmegaConf.to_yaml(cfg))
print("=" * 80)
log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")
log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
# Create checkpoint directory
# 创建检查点目录
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)
# =========================================================================
# 1. Instantiate Dataset & DataLoader
# 1. 实例化数据集与 DataLoader
# =========================================================================
log.info("📦 Loading dataset...")
log.info("📦 加载数据集...")
try:
dataset = instantiate(cfg.data)
log.info(f"Dataset loaded successfully. Total samples: {len(dataset)}")
log.info(f"数据集加载成功。总样本数: {len(dataset)}")
except Exception as e:
log.error(f"Failed to load dataset: {e}")
log.error(f"数据集加载失败: {e}")
raise
# Train/Val split
# 训练/验证集划分
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
@@ -128,10 +128,10 @@ def main(cfg: DictConfig):
[train_size, val_size],
generator=torch.Generator().manual_seed(seed)
)
log.info(f"Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
log.info(f"数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
else:
train_dataset, val_dataset = dataset, None
log.info("Dataset split: train=all, val=0 (val_split=0)")
log.info("数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
train_loader = DataLoader(
train_dataset,
@@ -139,7 +139,7 @@ def main(cfg: DictConfig):
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
drop_last=True # Drop incomplete batches for stable training
drop_last=True # 丢弃不完整批次以稳定训练
)
val_loader = None
@@ -153,34 +153,14 @@ def main(cfg: DictConfig):
drop_last=False
)
log.info(f"Train loader batches per epoch: {len(train_loader)}")
log.info(f"训练加载器每轮批次数: {len(train_loader)}")
if val_loader is not None:
log.info(f"Val loader batches per epoch: {len(val_loader)}")
log.info(f"验证加载器每轮批次数: {len(val_loader)}")
# =========================================================================
# 2. Instantiate VLA Agent
# 2. 加载数据集统计信息(将传递给 agent
# =========================================================================
log.info("🤖 Initializing VLA Agent...")
try:
agent = instantiate(cfg.agent)
agent.to(cfg.train.device)
agent.train()
log.info(f"✅ Agent initialized and moved to {cfg.train.device}")
# Count parameters
total_params = sum(p.numel() for p in agent.parameters())
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
log.info(f"📊 Total parameters: {total_params:,}")
log.info(f"📊 Trainable parameters: {trainable_params:,}")
except Exception as e:
log.error(f"❌ Failed to initialize agent: {e}")
raise
# =========================================================================
# 2.5. Load Dataset Statistics (will be saved into checkpoints)
# =========================================================================
log.info("💾 Loading dataset statistics...")
log.info("💾 加载数据集统计信息...")
dataset_stats = None
try:
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
@@ -201,22 +181,43 @@ def main(cfg: DictConfig):
'qpos_min': stats['qpos']['min'].tolist(),
'qpos_max': stats['qpos']['max'].tolist(),
}
log.info(f"Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})")
log.info(f"数据集统计信息加载完成 (归一化: {dataset_stats['normalization_type']})")
else:
log.warning(f"⚠️ Statistics file not found: {stats_path}")
log.warning("⚠️ Actions will not be denormalized during inference!")
log.warning(f"⚠️ 统计文件未找到: {stats_path}")
log.warning("⚠️ 推理时动作将无法反归一化!")
except Exception as e:
log.warning(f"⚠️ Failed to load statistics: {e}")
log.warning("⚠️ Training will continue, but inference may not work correctly")
log.warning(f"⚠️ 统计信息加载失败: {e}")
log.warning("⚠️ 训练将继续,但推理可能无法正常工作")
# =========================================================================
# 3. Setup Optimizer & LR Scheduler
# 3. 实例化 VLA Agent
# =========================================================================
log.info("🤖 初始化 VLA Agent...")
try:
# 将 dataset_stats 和 normalization_type 传递给 agent
agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
agent.to(cfg.train.device)
agent.train()
log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")
# 统计参数量
total_params = sum(p.numel() for p in agent.parameters())
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
log.info(f"📊 总参数量: {total_params:,}")
log.info(f"📊 可训练参数量: {trainable_params:,}")
except Exception as e:
log.error(f"❌ Agent 初始化失败: {e}")
raise
# =========================================================================
# 4. 设置优化器与学习率调度器
# =========================================================================
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})")
# Setup learning rate scheduler with warmup
# 设置带预热的学習率调度器
warmup_steps = int(cfg.train.get('warmup_steps', 500))
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
min_lr = float(cfg.train.get('min_lr', 1e-6))
@@ -228,33 +229,36 @@ def main(cfg: DictConfig):
scheduler_type=scheduler_type,
min_lr=min_lr
)
log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})")
log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})")
# =========================================================================
# 4. Training Loop
# 5. 训练循环
# =========================================================================
log.info("🏋️ Starting training loop...")
log.info("🏋️ 开始训练循环...")
def build_agent_input(batch_data):
"""构建 agent 输入格式"""
images = {}
# SimpleRobotDataset 返回 observation.{cam_name} 格式
for cam_name in cfg.data.camera_names:
key = f"image_{cam_name}"
key = f"observation.{cam_name}"
if key in batch_data:
images[cam_name] = batch_data[key]
return {
'images': images,
'qpos': batch_data['qpos'],
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
'action': batch_data['action']
}
def run_validation():
"""运行验证"""
if val_loader is None:
return None
agent.eval()
# 🔧 FIX: Set deterministic seed for validation to get reproducible loss
# This ensures validation loss is comparable across different steps
# 设置确定性种子以获得可重现的损失
# 这确保验证损失在不同步骤之间可比较
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
@@ -272,7 +276,7 @@ def main(cfg: DictConfig):
return total_loss / max(num_batches, 1)
data_iter = iter(train_loader)
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
best_loss = float('inf')
@@ -280,47 +284,47 @@ def main(cfg: DictConfig):
try:
batch = next(data_iter)
except StopIteration:
# Restart iterator when epoch ends
# 轮次结束时重启迭代器
data_iter = iter(train_loader)
batch = next(data_iter)
# =====================================================================
# Move batch to device
# 将批次移至设备
# =====================================================================
batch = recursive_to_device(batch, cfg.train.device)
# =====================================================================
# Prepare agent input
# 准备 agent 输入
# =====================================================================
# Dataset returns: {action, qpos, image_<cam_name>, ...}
# Agent expects: {images: dict, qpos: tensor, action: tensor}
# 数据集返回: {action, qpos, image_<cam_name>, ...}
# Agent 期望: {images: dict, qpos: tensor, action: tensor}
# Prepare agent input
# 准备 agent 输入
agent_input = build_agent_input(batch)
# =====================================================================
# Forward pass & compute loss
# 前向传播与损失计算
# =====================================================================
try:
loss = agent.compute_loss(agent_input)
except Exception as e:
log.error(f"Forward pass failed at step {step}: {e}")
log.error(f"步骤 {step} 前向传播失败: {e}")
raise
# =====================================================================
# Backward pass & optimization
# 反向传播与优化
# =====================================================================
optimizer.zero_grad()
loss.backward()
# Gradient clipping for stable training
# 梯度裁剪以稳定训练
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
# =====================================================================
# Logging
# 日志记录
# =====================================================================
if step % cfg.train.log_freq == 0:
current_lr = optimizer.param_groups[0]['lr']
@@ -329,16 +333,16 @@ def main(cfg: DictConfig):
"lr": f"{current_lr:.2e}",
"best_loss": f"{best_loss:.4f}"
})
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}")
log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
# =====================================================================
# Checkpoint saving & Validation
# 检查点保存与验证
# =====================================================================
if step > 0 and step % cfg.train.save_freq == 0:
# Run validation
# 运行验证
val_loss = run_validation()
if val_loss is not None:
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
torch.save({
@@ -351,9 +355,9 @@ def main(cfg: DictConfig):
'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'],
}, checkpoint_path)
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
log.info(f"💾 检查点已保存: {checkpoint_path}")
# Save best model based on validation loss
# 根据验证损失保存最佳模型
eval_loss = val_loss if val_loss is not None else loss.item()
if eval_loss < best_loss:
best_loss = eval_loss
@@ -368,10 +372,10 @@ def main(cfg: DictConfig):
'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'],
}, best_model_path)
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
# =========================================================================
# 5. Save Final Model
# 6. 保存最终模型
# =========================================================================
final_model_path = checkpoint_dir / "vla_model_final.pt"
torch.save({
@@ -383,11 +387,11 @@ def main(cfg: DictConfig):
'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path)
log.info(f"💾 Final model saved: {final_model_path}")
log.info(f"💾 最终模型已保存: {final_model_path}")
log.info("Training completed successfully!")
log.info(f"📊 Final Loss: {loss.item():.4f}")
log.info(f"📊 Best Loss: {best_loss:.4f}")
log.info("训练成功完成!")
log.info(f"📊 最终损失: {loss.item():.4f}")
log.info(f"📊 最佳损失: {best_loss:.4f}")
if __name__ == "__main__":

View File

@@ -1,17 +1,19 @@
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Optional, Any
from collections import deque
from typing import Dict, Optional, Any, Tuple
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgent(nn.Module):
def __init__(
self,
vision_backbone, # 你之前定义的 ResNet
vision_backbone, # 视觉编码器(ResNet 等)
state_encoder,
action_encoder,
head,
@@ -19,23 +21,35 @@ class VLAAgent(nn.Module):
obs_dim, # 本体感知维度 (例如 关节角度)
pred_horizon=16, # 预测未来多少步动作
obs_horizon=4, # 使用多少步历史观测
diffusion_steps=100,
diffusion_steps=100, # DDPM 加噪步数
inference_steps=10, # DDIM 推理步数
num_cams=3, # 视觉输入的摄像头数量
dataset_stats=None, # 数据集统计信息,用于归一化
normalization_type='gaussian', # 归一化类型: 'gaussian' 或 'min_max'
num_action_steps=1, # 每次推理实际执行多少步动作
):
super().__init__()
# Store parameters
# 保存参数
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
# 归一化模块 - 统一训练和推理的归一化逻辑
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type
)
self.vision_encoder = vision_backbone
single_img_feat_dim = self.vision_encoder.output_dim
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
single_cam_feat_dim = self.vision_encoder.output_dim
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim
# self.global_cond_dim = total_vision_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
@@ -44,7 +58,7 @@ class VLAAgent(nn.Module):
prediction_type='epsilon' # 预测噪声
)
# DDIM scheduler for faster inference
# DDIM 调度器用于快速推理
self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2',
@@ -54,45 +68,55 @@ class VLAAgent(nn.Module):
self.noise_pred_net = head(
input_dim=action_dim,
# input_dim = action_dim + obs_dim,
# input_dim = action_dim + obs_dim, # 备选:包含观测维度
global_cond_dim=self.global_cond_dim
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
# 初始化队列(用于在线推理)
self.reset()
# ==========================
# 训练阶段 (Training)
# ==========================
def compute_loss(self, batch):
"""
batch: 包含 images, qpos (proprioception), action
计算训练损失
Args:
batch: 包含 images, qpos (本体感知), action 的字典
"""
actions, states, images = batch['action'], batch['qpos'], batch['images']
B = actions.shape[0]
# 归一化 states (qpos) 和 actions
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
state_features = self.state_encoder(states)
# 1. 提取视觉特征
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
action_features = self.action_encoder(actions)
# 3. 采样噪声
# 2. 采样噪声
noise = torch.randn_like(action_features)
# 4. 随机采样时间步 (Timesteps)
# 3. 随机采样时间步 (Timesteps)
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=action_features.device
).long()
# 5. 给动作加噪 (Forward Diffusion)
# 4. 给动作加噪 (Forward Diffusion)
noisy_actions = self.noise_scheduler.add_noise(
action_features, noise, timesteps
)
# 6. 网络预测噪声
# 5. 网络预测噪声
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
@@ -100,30 +124,192 @@ class VLAAgent(nn.Module):
proprioception=state_features
)
# 7. 计算 Loss (MSE)
# 6. 计算 Loss (MSE)
loss = nn.functional.mse_loss(pred_noise, noise)
return loss
# ==========================
# 推理阶段 (Inference)
# 队列管理 (Queue Management)
# ==========================
def reset(self):
"""清空观测和动作队列。应在 env.reset() 时调用"""
self._queues = {
'qpos': deque(maxlen=self.obs_horizon),
'images': deque(maxlen=self.obs_horizon),
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
"""
将新的观测添加到队列中。
Args:
observation: 包含 'qpos''images' 的字典
"""
# 添加本体感知
if 'qpos' in observation:
self._queues['qpos'].append(observation['qpos'].clone())
# 添加图像
if 'images' in observation:
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
"""
从队列中准备用于推理的批量观测。
如果队列未满(首次调用时),用最新观测重复填充。
Returns:
batch: 包含堆叠后的历史观测的字典
"""
# 堆叠历史本体感知
qpos_list = list(self._queues['qpos'])
if len(qpos_list) == 0:
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
# 堆叠历史图像
images_list = list(self._queues['images'])
if len(images_list) == 0:
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
return {'qpos': batch_qpos, 'images': batch_images}
# ==========================
# 在线推理 (Online Inference)
# ==========================
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
根据当前观测选择单个动作。
这个方法维护一个历史观测和生成动作轨迹的缓存。工作流程:
- 缓存 `obs_horizon` 步的历史观测
- Diffusion 模型生成 `pred_horizon` 步的动作
- 实际执行 `num_action_steps` 步动作
示意图:
--------------------------------------------------------------
(图例: o=obs_horizon, h=pred_horizon, a=num_action_steps)
|时间步 | 0 | 1 | ... | o-1 | o | ... | h-1 |
|观测是否使用 | 是 | 是 | 是 | 是 | 否 | 否 | 否 |
|动作是否生成 | 是 | 是 | 是 | 是 | 是 | 是 | 是 |
|动作是否执行 | 否 | 否 | 否 | 否 | 是 | 是 | 是 |
--------------------------------------------------------------
Args:
observation: 包含 'qpos''images' 的字典
Returns:
action: (action_dim,) 单个动作
"""
# 检测设备并确保所有组件在同一设备上
# 尝试从观测中获取设备
device = None
for v in observation.values():
if isinstance(v, torch.Tensor):
device = v.device
break
if device is not None and self.normalization.enabled:
# 确保归一化参数在同一设备上
norm_device = self.normalization.qpos_mean.device
if device != norm_device:
self.normalization.to(device)
# 同时确保其他模块也在正确设备
self.vision_encoder.to(device)
self.state_encoder.to(device)
self.action_encoder.to(device)
self.noise_pred_net.to(device)
# 将所有 observation 移到正确设备
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()}
# 将新观测添加到队列
self._populate_queues(observation)
# 如果动作队列为空,生成新的动作序列
if len(self._queues['action']) == 0:
# 从队列准备批量观测
batch = self._prepare_observation_batch()
# 生成动作块
actions = self.predict_action_chunk(batch) # (1, pred_horizon, action_dim)
# 提取可执行的动作部分
# 从 obs_horizon-1 开始,因为前面的动作对应过去的观测
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end] # (1, num_action_steps, action_dim)
# 将动作添加到队列
for i in range(executable_actions.shape[1]):
self._queues['action'].append(executable_actions[:, i].squeeze(0)) # (action_dim,)
# 从队列中取出一个动作
action = self._queues['action'].popleft() # (action_dim,)
return action
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
预测一个动作块(用于在线推理)。
Args:
batch: 包含 'qpos''images' 的字典
- qpos: (B, obs_horizon, obs_dim)
- images: Dict[str, (B, obs_horizon, C, H, W)]
Returns:
actions: (B, pred_horizon, action_dim) 预测的动作序列
"""
return self.predict_action(batch['images'], batch['qpos'])
# ==========================
# 批量推理 (Batch Inference - 原有方法)
# ==========================
@torch.no_grad()
def predict_action(self, images, proprioception):
"""
批量预测动作序列(用于训练和离线评估)
Args:
images: 图像观测字典
proprioception: 本体感知观测 (qpos)
Returns:
denormalized_actions: 反归一化后的动作序列
"""
B = proprioception.shape[0]
# 1. 提取当前观测特征 (只做一次)
# 归一化 proprioception (qpos)
proprioception = self.normalization.normalize_qpos(proprioception)
# 1. 提取当前观测特征(只提取一次)
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception)
# 2. 初始化纯高斯噪声动作
# Shape: (B, pred_horizon, action_dim)
# 形状: (B, pred_horizon, action_dim)
device = visual_features.device
current_actions = torch.randn(
(B, self.pred_horizon, self.action_dim), device=device
)
# 3. 逐步去噪循环 (Reverse Diffusion)
self.infer_scheduler.set_timesteps(10) # DDIM 推理步数
self.infer_scheduler.set_timesteps(self.inference_steps) # DDIM 推理步数
for t in self.infer_scheduler.timesteps:
model_input = current_actions
@@ -141,5 +327,11 @@ class VLAAgent(nn.Module):
noise_pred, t, current_actions
).prev_sample
# 4. 输出最终动作序列(归一化空间,由调用方负责反归一化)
return current_actions
# 4. 反归一化动作序列
denormalized_actions = self.normalization.denormalize_action(current_actions)
return denormalized_actions
def get_normalization_stats(self):
"""获取归一化统计信息(用于保存到 checkpoint"""
return self.normalization.get_stats()

View File

@@ -9,14 +9,26 @@ defaults:
_target_: roboimi.vla.agent.VLAAgent
# Action and Observation Dimensions
action_dim: 16
obs_dim: 16
# ====================
# 模型维度配置
# ====================
action_dim: 16 # 动作维度(机器人关节数)
obs_dim: 16 # 本体感知维度(关节位置)
# Prediction and Observation Horizons
pred_horizon: 16
obs_horizon: 2
# ====================
# 时间步配置
# ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
# ====================
# 相机配置
# ====================
num_cams: 3 # 摄像头数量 (r_vis, top, front)
# Camera Configuration
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
# ====================
# 扩散过程配置
# ====================
diffusion_steps: 100 # 扩散训练步数DDPM
inference_steps: 10 # 推理时的去噪步数DDIM固定为 10

View File

@@ -1,4 +0,0 @@
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true

View File

@@ -1,8 +1,28 @@
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
vision_backbone: "resnet18"
pretrained_backbone_weights: null
input_shape: [3, 96, 96]
crop_shape: [84, 84]
crop_is_random: true
use_group_norm: true
spatial_softmax_num_keypoints: 32
# ====================
# 骨干网络选择
# ====================
vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
pretrained_backbone_weights: null # 预训练权重路径或 nullImageNet 权重)
# ====================
# 输入配置
# ====================
input_shape: [3, 96, 96] # 输入图像形状 (C, H, W)
crop_shape: [84, 84] # 裁剪后的图像形状 (H, W)
crop_is_random: true # 训练时使用随机裁剪,评估时使用中心裁剪
# ====================
# 归一化和特征提取
# ====================
use_group_norm: true # 使用 GroupNorm 替代 BatchNorm更适合小批次训练
spatial_softmax_num_keypoints: 32 # Spatial Softmax 关键点数量
# ====================
# 编码器模式
# ====================
# false: 共享编码器(所有摄像头共享一个 ResNet参数少但容量受限推荐
# true: 独立编码器(每个摄像头有独立的 ResNet参数多但容量大
use_separate_rgb_encoder_per_camera: true
num_cameras: 3 # 摄像头数量

View File

@@ -1,19 +1,41 @@
defaults:
- agent: resnet_diffusion
- data: resnet_dataset
- data: simpe_robot_dataset
- eval: eval
- _self_
# ====================
# 训练配置
# ====================
train:
batch_size: 8 # Batch size for training
lr: 1e-4 # Learning rate
max_steps: 20000 # Maximum training steps
log_freq: 100 # Log frequency (steps)
save_freq: 2000 # Save checkpoint frequency (steps)
device: "cuda" # Device: "cuda" or "cpu"
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
# 基础训练参数
batch_size: 8 # 批次大小
lr: 1e-4 # 学习率
max_steps: 100000 # 最大训练步数
device: "cuda" # 设备: "cuda" 或 "cpu"
# Learning rate scheduler with warmup
warmup_steps: 500 # Number of warmup steps
scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine"
min_lr: 1e-6 # Minimum learning rate (for cosine decay)
# 数据加载
num_workers: 8 # DataLoader 工作进程数(调试时设为 0生产环境用 8
val_split: 0.1 # 验证集比例
seed: 42 # 随机种子(用于数据划分)
# 日志和检查点
log_freq: 100 # 日志记录频率(步数)
save_freq: 5000 # 保存检查点频率(步数)
# 学习率调度器(带预热)
warmup_steps: 500 # 预热步数
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
min_lr: 1e-6 # 最小学习率(用于余弦退火)
# 优化器
weight_decay: 1e-5 # 权重衰减L2 正则化)
grad_clip: 1.0 # 梯度裁剪阈值
# ====================
# 实验配置
# ====================
experiment:
name: "vla_diffusion" # 实验名称
notes: "" # 实验备注
tags: [] # 实验标签

View File

@@ -1,19 +0,0 @@
# @package data
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
# Horizon Parameters — 使用 Hydra 插值,从 agent 配置中引用,保持一致性
pred_horizon: ${agent.pred_horizon}
obs_horizon: ${agent.obs_horizon}
action_horizon: 8 # Action execution horizon (used during evaluation)
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
camera_names:
- r_vis
- top
- front
# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
normalization_type: min_max

View File

@@ -0,0 +1,21 @@
# @package data
_target_: roboimi.vla.data.simpe_robot_dataset.SimpleRobotDataset
# ====================
# 数据集路径
# ====================
dataset_dir: "roboimi/demos/dataset/sim_transfer"
# ====================
# 时间步参数(从 agent 配置引用)
# ====================
pred_horizon: ${agent.pred_horizon} # 预测步数
obs_horizon: ${agent.obs_horizon} # 观测步数
# ====================
# 相机配置
# ====================
camera_names:
- r_vis # 机器人视角相机
- top # 顶部相机
- front # 前方相机

View File

@@ -1,19 +1,27 @@
# @package eval
# Evaluation Configuration
ckpt_path: "checkpoints/vla_model_best.pt" # Path to model checkpoint
num_episodes: 3 # Number of evaluation episodes
max_timesteps: 700 # Maximum timesteps per episode
# 评估配置
ckpt_path: "checkpoints/vla_model_best.pt" # 模型检查点路径
num_episodes: 3 # 评估回合数
max_timesteps: 700 # 每回合最大时间步
device: ${train.device} # 与训练保持一致
task_name: "sim_transfer" # Task name for environment creation
task_name: "sim_transfer" # 环境任务名称
# Policy execution — 从 agent 配置中引用,保持一致性
num_queries: 4 # 每次预测 pred_horizon 步后重新查询
# ====================
# 策略执行参数
# ====================
# num_queries 已废弃,现在使用 agent 的 select_action() 自动管理队列
# 以下参数仅用于兼容旧代码,实际使用 agent.num_action_steps
num_queries: ${agent.num_action_steps}
obs_horizon: ${agent.obs_horizon}
# Camera names — 从 data 配置中引用,保持一致性
# ====================
# 相机配置
# ====================
camera_names: ${data.camera_names}
# Action smoothing
# ====================
# 动作平滑
# ====================
use_smoothing: false
smooth_method: "ema"
smooth_alpha: 0.3

View File

@@ -1,5 +1,15 @@
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
_partial_: true
kernel_size: 3
cond_predict_scale: false
# ====================
# UNet1D 配置
# ====================
kernel_size: 3 # 卷积核大小
cond_predict_scale: false # FiLM 条件化时是否同时预测 scalebias + scale 或仅 bias
# ====================
# 网络架构(默认值,可覆盖)
# ====================
# diffusion_step_embed_dim: 256 # 扩散时间步嵌入维度
# down_dims: [256, 512, 1024] # 下采样各层通道数
# n_groups: 8 # GroupNorm 分组数

View File

@@ -1,152 +0,0 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import h5py
import numpy as np
import os
import glob
import pickle
class RobotDiffusionDataset(Dataset):
def __init__(self,
dataset_dir,
pred_horizon=16,
obs_horizon=2,
action_horizon=8,
camera_names=['r_vis', 'top', 'front'],
normalization_type='gaussian'):
"""
Args:
dataset_dir: 存放 episode_*.hdf5 的文件夹路径
pred_horizon: 预测未来动作的长度 (Tp)
obs_horizon: 历史观测长度 (To)
action_horizon: 执行动作长度 (Ta) - 在Dataset中主要影响Evaluation这里作为参数保留
"""
self.dataset_dir = dataset_dir
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.action_horizon = action_horizon
self.camera_names = camera_names
self.normalization_type = normalization_type
# 1. 扫描所有HDF5文件并建立索引
# 格式: [(file_path, episode_length), ...]
self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
self.indices = []
print(f"Found {len(self.episode_files)} episodes. Building index...")
for file_path in self.episode_files:
with h5py.File(file_path, 'r') as f:
# 获取该 episode 的长度 (例如 700)
l = f['action'].shape[0]
# 保存每个有效 step 的索引信息
# (file_path, episode_length, current_step_index)
for i in range(l):
self.indices.append((file_path, l, i))
# 2. 统计数据
with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
self.stats = pickle.load(f)
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
file_path, episode_len, start_ts = self.indices[idx]
# -----------------------------
# 1. 打开文件
# -----------------------------
# 注意: 在 __getitem__ 中打开文件对多进程 DataLoader 更友好
# 如果追求极致IO性能可以考虑使用 h5py 的 swmr 模式或内存缓存
with h5py.File(file_path, 'r') as root:
# -----------------------------
# 2. 处理 Action (Prediction Target)
# -----------------------------
# 目标: 获取 [t, t + pred_horizon] 的动作
action_start = start_ts
action_end = min(start_ts + self.pred_horizon, episode_len)
actions = root['action'][action_start:action_end] # shape: (T_subset, 16)
# Padding: 如果剩余动作不足 pred_horizon复制最后一步
if len(actions) < self.pred_horizon:
pad_len = self.pred_horizon - len(actions)
last_action = actions[-1]
# 重复最后一行
pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
actions = np.concatenate([actions, pad_content], axis=0)
# 归一化 Action
if self.stats:
actions = self._normalize_data(actions, self.stats['action'])
# -----------------------------
# 3. 处理 Observations (History)
# -----------------------------
# 目标: 获取 [t - obs_horizon + 1, t + 1] 的观测
# 索引逻辑:
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
start_idx_raw = start_ts - (self.obs_horizon - 1)
start_idx = max(start_idx_raw, 0)
end_idx = start_ts + 1
pad_len = max(0, -start_idx_raw)
# Qpos
qpos_data = root['observations/qpos']
qpos_val = qpos_data[start_idx:end_idx]
if pad_len > 0:
first_frame = qpos_val[0]
padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
qpos_val = np.concatenate([padding, qpos_val], axis=0)
if self.stats:
qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
# Images
image_dict = {}
for cam_name in self.camera_names:
img_dset = root['observations']['images'][cam_name]
imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
if pad_len > 0:
first_frame = imgs_np[0]
padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
imgs_np = np.concatenate([padding, imgs_np], axis=0)
# 转换为 Tensor: (T, H, W, C) -> (T, C, H, W)
imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
image_dict[cam_name] = imgs_tensor
# ==============================
# 3. 组装 Batch
# ==============================
data_batch = {
'action': torch.from_numpy(actions).float(),
'qpos': torch.from_numpy(qpos_val).float(),
}
for cam_name, img_tensor in image_dict.items():
data_batch[f'image_{cam_name}'] = img_tensor
return data_batch
def _normalize_data(self, data, stats):
if self.normalization_type == 'min_max':
# 之前的逻辑: [-1, 1]
min_val = stats['min']
max_val = stats['max']
data = (data - min_val) / (max_val - min_val + 1e-8)
return data * 2 - 1
elif self.normalization_type == 'gaussian':
# 新逻辑: Mean/Std
mean = stats['mean']
std = stats['std']
# (data - mean) / std
# 这里的 data 是 numpy array
return (data - mean) / (std + 1e-8)

View File

@@ -1,53 +1,98 @@
import torch
import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Optional
from typing import List, Dict, Union
from pathlib import Path
class SimpleRobotDataset(Dataset):
"""
LeRobotDataset 简化版 - 图像以字典形式存储
HDF5 懒加载数据集 - LeRobotDataset 格式
与真实 LeRobotDataset 保持一致:
- Dataset 返回字典,每个摄像头单独的 key
- Policy 负责在 forward 时 stack 图像
返回格式:
- observation.state: (obs_horizon, state_dim)
- observation.{cam_name}: (obs_horizon, C, H, W)
- action: (pred_horizon, action_dim)
"""
def __init__(
self,
frames: List[Dict],
dataset_dir: Union[str, Path],
obs_horizon: int = 2,
pred_horizon: int = 8,
image_keys: List[str] = None,
camera_names: List[str] = None,
):
"""
Args:
frames: 帧数据列表。每个元素是一个字典,包含:
- "episode_index" (int): [必须] 该帧所属的 Episode ID。Dataset 使用它来确定 Episode 的边界(用于 Padding
- "task" (str): [必须] 任务描述字符串(例如 "pick_up_cube")。
- "observation.state" (torch.Tensor): (state_dim,) [必须] 当前帧的机器人状态向量(例如关节角度)。
- "action" (torch.Tensor): (action_dim,) [必须] 当前帧对应的动作向量。
- "{image_key}" (torch.Tensor): (C, H, W) [可选] 当前帧的图像数据。键名必须与初始化 Dataset 时传入的 image_keys 列表一致。
dataset_dir: HDF5 文件目录路径
obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作
image_keys: 哪些 key 是图像数据(例如 ["observation.image_0", "observation.image_1"]
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
HDF5 文件格式:
- action: [T, action_dim]
- observations/qpos: [T, obs_dim]
- observations/images/{cam_name}: [T, H, W, C]
"""
self.frames = frames
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.image_keys = image_keys or []
self.camera_names = camera_names or []
# 构建 episode 索引
self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")
# 查找 HDF5 文件
self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
if not self.hdf5_files:
self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
if not self.hdf5_files:
raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")
# 构建 episode 索引(只存储元数据,不加载数据)
self.episodes = {}
for idx, frame in enumerate(frames):
ep_idx = frame["episode_index"]
if ep_idx not in self.episodes:
self.episodes[ep_idx] = []
self.episodes[ep_idx].append(idx)
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
with h5py.File(hdf5_path, 'r') as f:
T = f['action'].shape[0]
start_idx = len(self.frame_meta)
for t in range(T):
self.frame_meta.append({
"ep_idx": ep_idx,
"frame_idx": t,
"hdf5_path": hdf5_path,
})
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")
def __len__(self):
return len(self.frames)
return len(self.frame_meta)
def _load_frame(self, idx: int) -> Dict:
"""从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx]
with h5py.File(meta["hdf5_path"], 'r') as f:
frame = {
"episode_index": meta["ep_idx"],
"frame_index": meta["frame_idx"],
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
}
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
for cam_name in self.camera_names:
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
img = torch.from_numpy(img).float()
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
frame = self.frames[idx]
frame = self._load_frame(idx)
ep_idx = frame["episode_index"]
# 获取当前 episode 的帧索引范围
@@ -61,9 +106,9 @@ class SimpleRobotDataset(Dataset):
observations = {
"state": [], # 状态数据
}
# 为每个摄像头初始化独立列表(字典形式)
for cam_key in self.image_keys:
observations[cam_key] = []
# 为每个摄像头初始化独立列表
for cam_name in self.camera_names:
observations[f"observation.{cam_name}"] = []
observation_is_pad = []
@@ -72,22 +117,22 @@ class SimpleRobotDataset(Dataset):
# 边界检查
if ep_start <= target_idx <= ep_end:
target_frame = self.frames[target_idx]
target_frame = self._load_frame(target_idx)
is_pad = False
else:
# 超出边界,用边界帧填充
if target_idx < ep_start:
target_frame = self.frames[ep_start]
target_frame = self._load_frame(ep_start)
else:
target_frame = self.frames[ep_end]
target_frame = self._load_frame(ep_end)
is_pad = True
# 收集状态
observations["state"].append(target_frame["observation.state"])
# 收集每个摄像头的图像(字典形式,不 stack
for cam_key in self.image_keys:
observations[cam_key].append(target_frame[cam_key])
# 收集每个摄像头的图像
for cam_name in self.camera_names:
observations[f"observation.{cam_name}"].append(target_frame[f"observation.{cam_name}"])
observation_is_pad.append(is_pad)
@@ -101,14 +146,14 @@ class SimpleRobotDataset(Dataset):
target_idx = idx + delta
if target_idx <= ep_end:
actions.append(self.frames[target_idx]["action"])
actions.append(self._load_frame(target_idx)["action"])
action_is_pad.append(False)
else:
actions.append(self.frames[ep_end]["action"])
actions.append(self._load_frame(ep_end)["action"])
action_is_pad.append(True)
# ============================================
# 3. 组装返回数据(字典形式)
# 3. 组装返回数据(LeRobotDataset 格式)
# ============================================
result = {
# 状态观察: (obs_horizon, state_dim)
@@ -123,401 +168,32 @@ class SimpleRobotDataset(Dataset):
"task": frame["task"],
}
# 图像:每个摄像头独立的 key(字典形式)
# 图像:每个摄像头独立的 key
# 形状: (obs_horizon, C, H, W)
for cam_key in self.image_keys:
result[cam_key] = torch.stack(observations[cam_key])
for cam_name in self.camera_names:
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
return result
@property
def camera_keys(self) -> list[str]:
"""获取所有相机键名"""
return self.image_keys
"""获取所有相机键名 (LeRobotDataset 格式)"""
return [f"observation.{cam_name}" for cam_name in self.camera_names]
@property
def camera_info(self) -> dict:
"""获取相机信息"""
if not self.image_keys:
if not self.camera_names:
return {}
# 从第一个样本获取形状
sample = self[0]
info = {}
for cam_key in self.image_keys:
if cam_key in sample:
info[cam_key] = {
"shape": sample[cam_key].shape,
"dtype": str(sample[cam_key].dtype),
for cam_name in self.camera_names:
key = f"observation.{cam_name}"
if key in sample:
info[key] = {
"shape": sample[key].shape,
"dtype": str(sample[key].dtype),
}
return info
class SimpleDiffusionPolicy(torch.nn.Module):
    """Simplified diffusion-style policy showing how per-camera image
    tensors are stacked inside ``forward`` (dict-of-cameras input).

    Expects a batch dict with:
        - "observation.state": (B, obs_horizon, state_dim)
        - one (B, obs_horizon, C, H, W) tensor per key in
          ``image_features`` (optional; state-only mode works too).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        image_features: Dict[str, tuple] = None,
        obs_horizon: int = 2,
        pred_horizon: int = 8,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.image_features = image_features or {}
        self.state_encoder = torch.nn.Linear(state_dim, 64)
        if image_features:
            num_cameras = len(image_features)
            # One conv encoder shared across all cameras and timesteps.
            self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
            self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
        else:
            self.fusion = torch.nn.Linear(64, 128)
        self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict actions of shape (B, pred_horizon, action_dim).

        Bug fix: `B` used to be bound only inside the image branch, so the
        state-only configuration raised NameError at the final reshape.
        It is now derived from the state tensor up front.
        """
        B = batch["observation.state"].shape[0]
        # (B, T, state_dim) -> (B, 64), averaged over the time dimension.
        state_features = self.state_encoder(batch["observation.state"])
        state_features = state_features.mean(dim=1)
        if self.image_features:
            # Stack the per-camera dict entries: (B, num_cam, T, C, H, W).
            image_tensors = [batch[key] for key in self.image_features.keys()]
            stacked_images = torch.stack(image_tensors, dim=1)
            B, num_cam, T, C, H, W = stacked_images.shape
            images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
            image_features = self.image_encoder(images_flat)
            image_features = image_features.mean(dim=[2, 3])  # global average pool
            # Average over time, then flatten the cameras into one vector.
            image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
            image_features = image_features.reshape(B, -1)
            features = torch.cat([state_features, image_features], dim=-1)
        else:
            features = state_features
        fused = self.fusion(features)
        pred_actions = self.action_head(fused)
        pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim)
        return pred_actions
def create_demo_data_with_images():
    """Build a small synthetic dataset with random camera images.

    Returns:
        A list of 20 frame dicts: two 10-step episodes ("pick_up_cube"
        then "stack_blocks"), each frame carrying an episode index, a
        frame index, the task name, 6-dim state/action vectors, and two
        random 3x64x64 images.
    """
    episode_tasks = ("pick_up_cube", "stack_blocks")
    frames = []
    for ep_idx, task in enumerate(episode_tasks):
        for t in range(10):
            # Random draws happen in the same order per frame:
            # state, high camera, wrist camera, action.
            frames.append({
                "episode_index": ep_idx,
                "frame_index": t,
                "task": task,
                "observation.state": torch.randn(6),
                "observation.image_high_resize": torch.randn(3, 64, 64),
                "observation.image_left_wrist": torch.randn(3, 64, 64),
                "action": torch.randn(6),
            })
    return frames
def print_section(title: str):
    """Print *title* framed above and below by 80-char '=' rules."""
    rule = "=" * 80
    print(f"\n{rule}\n {title}\n{rule}")
def test_dataset_basic_info(dataset):
    """Print overall dataset stats plus per-camera shape/dtype info."""
    print("\n📊 数据集基本信息:")
    print(f" 总帧数: {len(dataset)}")
    print(f" 总 episode 数: {len(dataset.episodes)}")
    print(f" 观察窗口: {dataset.obs_horizon}")
    print(f" 预测窗口: {dataset.pred_horizon}")
    print(f"\n📷 相机信息:")
    cameras = dataset.camera_keys
    print(f" 相机数量: {len(cameras)}")
    for cam in cameras:
        print(f" - {cam}")
    print(f"\n相机详细信息:")
    # camera_info loads one sample to report each camera's shape/dtype.
    cam_info = dataset.camera_info
    for cam, info in cam_info.items():
        print(f" {cam}:")
        print(f" shape: {info['shape']}")
        print(f" dtype: {info['dtype']}")
def test_single_sample(dataset):
    """Inspect one mid-episode sample and assert its time dimensions."""
    print_section("1. 测试单个样本")
    # A sample from the middle of an episode (no padding expected).
    sample = dataset[5]
    print("\n样本结构 (字典形式):")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
        elif isinstance(value, str):
            print(f" {key:30s}: {value}")
    # Images are stored as a dict: one key per camera, not stacked.
    print("\n✅ 验证图像存储形式:")
    print(" 图像以字典形式存储,每个摄像头独立的 key:")
    for cam_key in dataset.camera_keys:
        if cam_key in sample:
            print(f" - {cam_key}: {sample[cam_key].shape}")
    # Check the leading time dimension of state and action tensors.
    print("\n✅ 验证时间维度:")
    print(f" observation.state: {sample['observation.state'].shape}")
    print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
    assert sample['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误"
    print(f" action: {sample['action'].shape}")
    print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
    assert sample['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误"
    print(" ✓ 时间维度验证通过")
def test_edge_cases(dataset):
    """Check padding flags at episode start/middle/end and across episodes."""
    print_section("2. 测试边界情况")
    # NOTE(review): `expected` is printed nowhere and only the obs_pad
    # part is asserted below — the action_pad expectations are unused.
    test_cases = [
        ("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
        ("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
        ("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
        ("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
    ]
    for name, idx, expected in test_cases:
        print(f"\n📍 {name} (idx={idx}):")
        sample = dataset[idx]
        obs_pad = sample["observation_is_pad"].tolist()
        action_pad_count = sample["action_is_pad"].sum().item()
        print(f" observation_is_pad: {obs_pad}")
        print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
        print(f" action padding 数量: {action_pad_count}")
        # Only observation padding is asserted; the first frame of an
        # episode (and of idx 10, the start of episode 1) must be padded.
        if name == "Episode 开头":
            assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding"
        elif name == "跨 Episode":
            assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding"
def test_dataloader(dataset):
    """Run the dataset through a DataLoader and verify batch shapes."""
    print_section("3. 测试 DataLoader 集成")
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,  # keep single-process for tests
    )
    batch = next(iter(dataloader))
    print("\n📦 Batch 结构:")
    for key in ["observation.state", "observation.image_high_resize",
                "observation.image_left_wrist", "action", "task"]:
        if key in batch:
            value = batch[key]
            if isinstance(value, torch.Tensor):
                print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
            else:
                # "task" is collated into a list of strings, not a tensor.
                print(f" {key:30s}: {type(value).__name__} (length={len(value)})")
    print("\n✅ 验证 Batch 形状:")
    B = len(batch["observation.state"])
    print(f" Batch size: {B}")
    # Each camera should collate to (B, obs_horizon, 3, 64, 64).
    for cam_key in dataset.camera_keys:
        expected_shape = (B, dataset.obs_horizon, 3, 64, 64)
        actual_shape = batch[cam_key].shape
        print(f" {cam_key}:")
        print(f" 预期: {expected_shape}")
        print(f" 实际: {actual_shape}")
        assert actual_shape == expected_shape, f"{cam_key} 形状不匹配"
    print(" ✓ Batch 形状验证通过")
def test_policy_forward(dataset):
    """Feed one real batch through SimpleDiffusionPolicy end to end."""
    print_section("4. 测试 Policy 前向传播")
    # Build a policy with two camera feature specs matching the dataset.
    policy = SimpleDiffusionPolicy(
        state_dim=6,
        action_dim=6,
        image_features={
            "observation.image_high_resize": (3, 64, 64),
            "observation.image_left_wrist": (3, 64, 64),
        },
        obs_horizon=dataset.obs_horizon,
        pred_horizon=dataset.pred_horizon,
    )
    # Grab one deterministic batch.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print("\n🔄 Policy.forward() 流程:")
    # Step 1: show the per-camera tensors before stacking.
    print("\n 1⃣ Stack 之前 (字典形式):")
    for cam_key in policy.image_features.keys():
        print(f" batch['{cam_key}']: {batch[cam_key].shape}")
    # Step 2: replicate the stack the policy performs internally.
    print("\n 2⃣ Stack 操作:")
    image_tensors = [batch[key] for key in policy.image_features.keys()]
    stacked = torch.stack(image_tensors, dim=1)
    print(f" stacked_images: {stacked.shape}")
    print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
    print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")
    # Step 3: the actual forward pass (no gradients needed here).
    print("\n 3⃣ 前向传播:")
    with torch.no_grad():
        pred_actions = policy(batch)
    print(f" 输入:")
    print(f" observation.state: {batch['observation.state'].shape}")
    print(f" 图像已 stack")
    print(f" 输出:")
    print(f" pred_actions: {pred_actions.shape}")
    print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")
    print("\n✅ Policy 前向传播验证通过")
def test_data_consistency(dataset):
    """Verify that observation padding copies the boundary frame exactly."""
    print_section("5. 测试数据一致性")
    print("\n🔍 验证图像 padding 的正确性:")
    # At episode start, frame 0 of the window should be a padded copy of
    # frame 1 (the true first frame).
    sample = dataset[0]
    if sample["observation_is_pad"][0]:
        img_0 = sample["observation.image_high_resize"][0]
        img_1 = sample["observation.image_high_resize"][1]
        print(f" Episode 开头 (idx=0):")
        print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}")
        print(f" 第0帧图像 = 第1帧图像: {torch.equal(img_0, img_1)}")
        assert torch.equal(img_0, img_1), "Padding 应该复制边界帧"
        print(" ✓ Padding 正确")
    # Mid-episode, consecutive frames should differ (no padding).
    sample = dataset[5]
    if not sample["observation_is_pad"].any():
        img_0 = sample["observation.image_high_resize"][0]
        img_1 = sample["observation.image_high_resize"][1]
        print(f"\n Episode 中间 (idx=5):")
        print(f" 没有 padding: {sample['observation_is_pad']}")
        print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(img_0, img_1)}")
        print(" ✓ 正常帧不重复")
    print("\n✅ 数据一致性验证通过")
def test_task_info(dataset):
    """Check task labels at dataset, single-sample, and batch level."""
    print_section("6. 测试任务信息")
    print("\n📋 统计任务分布:")
    # NOTE(review): relies on the legacy in-memory `dataset.frames` list;
    # the HDF5 lazy-loading variant does not expose this attribute.
    task_count = {}
    for frame in dataset.frames:
        task = frame["task"]
        task_count[task] = task_count.get(task, 0) + 1
    for task, count in task_count.items():
        print(f" {task}: {count}")
    # Task string carried by a single sample.
    sample = dataset[0]
    print(f"\n样本 task: {sample['task']}")
    print(f" 类型: {type(sample['task'])}")
    # Task as collated by the DataLoader (a list of strings).
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print(f"\nBatch task:")
    print(f" 值: {batch['task']}")
    print(f" 类型: {type(batch['task'])}")
    print(f" 长度: {len(batch['task'])}")
    print("\n✅ 任务信息验证通过")
def run_all_tests():
    """Build the demo dataset and run the full test suite in order."""
    print("\n" + "🚀" * 40)
    print(" SimpleRobotDataset 完整测试套件")
    print("🚀" * 40)
    # Create the in-memory demo dataset.
    print("\n创建测试数据...")
    frames = create_demo_data_with_images()
    # NOTE(review): this uses the legacy constructor signature
    # (frames + image_keys); the HDF5 lazy-loading SimpleRobotDataset
    # takes dataset_dir/camera_names instead — confirm which class this
    # suite targets.
    dataset = SimpleRobotDataset(
        frames,
        obs_horizon=2,
        pred_horizon=8,
        image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
    )
    print("✓ 数据集创建完成")
    # Run every test against the shared dataset instance.
    test_dataset_basic_info(dataset)
    test_single_sample(dataset)
    test_edge_cases(dataset)
    test_dataloader(dataset)
    test_policy_forward(dataset)
    test_data_consistency(dataset)
    test_task_info(dataset)
    # Summary banner.
    print_section("✅ 测试总结")
    print("\n所有测试通过!✨")
    print("\n关键验证点:")
    print(" ✓ 图像以字典形式存储")
    print(" ✓ 每个摄像头独立的 key")
    print(" ✓ Policy 在 forward 时 stack 图像")
    print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)")
    print(" ✓ Padding 处理正确")
    print(" ✓ DataLoader 集成正确")
    print(" ✓ Task 信息传递正确")
    print("\n与 LeRobotDataset 设计完全一致!🎉")
if __name__ == "__main__":
    # DataLoader is imported only here (script entry) yet is referenced
    # as a module global by the test helpers above — they work when run
    # via this entry point, but importing this module and calling them
    # directly would NameError on DataLoader.
    from torch.utils.data import DataLoader
    run_all_tests()

View File

@@ -1,4 +1,4 @@
# Backbone models
from .resnet import ResNetBackbone
from .resnet_diffusion import ResNetDiffusionBackbone
__all__ = ["ResNetBackbone"]
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]

View File

@@ -1,93 +0,0 @@
from roboimi.vla.core.interfaces import VLABackbone
from transformers import ResNetModel
from torchvision import transforms
import torch
import torch.nn as nn
class ResNetBackbone(VLABackbone):
    """Frozen-by-default HuggingFace ResNet feature extractor.

    Encodes a dict of per-camera image sequences {cam: (B, T, C, H, W)}
    into a single (B, T, num_cams * out_channels * 2) feature tensor by
    applying the backbone per frame and pooling with spatial softmax.
    """

    def __init__(
        self,
        model_name = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        # Channel count of the last ResNet stage (e.g. 512 for resnet-18).
        self.out_channels = self.model.config.hidden_sizes[-1]
        # ImageNet normalization; every frame is resized to 384x384.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # NOTE(review): 12x12 assumes the 384x384 input yields a 12x12
        # final feature map for this backbone — confirm for other model_names.
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        # Disable gradients and switch to eval so BatchNorm layers use
        # running statistics instead of batch statistics.
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def train(self, mode=True):
        """
        Override train() to keep frozen ResNet in eval mode.
        This ensures BatchNorm layers use running statistics consistently.
        """
        super().train(mode)
        if hasattr(self, 'model'):
            self.model.eval()  # Always keep ResNet in eval mode
        return self

    def forward_single_image(self, image):
        # image: (B, T, C, H, W) -> pooled features (B*T, out_channels*2).
        B, T, C, H, W = image.shape
        image = image.view(B * T, C, H, W)
        image = self.transform(image)
        feature_map = self.model(image).last_hidden_state  # (B*T, D, H', W')
        features = self.spatial_softmax(feature_map)  # (B*T, D*2)
        return features

    def forward(self, images):
        # images: {cam_name: (B, T, C, H, W)}. Cameras are concatenated
        # in sorted-name order so the feature layout is deterministic.
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        features_all = []
        sorted_cam_names = sorted(images.keys())
        for cam_name in sorted_cam_names:
            img = images[cam_name]
            features = self.forward_single_image(img)  # (B*T, D*2)
            features_all.append(features)
        combined_features = torch.cat(features_all, dim=1)  # (B*T, Num_Cams*D*2)
        return combined_features.view(B, T, -1)

    @property
    def output_dim(self):
        """Output dimension after spatial softmax: out_channels * 2"""
        return self.out_channels * 2
class SpatialSoftmax(nn.Module):
    """Convert a feature map (N, C, H, W) into per-channel expected 2-D
    keypoint coordinates, flattened to (N, C*2).

    Args:
        num_rows: Height of the expected feature map.
        num_cols: Width of the expected feature map.
        temperature: Softmax temperature. If None (default), a learnable
            scalar initialized to 1.0 is used; if given, it is registered
            as a fixed (non-trainable) buffer.
    """

    def __init__(self, num_rows, num_cols, temperature=None):
        super().__init__()
        # Bug fix: `temperature` was accepted but silently ignored — a
        # learnable ones(1) was always used. A caller-supplied value is
        # now honored as a fixed buffer; default behavior is unchanged.
        if temperature is None:
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            self.register_buffer('temperature', torch.tensor([float(temperature)]))
        # Normalized [-1, 1] coordinate grid, flattened row-major.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """Return expected (x, y) per channel, shape (N, C*2).

        Assumes H*W == num_rows*num_cols for the registered grid —
        confirm at call sites.
        """
        N, C, H, W = x.shape
        x = x.view(N, C, -1)  # (N, C, H*W)
        # Per-channel spatial attention map.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
        # Expected coordinates under the attention distribution.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
        # Concatenate and flatten -> (N, C*2).
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)

View File

@@ -91,20 +91,21 @@ class SpatialSoftmax(nn.Module):
return feature_keypoints
class ResNetDiffusionBackbone(VLABackbone):
class _SingleRgbEncoder(nn.Module):
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
def __init__(
self,
vision_backbone: str = "resnet18",
pretrained_backbone_weights: str | None = None,
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
crop_shape: Optional[Tuple[int, int]] = None,
crop_is_random: bool = True,
use_group_norm: bool = True,
spatial_softmax_num_keypoints: int = 32,
vision_backbone: str,
pretrained_backbone_weights: str | None,
input_shape: Tuple[int, int, int],
crop_shape: Optional[Tuple[int, int]],
crop_is_random: bool,
use_group_norm: bool,
spatial_softmax_num_keypoints: int,
):
super().__init__()
# 设置可选的预处理
# 设置可选的预处理
if crop_shape is not None:
self.do_crop = True
# 评估时始终使用中心裁剪
@@ -117,7 +118,7 @@ class ResNetDiffusionBackbone(VLABackbone):
self.do_crop = False
crop_shape = input_shape[1:]
# 设置骨干网络
# 设置骨干网络
backbone_model = getattr(torchvision.models, vision_backbone)(
weights=pretrained_backbone_weights
)
@@ -132,12 +133,12 @@ class ResNetDiffusionBackbone(VLABackbone):
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# 设置池化和最终层
# 使用试运行来获取特征图形状
# 设置池化和最终层
# 使用试运行来获取特征图形状
dummy_shape = (1, input_shape[0], *crop_shape)
with torch.no_grad():
dummy_out = self.backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
self.feature_dim = spatial_softmax_num_keypoints * 2
@@ -150,22 +151,94 @@ class ResNetDiffusionBackbone(VLABackbone):
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
return x
class ResNetDiffusionBackbone(VLABackbone):
def __init__(
self,
vision_backbone: str = "resnet18",
pretrained_backbone_weights: str | None = None,
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
crop_shape: Optional[Tuple[int, int]] = None,
crop_is_random: bool = True,
use_group_norm: bool = True,
spatial_softmax_num_keypoints: int = 32,
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
):
super().__init__()
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
self.num_cameras = num_cameras
if use_separate_rgb_encoder_per_camera:
# 独立编码器模式:为每个摄像头创建独立的编码器
encoders = [
_SingleRgbEncoder(
vision_backbone=vision_backbone,
pretrained_backbone_weights=pretrained_backbone_weights,
input_shape=input_shape,
crop_shape=crop_shape,
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
)
for _ in range(num_cameras)
]
self.rgb_encoder = nn.ModuleList(encoders)
# 重要output_dim 始终表示单个编码器的特征维度(与 lerobot 保持一致)
self.feature_dim = encoders[0].feature_dim
else:
# 共享编码器模式:所有摄像头共享同一个编码器
self.rgb_encoder = _SingleRgbEncoder(
vision_backbone=vision_backbone,
pretrained_backbone_weights=pretrained_backbone_weights,
input_shape=input_shape,
crop_shape=crop_shape,
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
)
self.feature_dim = self.rgb_encoder.feature_dim
def forward(self, images):
"""
Args:
images: Dict[str, Tensor], 每个摄像头的图像
形状: {cam_name: (B, T, C, H, W)}
Returns:
Tensor: (B, T, total_feature_dim)
"""
any_tensor = next(iter(images.values()))
B, T = any_tensor.shape[:2]
features_all = []
for cam_name in sorted(images.keys()):
img = images[cam_name]
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
cam_names = sorted(images.keys())
if self.use_separate_rgb_encoder_per_camera:
# 独立编码器模式:每个摄像头使用对应的编码器
features_all = []
for cam_idx, cam_name in enumerate(cam_names):
img = images[cam_name]
encoder = self.rgb_encoder[cam_idx]
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
else:
# 共享编码器模式:所有摄像头共享同一个编码器
features_all = []
for cam_name in cam_names:
img = images[cam_name]
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
@property
def output_dim(self):
return self.feature_dim
if __name__ == "__main__":
print("🚀 Testing ResNetDiffusionBackbone...")
print("=" * 60)
print("🚀 Testing ResNetDiffusionBackbone")
print("=" * 60)
# Configuration
B, T = 2, 5
@@ -174,34 +247,109 @@ if __name__ == "__main__":
num_keypoints = 32
feature_dim_per_cam = num_keypoints * 2
# Instantiate model
backbone = ResNetDiffusionBackbone(
vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W),
crop_shape=(crop_h, crop_w),
crop_is_random=True,
use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints
)
print(f"✅ Model instantiated. Output dim per camera: {backbone.output_dim}")
# Create dummy input
# Create dummy input (2 cameras)
images = {
"cam_high": torch.randn(B, T, C, H, W),
"cam_wrist": torch.randn(B, T, C, H, W)
}
num_cameras = len(images)
# Forward pass
print("🔄 Running forward pass...")
output = backbone(images)
# ============================================================================
# Test 1: Shared Encoder (默认模式)
# ============================================================================
print("\n[Test 1] Shared Encoder Mode")
print("-" * 60)
backbone_shared = ResNetDiffusionBackbone(
vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W),
crop_shape=(crop_h, crop_w),
crop_is_random=True,
use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints,
use_separate_rgb_encoder_per_camera=False, # 共享编码器
)
print(f"Input shapes: {[v.shape for v in images.values()]}")
print(f"Output shape: {output.shape}")
print(f"✅ Shared encoder model instantiated")
print(f" Output dim per camera: {feature_dim_per_cam}")
print(f" Number of cameras: {num_cameras}")
print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")
# Verification
expected_dim = len(images) * feature_dim_per_cam
output = backbone_shared(images)
print(f"\n🔄 Forward pass completed")
print(f" Input shapes: {[v.shape for v in images.values()]}")
print(f" Output shape: {output.shape}")
expected_dim = num_cameras * feature_dim_per_cam
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
print(f"✨ Test passed!")
print("✨ Test passed!")
# ============================================================================
# Test 2: Separate Encoders (独立编码器模式)
# ============================================================================
print("\n[Test 2] Separate Encoders Mode")
print("-" * 60)
backbone_separate = ResNetDiffusionBackbone(
vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W),
crop_shape=(crop_h, crop_w),
crop_is_random=True,
use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints,
use_separate_rgb_encoder_per_camera=True, # 独立编码器
num_cameras=num_cameras,
)
print(f"✅ Separate encoders model instantiated")
print(f" Output dim per camera: {feature_dim_per_cam}")
print(f" Number of cameras: {num_cameras}")
print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")
output = backbone_separate(images)
print(f"\n🔄 Forward pass completed")
print(f" Input shapes: {[v.shape for v in images.values()]}")
print(f" Output shape: {output.shape}")
expected_dim = num_cameras * feature_dim_per_cam
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
print(f"✨ Test passed!")
# ============================================================================
# Test 3: Verify parameters count
# ============================================================================
print("\n[Test 3] Parameter Count Comparison")
print("-" * 60)
shared_params = sum(p.numel() for p in backbone_shared.parameters())
separate_params = sum(p.numel() for p in backbone_separate.parameters())
print(f" Shared encoder parameters: {shared_params:,}")
print(f" Separate encoders parameters: {separate_params:,}")
print(f" Ratio: {separate_params / shared_params:.2f}x")
assert separate_params > shared_params, "Separate encoders should have more parameters"
print(f"✨ Verification passed!")
# ============================================================================
# Test 4: Verify independent parameters
# ============================================================================
print("\n[Test 4] Verify Independent Parameters")
print("-" * 60)
# Check that encoders have independent parameters
encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]
# Modify first encoder's parameter
with torch.no_grad():
encoder_0_first_param += 1.0
# Verify they are not the same tensor
assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
"Encoders should have independent parameters"
print(f"✅ Encoders have independent parameters")
print(f"✨ All tests passed!")
print("\n" + "=" * 60)
print("🎉 All tests completed successfully!")
print("=" * 60)

View File

@@ -0,0 +1,128 @@
"""
归一化模块 - 统一训练和推理的归一化逻辑
支持两种归一化方式:
1. Gaussian (z-score): (x - mean) / std
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Literal
class NormalizationModule(nn.Module):
    """Unified normalization module for qpos and action tensors.

    Used inside the agent to keep training and inference normalization
    consistent. Supports two modes:
        - 'gaussian': z-score, (x - mean) / std
        - 'min_max':  2 * (x - min) / (max - min) - 1, i.e. [-1, 1]
    Statistics are stored as buffers so they travel with checkpoints.
    """

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Literal['gaussian', 'min_max'] = 'gaussian'
    ):
        """
        Args:
            stats: Dataset statistics dict with keys:
                'normalization_type' (optional, overrides the argument),
                'qpos_mean', 'qpos_std', 'action_mean', 'action_std',
                and, for 'min_max' mode, 'qpos_min'/'qpos_max'/
                'action_min'/'action_max'. If None, the module is
                disabled and all methods are identity.
            normalization_type: Default mode ('gaussian' or 'min_max'),
                used when stats does not specify one.
        """
        super().__init__()
        self.normalization_type = normalization_type
        self.enabled = stats is not None
        if self.enabled:
            # A type stored in stats wins over the constructor argument.
            self.normalization_type = stats.get('normalization_type', normalization_type)
            # Buffers: not optimized, but saved/loaded with the model.
            self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
            self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
            self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
            self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
            if self.normalization_type == 'min_max':
                # Fall back to [0, 1] bounds when min/max are absent.
                qpos_min = stats.get('qpos_min', [0.0] * len(stats['qpos_mean']))
                qpos_max = stats.get('qpos_max', [1.0] * len(stats['qpos_mean']))
                action_min = stats.get('action_min', [0.0] * len(stats['action_mean']))
                action_max = stats.get('action_max', [1.0] * len(stats['action_mean']))
                self.register_buffer('qpos_min', torch.tensor(qpos_min, dtype=torch.float32))
                self.register_buffer('qpos_max', torch.tensor(qpos_max, dtype=torch.float32))
                self.register_buffer('action_min', torch.tensor(action_min, dtype=torch.float32))
                self.register_buffer('action_max', torch.tensor(action_max, dtype=torch.float32))

    def _bounds(self, prefix: str):
        """Return the stat pair for `prefix` ('qpos' or 'action'):
        (mean, std) in gaussian mode, (min, max) in min_max mode."""
        if self.normalization_type == 'gaussian':
            return getattr(self, f'{prefix}_mean'), getattr(self, f'{prefix}_std')
        return getattr(self, f'{prefix}_min'), getattr(self, f'{prefix}_max')

    def _normalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Shared forward normalization for qpos and action."""
        if not self.enabled:
            return x
        lo, hi = self._bounds(prefix)
        if self.normalization_type == 'gaussian':
            return (x - lo) / hi
        return 2 * (x - lo) / (hi - lo) - 1

    def _denormalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Shared inverse normalization for qpos and action."""
        if not self.enabled:
            return x
        lo, hi = self._bounds(prefix)
        if self.normalization_type == 'gaussian':
            return x * hi + lo
        return (x + 1) / 2 * (hi - lo) + lo

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize qpos."""
        return self._normalize(qpos, 'qpos')

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Denormalize qpos."""
        return self._denormalize(qpos, 'qpos')

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize action."""
        return self._normalize(action, 'action')

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Denormalize action."""
        return self._denormalize(action, 'action')

    def get_stats(self) -> Optional[Dict]:
        """Export statistics (for saving into a checkpoint); None if disabled."""
        if not self.enabled:
            return None
        stats = {
            'normalization_type': self.normalization_type,
            'qpos_mean': self.qpos_mean.cpu().tolist(),
            'qpos_std': self.qpos_std.cpu().tolist(),
            'action_mean': self.action_mean.cpu().tolist(),
            'action_std': self.action_std.cpu().tolist(),
        }
        if self.normalization_type == 'min_max':
            stats['qpos_min'] = self.qpos_min.cpu().tolist()
            stats['qpos_max'] = self.qpos_max.cpu().tolist()
            stats['action_min'] = self.action_min.cpu().tolist()
            stats['action_max'] = self.action_max.cpu().tolist()
        return stats