refactor:大重构

This commit is contained in:
gouhanke
2026-02-11 15:53:55 +08:00
parent 1e95d40bf9
commit 130d4bb3c5
19 changed files with 1411 additions and 1223 deletions

View File

@@ -0,0 +1,238 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(PreTrainedConfig):
    """Configuration class for DiffusionPolicy.

    Defaults are configured for training with PushT providing proprioceptive and single camera observations.

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

    Notes on the inputs and outputs:
        - "observation.state" is required as an input key.
        - Either:
            - At least one key starting with "observation.image" is required as an input.
              AND/OR
            - The key "observation.environment_state" is required as input.
        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
          views. Right now we only support all images having the same shape.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            See `DiffusionPolicy.select_action` for more details.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
            [-1, 1] range.
        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
            original scale. Note that this is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
            within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
            mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
        use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
        down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
            You may provide a variable number of dimensions, therefore also controlling the degree of
            downsampling.
        kernel_size: The convolutional kernel size of the diffusion modeling Unet.
        n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
        diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
            network. This is the output dimension of that network, i.e., the embedding dimension.
        use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning.
            Bias modulation is used be default, while this parameter indicates whether to also use scale
            modulation.
        noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
        beta_start: Beta value for the first forward-diffusion step.
        beta_end: Beta value for the last forward-diffusion step.
        prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
            or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
            "epsilon" has been shown to work better in many deep neural network settings.
        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
            denoising step at inference time. WARNING: you will need to make sure your action-space is
            normalized to fit within this range.
        clip_sample_range: The magnitude of the clipping range as described above.
        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
            spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
            to False as the original Diffusion Policy implementation does the same.
    """

    # Inputs / output structure.
    n_obs_steps: int = 2
    horizon: int = 16
    n_action_steps: int = 8

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    use_separate_rgb_encoder_per_camera: bool = False
    # Unet.
    down_dims: tuple[int, ...] = (512, 1024, 2048)
    kernel_size: int = 5
    n_groups: int = 8
    diffusion_step_embed_dim: int = 128
    use_film_scale_modulation: bool = True
    # Noise scheduler.
    noise_scheduler_type: str = "DDPM"
    num_train_timesteps: int = 100
    beta_schedule: str = "squaredcos_cap_v2"
    beta_start: float = 0.0001
    beta_end: float = 0.02
    prediction_type: str = "epsilon"
    clip_sample: bool = True
    clip_sample_range: float = 1.0

    # Inference
    num_inference_steps: int | None = None

    # Loss computation
    do_mask_loss_for_padding: bool = False

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-6
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 500

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        # NOTE: this string was previously placed *after* the super() call, making it a
        # dead statement rather than the method docstring; it now documents the method.
        super().__post_init__()

        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
            raise ValueError(
                f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
            )
        supported_noise_schedulers = ["DDPM", "DDIM"]
        if self.noise_scheduler_type not in supported_noise_schedulers:
            raise ValueError(
                f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
                f"Got {self.noise_scheduler_type}."
            )

        # Check that the horizon size and U-Net downsampling is compatible.
        # U-Net downsamples by 2 with each stage.
        downsampling_factor = 2 ** len(self.down_dims)
        if self.horizon % downsampling_factor != 0:
            raise ValueError(
                "The horizon should be an integer multiple of the downsampling factor (which is determined "
                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
            )

    def get_optimizer_preset(self) -> AdamConfig:
        """Return the Adam optimizer settings declared on this config."""
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        """Return the LR scheduler settings declared on this config."""
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        """Validate image/state features resolved by the parent config.

        Raises:
            ValueError: if no image nor environment-state input is present, if
                `crop_shape` exceeds any image's spatial size, or if the image
                inputs do not all share the same shape.
        """
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                # image_ft.shape is (C, H, W); crop_shape is (H, W).
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                    )

        # Check that all input images have the same shape.
        if len(self.image_features) > 0:
            first_image_key, first_image_ft = next(iter(self.image_features.items()))
            for key, image_ft in self.image_features.items():
                if image_ft.shape != first_image_ft.shape:
                    raise ValueError(
                        f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                    )

    @property
    def observation_delta_indices(self) -> list[int]:
        # The last n_obs_steps frames, ending at the current step (index 0).
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list[int]:
        # A horizon-length window starting at the oldest observation step.
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        # Diffusion policy does not consume reward frames.
        return None

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_diffusion_pre_post_processors(
    config: DiffusionConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Build the pre- and post-processing pipelines for a diffusion policy.

    The pre-processor prepares model inputs by:
        1. Renaming observation features (identity map by default).
        2. Adding a batch dimension.
        3. Moving the data to the configured device.
        4. Normalizing input and output features with the dataset statistics.

    The post-processor handles model outputs by:
        1. Unnormalizing the output features back to their original scale.
        2. Moving the data to the CPU.

    Args:
        config: Diffusion policy configuration providing feature definitions,
            normalization mappings, and the target device.
        dataset_stats: Statistics used for (un)normalization. Defaults to None.

    Returns:
        A `(preprocessor, postprocessor)` tuple of configured pipelines.
    """
    # Normalization covers both inputs and outputs so training targets are scaled too.
    normalizer = NormalizerProcessorStep(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=[
            RenameObservationsProcessorStep(rename_map={}),
            AddBatchDimensionProcessorStep(),
            DeviceProcessorStep(device=config.device),
            normalizer,
        ],
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )

    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=[
            UnnormalizerProcessorStep(
                features=config.output_features,
                norm_map=config.normalization_mapping,
                stats=dataset_stats,
            ),
            DeviceProcessorStep(device="cpu"),
        ],
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )

    return preprocessor, postprocessor

View File

@@ -1,13 +1,13 @@
""" """
VLA Policy Evaluation Script (Hydra-based) VLA 策略评估脚本(简化版)
This script evaluates a trained Vision-Language-Action (VLA) policy 该脚本使用 agent 内置的队列管理来评估训练好的 VLA 策略。
in the MuJoCo simulation environment. 无需单独的评估器类 - agent 处理一切!
Usage: 使用方法:
python roboimi/demos/eval_vla.py python roboimi/demos/eval_vla_simple.py
python roboimi/demos/eval_vla.py ckpt_path=checkpoints/vla_model_step_8000.pt num_episodes=5 python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_final.pt
python roboimi/demos/eval_vla.py use_smoothing=true smooth_alpha=0.5 python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_best.pt
""" """
import sys import sys
@@ -19,314 +19,152 @@ import torch
import numpy as np import numpy as np
import hydra import hydra
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict
from tqdm import tqdm from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate from hydra.utils import instantiate
from einops import rearrange
from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose from roboimi.utils.act_ex_utils import sample_transfer_pose
from einops import rearrange
# Ensure correct import path
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
if not OmegaConf.has_resolver("len"): if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x)) OmegaConf.register_new_resolver("len", lambda x: len(x))
class VLAEvaluator:
"""
VLA Policy Evaluator for MuJoCo Simulation
"""
def __init__(
self,
agent: torch.nn.Module,
device: str = 'cuda',
camera_names: List[str] = ['r_vis', 'top', 'front'],
num_queries: int = 1,
obs_horizon: int = 2,
pred_horizon: int = 16,
use_smoothing: bool = False,
smooth_method: str = 'ema',
smooth_alpha: float = 0.3,
dataset_stats: dict = None
):
self.agent = agent.to(device)
self.device = device
self.camera_names = camera_names
self.num_queries = num_queries
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
# Dataset statistics for normalization/denormalization
self.stats = dataset_stats
if self.stats is not None:
self.normalization_type = self.stats.get('normalization_type', 'gaussian')
self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
else:
self.normalization_type = None
# Action smoothing
self.use_smoothing = use_smoothing
self.smooth_method = smooth_method
self.smooth_alpha = smooth_alpha
self.smoother = ActionSmoother(
action_dim=16,
method=smooth_method,
alpha=smooth_alpha
) if use_smoothing else None
# Observation buffer for obs_horizon
self.obs_buffer = {
'images': {cam: [] for cam in camera_names},
'qpos': []
}
self.cached_actions = None
self.query_step = 0
# Timing statistics
self.inference_times = [] # Model inference time only
self.total_times = [] # Total prediction time (including preprocessing)
def reset(self):
"""Reset evaluator state"""
self.obs_buffer = {
'images': {cam: [] for cam in self.camera_names},
'qpos': []
}
self.cached_actions = None
self.query_step = 0
if self.smoother is not None:
self.smoother.reset()
# Reset timing stats for each episode
self.inference_times = []
self.total_times = []
def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
images = {}
for cam_name in self.camera_names:
img = obs['images'][cam_name]
img = rearrange(img, 'h w c -> c h w')
img = torch.from_numpy(img / 255.0).float()
images[cam_name] = img
image_dict = {}
for cam_name in self.camera_names:
cam_images = self.obs_buffer['images'][cam_name]
cam_images.append(images[cam_name])
while len(cam_images) < self.obs_horizon:
cam_images.insert(0, cam_images[0])
if len(cam_images) > self.obs_horizon:
cam_images = cam_images[-self.obs_horizon:]
img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
image_dict[cam_name] = img_tensor
self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]
return image_dict
def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
qpos = obs['qpos']
qpos = torch.from_numpy(qpos).float()
self.obs_buffer['qpos'].append(qpos)
while len(self.obs_buffer['qpos']) < self.obs_horizon:
self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])
if len(self.obs_buffer['qpos']) > self.obs_horizon:
self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]
qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
# Normalize qpos
if self.stats is not None:
if self.normalization_type == 'gaussian':
qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
else: # min_max: normalize to [-1, 1]
qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
return qpos_tensor
@torch.no_grad()
def predict_action(self, obs: Dict) -> np.ndarray:
start_total = time.time()
images = self._get_image_dict(obs)
qpos = self._get_qpos_dict(obs)
if self.cached_actions is None or self.query_step % self.num_queries == 0:
images = {k: v.to(self.device) for k, v in images.items()}
qpos = qpos.to(self.device)
# Measure pure model inference time
start_inference = time.time()
predicted_actions = self.agent.predict_action(
images=images,
proprioception=qpos
)
# Synchronize CUDA if using GPU to get accurate timing
if self.device == 'cuda':
torch.cuda.synchronize()
end_inference = time.time()
inference_time = end_inference - start_inference
self.inference_times.append(inference_time)
# Denormalize actions
if self.stats is not None:
if self.normalization_type == 'gaussian':
predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
else: # min_max
predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)
self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
self.query_step = 0
raw_action = self.cached_actions[self.query_step]
self.query_step += 1
if self.smoother is not None:
raw_action = self.smoother.smooth(raw_action)
end_total = time.time()
total_time = end_total - start_total
self.total_times.append(total_time)
return raw_action
def get_timing_stats(self) -> Dict:
"""Get timing statistics"""
if len(self.inference_times) == 0:
return {
'inference_fps': 0.0,
'control_fps': 0.0,
'avg_inference_time_ms': 0.0,
'avg_total_time_ms': 0.0
}
avg_inference_time = np.mean(self.inference_times)
avg_total_time = np.mean(self.total_times)
return {
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
'avg_inference_time_ms': avg_inference_time * 1000,
'avg_total_time_ms': avg_total_time * 1000,
'num_inferences': len(self.inference_times),
'num_steps': len(self.total_times)
}
class ActionSmoother:
"""Action smoothing for smoother execution"""
def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
self.action_dim = action_dim
self.method = method
self.alpha = alpha
self.prev_action = None
def smooth(self, action: np.ndarray) -> np.ndarray:
if self.method == 'ema':
if self.prev_action is None:
smoothed = action
else:
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
self.prev_action = smoothed
return smoothed
else:
return action
def reset(self):
self.prev_action = None
def load_checkpoint( def load_checkpoint(
ckpt_path: str, ckpt_path: str,
agent_cfg: DictConfig, agent_cfg: DictConfig,
device: str = 'cuda' device: str = 'cuda'
) -> torch.nn.Module: ) -> torch.nn.Module:
""" """
Load trained VLA model from checkpoint using Hydra agent config. 从检查点加载训练好的 VLA 模型,使用 Hydra agent 配置。
Args: Args:
ckpt_path: Path to checkpoint file (.pt) ckpt_path: 检查点文件路径 (.pt)
agent_cfg: Hydra agent config for instantiation agent_cfg: Hydra agent 配置,用于实例化
device: Device to load model on device: 加载模型的设备
Returns: Returns:
Loaded VLAAgent model 加载后的 VLAAgent 模型
""" """
from pathlib import Path as PathLib from pathlib import Path as PathLib
ckpt_path = PathLib(ckpt_path).absolute() ckpt_path = PathLib(ckpt_path).absolute()
if not ckpt_path.exists(): if not ckpt_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") raise FileNotFoundError(f"检查点未找到: {ckpt_path}")
log.info(f"Loading checkpoint from {ckpt_path}") log.info(f" {ckpt_path} 加载检查点")
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
log.info(f"Checkpoint keys: {checkpoint.keys()}") log.info(f"检查点键值: {checkpoint.keys()}")
# Instantiate agent from Hydra config # 加载数据集统计信息用于归一化
log.info("Instantiating agent from config...")
agent = instantiate(agent_cfg)
# Load model state
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
# Load dataset statistics for denormalization
stats = checkpoint.get('dataset_stats', None) stats = checkpoint.get('dataset_stats', None)
# 使用数据集统计信息从 Hydra 配置实例化 agent
log.info("从配置实例化 agent...")
agent = instantiate(agent_cfg, dataset_stats=stats)
# 加载模型状态
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ 模型状态已加载 (步数: {checkpoint.get('step', 'unknown')})")
if stats is not None: if stats is not None:
log.info(f"Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})") log.info(f"数据集统计信息已加载 (归一化: {stats.get('normalization_type', 'gaussian')})")
else: else:
# Fallback: try external JSON file (兼容旧 checkpoint) # 后备方案:尝试从外部 JSON 文件加载(兼容旧检查点)
stats_path = ckpt_path.parent / 'dataset_stats.json' stats_path = ckpt_path.parent / 'dataset_stats.json'
if stats_path.exists(): if stats_path.exists():
with open(stats_path, 'r') as f: with open(stats_path, 'r') as f:
stats = json.load(f) stats = json.load(f)
log.info("Dataset statistics loaded from external JSON (legacy)") log.info("数据集统计信息已从外部 JSON 加载(旧版本兼容)")
else: else:
log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!") log.warning("⚠️ 未找到数据集统计信息。动作将无法反归一化!")
agent.eval() agent.eval()
agent.to(device) agent.to(device)
log.info(f"Model loaded successfully on {device}") log.info(f"模型已成功加载到 {device}")
return agent, stats return agent, stats
def prepare_observation(obs: Dict, camera_names: list) -> Dict:
"""
将环境观测转换为 agent 格式。
Args:
obs: 环境观测字典,包含图像和 qpos
camera_names: 摄像头名称列表
Returns:
agent 格式的观测字典
"""
# 转换图像: numpy -> tensor, HWC -> CHW
images = {}
for cam_name in camera_names:
img = obs['images'][cam_name]
img = rearrange(img, 'h w c -> c h w')
img = torch.from_numpy(img / 255.0).float()
images[cam_name] = img
# 转换 qpos: numpy -> tensor
qpos = torch.from_numpy(obs['qpos']).float()
return {'qpos': qpos, 'images': images}
class ActionSmoother:
"""
动作平滑器(指数移动平均)
用于平滑执行动作以获得更稳定的控制
"""
def __init__(self, alpha: float = 0.3):
"""
Args:
alpha: 平滑系数 (0-1),值越大越重视当前动作
"""
self.alpha = alpha
self.prev_action = None
def smooth(self, action: np.ndarray) -> np.ndarray:
"""
平滑动作
Args:
action: 当前动作
Returns:
平滑后的动作
"""
if self.prev_action is None:
smoothed = action
else:
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
self.prev_action = smoothed
return smoothed
def reset(self):
"""重置平滑器状态"""
self.prev_action = None
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") @hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig): def main(cfg: DictConfig):
""" """
VLA Evaluation Script with Hydra Configuration. 使用 agent 内置队列管理的简化版 VLA 评估
All eval parameters come from vla/conf/eval.yaml, merged into cfg. 所有评估参数来自 vla/conf/eval.yaml,合并到 cfg 中。
Override on command line: python eval_vla.py eval.ckpt_path=... eval.num_episodes=5 命令行覆盖: python eval_vla_simple.py eval.ckpt_path=... eval.num_episodes=5
""" """
# Print configuration # 打印配置
print("=" * 80) print("=" * 80)
print("VLA Evaluation Configuration:") print("VLA 评估配置:")
print("=" * 80) print("=" * 80)
print(OmegaConf.to_yaml(cfg)) print(OmegaConf.to_yaml(cfg))
print("=" * 80) print("=" * 80)
@@ -335,67 +173,114 @@ def main(cfg: DictConfig):
device = eval_cfg.device device = eval_cfg.device
camera_names = list(eval_cfg.camera_names) camera_names = list(eval_cfg.camera_names)
# Load model # =========================================================================
log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...") # 加载模型
# =========================================================================
log.info(f"🚀 从 {eval_cfg.ckpt_path} 加载模型...")
agent, dataset_stats = load_checkpoint( agent, dataset_stats = load_checkpoint(
ckpt_path=eval_cfg.ckpt_path, ckpt_path=eval_cfg.ckpt_path,
agent_cfg=cfg.agent, agent_cfg=cfg.agent,
device=device device=device
) )
# Create evaluator # 重置 agent 的队列
evaluator = VLAEvaluator( agent.reset()
agent=agent,
device=device,
camera_names=camera_names,
num_queries=eval_cfg.num_queries,
obs_horizon=eval_cfg.obs_horizon,
use_smoothing=eval_cfg.use_smoothing,
smooth_method=eval_cfg.smooth_method,
smooth_alpha=eval_cfg.smooth_alpha,
dataset_stats=dataset_stats
)
# Create environment # 可选:动作平滑器
smoother = ActionSmoother(alpha=eval_cfg.smooth_alpha) if eval_cfg.use_smoothing else None
# =========================================================================
# 创建环境
# =========================================================================
env = make_sim_env(eval_cfg.task_name) env = make_sim_env(eval_cfg.task_name)
# Run episodes # =========================================================================
# 运行评估回合
# =========================================================================
all_stats = [] all_stats = []
for episode_idx in range(eval_cfg.num_episodes): for episode_idx in range(eval_cfg.num_episodes):
print(f"\n{'='*60}") print(f"\n{'='*60}")
print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}") print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
print(f"{'='*60}\n") print(f"{'='*60}\n")
box_pos = sample_transfer_pose() box_pos = sample_transfer_pose()
env.reset(box_pos) env.reset(box_pos)
evaluator.reset()
# 为新回合重置 agent 队列
agent.reset()
if smoother:
smoother.reset()
# 计时统计
inference_times = []
total_times = []
with torch.inference_mode(): with torch.inference_mode():
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"): for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
start_total = time.time()
# 从环境获取观测
obs = env._get_image_obs() obs = env._get_image_obs()
qpos_obs = env._get_qpos_obs() qpos_obs = env._get_qpos_obs()
obs['qpos'] = qpos_obs['qpos'] obs['qpos'] = qpos_obs['qpos']
action = evaluator.predict_action(obs) # 准备给 agent 的观测
env.step_jnt(action) observation = prepare_observation(obs, camera_names)
# 选择动作agent 内部处理队列管理)
start_inference = time.time()
action = agent.select_action(observation)
if device == 'cuda':
torch.cuda.synchronize()
end_inference = time.time()
# 转换为 numpy
action = action.cpu().numpy()
# 可选:平滑动作
if smoother:
action = smoother.smooth(action)
# 执行动作
env.step_jnt(action)
env.render() env.render()
# Get timing statistics for this episode end_total = time.time()
stats = evaluator.get_timing_stats()
# 记录计时
inference_times.append(end_inference - start_inference)
total_times.append(end_total - start_total)
# =========================================================================
# 打印回合统计
# =========================================================================
avg_inference_time = np.mean(inference_times)
avg_total_time = np.mean(total_times)
stats = {
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
'avg_inference_time_ms': avg_inference_time * 1000,
'avg_total_time_ms': avg_total_time * 1000,
'num_inferences': len([t for t in inference_times if t > 0.001]), # 统计实际推理次数
'num_steps': len(total_times)
}
all_stats.append(stats) all_stats.append(stats)
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)") print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
print(f" Model Inference FPS: {stats['inference_fps']:.2f} Hz") print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
print(f" Control Loop FPS: {stats['control_fps']:.2f} Hz") print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
print(f" Avg Inference Time: {stats['avg_inference_time_ms']:.2f} ms") print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
print(f" Avg Total Time: {stats['avg_total_time_ms']:.2f} ms") print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
print(f" Total Inferences: {stats['num_inferences']}") print(f" 总推理次数: {stats['num_inferences']}")
# Print overall statistics # =========================================================================
# 总体统计
# =========================================================================
print(f"\n{'='*60}") print(f"\n{'='*60}")
print("Evaluation complete!") print("评估完成!")
print(f"{'='*60}") print(f"{'='*60}")
if all_stats: if all_stats:
@@ -404,11 +289,11 @@ def main(cfg: DictConfig):
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats]) avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats]) avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
print(f"\nOverall Statistics ({eval_cfg.num_episodes} episodes):") print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
print(f" Average Model Inference FPS: {avg_inference_fps:.2f} Hz") print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
print(f" Average Control Loop FPS: {avg_control_fps:.2f} Hz") print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
print(f" Average Inference Time: {avg_inference_time:.2f} ms") print(f" 平均推理时间: {avg_inference_time:.2f} ms")
print(f" Average Total Time: {avg_total_time:.2f} ms") print(f" 平均总时间: {avg_total_time:.2f} ms")
print() print()

View File

@@ -12,28 +12,28 @@ from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from pathlib import Path from pathlib import Path
# Ensure correct import path # 确保正确的导入路径
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from hydra.utils import instantiate from hydra.utils import instantiate
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}}) # 注册列表长度解析器(用于配置中如 ${len:${data.camera_names}}
if not OmegaConf.has_resolver("len"): if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x)) OmegaConf.register_new_resolver("len", lambda x: len(x))
def recursive_to_device(data, device): def recursive_to_device(data, device):
""" """
Recursively move nested dictionaries/lists of tensors to specified device. 递归地将嵌套字典/列表中的张量移动到指定设备。
Args: Args:
data: Dictionary, list, or tensor data: 字典、列表或张量
device: Target device (e.g., 'cuda', 'cpu') device: 目标设备 (例如 'cuda', 'cpu')
Returns: Returns:
Data structure with all tensors moved to device 所有张量已移动到指定设备的数据结构
""" """
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.to(device) return data.to(device)
@@ -46,36 +46,36 @@ def recursive_to_device(data, device):
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0): def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
""" """
Create a learning rate scheduler with warmup. 创建带预热的学习率调度器。
Args: Args:
optimizer: PyTorch optimizer optimizer: PyTorch 优化器
warmup_steps: Number of warmup steps warmup_steps: 预热步数
max_steps: Total training steps max_steps: 总训练步数
scheduler_type: Type of scheduler after warmup ('cosine' or 'constant') scheduler_type: 预热后的调度器类型 ('cosine' 'constant')
min_lr: Minimum learning rate (for cosine decay) min_lr: 最小学习率(用于余弦衰减)
Returns: Returns:
LambdaLR scheduler LambdaLR 调度器
""" """
import math import math
# Capture initial lr before LambdaLR modifies it # 在 LambdaLR 修改前捕获初始学习率
base_lr = optimizer.param_groups[0]['lr'] base_lr = optimizer.param_groups[0]['lr']
min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0 min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0
def lr_lambda(step): def lr_lambda(step):
# Warmup phase: linear increase from 0 to 1 # 预热阶段:从 0 线性增加到 1
if step < warmup_steps: if step < warmup_steps:
return float(step) / float(max(1, warmup_steps)) return float(step) / float(max(1, warmup_steps))
# Post-warmup phase # 预热后阶段
if scheduler_type == 'cosine': if scheduler_type == 'cosine':
# Cosine annealing from 1 to min_lr_ratio # 1 min_lr_ratio 的余弦退火
progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps)) progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
return max(min_lr_ratio, cosine_decay) return max(min_lr_ratio, cosine_decay)
else: else:
# Constant learning rate # 恒定学习率
return 1.0 return 1.0
return LambdaLR(optimizer, lr_lambda) return LambdaLR(optimizer, lr_lambda)
@@ -84,40 +84,40 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") @hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig): def main(cfg: DictConfig):
""" """
VLA Training Script with ResNet Backbone and Diffusion Policy. VLA 训练脚本ResNet 骨干网络 + Diffusion 策略)
This script: 该脚本功能:
1. Loads dataset from HDF5 files 1. 从 HDF5 文件加载数据集
2. Instantiates VLAAgent with ResNet vision encoder 2. 实例化带 ResNet 视觉编码器的 VLAAgent
3. Trains diffusion-based action prediction 3. 训练基于扩散的动作预测模型
4. Saves checkpoints periodically 4. 定期保存检查点
""" """
# Print configuration # 打印配置
print("=" * 80) print("=" * 80)
print("VLA Training Configuration:") print("VLA 训练配置:")
print("=" * 80) print("=" * 80)
print(OmegaConf.to_yaml(cfg)) print(OmegaConf.to_yaml(cfg))
print("=" * 80) print("=" * 80)
log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})") log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
# Create checkpoint directory # 创建检查点目录
checkpoint_dir = Path("checkpoints") checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True) checkpoint_dir.mkdir(exist_ok=True)
# ========================================================================= # =========================================================================
# 1. Instantiate Dataset & DataLoader # 1. 实例化数据集与 DataLoader
# ========================================================================= # =========================================================================
log.info("📦 Loading dataset...") log.info("📦 加载数据集...")
try: try:
dataset = instantiate(cfg.data) dataset = instantiate(cfg.data)
log.info(f"Dataset loaded successfully. Total samples: {len(dataset)}") log.info(f"数据集加载成功。总样本数: {len(dataset)}")
except Exception as e: except Exception as e:
log.error(f"Failed to load dataset: {e}") log.error(f"数据集加载失败: {e}")
raise raise
# Train/Val split # 训练/验证集划分
val_split = float(cfg.train.get('val_split', 0.1)) val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42)) seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split) val_size = int(len(dataset) * val_split)
@@ -128,10 +128,10 @@ def main(cfg: DictConfig):
[train_size, val_size], [train_size, val_size],
generator=torch.Generator().manual_seed(seed) generator=torch.Generator().manual_seed(seed)
) )
log.info(f"Dataset split: train={train_size}, val={val_size} (val_split={val_split})") log.info(f"数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
else: else:
train_dataset, val_dataset = dataset, None train_dataset, val_dataset = dataset, None
log.info("Dataset split: train=all, val=0 (val_split=0)") log.info("数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
train_loader = DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
@@ -139,7 +139,7 @@ def main(cfg: DictConfig):
shuffle=True, shuffle=True,
num_workers=cfg.train.num_workers, num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"), pin_memory=(cfg.train.device != "cpu"),
drop_last=True # Drop incomplete batches for stable training drop_last=True # 丢弃不完整批次以稳定训练
) )
val_loader = None val_loader = None
@@ -153,34 +153,14 @@ def main(cfg: DictConfig):
drop_last=False drop_last=False
) )
log.info(f"Train loader batches per epoch: {len(train_loader)}") log.info(f"训练加载器每轮批次数: {len(train_loader)}")
if val_loader is not None: if val_loader is not None:
log.info(f"Val loader batches per epoch: {len(val_loader)}") log.info(f"验证加载器每轮批次数: {len(val_loader)}")
# ========================================================================= # =========================================================================
# 2. Instantiate VLA Agent # 2. 加载数据集统计信息(将传递给 agent
# ========================================================================= # =========================================================================
log.info("🤖 Initializing VLA Agent...") log.info("💾 加载数据集统计信息...")
try:
agent = instantiate(cfg.agent)
agent.to(cfg.train.device)
agent.train()
log.info(f"✅ Agent initialized and moved to {cfg.train.device}")
# Count parameters
total_params = sum(p.numel() for p in agent.parameters())
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
log.info(f"📊 Total parameters: {total_params:,}")
log.info(f"📊 Trainable parameters: {trainable_params:,}")
except Exception as e:
log.error(f"❌ Failed to initialize agent: {e}")
raise
# =========================================================================
# 2.5. Load Dataset Statistics (will be saved into checkpoints)
# =========================================================================
log.info("💾 Loading dataset statistics...")
dataset_stats = None dataset_stats = None
try: try:
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer') dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
@@ -201,22 +181,43 @@ def main(cfg: DictConfig):
'qpos_min': stats['qpos']['min'].tolist(), 'qpos_min': stats['qpos']['min'].tolist(),
'qpos_max': stats['qpos']['max'].tolist(), 'qpos_max': stats['qpos']['max'].tolist(),
} }
log.info(f"Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})") log.info(f"数据集统计信息加载完成 (归一化: {dataset_stats['normalization_type']})")
else: else:
log.warning(f"⚠️ Statistics file not found: {stats_path}") log.warning(f"⚠️ 统计文件未找到: {stats_path}")
log.warning("⚠️ Actions will not be denormalized during inference!") log.warning("⚠️ 推理时动作将无法反归一化!")
except Exception as e: except Exception as e:
log.warning(f"⚠️ Failed to load statistics: {e}") log.warning(f"⚠️ 统计信息加载失败: {e}")
log.warning("⚠️ Training will continue, but inference may not work correctly") log.warning("⚠️ 训练将继续,但推理可能无法正常工作")
# ========================================================================= # =========================================================================
# 3. Setup Optimizer & LR Scheduler # 3. 实例化 VLA Agent
# =========================================================================
log.info("🤖 初始化 VLA Agent...")
try:
# 将 dataset_stats 和 normalization_type 传递给 agent
agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
agent.to(cfg.train.device)
agent.train()
log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")
# 统计参数量
total_params = sum(p.numel() for p in agent.parameters())
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
log.info(f"📊 总参数量: {total_params:,}")
log.info(f"📊 可训练参数量: {trainable_params:,}")
except Exception as e:
log.error(f"❌ Agent 初始化失败: {e}")
raise
# =========================================================================
# 4. 设置优化器与学习率调度器
# ========================================================================= # =========================================================================
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5) optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})") log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})")
# Setup learning rate scheduler with warmup # 设置带预热的学習率调度器
warmup_steps = int(cfg.train.get('warmup_steps', 500)) warmup_steps = int(cfg.train.get('warmup_steps', 500))
scheduler_type = cfg.train.get('scheduler_type', 'cosine') scheduler_type = cfg.train.get('scheduler_type', 'cosine')
min_lr = float(cfg.train.get('min_lr', 1e-6)) min_lr = float(cfg.train.get('min_lr', 1e-6))
@@ -228,33 +229,36 @@ def main(cfg: DictConfig):
scheduler_type=scheduler_type, scheduler_type=scheduler_type,
min_lr=min_lr min_lr=min_lr
) )
log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})") log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})")
# ========================================================================= # =========================================================================
# 4. Training Loop # 5. 训练循环
# ========================================================================= # =========================================================================
log.info("🏋️ Starting training loop...") log.info("🏋️ 开始训练循环...")
def build_agent_input(batch_data): def build_agent_input(batch_data):
"""构建 agent 输入格式"""
images = {} images = {}
# SimpleRobotDataset 返回 observation.{cam_name} 格式
for cam_name in cfg.data.camera_names: for cam_name in cfg.data.camera_names:
key = f"image_{cam_name}" key = f"observation.{cam_name}"
if key in batch_data: if key in batch_data:
images[cam_name] = batch_data[key] images[cam_name] = batch_data[key]
return { return {
'images': images, 'images': images,
'qpos': batch_data['qpos'], 'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
'action': batch_data['action'] 'action': batch_data['action']
} }
def run_validation(): def run_validation():
"""运行验证"""
if val_loader is None: if val_loader is None:
return None return None
agent.eval() agent.eval()
# 🔧 FIX: Set deterministic seed for validation to get reproducible loss # 设置确定性种子以获得可重现的损失
# This ensures validation loss is comparable across different steps # 这确保验证损失在不同步骤之间可比较
torch.manual_seed(42) torch.manual_seed(42)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
@@ -272,7 +276,7 @@ def main(cfg: DictConfig):
return total_loss / max(num_batches, 1) return total_loss / max(num_batches, 1)
data_iter = iter(train_loader) data_iter = iter(train_loader)
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100) pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
best_loss = float('inf') best_loss = float('inf')
@@ -280,47 +284,47 @@ def main(cfg: DictConfig):
try: try:
batch = next(data_iter) batch = next(data_iter)
except StopIteration: except StopIteration:
# Restart iterator when epoch ends # 轮次结束时重启迭代器
data_iter = iter(train_loader) data_iter = iter(train_loader)
batch = next(data_iter) batch = next(data_iter)
# ===================================================================== # =====================================================================
# Move batch to device # 将批次移至设备
# ===================================================================== # =====================================================================
batch = recursive_to_device(batch, cfg.train.device) batch = recursive_to_device(batch, cfg.train.device)
# ===================================================================== # =====================================================================
# Prepare agent input # 准备 agent 输入
# ===================================================================== # =====================================================================
# Dataset returns: {action, qpos, image_<cam_name>, ...} # 数据集返回: {action, qpos, image_<cam_name>, ...}
# Agent expects: {images: dict, qpos: tensor, action: tensor} # Agent 期望: {images: dict, qpos: tensor, action: tensor}
# Prepare agent input # 准备 agent 输入
agent_input = build_agent_input(batch) agent_input = build_agent_input(batch)
# ===================================================================== # =====================================================================
# Forward pass & compute loss # 前向传播与损失计算
# ===================================================================== # =====================================================================
try: try:
loss = agent.compute_loss(agent_input) loss = agent.compute_loss(agent_input)
except Exception as e: except Exception as e:
log.error(f"Forward pass failed at step {step}: {e}") log.error(f"步骤 {step} 前向传播失败: {e}")
raise raise
# ===================================================================== # =====================================================================
# Backward pass & optimization # 反向传播与优化
# ===================================================================== # =====================================================================
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
# Gradient clipping for stable training # 梯度裁剪以稳定训练
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0) torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
# ===================================================================== # =====================================================================
# Logging # 日志记录
# ===================================================================== # =====================================================================
if step % cfg.train.log_freq == 0: if step % cfg.train.log_freq == 0:
current_lr = optimizer.param_groups[0]['lr'] current_lr = optimizer.param_groups[0]['lr']
@@ -329,16 +333,16 @@ def main(cfg: DictConfig):
"lr": f"{current_lr:.2e}", "lr": f"{current_lr:.2e}",
"best_loss": f"{best_loss:.4f}" "best_loss": f"{best_loss:.4f}"
}) })
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}") log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
# ===================================================================== # =====================================================================
# Checkpoint saving & Validation # 检查点保存与验证
# ===================================================================== # =====================================================================
if step > 0 and step % cfg.train.save_freq == 0: if step > 0 and step % cfg.train.save_freq == 0:
# Run validation # 运行验证
val_loss = run_validation() val_loss = run_validation()
if val_loss is not None: if val_loss is not None:
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}") log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
torch.save({ torch.save({
@@ -351,9 +355,9 @@ def main(cfg: DictConfig):
'dataset_stats': dataset_stats, 'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, checkpoint_path) }, checkpoint_path)
log.info(f"💾 Checkpoint saved: {checkpoint_path}") log.info(f"💾 检查点已保存: {checkpoint_path}")
# Save best model based on validation loss # 根据验证损失保存最佳模型
eval_loss = val_loss if val_loss is not None else loss.item() eval_loss = val_loss if val_loss is not None else loss.item()
if eval_loss < best_loss: if eval_loss < best_loss:
best_loss = eval_loss best_loss = eval_loss
@@ -368,10 +372,10 @@ def main(cfg: DictConfig):
'dataset_stats': dataset_stats, 'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, best_model_path) }, best_model_path)
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})") log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
# ========================================================================= # =========================================================================
# 5. Save Final Model # 6. 保存最终模型
# ========================================================================= # =========================================================================
final_model_path = checkpoint_dir / "vla_model_final.pt" final_model_path = checkpoint_dir / "vla_model_final.pt"
torch.save({ torch.save({
@@ -383,11 +387,11 @@ def main(cfg: DictConfig):
'dataset_stats': dataset_stats, 'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'], 'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path) }, final_model_path)
log.info(f"💾 Final model saved: {final_model_path}") log.info(f"💾 最终模型已保存: {final_model_path}")
log.info("Training completed successfully!") log.info("训练成功完成!")
log.info(f"📊 Final Loss: {loss.item():.4f}") log.info(f"📊 最终损失: {loss.item():.4f}")
log.info(f"📊 Best Loss: {best_loss:.4f}") log.info(f"📊 最佳损失: {best_loss:.4f}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,17 +1,19 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from typing import Dict, Optional, Any from collections import deque
from typing import Dict, Optional, Any, Tuple
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgent(nn.Module): class VLAAgent(nn.Module):
def __init__( def __init__(
self, self,
vision_backbone, # 你之前定义的 ResNet vision_backbone, # 视觉编码器(ResNet 等)
state_encoder, state_encoder,
action_encoder, action_encoder,
head, head,
@@ -19,23 +21,35 @@ class VLAAgent(nn.Module):
obs_dim, # 本体感知维度 (例如 关节角度) obs_dim, # 本体感知维度 (例如 关节角度)
pred_horizon=16, # 预测未来多少步动作 pred_horizon=16, # 预测未来多少步动作
obs_horizon=4, # 使用多少步历史观测 obs_horizon=4, # 使用多少步历史观测
diffusion_steps=100, diffusion_steps=100, # DDPM 加噪步数
inference_steps=10, # DDIM 推理步数
num_cams=3, # 视觉输入的摄像头数量 num_cams=3, # 视觉输入的摄像头数量
dataset_stats=None, # 数据集统计信息,用于归一化
normalization_type='gaussian', # 归一化类型: 'gaussian' 或 'min_max'
num_action_steps=1, # 每次推理实际执行多少步动作
): ):
super().__init__() super().__init__()
# Store parameters # 保存参数
self.action_dim = action_dim self.action_dim = action_dim
self.obs_dim = obs_dim self.obs_dim = obs_dim
self.pred_horizon = pred_horizon self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon self.obs_horizon = obs_horizon
self.num_cams = num_cams self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
# 归一化模块 - 统一训练和推理的归一化逻辑
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type
)
self.vision_encoder = vision_backbone self.vision_encoder = vision_backbone
single_img_feat_dim = self.vision_encoder.output_dim single_cam_feat_dim = self.vision_encoder.output_dim
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim self.global_cond_dim = total_vision_dim + total_prop_dim
# self.global_cond_dim = total_vision_dim
self.noise_scheduler = DDPMScheduler( self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps, num_train_timesteps=diffusion_steps,
@@ -44,7 +58,7 @@ class VLAAgent(nn.Module):
prediction_type='epsilon' # 预测噪声 prediction_type='epsilon' # 预测噪声
) )
# DDIM scheduler for faster inference # DDIM 调度器用于快速推理
self.infer_scheduler = DDIMScheduler( self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps, num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2', beta_schedule='squaredcos_cap_v2',
@@ -54,45 +68,55 @@ class VLAAgent(nn.Module):
self.noise_pred_net = head( self.noise_pred_net = head(
input_dim=action_dim, input_dim=action_dim,
# input_dim = action_dim + obs_dim, # input_dim = action_dim + obs_dim, # 备选:包含观测维度
global_cond_dim=self.global_cond_dim global_cond_dim=self.global_cond_dim
) )
self.state_encoder = state_encoder self.state_encoder = state_encoder
self.action_encoder = action_encoder self.action_encoder = action_encoder
# 初始化队列(用于在线推理)
self.reset()
# ========================== # ==========================
# 训练阶段 (Training) # 训练阶段 (Training)
# ========================== # ==========================
def compute_loss(self, batch): def compute_loss(self, batch):
""" """
batch: 包含 images, qpos (proprioception), action 计算训练损失
Args:
batch: 包含 images, qpos (本体感知), action 的字典
""" """
actions, states, images = batch['action'], batch['qpos'], batch['images'] actions, states, images = batch['action'], batch['qpos'], batch['images']
B = actions.shape[0] B = actions.shape[0]
# 归一化 states (qpos) 和 actions
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
state_features = self.state_encoder(states) state_features = self.state_encoder(states)
# 1. 提取视觉特征 # 1. 提取视觉特征
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim) visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
action_features = self.action_encoder(actions) action_features = self.action_encoder(actions)
# 3. 采样噪声 # 2. 采样噪声
noise = torch.randn_like(action_features) noise = torch.randn_like(action_features)
# 4. 随机采样时间步 (Timesteps) # 3. 随机采样时间步 (Timesteps)
timesteps = torch.randint( timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, 0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=action_features.device (B,), device=action_features.device
).long() ).long()
# 5. 给动作加噪 (Forward Diffusion) # 4. 给动作加噪 (Forward Diffusion)
noisy_actions = self.noise_scheduler.add_noise( noisy_actions = self.noise_scheduler.add_noise(
action_features, noise, timesteps action_features, noise, timesteps
) )
# 6. 网络预测噪声 # 5. 网络预测噪声
pred_noise = self.noise_pred_net( pred_noise = self.noise_pred_net(
sample=noisy_actions, sample=noisy_actions,
timestep=timesteps, timestep=timesteps,
@@ -100,30 +124,192 @@ class VLAAgent(nn.Module):
proprioception=state_features proprioception=state_features
) )
# 7. 计算 Loss (MSE) # 6. 计算 Loss (MSE)
loss = nn.functional.mse_loss(pred_noise, noise) loss = nn.functional.mse_loss(pred_noise, noise)
return loss return loss
# ========================== # ==========================
# 推理阶段 (Inference) # 队列管理 (Queue Management)
# ==========================
def reset(self):
"""清空观测和动作队列。应在 env.reset() 时调用"""
self._queues = {
'qpos': deque(maxlen=self.obs_horizon),
'images': deque(maxlen=self.obs_horizon),
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
"""
将新的观测添加到队列中。
Args:
observation: 包含 'qpos''images' 的字典
"""
# 添加本体感知
if 'qpos' in observation:
self._queues['qpos'].append(observation['qpos'].clone())
# 添加图像
if 'images' in observation:
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
"""
从队列中准备用于推理的批量观测。
如果队列未满(首次调用时),用最新观测重复填充。
Returns:
batch: 包含堆叠后的历史观测的字典
"""
# 堆叠历史本体感知
qpos_list = list(self._queues['qpos'])
if len(qpos_list) == 0:
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
# 堆叠历史图像
images_list = list(self._queues['images'])
if len(images_list) == 0:
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
return {'qpos': batch_qpos, 'images': batch_images}
# ==========================
# 在线推理 (Online Inference)
# ==========================
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
根据当前观测选择单个动作。
这个方法维护一个历史观测和生成动作轨迹的缓存。工作流程:
- 缓存 `obs_horizon` 步的历史观测
- Diffusion 模型生成 `pred_horizon` 步的动作
- 实际执行 `num_action_steps` 步动作
示意图:
--------------------------------------------------------------
(图例: o=obs_horizon, h=pred_horizon, a=num_action_steps)
|时间步 | 0 | 1 | ... | o-1 | o | ... | h-1 |
|观测是否使用 | 是 | 是 | 是 | 是 | 否 | 否 | 否 |
|动作是否生成 | 是 | 是 | 是 | 是 | 是 | 是 | 是 |
|动作是否执行 | 否 | 否 | 否 | 否 | 是 | 是 | 是 |
--------------------------------------------------------------
Args:
observation: 包含 'qpos''images' 的字典
Returns:
action: (action_dim,) 单个动作
"""
# 检测设备并确保所有组件在同一设备上
# 尝试从观测中获取设备
device = None
for v in observation.values():
if isinstance(v, torch.Tensor):
device = v.device
break
if device is not None and self.normalization.enabled:
# 确保归一化参数在同一设备上
norm_device = self.normalization.qpos_mean.device
if device != norm_device:
self.normalization.to(device)
# 同时确保其他模块也在正确设备
self.vision_encoder.to(device)
self.state_encoder.to(device)
self.action_encoder.to(device)
self.noise_pred_net.to(device)
# 将所有 observation 移到正确设备
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()}
# 将新观测添加到队列
self._populate_queues(observation)
# 如果动作队列为空,生成新的动作序列
if len(self._queues['action']) == 0:
# 从队列准备批量观测
batch = self._prepare_observation_batch()
# 生成动作块
actions = self.predict_action_chunk(batch) # (1, pred_horizon, action_dim)
# 提取可执行的动作部分
# 从 obs_horizon-1 开始,因为前面的动作对应过去的观测
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end] # (1, num_action_steps, action_dim)
# 将动作添加到队列
for i in range(executable_actions.shape[1]):
self._queues['action'].append(executable_actions[:, i].squeeze(0)) # (action_dim,)
# 从队列中取出一个动作
action = self._queues['action'].popleft() # (action_dim,)
return action
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
预测一个动作块(用于在线推理)。
Args:
batch: 包含 'qpos''images' 的字典
- qpos: (B, obs_horizon, obs_dim)
- images: Dict[str, (B, obs_horizon, C, H, W)]
Returns:
actions: (B, pred_horizon, action_dim) 预测的动作序列
"""
return self.predict_action(batch['images'], batch['qpos'])
# ==========================
# 批量推理 (Batch Inference - 原有方法)
# ========================== # ==========================
@torch.no_grad() @torch.no_grad()
def predict_action(self, images, proprioception): def predict_action(self, images, proprioception):
"""
批量预测动作序列(用于训练和离线评估)
Args:
images: 图像观测字典
proprioception: 本体感知观测 (qpos)
Returns:
denormalized_actions: 反归一化后的动作序列
"""
B = proprioception.shape[0] B = proprioception.shape[0]
# 1. 提取当前观测特征 (只做一次) # 归一化 proprioception (qpos)
proprioception = self.normalization.normalize_qpos(proprioception)
# 1. 提取当前观测特征(只提取一次)
visual_features = self.vision_encoder(images) visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception) state_features = self.state_encoder(proprioception)
# 2. 初始化纯高斯噪声动作 # 2. 初始化纯高斯噪声动作
# Shape: (B, pred_horizon, action_dim) # 形状: (B, pred_horizon, action_dim)
device = visual_features.device device = visual_features.device
current_actions = torch.randn( current_actions = torch.randn(
(B, self.pred_horizon, self.action_dim), device=device (B, self.pred_horizon, self.action_dim), device=device
) )
# 3. 逐步去噪循环 (Reverse Diffusion) # 3. 逐步去噪循环 (Reverse Diffusion)
self.infer_scheduler.set_timesteps(10) # DDIM 推理步数 self.infer_scheduler.set_timesteps(self.inference_steps) # DDIM 推理步数
for t in self.infer_scheduler.timesteps: for t in self.infer_scheduler.timesteps:
model_input = current_actions model_input = current_actions
@@ -141,5 +327,11 @@ class VLAAgent(nn.Module):
noise_pred, t, current_actions noise_pred, t, current_actions
).prev_sample ).prev_sample
# 4. 输出最终动作序列(归一化空间,由调用方负责反归一化) # 4. 反归一化动作序列
return current_actions denormalized_actions = self.normalization.denormalize_action(current_actions)
return denormalized_actions
def get_normalization_stats(self):
"""获取归一化统计信息(用于保存到 checkpoint"""
return self.normalization.get_stats()

View File

@@ -9,14 +9,26 @@ defaults:
_target_: roboimi.vla.agent.VLAAgent _target_: roboimi.vla.agent.VLAAgent
# Action and Observation Dimensions # ====================
action_dim: 16 # 模型维度配置
obs_dim: 16 # ====================
action_dim: 16 # 动作维度(机器人关节数)
obs_dim: 16 # 本体感知维度(关节位置)
# Prediction and Observation Horizons # ====================
pred_horizon: 16 # 时间步配置
obs_horizon: 2 # ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
# ====================
# 相机配置
# ====================
num_cams: 3 # 摄像头数量 (r_vis, top, front)
# Camera Configuration # ====================
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取 # 扩散过程配置
# ====================
diffusion_steps: 100 # 扩散训练步数DDPM
inference_steps: 10 # 推理时的去噪步数DDIM固定为 10

View File

@@ -1,4 +0,0 @@
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true

View File

@@ -1,8 +1,28 @@
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone _target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
vision_backbone: "resnet18"
pretrained_backbone_weights: null # ====================
input_shape: [3, 96, 96] # 骨干网络选择
crop_shape: [84, 84] # ====================
crop_is_random: true vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
use_group_norm: true pretrained_backbone_weights: null # 预训练权重路径或 nullImageNet 权重)
spatial_softmax_num_keypoints: 32
# ====================
# 输入配置
# ====================
input_shape: [3, 96, 96] # 输入图像形状 (C, H, W)
crop_shape: [84, 84] # 裁剪后的图像形状 (H, W)
crop_is_random: true # 训练时使用随机裁剪,评估时使用中心裁剪
# ====================
# 归一化和特征提取
# ====================
use_group_norm: true # 使用 GroupNorm 替代 BatchNorm更适合小批次训练
spatial_softmax_num_keypoints: 32 # Spatial Softmax 关键点数量
# ====================
# 编码器模式
# ====================
# false: 共享编码器(所有摄像头共享一个 ResNet参数少但容量受限推荐
# true: 独立编码器(每个摄像头有独立的 ResNet参数多但容量大
use_separate_rgb_encoder_per_camera: true
num_cameras: 3 # 摄像头数量

View File

@@ -1,19 +1,41 @@
defaults: defaults:
- agent: resnet_diffusion - agent: resnet_diffusion
- data: resnet_dataset - data: simpe_robot_dataset
- eval: eval - eval: eval
- _self_ - _self_
# ====================
# 训练配置
# ====================
train: train:
batch_size: 8 # Batch size for training # 基础训练参数
lr: 1e-4 # Learning rate batch_size: 8 # 批次大小
max_steps: 20000 # Maximum training steps lr: 1e-4 # 学习率
log_freq: 100 # Log frequency (steps) max_steps: 100000 # 最大训练步数
save_freq: 2000 # Save checkpoint frequency (steps) device: "cuda" # 设备: "cuda" 或 "cpu"
device: "cuda" # Device: "cuda" or "cpu"
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
# Learning rate scheduler with warmup # 数据加载
warmup_steps: 500 # Number of warmup steps num_workers: 8 # DataLoader 工作进程数(调试时设为 0生产环境用 8
scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine" val_split: 0.1 # 验证集比例
min_lr: 1e-6 # Minimum learning rate (for cosine decay) seed: 42 # 随机种子(用于数据划分)
# 日志和检查点
log_freq: 100 # 日志记录频率(步数)
save_freq: 5000 # 保存检查点频率(步数)
# 学习率调度器(带预热)
warmup_steps: 500 # 预热步数
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
min_lr: 1e-6 # 最小学习率(用于余弦退火)
# 优化器
weight_decay: 1e-5 # 权重衰减L2 正则化)
grad_clip: 1.0 # 梯度裁剪阈值
# ====================
# 实验配置
# ====================
experiment:
name: "vla_diffusion" # 实验名称
notes: "" # 实验备注
tags: [] # 实验标签

View File

@@ -1,19 +0,0 @@
# @package data
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
# Horizon Parameters — 使用 Hydra 插值,从 agent 配置中引用,保持一致性
pred_horizon: ${agent.pred_horizon}
obs_horizon: ${agent.obs_horizon}
action_horizon: 8 # Action execution horizon (used during evaluation)
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
camera_names:
- r_vis
- top
- front
# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
normalization_type: min_max

View File

@@ -0,0 +1,21 @@
# @package data
_target_: roboimi.vla.data.simpe_robot_dataset.SimpleRobotDataset
# ====================
# 数据集路径
# ====================
dataset_dir: "roboimi/demos/dataset/sim_transfer"
# ====================
# 时间步参数(从 agent 配置引用)
# ====================
pred_horizon: ${agent.pred_horizon} # 预测步数
obs_horizon: ${agent.obs_horizon} # 观测步数
# ====================
# 相机配置
# ====================
camera_names:
- r_vis # 机器人视角相机
- top # 顶部相机
- front # 前方相机

View File

@@ -1,19 +1,27 @@
# @package eval # @package eval
# Evaluation Configuration # 评估配置
ckpt_path: "checkpoints/vla_model_best.pt" # Path to model checkpoint ckpt_path: "checkpoints/vla_model_best.pt" # 模型检查点路径
num_episodes: 3 # Number of evaluation episodes num_episodes: 3 # 评估回合数
max_timesteps: 700 # Maximum timesteps per episode max_timesteps: 700 # 每回合最大时间步
device: ${train.device} # 与训练保持一致 device: ${train.device} # 与训练保持一致
task_name: "sim_transfer" # Task name for environment creation task_name: "sim_transfer" # 环境任务名称
# Policy execution — 从 agent 配置中引用,保持一致性 # ====================
num_queries: 4 # 每次预测 pred_horizon 步后重新查询 # 策略执行参数
# ====================
# num_queries 已废弃,现在使用 agent 的 select_action() 自动管理队列
# 以下参数仅用于兼容旧代码,实际使用 agent.num_action_steps
num_queries: ${agent.num_action_steps}
obs_horizon: ${agent.obs_horizon} obs_horizon: ${agent.obs_horizon}
# Camera names — 从 data 配置中引用,保持一致性 # ====================
# 相机配置
# ====================
camera_names: ${data.camera_names} camera_names: ${data.camera_names}
# Action smoothing # ====================
# 动作平滑
# ====================
use_smoothing: false use_smoothing: false
smooth_method: "ema" smooth_method: "ema"
smooth_alpha: 0.3 smooth_alpha: 0.3

View File

@@ -1,5 +1,15 @@
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D _target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
_partial_: true _partial_: true
kernel_size: 3 # ====================
cond_predict_scale: false # UNet1D 配置
# ====================
kernel_size: 3 # 卷积核大小
cond_predict_scale: false # FiLM 条件化时是否同时预测 scalebias + scale 或仅 bias
# ====================
# 网络架构(默认值,可覆盖)
# ====================
# diffusion_step_embed_dim: 256 # 扩散时间步嵌入维度
# down_dims: [256, 512, 1024] # 下采样各层通道数
# n_groups: 8 # GroupNorm 分组数

View File

@@ -1,152 +0,0 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import h5py
import numpy as np
import os
import glob
import pickle
class RobotDiffusionDataset(Dataset):
    """Per-timestep diffusion-policy dataset over ``episode_*.hdf5`` files.

    Every timestep of every episode is one sample: a short observation
    history (qpos + per-camera images) paired with the next ``pred_horizon``
    actions. Episode edges are padded by repeating the boundary frame/action.

    HDF5 layout assumed per episode file (see __getitem__):
        - action: (T, action_dim)
        - observations/qpos: (T, obs_dim)
        - observations/images/{cam}: (T, H, W, C), uint8
    """

    def __init__(self,
                 dataset_dir,
                 pred_horizon=16,
                 obs_horizon=2,
                 action_horizon=8,
                 camera_names=None,
                 normalization_type='gaussian'):
        """
        Args:
            dataset_dir: directory containing ``episode_*.hdf5`` files and a
                precomputed ``data_stats.pkl``.
            pred_horizon: length of the predicted action chunk (Tp).
            obs_horizon: length of the observation history (To).
            action_horizon: executed-action length (Ta); kept for evaluation
                code — it does not affect sampling here.
            camera_names: camera keys under ``observations/images``. Defaults
                to ['r_vis', 'top', 'front'] (previously a mutable default
                argument shared across instances).
            normalization_type: 'gaussian' (mean/std) or 'min_max' ([-1, 1]).
        """
        self.dataset_dir = dataset_dir
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        # Avoid the mutable-default-argument pitfall.
        self.camera_names = list(camera_names) if camera_names is not None else ['r_vis', 'top', 'front']
        self.normalization_type = normalization_type
        # 1. Scan all HDF5 files and build a flat per-timestep index.
        #    Each entry: (file_path, episode_length, current_step_index)
        self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
        self.indices = []
        print(f"Found {len(self.episode_files)} episodes. Building index...")
        for file_path in self.episode_files:
            with h5py.File(file_path, 'r') as f:
                l = f['action'].shape[0]
                for i in range(l):
                    self.indices.append((file_path, l, i))
        # 2. Normalization statistics, precomputed offline.
        with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
            self.stats = pickle.load(f)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        file_path, episode_len, start_ts = self.indices[idx]
        # Opening the file inside __getitem__ keeps the dataset friendly to
        # multi-process DataLoaders (no file handles shared across workers).
        with h5py.File(file_path, 'r') as root:
            # --- Actions (prediction target): [t, t + pred_horizon) ---
            action_start = start_ts
            action_end = min(start_ts + self.pred_horizon, episode_len)
            actions = root['action'][action_start:action_end]  # (T_subset, action_dim)
            # Pad by repeating the last action when the episode tail is short.
            if len(actions) < self.pred_horizon:
                pad_len = self.pred_horizon - len(actions)
                last_action = actions[-1]
                pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
                actions = np.concatenate([actions, pad_content], axis=0)
            if self.stats:
                actions = self._normalize_data(actions, self.stats['action'])
            # --- Observations (history): [t - obs_horizon + 1, t] ---
            # Index logic, e.g. obs_horizon=2:
            #   t=0 -> frames [0, 0] (front padding)
            #   t=5 -> frames [4, 5]
            start_idx_raw = start_ts - (self.obs_horizon - 1)
            start_idx = max(start_idx_raw, 0)
            end_idx = start_ts + 1
            pad_len = max(0, -start_idx_raw)
            # qpos history, front-padded with the first available frame.
            qpos_data = root['observations/qpos']
            qpos_val = qpos_data[start_idx:end_idx]
            if pad_len > 0:
                first_frame = qpos_val[0]
                padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
                qpos_val = np.concatenate([padding, qpos_val], axis=0)
            if self.stats:
                qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
            # Image history per camera, same front-padding scheme.
            image_dict = {}
            for cam_name in self.camera_names:
                img_dset = root['observations']['images'][cam_name]
                imgs_np = img_dset[start_idx:end_idx]  # (T, H, W, C)
                if pad_len > 0:
                    first_frame = imgs_np[0]
                    padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
                    imgs_np = np.concatenate([padding, imgs_np], axis=0)
                # uint8 HWC -> float CHW in [0, 1].
                imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
                imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
                image_dict[cam_name] = imgs_tensor
        # --- Assemble the sample dict (all arrays are in-memory copies) ---
        data_batch = {
            'action': torch.from_numpy(actions).float(),
            'qpos': torch.from_numpy(qpos_val).float(),
        }
        for cam_name, img_tensor in image_dict.items():
            data_batch[f'image_{cam_name}'] = img_tensor
        return data_batch

    def _normalize_data(self, data, stats):
        """Normalize ``data`` using ``stats`` per ``self.normalization_type``.

        Raises:
            ValueError: for an unknown normalization type. Previously this
                method fell through and silently returned None.
        """
        if self.normalization_type == 'min_max':
            # Scale to [-1, 1]; eps avoids division by zero on constant dims.
            min_val = stats['min']
            max_val = stats['max']
            data = (data - min_val) / (max_val - min_val + 1e-8)
            return data * 2 - 1
        elif self.normalization_type == 'gaussian':
            # Standardize with mean/std; eps guards a zero std.
            mean = stats['mean']
            std = stats['std']
            return (data - mean) / (std + 1e-8)
        raise ValueError(f"Unknown normalization_type: {self.normalization_type}")

View File

@@ -1,53 +1,98 @@
import torch import torch
import h5py
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import List, Dict, Optional from typing import List, Dict, Union
from pathlib import Path
class SimpleRobotDataset(Dataset): class SimpleRobotDataset(Dataset):
""" """
LeRobotDataset 简化版 - 图像以字典形式存储 HDF5 懒加载数据集 - LeRobotDataset 格式
与真实 LeRobotDataset 保持一致: 返回格式:
- Dataset 返回字典,每个摄像头单独的 key - observation.state: (obs_horizon, state_dim)
- Policy 负责在 forward 时 stack 图像 - observation.{cam_name}: (obs_horizon, C, H, W)
- action: (pred_horizon, action_dim)
""" """
def __init__( def __init__(
self, self,
frames: List[Dict], dataset_dir: Union[str, Path],
obs_horizon: int = 2, obs_horizon: int = 2,
pred_horizon: int = 8, pred_horizon: int = 8,
image_keys: List[str] = None, camera_names: List[str] = None,
): ):
""" """
Args: Args:
frames: 帧数据列表。每个元素是一个字典,包含: dataset_dir: HDF5 文件目录路径
- "episode_index" (int): [必须] 该帧所属的 Episode ID。Dataset 使用它来确定 Episode 的边界(用于 Padding
- "task" (str): [必须] 任务描述字符串(例如 "pick_up_cube")。
- "observation.state" (torch.Tensor): (state_dim,) [必须] 当前帧的机器人状态向量(例如关节角度)。
- "action" (torch.Tensor): (action_dim,) [必须] 当前帧对应的动作向量。
- "{image_key}" (torch.Tensor): (C, H, W) [可选] 当前帧的图像数据。键名必须与初始化 Dataset 时传入的 image_keys 列表一致。
obs_horizon: 观察过去多少帧 obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作 pred_horizon: 预测未来多少帧动作
image_keys: 哪些 key 是图像数据(例如 ["observation.image_0", "observation.image_1"] camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
HDF5 文件格式:
- action: [T, action_dim]
- observations/qpos: [T, obs_dim]
- observations/images/{cam_name}: [T, H, W, C]
""" """
self.frames = frames
self.obs_horizon = obs_horizon self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon self.pred_horizon = pred_horizon
self.image_keys = image_keys or [] self.camera_names = camera_names or []
# 构建 episode 索引 self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")
# 查找 HDF5 文件
self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
if not self.hdf5_files:
self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
if not self.hdf5_files:
raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")
# 构建 episode 索引(只存储元数据,不加载数据)
self.episodes = {} self.episodes = {}
for idx, frame in enumerate(frames): self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
ep_idx = frame["episode_index"] for ep_idx, hdf5_path in enumerate(self.hdf5_files):
if ep_idx not in self.episodes: with h5py.File(hdf5_path, 'r') as f:
self.episodes[ep_idx] = [] T = f['action'].shape[0]
self.episodes[ep_idx].append(idx) start_idx = len(self.frame_meta)
for t in range(T):
self.frame_meta.append({
"ep_idx": ep_idx,
"frame_idx": t,
"hdf5_path": hdf5_path,
})
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")
def __len__(self): def __len__(self):
return len(self.frames) return len(self.frame_meta)
def _load_frame(self, idx: int) -> Dict:
"""从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx]
with h5py.File(meta["hdf5_path"], 'r') as f:
frame = {
"episode_index": meta["ep_idx"],
"frame_index": meta["frame_idx"],
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
}
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
for cam_name in self.camera_names:
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
img = torch.from_numpy(img).float()
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
frame = self.frames[idx] frame = self._load_frame(idx)
ep_idx = frame["episode_index"] ep_idx = frame["episode_index"]
# 获取当前 episode 的帧索引范围 # 获取当前 episode 的帧索引范围
@@ -61,9 +106,9 @@ class SimpleRobotDataset(Dataset):
observations = { observations = {
"state": [], # 状态数据 "state": [], # 状态数据
} }
# 为每个摄像头初始化独立列表(字典形式) # 为每个摄像头初始化独立列表
for cam_key in self.image_keys: for cam_name in self.camera_names:
observations[cam_key] = [] observations[f"observation.{cam_name}"] = []
observation_is_pad = [] observation_is_pad = []
@@ -72,22 +117,22 @@ class SimpleRobotDataset(Dataset):
# 边界检查 # 边界检查
if ep_start <= target_idx <= ep_end: if ep_start <= target_idx <= ep_end:
target_frame = self.frames[target_idx] target_frame = self._load_frame(target_idx)
is_pad = False is_pad = False
else: else:
# 超出边界,用边界帧填充 # 超出边界,用边界帧填充
if target_idx < ep_start: if target_idx < ep_start:
target_frame = self.frames[ep_start] target_frame = self._load_frame(ep_start)
else: else:
target_frame = self.frames[ep_end] target_frame = self._load_frame(ep_end)
is_pad = True is_pad = True
# 收集状态 # 收集状态
observations["state"].append(target_frame["observation.state"]) observations["state"].append(target_frame["observation.state"])
# 收集每个摄像头的图像(字典形式,不 stack # 收集每个摄像头的图像
for cam_key in self.image_keys: for cam_name in self.camera_names:
observations[cam_key].append(target_frame[cam_key]) observations[f"observation.{cam_name}"].append(target_frame[f"observation.{cam_name}"])
observation_is_pad.append(is_pad) observation_is_pad.append(is_pad)
@@ -101,14 +146,14 @@ class SimpleRobotDataset(Dataset):
target_idx = idx + delta target_idx = idx + delta
if target_idx <= ep_end: if target_idx <= ep_end:
actions.append(self.frames[target_idx]["action"]) actions.append(self._load_frame(target_idx)["action"])
action_is_pad.append(False) action_is_pad.append(False)
else: else:
actions.append(self.frames[ep_end]["action"]) actions.append(self._load_frame(ep_end)["action"])
action_is_pad.append(True) action_is_pad.append(True)
# ============================================ # ============================================
# 3. 组装返回数据(字典形式) # 3. 组装返回数据(LeRobotDataset 格式)
# ============================================ # ============================================
result = { result = {
# 状态观察: (obs_horizon, state_dim) # 状态观察: (obs_horizon, state_dim)
@@ -123,401 +168,32 @@ class SimpleRobotDataset(Dataset):
"task": frame["task"], "task": frame["task"],
} }
# 图像:每个摄像头独立的 key(字典形式) # 图像:每个摄像头独立的 key
# 形状: (obs_horizon, C, H, W) # 形状: (obs_horizon, C, H, W)
for cam_key in self.image_keys: for cam_name in self.camera_names:
result[cam_key] = torch.stack(observations[cam_key]) result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
return result return result
@property @property
def camera_keys(self) -> list[str]: def camera_keys(self) -> list[str]:
"""获取所有相机键名""" """获取所有相机键名 (LeRobotDataset 格式)"""
return self.image_keys return [f"observation.{cam_name}" for cam_name in self.camera_names]
@property @property
def camera_info(self) -> dict: def camera_info(self) -> dict:
"""获取相机信息""" """获取相机信息"""
if not self.image_keys: if not self.camera_names:
return {} return {}
# 从第一个样本获取形状 # 从第一个样本获取形状
sample = self[0] sample = self[0]
info = {} info = {}
for cam_key in self.image_keys: for cam_name in self.camera_names:
if cam_key in sample: key = f"observation.{cam_name}"
info[cam_key] = { if key in sample:
"shape": sample[cam_key].shape, info[key] = {
"dtype": str(sample[cam_key].dtype), "shape": sample[key].shape,
"dtype": str(sample[key].dtype),
} }
return info return info
class SimpleDiffusionPolicy(torch.nn.Module):
    """Minimal policy illustrating the dict-of-cameras convention.

    Images arrive as one batch key per camera and are stacked inside
    ``forward`` (mirroring the LeRobotDataset layout).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        image_features: Dict[str, tuple] = None,
        obs_horizon: int = 2,
        pred_horizon: int = 8,
    ):
        """
        Args:
            state_dim: dimensionality of each observation.state vector.
            action_dim: dimensionality of each predicted action.
            image_features: mapping camera key -> image shape; None/empty
                disables the image branch entirely.
            obs_horizon: number of history frames per observation.
            pred_horizon: number of future actions predicted per call.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.image_features = image_features or {}
        self.state_encoder = torch.nn.Linear(state_dim, 64)
        if image_features:
            num_cameras = len(image_features)
            self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
            self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
        else:
            self.fusion = torch.nn.Linear(64, 128)
        self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict an action chunk from a batch dict.

        Args:
            batch: contains "observation.state" of shape
                (B, obs_horizon, state_dim) and, if configured, one
                (B, obs_horizon, C, H, W) tensor per camera key.

        Returns:
            Tensor of shape (B, pred_horizon, action_dim).
        """
        # Bug fix: B must come from the always-present state tensor. It was
        # previously assigned only inside the image branch, so the state-only
        # configuration raised NameError at the final reshape.
        B = batch["observation.state"].shape[0]
        # State branch: encode each history frame, then average over time.
        state_features = self.state_encoder(batch["observation.state"])
        state_features = state_features.mean(dim=1)
        if self.image_features:
            # Stack per-camera tensors into (B, num_cam, T, C, H, W).
            image_tensors = [batch[key] for key in self.image_features.keys()]
            stacked_images = torch.stack(image_tensors, dim=1)
            B, num_cam, T, C, H, W = stacked_images.shape
            images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
            image_features = self.image_encoder(images_flat)
            # Global average pool, average over time, flatten cameras.
            image_features = image_features.mean(dim=[2, 3])
            image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
            image_features = image_features.reshape(B, -1)
            features = torch.cat([state_features, image_features], dim=-1)
        else:
            features = state_features
        fused = self.fusion(features)
        pred_actions = self.action_head(fused)
        pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim)
        return pred_actions
def create_demo_data_with_images():
    """Build two synthetic 10-frame episodes (with random images) for tests."""
    episode_specs = [(0, "pick_up_cube"), (1, "stack_blocks")]
    frames = []
    for ep_index, task_name in episode_specs:
        for step in range(10):
            frames.append({
                "episode_index": ep_index,
                "frame_index": step,
                "task": task_name,
                "observation.state": torch.randn(6),
                "observation.image_high_resize": torch.randn(3, 64, 64),
                "observation.image_left_wrist": torch.randn(3, 64, 64),
                "action": torch.randn(6),
            })
    return frames
def print_section(title: str):
    """Print a section banner: blank line, rule, title line, rule."""
    rule = "=" * 80
    print("\n" + rule)
    print(f" {title}")
    print(rule)
def test_dataset_basic_info(dataset):
    """Print the dataset's size, horizons, and camera metadata (smoke check)."""
    print("\n📊 数据集基本信息:")
    print(f" 总帧数: {len(dataset)}")
    print(f" 总 episode 数: {len(dataset.episodes)}")
    print(f" 观察窗口: {dataset.obs_horizon}")
    print(f" 预测窗口: {dataset.pred_horizon}")
    print(f"\n📷 相机信息:")
    cameras = dataset.camera_keys
    print(f" 相机数量: {len(cameras)}")
    for cam in cameras:
        print(f" - {cam}")
    print(f"\n相机详细信息:")
    # camera_info is a property that samples item 0 to read shapes/dtypes.
    cam_info = dataset.camera_info
    for cam, info in cam_info.items():
        print(f" {cam}:")
        print(f" shape: {info['shape']}")
        print(f" dtype: {info['dtype']}")
def test_single_sample(dataset):
    """Inspect one mid-episode sample and assert its time dimensions."""
    print_section("1. 测试单个样本")
    # A sample from the middle of an episode (no boundary padding expected).
    sample = dataset[5]
    print("\n样本结构 (字典形式):")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
        elif isinstance(value, str):
            print(f" {key:30s}: {value}")
    # Verify images are stored as one key per camera (dict form).
    print("\n✅ 验证图像存储形式:")
    print(" 图像以字典形式存储,每个摄像头独立的 key:")
    for cam_key in dataset.camera_keys:
        if cam_key in sample:
            print(f" - {cam_key}: {sample[cam_key].shape}")
    # Verify the leading time dimension of state/action tensors.
    print("\n✅ 验证时间维度:")
    print(f" observation.state: {sample['observation.state'].shape}")
    print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
    assert sample['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误"
    print(f" action: {sample['action'].shape}")
    print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
    assert sample['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误"
    print(" ✓ 时间维度验证通过")
def test_edge_cases(dataset):
    """Check padding flags at episode start/middle/end and across episodes.

    Assumes the demo fixture: two 10-frame episodes, obs_horizon=2,
    pred_horizon=8 — the expected pad patterns below encode that.
    """
    test_cases = [
        ("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
        ("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
        ("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
        ("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
    ]
    for name, idx, expected in test_cases:
        print(f"\n📍 {name} (idx={idx}):")
        sample = dataset[idx]
        obs_pad = sample["observation_is_pad"].tolist()
        action_pad_count = sample["action_is_pad"].sum().item()
        print(f" observation_is_pad: {obs_pad}")
        print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
        print(f" action padding 数量: {action_pad_count}")
        # Only the observation padding at episode starts is asserted;
        # the `expected` dicts above are printed context, not checked.
        if name == "Episode 开头":
            assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding"
        elif name == "跨 Episode":
            assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding"
def test_dataloader(dataset):
    """Batch the dataset with a DataLoader and assert per-camera batch shapes.

    NOTE(review): `DataLoader` is only imported under the module's
    `__main__` guard — calling this function after a plain import of the
    module would raise NameError. Confirm intended.
    """
    print_section("3. 测试 DataLoader 集成")
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,  # use 0 workers while testing
    )
    batch = next(iter(dataloader))
    print("\n📦 Batch 结构:")
    for key in ["observation.state", "observation.image_high_resize",
                "observation.image_left_wrist", "action", "task"]:
        if key in batch:
            value = batch[key]
            if isinstance(value, torch.Tensor):
                print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
            else:
                print(f" {key:30s}: {type(value).__name__} (length={len(value)})")
    print("\n✅ 验证 Batch 形状:")
    B = len(batch["observation.state"])
    print(f" Batch size: {B}")
    # Every camera batch must be (B, obs_horizon, 3, 64, 64) — the demo
    # fixture's image size.
    for cam_key in dataset.camera_keys:
        expected_shape = (B, dataset.obs_horizon, 3, 64, 64)
        actual_shape = batch[cam_key].shape
        print(f" {cam_key}:")
        print(f" 预期: {expected_shape}")
        print(f" 实际: {actual_shape}")
        assert actual_shape == expected_shape, f"{cam_key} 形状不匹配"
    print(" ✓ Batch 形状验证通过")
def test_policy_forward(dataset):
    """Run a batch through SimpleDiffusionPolicy and narrate the stack step.

    NOTE(review): uses `DataLoader`, which is imported only under the
    `__main__` guard — see test_dataloader.
    """
    print_section("4. 测试 Policy 前向传播")
    # Build a policy matching the demo fixture (6-dim state/action, 2 cams).
    policy = SimpleDiffusionPolicy(
        state_dim=6,
        action_dim=6,
        image_features={
            "observation.image_high_resize": (3, 64, 64),
            "observation.image_left_wrist": (3, 64, 64),
        },
        obs_horizon=dataset.obs_horizon,
        pred_horizon=dataset.pred_horizon,
    )
    # One deterministic batch.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print("\n🔄 Policy.forward() 流程:")
    # 1. Before stacking: one tensor per camera key.
    print("\n 1⃣  Stack 之前 (字典形式):")
    for cam_key in policy.image_features.keys():
        print(f" batch['{cam_key}']: {batch[cam_key].shape}")
    # 2. Demonstrate the stack the policy performs internally.
    print("\n 2⃣  Stack 操作:")
    image_tensors = [batch[key] for key in policy.image_features.keys()]
    stacked = torch.stack(image_tensors, dim=1)
    print(f" stacked_images: {stacked.shape}")
    print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
    print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")
    # 3. Actual forward pass (no gradients needed).
    print("\n 3⃣  前向传播:")
    with torch.no_grad():
        pred_actions = policy(batch)
    print(f" 输入:")
    print(f" observation.state: {batch['observation.state'].shape}")
    print(f" 图像已 stack")
    print(f" 输出:")
    print(f" pred_actions: {pred_actions.shape}")
    print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")
    print("\n✅ Policy 前向传播验证通过")
def test_data_consistency(dataset):
    """Verify padded frames duplicate the boundary image and normal frames don't."""
    print_section("5. 测试数据一致性")
    print("\n🔍 验证图像 padding 的正确性:")
    # Sample at the very start of an episode: frame 0 should be padding.
    sample = dataset[0]
    if sample["observation_is_pad"][0]:
        img_0 = sample["observation.image_high_resize"][0]
        img_1 = sample["observation.image_high_resize"][1]
        print(f" Episode 开头 (idx=0):")
        print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}")
        print(f" 第0帧图像 = 第1帧图像: {torch.equal(img_0, img_1)}")
        # Padding must be an exact copy of the boundary frame.
        assert torch.equal(img_0, img_1), "Padding 应该复制边界帧"
        print(" ✓ Padding 正确")
    # Mid-episode sample: no padding, consecutive frames should differ
    # (fixture images are random per frame).
    sample = dataset[5]
    if not sample["observation_is_pad"].any():
        img_0 = sample["observation.image_high_resize"][0]
        img_1 = sample["observation.image_high_resize"][1]
        print(f"\n Episode 中间 (idx=5):")
        print(f" 没有 padding: {sample['observation_is_pad']}")
        print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(img_0, img_1)}")
        print(" ✓ 正常帧不重复")
    print("\n✅ 数据一致性验证通过")
def test_task_info(dataset):
    """Show the per-task frame distribution and task propagation into batches.

    NOTE(review): reads `dataset.frames` directly — this only exists on the
    in-memory variant of the dataset, and `DataLoader` is imported only under
    the `__main__` guard.
    """
    print_section("6. 测试任务信息")
    print("\n📋 统计任务分布:")
    # Count frames per task label.
    task_count = {}
    for frame in dataset.frames:
        task = frame["task"]
        task_count[task] = task_count.get(task, 0) + 1
    for task, count in task_count.items():
        print(f" {task}: {count}")
    # Task string on a single sample.
    sample = dataset[0]
    print(f"\n样本 task: {sample['task']}")
    print(f" 类型: {type(sample['task'])}")
    # Task strings as collated by the DataLoader (list of str per batch).
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print(f"\nBatch task:")
    print(f" 值: {batch['task']}")
    print(f" 类型: {type(batch['task'])}")
    print(f" 长度: {len(batch['task'])}")
    print("\n✅ 任务信息验证通过")
def run_all_tests():
    """Build the demo dataset and run the full SimpleRobotDataset test suite."""
    print("\n" + "🚀" * 40)
    print(" SimpleRobotDataset 完整测试套件")
    print("🚀" * 40)
    # Create an in-memory dataset from synthetic frames.
    print("\n创建测试数据...")
    frames = create_demo_data_with_images()
    dataset = SimpleRobotDataset(
        frames,
        obs_horizon=2,
        pred_horizon=8,
        image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
    )
    print("✓ 数据集创建完成")
    # Run each test section in order.
    test_dataset_basic_info(dataset)
    test_single_sample(dataset)
    test_edge_cases(dataset)
    test_dataloader(dataset)
    test_policy_forward(dataset)
    test_data_consistency(dataset)
    test_task_info(dataset)
    # Summary banner.
    print_section("✅ 测试总结")
    print("\n所有测试通过!✨")
    print("\n关键验证点:")
    print(" ✓ 图像以字典形式存储")
    print(" ✓ 每个摄像头独立的 key")
    print(" ✓ Policy 在 forward 时 stack 图像")
    print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)")
    print(" ✓ Padding 处理正确")
    print(" ✓ DataLoader 集成正确")
    print(" ✓ Task 信息传递正确")
    print("\n与 LeRobotDataset 设计完全一致!🎉")
if __name__ == "__main__":
    # DataLoader is imported lazily here; after this runs it becomes a module
    # global used by the test_* functions above. Importing this module without
    # executing the guard leaves DataLoader undefined for those functions.
    from torch.utils.data import DataLoader
    run_all_tests()

View File

@@ -1,4 +1,4 @@
# Backbone models # Backbone models
from .resnet import ResNetBackbone from .resnet_diffusion import ResNetDiffusionBackbone
__all__ = ["ResNetBackbone"] __all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]

View File

@@ -1,93 +0,0 @@
from roboimi.vla.core.interfaces import VLABackbone
from transformers import ResNetModel
from torchvision import transforms
import torch
import torch.nn as nn
class ResNetBackbone(VLABackbone):
    """Frozen HuggingFace ResNet feature extractor for multi-camera inputs.

    Each camera's (B, T, C, H, W) stream is resized/normalized, run through
    the ResNet, reduced to keypoint coordinates via SpatialSoftmax, and the
    per-camera features are concatenated along the feature axis.
    """
    def __init__(
        self,
        model_name = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        # Channel count of the deepest ResNet stage (last hidden size).
        self.out_channels = self.model.config.hidden_sizes[-1]
        # Resize + ImageNet mean/std normalization.
        # NOTE(review): the fixed 12x12 SpatialSoftmax grid below presumably
        # matches the feature map a 384x384 input yields for resnet-18 —
        # confirm before swapping in a different backbone.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        if freeze:
            self._freeze_parameters()
    def _freeze_parameters(self):
        """Disable gradients on the backbone and switch it to eval mode."""
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
    def train(self, mode=True):
        """
        Override train() to keep frozen ResNet in eval mode.
        This ensures BatchNorm layers use running statistics consistently.
        """
        super().train(mode)
        if hasattr(self, 'model'):
            self.model.eval()  # Always keep ResNet in eval mode
        return self
    def forward_single_image(self, image):
        """Encode one camera stream: (B, T, C, H, W) -> (B*T, out_channels*2)."""
        B, T, C, H, W = image.shape
        # Fold time into the batch so the CNN sees plain 4-D input.
        image = image.view(B * T, C, H, W)
        image = self.transform(image)
        feature_map = self.model(image).last_hidden_state  # (B*T, D, H', W')
        features = self.spatial_softmax(feature_map)  # (B*T, D*2)
        return features
    def forward(self, images):
        """Encode a dict of camera streams into (B, T, num_cams * D * 2).

        Cameras are processed in sorted-name order so the feature layout is
        deterministic regardless of dict insertion order.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        features_all = []
        sorted_cam_names = sorted(images.keys())
        for cam_name in sorted_cam_names:
            img = images[cam_name]
            features = self.forward_single_image(img)  # (B*T, D*2)
            features_all.append(features)
        combined_features = torch.cat(features_all, dim=1)  # (B*T, Num_Cams*D*2)
        return combined_features.view(B, T, -1)
    @property
    def output_dim(self):
        """Output dimension after spatial softmax: out_channels * 2"""
        # NOTE(review): this is the per-camera feature size; forward()
        # concatenates num_cams of these — confirm callers account for that.
        return self.out_channels * 2
class SpatialSoftmax(nn.Module):
    """Convert a feature map (N, C, H, W) into expected 2-D keypoint
    coordinates (N, C*2) via a per-channel spatial softmax.

    Args:
        num_rows: height of the incoming feature map.
        num_cols: width of the incoming feature map.
        temperature: optional fixed softmax temperature. If None (default),
            the temperature is a learnable parameter initialized to 1 —
            the original behavior. Previously a supplied value was silently
            ignored; it is now honored as a fixed, non-trainable buffer.
    """

    def __init__(self, num_rows, num_cols, temperature=None):
        super().__init__()
        if temperature is None:
            # Learnable temperature (backward-compatible default).
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            # Fixed temperature supplied by the caller; stored as a buffer so
            # it moves with the module but receives no gradients.
            self.register_buffer('temperature', torch.full((1,), float(temperature)))
        # Coordinate grid over [-1, 1] x [-1, 1], flattened to length H*W.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """Return expected (x, y) per channel: (N, C, H, W) -> (N, C*2)."""
        N, C, H, W = x.shape
        x = x.view(N, C, -1)  # (N, C, H*W)
        # Softmax over spatial positions yields a per-channel attention map.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
        # Expected coordinates under that attention distribution.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)

View File

@@ -91,20 +91,21 @@ class SpatialSoftmax(nn.Module):
return feature_keypoints return feature_keypoints
class ResNetDiffusionBackbone(VLABackbone): class _SingleRgbEncoder(nn.Module):
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
def __init__( def __init__(
self, self,
vision_backbone: str = "resnet18", vision_backbone: str,
pretrained_backbone_weights: str | None = None, pretrained_backbone_weights: str | None,
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W) input_shape: Tuple[int, int, int],
crop_shape: Optional[Tuple[int, int]] = None, crop_shape: Optional[Tuple[int, int]],
crop_is_random: bool = True, crop_is_random: bool,
use_group_norm: bool = True, use_group_norm: bool,
spatial_softmax_num_keypoints: int = 32, spatial_softmax_num_keypoints: int,
): ):
super().__init__() super().__init__()
# 设置可选的预处理 # 设置可选的预处理
if crop_shape is not None: if crop_shape is not None:
self.do_crop = True self.do_crop = True
# 评估时始终使用中心裁剪 # 评估时始终使用中心裁剪
@@ -117,7 +118,7 @@ class ResNetDiffusionBackbone(VLABackbone):
self.do_crop = False self.do_crop = False
crop_shape = input_shape[1:] crop_shape = input_shape[1:]
# 设置骨干网络 # 设置骨干网络
backbone_model = getattr(torchvision.models, vision_backbone)( backbone_model = getattr(torchvision.models, vision_backbone)(
weights=pretrained_backbone_weights weights=pretrained_backbone_weights
) )
@@ -132,8 +133,8 @@ class ResNetDiffusionBackbone(VLABackbone):
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
) )
# 设置池化和最终层 # 设置池化和最终层
# 使用试运行来获取特征图形状 # 使用试运行来获取特征图形状
dummy_shape = (1, input_shape[0], *crop_shape) dummy_shape = (1, input_shape[0], *crop_shape)
with torch.no_grad(): with torch.no_grad():
dummy_out = self.backbone(torch.zeros(dummy_shape)) dummy_out = self.backbone(torch.zeros(dummy_shape))
@@ -150,13 +151,83 @@ class ResNetDiffusionBackbone(VLABackbone):
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
return x return x
class ResNetDiffusionBackbone(VLABackbone):
def __init__(
self,
vision_backbone: str = "resnet18",
pretrained_backbone_weights: str | None = None,
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
crop_shape: Optional[Tuple[int, int]] = None,
crop_is_random: bool = True,
use_group_norm: bool = True,
spatial_softmax_num_keypoints: int = 32,
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
):
super().__init__()
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
self.num_cameras = num_cameras
if use_separate_rgb_encoder_per_camera:
# 独立编码器模式:为每个摄像头创建独立的编码器
encoders = [
_SingleRgbEncoder(
vision_backbone=vision_backbone,
pretrained_backbone_weights=pretrained_backbone_weights,
input_shape=input_shape,
crop_shape=crop_shape,
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
)
for _ in range(num_cameras)
]
self.rgb_encoder = nn.ModuleList(encoders)
# 重要output_dim 始终表示单个编码器的特征维度(与 lerobot 保持一致)
self.feature_dim = encoders[0].feature_dim
else:
# 共享编码器模式:所有摄像头共享同一个编码器
self.rgb_encoder = _SingleRgbEncoder(
vision_backbone=vision_backbone,
pretrained_backbone_weights=pretrained_backbone_weights,
input_shape=input_shape,
crop_shape=crop_shape,
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
)
self.feature_dim = self.rgb_encoder.feature_dim
def forward(self, images): def forward(self, images):
"""
Args:
images: Dict[str, Tensor], 每个摄像头的图像
形状: {cam_name: (B, T, C, H, W)}
Returns:
Tensor: (B, T, total_feature_dim)
"""
any_tensor = next(iter(images.values())) any_tensor = next(iter(images.values()))
B, T = any_tensor.shape[:2] B, T = any_tensor.shape[:2]
cam_names = sorted(images.keys())
if self.use_separate_rgb_encoder_per_camera:
# 独立编码器模式:每个摄像头使用对应的编码器
features_all = [] features_all = []
for cam_name in sorted(images.keys()): for cam_idx, cam_name in enumerate(cam_names):
img = images[cam_name] img = images[cam_name]
features = self.forward_single_image(img.view(B * T, *img.shape[2:])) encoder = self.rgb_encoder[cam_idx]
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
else:
# 共享编码器模式:所有摄像头共享同一个编码器
features_all = []
for cam_name in cam_names:
img = images[cam_name]
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features) features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1) return torch.cat(features_all, dim=1).view(B, T, -1)
@@ -165,7 +236,9 @@ class ResNetDiffusionBackbone(VLABackbone):
return self.feature_dim return self.feature_dim
if __name__ == "__main__": if __name__ == "__main__":
print("🚀 Testing ResNetDiffusionBackbone...") print("=" * 60)
print("🚀 Testing ResNetDiffusionBackbone")
print("=" * 60)
# Configuration # Configuration
B, T = 2, 5 B, T = 2, 5
@@ -174,34 +247,109 @@ if __name__ == "__main__":
num_keypoints = 32 num_keypoints = 32
feature_dim_per_cam = num_keypoints * 2 feature_dim_per_cam = num_keypoints * 2
# Instantiate model # Create dummy input (2 cameras)
backbone = ResNetDiffusionBackbone( images = {
"cam_high": torch.randn(B, T, C, H, W),
"cam_wrist": torch.randn(B, T, C, H, W)
}
num_cameras = len(images)
# ============================================================================
# Test 1: Shared Encoder (默认模式)
# ============================================================================
print("\n[Test 1] Shared Encoder Mode")
print("-" * 60)
backbone_shared = ResNetDiffusionBackbone(
vision_backbone="resnet18", vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W), input_shape=(C, H, W),
crop_shape=(crop_h, crop_w), crop_shape=(crop_h, crop_w),
crop_is_random=True, crop_is_random=True,
use_group_norm=True, use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints spatial_softmax_num_keypoints=num_keypoints,
use_separate_rgb_encoder_per_camera=False, # 共享编码器
) )
print(f"Model instantiated. Output dim per camera: {backbone.output_dim}") print(f"Shared encoder model instantiated")
print(f" Output dim per camera: {feature_dim_per_cam}")
print(f" Number of cameras: {num_cameras}")
print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")
# Create dummy input output = backbone_shared(images)
images = { print(f"\n🔄 Forward pass completed")
"cam_high": torch.randn(B, T, C, H, W), print(f" Input shapes: {[v.shape for v in images.values()]}")
"cam_wrist": torch.randn(B, T, C, H, W) print(f" Output shape: {output.shape}")
}
# Forward pass expected_dim = num_cameras * feature_dim_per_cam
print("🔄 Running forward pass...")
output = backbone(images)
print(f"Input shapes: {[v.shape for v in images.values()]}")
print(f"Output shape: {output.shape}")
# Verification
expected_dim = len(images) * feature_dim_per_cam
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}" assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
print(f"✨ Test passed!")
print("✨ Test passed!") # ============================================================================
# Test 2: Separate Encoders (独立编码器模式)
# ============================================================================
print("\n[Test 2] Separate Encoders Mode")
print("-" * 60)
backbone_separate = ResNetDiffusionBackbone(
vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W),
crop_shape=(crop_h, crop_w),
crop_is_random=True,
use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints,
use_separate_rgb_encoder_per_camera=True, # 独立编码器
num_cameras=num_cameras,
)
print(f"✅ Separate encoders model instantiated")
print(f" Output dim per camera: {feature_dim_per_cam}")
print(f" Number of cameras: {num_cameras}")
print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")
output = backbone_separate(images)
print(f"\n🔄 Forward pass completed")
print(f" Input shapes: {[v.shape for v in images.values()]}")
print(f" Output shape: {output.shape}")
expected_dim = num_cameras * feature_dim_per_cam
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
print(f"✨ Test passed!")
# ============================================================================
# Test 3: Verify parameters count
# ============================================================================
print("\n[Test 3] Parameter Count Comparison")
print("-" * 60)
shared_params = sum(p.numel() for p in backbone_shared.parameters())
separate_params = sum(p.numel() for p in backbone_separate.parameters())
print(f" Shared encoder parameters: {shared_params:,}")
print(f" Separate encoders parameters: {separate_params:,}")
print(f" Ratio: {separate_params / shared_params:.2f}x")
assert separate_params > shared_params, "Separate encoders should have more parameters"
print(f"✨ Verification passed!")
# ============================================================================
# Test 4: Verify independent parameters
# ============================================================================
print("\n[Test 4] Verify Independent Parameters")
print("-" * 60)
# Check that encoders have independent parameters
encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]
# Modify first encoder's parameter
with torch.no_grad():
encoder_0_first_param += 1.0
# Verify they are not the same tensor
assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
"Encoders should have independent parameters"
print(f"✅ Encoders have independent parameters")
print(f"✨ All tests passed!")
print("\n" + "=" * 60)
print("🎉 All tests completed successfully!")
print("=" * 60)

View File

@@ -0,0 +1,128 @@
"""
归一化模块 - 统一训练和推理的归一化逻辑
支持两种归一化方式:
1. Gaussian (z-score): (x - mean) / std
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Literal
class NormalizationModule(nn.Module):
"""
统一的归一化模块
用于在 Agent 内部对 qpos 和 action 进行归一化/反归一化
"""
def __init__(
self,
stats: Optional[Dict] = None,
normalization_type: Literal['gaussian', 'min_max'] = 'gaussian'
):
"""
Args:
stats: 数据集统计信息字典,格式:
{
'normalization_type': 'gaussian' | 'min_max',
'qpos_mean': [...],
'qpos_std': [...],
'qpos_min': [...], # 仅 min_max 需要
'qpos_max': [...], # 仅 min_max 需要
'action_mean': [...],
'action_std': [...],
'action_min': [...], # 仅 min_max 需要
'action_max': [...], # 仅 min_max 需要
}
normalization_type: 归一化类型 ('gaussian''min_max')
"""
super().__init__()
self.normalization_type = normalization_type
self.enabled = stats is not None
if self.enabled:
# 从 stats 中读取归一化类型(如果提供)
self.normalization_type = stats.get('normalization_type', normalization_type)
# 注册为 buffer (不会被优化,但会随模型保存)
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
# MinMax 归一化需要 min/max
if self.normalization_type == 'min_max':
qpos_min = stats.get('qpos_min', [0.0] * len(stats['qpos_mean']))
qpos_max = stats.get('qpos_max', [1.0] * len(stats['qpos_mean']))
action_min = stats.get('action_min', [0.0] * len(stats['action_mean']))
action_max = stats.get('action_max', [1.0] * len(stats['action_mean']))
self.register_buffer('qpos_min', torch.tensor(qpos_min, dtype=torch.float32))
self.register_buffer('qpos_max', torch.tensor(qpos_max, dtype=torch.float32))
self.register_buffer('action_min', torch.tensor(action_min, dtype=torch.float32))
self.register_buffer('action_max', torch.tensor(action_max, dtype=torch.float32))
def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
"""归一化 qpos"""
if not self.enabled:
return qpos
if self.normalization_type == 'gaussian':
return (qpos - self.qpos_mean) / self.qpos_std
else: # min_max
return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
"""反归一化 qpos"""
if not self.enabled:
return qpos
if self.normalization_type == 'gaussian':
return qpos * self.qpos_std + self.qpos_mean
else: # min_max
return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min
def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
"""归一化 action"""
if not self.enabled:
return action
if self.normalization_type == 'gaussian':
return (action - self.action_mean) / self.action_std
else: # min_max
return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1
def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
"""反归一化 action"""
if not self.enabled:
return action
if self.normalization_type == 'gaussian':
return action * self.action_std + self.action_mean
else: # min_max
return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min
def get_stats(self) -> Optional[Dict]:
"""导出统计信息(用于保存到 checkpoint"""
if not self.enabled:
return None
stats = {
'normalization_type': self.normalization_type,
'qpos_mean': self.qpos_mean.cpu().tolist(),
'qpos_std': self.qpos_std.cpu().tolist(),
'action_mean': self.action_mean.cpu().tolist(),
'action_std': self.action_std.cpu().tolist(),
}
if self.normalization_type == 'min_max':
stats['qpos_min'] = self.qpos_min.cpu().tolist()
stats['qpos_max'] = self.qpos_max.cpu().tolist()
stats['action_min'] = self.action_min.cpu().tolist()
stats['action_max'] = self.action_max.cpu().tolist()
return stats