refactor:大重构
This commit is contained in:
238
diffusion/configuration_diffusion.py
Normal file
238
diffusion/configuration_diffusion.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
from lerobot.optim.optimizers import AdamConfig
|
||||||
|
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(PreTrainedConfig):
    """Configuration class for DiffusionPolicy.

    Defaults are configured for training with PushT providing proprioceptive and single camera observations.

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

    Notes on the inputs and outputs:
        - "observation.state" is required as an input key.
        - Either:
            - At least one key starting with "observation.image" is required as an input.
              AND/OR
            - The key "observation.environment_state" is required as input.
        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
          views. Right now we only support all images having the same shape.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            See `DiffusionPolicy.select_action` for more details.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescales in a
            [-1, 1] range.
        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
            original scale. Note that this is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
            within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
            mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
        use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
        down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
            You may provide a variable number of dimensions, therefore also controlling the degree of
            downsampling.
        kernel_size: The convolutional kernel size of the diffusion modeling Unet.
        n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
        diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
            network. This is the output dimension of that network, i.e., the embedding dimension.
        use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning.
            Bias modulation is used by default, while this parameter indicates whether to also use scale
            modulation.
        noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
        beta_start: Beta value for the first forward-diffusion step.
        beta_end: Beta value for the last forward-diffusion step.
        prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
            or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
            "epsilon" has been shown to work better in many deep neural network settings.
        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
            denoising step at inference time. WARNING: you will need to make sure your action-space is
            normalized to fit within this range.
        clip_sample_range: The magnitude of the clipping range as described above.
        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
            spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
            to False as the original Diffusion Policy implementation does the same.
    """

    # Inputs / output structure.
    n_obs_steps: int = 2
    horizon: int = 16
    n_action_steps: int = 8

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    use_separate_rgb_encoder_per_camera: bool = False
    # Unet.
    down_dims: tuple[int, ...] = (512, 1024, 2048)
    kernel_size: int = 5
    n_groups: int = 8
    diffusion_step_embed_dim: int = 128
    use_film_scale_modulation: bool = True
    # Noise scheduler.
    noise_scheduler_type: str = "DDPM"
    num_train_timesteps: int = 100
    beta_schedule: str = "squaredcos_cap_v2"
    beta_start: float = 0.0001
    beta_end: float = 0.02
    prediction_type: str = "epsilon"
    clip_sample: bool = True
    clip_sample_range: float = 1.0

    # Inference
    num_inference_steps: int | None = None

    # Loss computation
    do_mask_loss_for_padding: bool = False

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-6
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 500

    def __post_init__(self):
        super().__post_init__()

        """Input validation (not exhaustive)."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
            raise ValueError(
                f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
            )
        supported_noise_schedulers = ["DDPM", "DDIM"]
        if self.noise_scheduler_type not in supported_noise_schedulers:
            raise ValueError(
                f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
                f"Got {self.noise_scheduler_type}."
            )

        # Check that the horizon size and U-Net downsampling is compatible.
        # U-Net downsamples by 2 with each stage.
        downsampling_factor = 2 ** len(self.down_dims)
        if self.horizon % downsampling_factor != 0:
            raise ValueError(
                "The horizon should be an integer multiple of the downsampling factor (which is determined "
                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
            )

    def get_optimizer_preset(self) -> AdamConfig:
        """Return the Adam optimizer settings bundled with this config."""
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        """Return the LR scheduler settings bundled with this config."""
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        """Validate image/state features: presence, crop fit, and matching shapes.

        Raises:
            ValueError: if no image nor environment state is provided, if `crop_shape`
                exceeds any image's spatial size, or if image shapes differ.
        """
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                # image_ft.shape is (C, H, W); crop must fit within (H, W).
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                    )

        # Check that all input images have the same shape.
        if len(self.image_features) > 0:
            first_image_key, first_image_ft = next(iter(self.image_features.items()))
            for key, image_ft in self.image_features.items():
                if image_ft.shape != first_image_ft.shape:
                    raise ValueError(
                        f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                    )

    @property
    def observation_delta_indices(self) -> list:
        # Observations: the current step and `n_obs_steps - 1` steps back.
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        # Actions: a `horizon`-long window aligned with the oldest observation.
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        # Rewards are not used by diffusion policy.
        return None
|
||||||
92
diffusion/processor_diffusion.py
Normal file
92
diffusion/processor_diffusion.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
|
|
||||||
|
def make_diffusion_pre_post_processors(
    config: DiffusionConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Constructs pre-processor and post-processor pipelines for a diffusion policy.

    The pre-processing pipeline prepares the input data for the model by:
    1. Renaming features.
    2. Adding a batch dimension.
    3. Moving the data to the specified device.
    4. Normalizing the input and output features based on dataset statistics.

    The post-processing pipeline handles the model's output by:
    1. Unnormalizing the output features to their original scale.
    2. Moving the data to the CPU.

    Args:
        config: The configuration object for the diffusion policy,
            containing feature definitions, normalization mappings, and device information.
        dataset_stats: A dictionary of statistics used for normalization.
            Defaults to None.

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """

    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
        # Normalize both inputs and outputs: output stats are also needed for training targets.
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        DeviceProcessorStep(device="cpu"),
    ]
    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
VLA Policy Evaluation Script (Hydra-based)
|
VLA 策略评估脚本(简化版)
|
||||||
|
|
||||||
This script evaluates a trained Vision-Language-Action (VLA) policy
|
该脚本使用 agent 内置的队列管理来评估训练好的 VLA 策略。
|
||||||
in the MuJoCo simulation environment.
|
无需单独的评估器类 - agent 处理一切!
|
||||||
|
|
||||||
Usage:
|
使用方法:
|
||||||
python roboimi/demos/eval_vla.py
|
python roboimi/demos/eval_vla_simple.py
|
||||||
python roboimi/demos/eval_vla.py ckpt_path=checkpoints/vla_model_step_8000.pt num_episodes=5
|
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_final.pt
|
||||||
python roboimi/demos/eval_vla.py use_smoothing=true smooth_alpha=0.5
|
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_best.pt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
@@ -19,314 +19,152 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import hydra
|
import hydra
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
# Ensure correct import path
|
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
|
|
||||||
if not OmegaConf.has_resolver("len"):
|
if not OmegaConf.has_resolver("len"):
|
||||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||||
|
|
||||||
|
|
||||||
class VLAEvaluator:
|
|
||||||
"""
|
|
||||||
VLA Policy Evaluator for MuJoCo Simulation
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
agent: torch.nn.Module,
|
|
||||||
device: str = 'cuda',
|
|
||||||
camera_names: List[str] = ['r_vis', 'top', 'front'],
|
|
||||||
num_queries: int = 1,
|
|
||||||
obs_horizon: int = 2,
|
|
||||||
pred_horizon: int = 16,
|
|
||||||
use_smoothing: bool = False,
|
|
||||||
smooth_method: str = 'ema',
|
|
||||||
smooth_alpha: float = 0.3,
|
|
||||||
dataset_stats: dict = None
|
|
||||||
):
|
|
||||||
self.agent = agent.to(device)
|
|
||||||
self.device = device
|
|
||||||
self.camera_names = camera_names
|
|
||||||
self.num_queries = num_queries
|
|
||||||
self.obs_horizon = obs_horizon
|
|
||||||
self.pred_horizon = pred_horizon
|
|
||||||
|
|
||||||
# Dataset statistics for normalization/denormalization
|
|
||||||
self.stats = dataset_stats
|
|
||||||
if self.stats is not None:
|
|
||||||
self.normalization_type = self.stats.get('normalization_type', 'gaussian')
|
|
||||||
self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
|
|
||||||
self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
|
|
||||||
self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
|
|
||||||
self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
|
|
||||||
self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
|
|
||||||
self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
|
|
||||||
self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
|
|
||||||
self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
self.normalization_type = None
|
|
||||||
|
|
||||||
# Action smoothing
|
|
||||||
self.use_smoothing = use_smoothing
|
|
||||||
self.smooth_method = smooth_method
|
|
||||||
self.smooth_alpha = smooth_alpha
|
|
||||||
self.smoother = ActionSmoother(
|
|
||||||
action_dim=16,
|
|
||||||
method=smooth_method,
|
|
||||||
alpha=smooth_alpha
|
|
||||||
) if use_smoothing else None
|
|
||||||
|
|
||||||
# Observation buffer for obs_horizon
|
|
||||||
self.obs_buffer = {
|
|
||||||
'images': {cam: [] for cam in camera_names},
|
|
||||||
'qpos': []
|
|
||||||
}
|
|
||||||
self.cached_actions = None
|
|
||||||
self.query_step = 0
|
|
||||||
|
|
||||||
# Timing statistics
|
|
||||||
self.inference_times = [] # Model inference time only
|
|
||||||
self.total_times = [] # Total prediction time (including preprocessing)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Reset evaluator state"""
|
|
||||||
self.obs_buffer = {
|
|
||||||
'images': {cam: [] for cam in self.camera_names},
|
|
||||||
'qpos': []
|
|
||||||
}
|
|
||||||
self.cached_actions = None
|
|
||||||
self.query_step = 0
|
|
||||||
if self.smoother is not None:
|
|
||||||
self.smoother.reset()
|
|
||||||
|
|
||||||
# Reset timing stats for each episode
|
|
||||||
self.inference_times = []
|
|
||||||
self.total_times = []
|
|
||||||
|
|
||||||
def _get_image_dict(self, obs: Dict) -> Dict[str, torch.Tensor]:
|
|
||||||
images = {}
|
|
||||||
for cam_name in self.camera_names:
|
|
||||||
img = obs['images'][cam_name]
|
|
||||||
img = rearrange(img, 'h w c -> c h w')
|
|
||||||
img = torch.from_numpy(img / 255.0).float()
|
|
||||||
images[cam_name] = img
|
|
||||||
|
|
||||||
image_dict = {}
|
|
||||||
for cam_name in self.camera_names:
|
|
||||||
cam_images = self.obs_buffer['images'][cam_name]
|
|
||||||
cam_images.append(images[cam_name])
|
|
||||||
|
|
||||||
while len(cam_images) < self.obs_horizon:
|
|
||||||
cam_images.insert(0, cam_images[0])
|
|
||||||
|
|
||||||
if len(cam_images) > self.obs_horizon:
|
|
||||||
cam_images = cam_images[-self.obs_horizon:]
|
|
||||||
|
|
||||||
img_tensor = torch.stack(cam_images, dim=0).unsqueeze(0)
|
|
||||||
image_dict[cam_name] = img_tensor
|
|
||||||
|
|
||||||
self.obs_buffer['images'][cam_name] = cam_images[-self.obs_horizon:]
|
|
||||||
|
|
||||||
return image_dict
|
|
||||||
|
|
||||||
def _get_qpos_dict(self, obs: Dict) -> torch.Tensor:
|
|
||||||
qpos = obs['qpos']
|
|
||||||
qpos = torch.from_numpy(qpos).float()
|
|
||||||
|
|
||||||
self.obs_buffer['qpos'].append(qpos)
|
|
||||||
|
|
||||||
while len(self.obs_buffer['qpos']) < self.obs_horizon:
|
|
||||||
self.obs_buffer['qpos'].insert(0, self.obs_buffer['qpos'][0])
|
|
||||||
|
|
||||||
if len(self.obs_buffer['qpos']) > self.obs_horizon:
|
|
||||||
self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]
|
|
||||||
|
|
||||||
qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
|
|
||||||
|
|
||||||
# Normalize qpos
|
|
||||||
if self.stats is not None:
|
|
||||||
if self.normalization_type == 'gaussian':
|
|
||||||
qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
|
|
||||||
else: # min_max: normalize to [-1, 1]
|
|
||||||
qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
|
|
||||||
|
|
||||||
return qpos_tensor
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def predict_action(self, obs: Dict) -> np.ndarray:
|
|
||||||
start_total = time.time()
|
|
||||||
|
|
||||||
images = self._get_image_dict(obs)
|
|
||||||
qpos = self._get_qpos_dict(obs)
|
|
||||||
|
|
||||||
if self.cached_actions is None or self.query_step % self.num_queries == 0:
|
|
||||||
images = {k: v.to(self.device) for k, v in images.items()}
|
|
||||||
qpos = qpos.to(self.device)
|
|
||||||
|
|
||||||
# Measure pure model inference time
|
|
||||||
start_inference = time.time()
|
|
||||||
predicted_actions = self.agent.predict_action(
|
|
||||||
images=images,
|
|
||||||
proprioception=qpos
|
|
||||||
)
|
|
||||||
|
|
||||||
# Synchronize CUDA if using GPU to get accurate timing
|
|
||||||
if self.device == 'cuda':
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_inference = time.time()
|
|
||||||
|
|
||||||
inference_time = end_inference - start_inference
|
|
||||||
self.inference_times.append(inference_time)
|
|
||||||
|
|
||||||
# Denormalize actions
|
|
||||||
if self.stats is not None:
|
|
||||||
if self.normalization_type == 'gaussian':
|
|
||||||
predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
|
|
||||||
else: # min_max
|
|
||||||
predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)
|
|
||||||
|
|
||||||
self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
|
|
||||||
self.query_step = 0
|
|
||||||
|
|
||||||
raw_action = self.cached_actions[self.query_step]
|
|
||||||
self.query_step += 1
|
|
||||||
|
|
||||||
if self.smoother is not None:
|
|
||||||
raw_action = self.smoother.smooth(raw_action)
|
|
||||||
|
|
||||||
end_total = time.time()
|
|
||||||
total_time = end_total - start_total
|
|
||||||
self.total_times.append(total_time)
|
|
||||||
|
|
||||||
return raw_action
|
|
||||||
|
|
||||||
def get_timing_stats(self) -> Dict:
|
|
||||||
"""Get timing statistics"""
|
|
||||||
if len(self.inference_times) == 0:
|
|
||||||
return {
|
|
||||||
'inference_fps': 0.0,
|
|
||||||
'control_fps': 0.0,
|
|
||||||
'avg_inference_time_ms': 0.0,
|
|
||||||
'avg_total_time_ms': 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
avg_inference_time = np.mean(self.inference_times)
|
|
||||||
avg_total_time = np.mean(self.total_times)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
|
|
||||||
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
|
|
||||||
'avg_inference_time_ms': avg_inference_time * 1000,
|
|
||||||
'avg_total_time_ms': avg_total_time * 1000,
|
|
||||||
'num_inferences': len(self.inference_times),
|
|
||||||
'num_steps': len(self.total_times)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ActionSmoother:
|
|
||||||
"""Action smoothing for smoother execution"""
|
|
||||||
|
|
||||||
def __init__(self, action_dim: int, method: str = 'ema', alpha: float = 0.3):
|
|
||||||
self.action_dim = action_dim
|
|
||||||
self.method = method
|
|
||||||
self.alpha = alpha
|
|
||||||
self.prev_action = None
|
|
||||||
|
|
||||||
def smooth(self, action: np.ndarray) -> np.ndarray:
|
|
||||||
if self.method == 'ema':
|
|
||||||
if self.prev_action is None:
|
|
||||||
smoothed = action
|
|
||||||
else:
|
|
||||||
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
|
|
||||||
self.prev_action = smoothed
|
|
||||||
return smoothed
|
|
||||||
else:
|
|
||||||
return action
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.prev_action = None
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(
    ckpt_path: str,
    agent_cfg: DictConfig,
    device: str = 'cuda'
) -> "tuple[torch.nn.Module, dict | None]":
    """
    Load a trained VLA model from a checkpoint using the Hydra agent config.

    Args:
        ckpt_path: Path to the checkpoint file (.pt).
        agent_cfg: Hydra agent config used for instantiation.
        device: Device to load the model on.

    Returns:
        A tuple of (loaded VLAAgent model, dataset statistics dict or None).
        NOTE: the original annotation claimed only the model was returned; the
        function actually returns `(agent, stats)`, so the annotation is fixed here.
    """
    from pathlib import Path as PathLib

    ckpt_path = PathLib(ckpt_path).absolute()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"检查点未找到: {ckpt_path}")

    log.info(f"从 {ckpt_path} 加载检查点")
    # weights_only=False: the checkpoint stores non-tensor objects (e.g. dataset stats).
    # NOTE(review): this unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    log.info(f"检查点键值: {checkpoint.keys()}")

    # Load dataset statistics (for normalization) before building the agent,
    # because the agent is instantiated with them.
    stats = checkpoint.get('dataset_stats', None)

    # Instantiate the agent from the Hydra config, passing the dataset stats.
    log.info("从配置实例化 agent...")
    agent = instantiate(agent_cfg, dataset_stats=stats)

    # Load model weights.
    agent.load_state_dict(checkpoint['model_state_dict'])
    log.info(f"✅ 模型状态已加载 (步数: {checkpoint.get('step', 'unknown')})")

    if stats is not None:
        log.info(f"✅ 数据集统计信息已加载 (归一化: {stats.get('normalization_type', 'gaussian')})")
    else:
        # Fallback: try an external JSON file (compatibility with old checkpoints).
        stats_path = ckpt_path.parent / 'dataset_stats.json'
        if stats_path.exists():
            with open(stats_path, 'r') as f:
                stats = json.load(f)
            log.info("✅ 数据集统计信息已从外部 JSON 加载(旧版本兼容)")
        else:
            log.warning("⚠️ 未找到数据集统计信息。动作将无法反归一化!")

    agent.eval()
    agent.to(device)

    log.info(f"✅ 模型已成功加载到 {device}")
    return agent, stats
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_observation(obs: Dict, camera_names: list) -> Dict:
|
||||||
|
"""
|
||||||
|
将环境观测转换为 agent 格式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: 环境观测字典,包含图像和 qpos
|
||||||
|
camera_names: 摄像头名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
agent 格式的观测字典
|
||||||
|
"""
|
||||||
|
# 转换图像: numpy -> tensor, HWC -> CHW
|
||||||
|
images = {}
|
||||||
|
for cam_name in camera_names:
|
||||||
|
img = obs['images'][cam_name]
|
||||||
|
img = rearrange(img, 'h w c -> c h w')
|
||||||
|
img = torch.from_numpy(img / 255.0).float()
|
||||||
|
images[cam_name] = img
|
||||||
|
|
||||||
|
# 转换 qpos: numpy -> tensor
|
||||||
|
qpos = torch.from_numpy(obs['qpos']).float()
|
||||||
|
|
||||||
|
return {'qpos': qpos, 'images': images}
|
||||||
|
|
||||||
|
|
||||||
|
class ActionSmoother:
|
||||||
|
"""
|
||||||
|
动作平滑器(指数移动平均)
|
||||||
|
用于平滑执行动作以获得更稳定的控制
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, alpha: float = 0.3):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
alpha: 平滑系数 (0-1),值越大越重视当前动作
|
||||||
|
"""
|
||||||
|
self.alpha = alpha
|
||||||
|
self.prev_action = None
|
||||||
|
|
||||||
|
def smooth(self, action: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
平滑动作
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: 当前动作
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
平滑后的动作
|
||||||
|
"""
|
||||||
|
if self.prev_action is None:
|
||||||
|
smoothed = action
|
||||||
|
else:
|
||||||
|
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
|
||||||
|
self.prev_action = smoothed
|
||||||
|
return smoothed
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""重置平滑器状态"""
|
||||||
|
self.prev_action = None
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
"""
|
"""
|
||||||
VLA Evaluation Script with Hydra Configuration.
|
使用 agent 内置队列管理的简化版 VLA 评估
|
||||||
|
|
||||||
All eval parameters come from vla/conf/eval.yaml, merged into cfg.
|
所有评估参数来自 vla/conf/eval.yaml,合并到 cfg 中。
|
||||||
Override on command line: python eval_vla.py eval.ckpt_path=... eval.num_episodes=5
|
命令行覆盖: python eval_vla_simple.py eval.ckpt_path=... eval.num_episodes=5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Print configuration
|
# 打印配置
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("VLA Evaluation Configuration:")
|
print("VLA 评估配置:")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print(OmegaConf.to_yaml(cfg))
|
print(OmegaConf.to_yaml(cfg))
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -335,67 +173,114 @@ def main(cfg: DictConfig):
|
|||||||
device = eval_cfg.device
|
device = eval_cfg.device
|
||||||
camera_names = list(eval_cfg.camera_names)
|
camera_names = list(eval_cfg.camera_names)
|
||||||
|
|
||||||
# Load model
|
# =========================================================================
|
||||||
log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
|
# 加载模型
|
||||||
|
# =========================================================================
|
||||||
|
log.info(f"🚀 从 {eval_cfg.ckpt_path} 加载模型...")
|
||||||
agent, dataset_stats = load_checkpoint(
|
agent, dataset_stats = load_checkpoint(
|
||||||
ckpt_path=eval_cfg.ckpt_path,
|
ckpt_path=eval_cfg.ckpt_path,
|
||||||
agent_cfg=cfg.agent,
|
agent_cfg=cfg.agent,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create evaluator
|
# 重置 agent 的队列
|
||||||
evaluator = VLAEvaluator(
|
agent.reset()
|
||||||
agent=agent,
|
|
||||||
device=device,
|
|
||||||
camera_names=camera_names,
|
|
||||||
num_queries=eval_cfg.num_queries,
|
|
||||||
obs_horizon=eval_cfg.obs_horizon,
|
|
||||||
use_smoothing=eval_cfg.use_smoothing,
|
|
||||||
smooth_method=eval_cfg.smooth_method,
|
|
||||||
smooth_alpha=eval_cfg.smooth_alpha,
|
|
||||||
dataset_stats=dataset_stats
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create environment
|
# 可选:动作平滑器
|
||||||
|
smoother = ActionSmoother(alpha=eval_cfg.smooth_alpha) if eval_cfg.use_smoothing else None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 创建环境
|
||||||
|
# =========================================================================
|
||||||
env = make_sim_env(eval_cfg.task_name)
|
env = make_sim_env(eval_cfg.task_name)
|
||||||
|
|
||||||
# Run episodes
|
# =========================================================================
|
||||||
|
# 运行评估回合
|
||||||
|
# =========================================================================
|
||||||
all_stats = []
|
all_stats = []
|
||||||
|
|
||||||
for episode_idx in range(eval_cfg.num_episodes):
|
for episode_idx in range(eval_cfg.num_episodes):
|
||||||
print(f"\n{'='*60}")
|
print(f"\n{'='*60}")
|
||||||
print(f"Episode {episode_idx + 1}/{eval_cfg.num_episodes}")
|
print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
|
||||||
print(f"{'='*60}\n")
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
box_pos = sample_transfer_pose()
|
box_pos = sample_transfer_pose()
|
||||||
env.reset(box_pos)
|
env.reset(box_pos)
|
||||||
evaluator.reset()
|
|
||||||
|
# 为新回合重置 agent 队列
|
||||||
|
agent.reset()
|
||||||
|
if smoother:
|
||||||
|
smoother.reset()
|
||||||
|
|
||||||
|
# 计时统计
|
||||||
|
inference_times = []
|
||||||
|
total_times = []
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
|
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
|
||||||
|
start_total = time.time()
|
||||||
|
|
||||||
|
# 从环境获取观测
|
||||||
obs = env._get_image_obs()
|
obs = env._get_image_obs()
|
||||||
qpos_obs = env._get_qpos_obs()
|
qpos_obs = env._get_qpos_obs()
|
||||||
obs['qpos'] = qpos_obs['qpos']
|
obs['qpos'] = qpos_obs['qpos']
|
||||||
|
|
||||||
action = evaluator.predict_action(obs)
|
# 准备给 agent 的观测
|
||||||
env.step_jnt(action)
|
observation = prepare_observation(obs, camera_names)
|
||||||
|
|
||||||
|
# 选择动作(agent 内部处理队列管理)
|
||||||
|
start_inference = time.time()
|
||||||
|
action = agent.select_action(observation)
|
||||||
|
|
||||||
|
if device == 'cuda':
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end_inference = time.time()
|
||||||
|
|
||||||
|
# 转换为 numpy
|
||||||
|
action = action.cpu().numpy()
|
||||||
|
|
||||||
|
# 可选:平滑动作
|
||||||
|
if smoother:
|
||||||
|
action = smoother.smooth(action)
|
||||||
|
|
||||||
|
# 执行动作
|
||||||
|
env.step_jnt(action)
|
||||||
env.render()
|
env.render()
|
||||||
|
|
||||||
# Get timing statistics for this episode
|
end_total = time.time()
|
||||||
stats = evaluator.get_timing_stats()
|
|
||||||
|
# 记录计时
|
||||||
|
inference_times.append(end_inference - start_inference)
|
||||||
|
total_times.append(end_total - start_total)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 打印回合统计
|
||||||
|
# =========================================================================
|
||||||
|
avg_inference_time = np.mean(inference_times)
|
||||||
|
avg_total_time = np.mean(total_times)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
|
||||||
|
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
|
||||||
|
'avg_inference_time_ms': avg_inference_time * 1000,
|
||||||
|
'avg_total_time_ms': avg_total_time * 1000,
|
||||||
|
'num_inferences': len([t for t in inference_times if t > 0.001]), # 统计实际推理次数
|
||||||
|
'num_steps': len(total_times)
|
||||||
|
}
|
||||||
all_stats.append(stats)
|
all_stats.append(stats)
|
||||||
|
|
||||||
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
|
print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
|
||||||
print(f" Model Inference FPS: {stats['inference_fps']:.2f} Hz")
|
print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
|
||||||
print(f" Control Loop FPS: {stats['control_fps']:.2f} Hz")
|
print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
|
||||||
print(f" Avg Inference Time: {stats['avg_inference_time_ms']:.2f} ms")
|
print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
|
||||||
print(f" Avg Total Time: {stats['avg_total_time_ms']:.2f} ms")
|
print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
|
||||||
print(f" Total Inferences: {stats['num_inferences']}")
|
print(f" 总推理次数: {stats['num_inferences']}")
|
||||||
|
|
||||||
# Print overall statistics
|
# =========================================================================
|
||||||
|
# 总体统计
|
||||||
|
# =========================================================================
|
||||||
print(f"\n{'='*60}")
|
print(f"\n{'='*60}")
|
||||||
print("Evaluation complete!")
|
print("评估完成!")
|
||||||
print(f"{'='*60}")
|
print(f"{'='*60}")
|
||||||
|
|
||||||
if all_stats:
|
if all_stats:
|
||||||
@@ -404,11 +289,11 @@ def main(cfg: DictConfig):
|
|||||||
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
|
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
|
||||||
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
|
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
|
||||||
|
|
||||||
print(f"\nOverall Statistics ({eval_cfg.num_episodes} episodes):")
|
print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
|
||||||
print(f" Average Model Inference FPS: {avg_inference_fps:.2f} Hz")
|
print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
|
||||||
print(f" Average Control Loop FPS: {avg_control_fps:.2f} Hz")
|
print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
|
||||||
print(f" Average Inference Time: {avg_inference_time:.2f} ms")
|
print(f" 平均推理时间: {avg_inference_time:.2f} ms")
|
||||||
print(f" Average Total Time: {avg_total_time:.2f} ms")
|
print(f" 平均总时间: {avg_total_time:.2f} ms")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,28 +12,28 @@ from torch.optim import AdamW
|
|||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Ensure correct import path
|
# 确保正确的导入路径
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
|
# 注册列表长度解析器(用于配置中如 ${len:${data.camera_names}})
|
||||||
if not OmegaConf.has_resolver("len"):
|
if not OmegaConf.has_resolver("len"):
|
||||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||||
|
|
||||||
|
|
||||||
def recursive_to_device(data, device):
|
def recursive_to_device(data, device):
|
||||||
"""
|
"""
|
||||||
Recursively move nested dictionaries/lists of tensors to specified device.
|
递归地将嵌套字典/列表中的张量移动到指定设备。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Dictionary, list, or tensor
|
data: 字典、列表或张量
|
||||||
device: Target device (e.g., 'cuda', 'cpu')
|
device: 目标设备 (例如 'cuda', 'cpu')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Data structure with all tensors moved to device
|
所有张量已移动到指定设备的数据结构
|
||||||
"""
|
"""
|
||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
return data.to(device)
|
return data.to(device)
|
||||||
@@ -46,36 +46,36 @@ def recursive_to_device(data, device):
|
|||||||
|
|
||||||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
||||||
"""
|
"""
|
||||||
Create a learning rate scheduler with warmup.
|
创建带预热的学习率调度器。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer: PyTorch optimizer
|
optimizer: PyTorch 优化器
|
||||||
warmup_steps: Number of warmup steps
|
warmup_steps: 预热步数
|
||||||
max_steps: Total training steps
|
max_steps: 总训练步数
|
||||||
scheduler_type: Type of scheduler after warmup ('cosine' or 'constant')
|
scheduler_type: 预热后的调度器类型 ('cosine' 或 'constant')
|
||||||
min_lr: Minimum learning rate (for cosine decay)
|
min_lr: 最小学习率(用于余弦衰减)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LambdaLR scheduler
|
LambdaLR 调度器
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
# Capture initial lr before LambdaLR modifies it
|
# 在 LambdaLR 修改前捕获初始学习率
|
||||||
base_lr = optimizer.param_groups[0]['lr']
|
base_lr = optimizer.param_groups[0]['lr']
|
||||||
min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0
|
min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0
|
||||||
|
|
||||||
def lr_lambda(step):
|
def lr_lambda(step):
|
||||||
# Warmup phase: linear increase from 0 to 1
|
# 预热阶段:从 0 线性增加到 1
|
||||||
if step < warmup_steps:
|
if step < warmup_steps:
|
||||||
return float(step) / float(max(1, warmup_steps))
|
return float(step) / float(max(1, warmup_steps))
|
||||||
|
|
||||||
# Post-warmup phase
|
# 预热后阶段
|
||||||
if scheduler_type == 'cosine':
|
if scheduler_type == 'cosine':
|
||||||
# Cosine annealing from 1 to min_lr_ratio
|
# 从 1 到 min_lr_ratio 的余弦退火
|
||||||
progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
|
progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
|
||||||
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
|
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||||
return max(min_lr_ratio, cosine_decay)
|
return max(min_lr_ratio, cosine_decay)
|
||||||
else:
|
else:
|
||||||
# Constant learning rate
|
# 恒定学习率
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
return LambdaLR(optimizer, lr_lambda)
|
return LambdaLR(optimizer, lr_lambda)
|
||||||
@@ -84,40 +84,40 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
|
|||||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
"""
|
"""
|
||||||
VLA Training Script with ResNet Backbone and Diffusion Policy.
|
VLA 训练脚本(ResNet 骨干网络 + Diffusion 策略)
|
||||||
|
|
||||||
This script:
|
该脚本功能:
|
||||||
1. Loads dataset from HDF5 files
|
1. 从 HDF5 文件加载数据集
|
||||||
2. Instantiates VLAAgent with ResNet vision encoder
|
2. 实例化带 ResNet 视觉编码器的 VLAAgent
|
||||||
3. Trains diffusion-based action prediction
|
3. 训练基于扩散的动作预测模型
|
||||||
4. Saves checkpoints periodically
|
4. 定期保存检查点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Print configuration
|
# 打印配置
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("VLA Training Configuration:")
|
print("VLA 训练配置:")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print(OmegaConf.to_yaml(cfg))
|
print(OmegaConf.to_yaml(cfg))
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")
|
log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
|
||||||
|
|
||||||
# Create checkpoint directory
|
# 创建检查点目录
|
||||||
checkpoint_dir = Path("checkpoints")
|
checkpoint_dir = Path("checkpoints")
|
||||||
checkpoint_dir.mkdir(exist_ok=True)
|
checkpoint_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 1. Instantiate Dataset & DataLoader
|
# 1. 实例化数据集与 DataLoader
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
log.info("📦 Loading dataset...")
|
log.info("📦 加载数据集...")
|
||||||
try:
|
try:
|
||||||
dataset = instantiate(cfg.data)
|
dataset = instantiate(cfg.data)
|
||||||
log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
|
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"❌ Failed to load dataset: {e}")
|
log.error(f"❌ 数据集加载失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Train/Val split
|
# 训练/验证集划分
|
||||||
val_split = float(cfg.train.get('val_split', 0.1))
|
val_split = float(cfg.train.get('val_split', 0.1))
|
||||||
seed = int(cfg.train.get('seed', 42))
|
seed = int(cfg.train.get('seed', 42))
|
||||||
val_size = int(len(dataset) * val_split)
|
val_size = int(len(dataset) * val_split)
|
||||||
@@ -128,10 +128,10 @@ def main(cfg: DictConfig):
|
|||||||
[train_size, val_size],
|
[train_size, val_size],
|
||||||
generator=torch.Generator().manual_seed(seed)
|
generator=torch.Generator().manual_seed(seed)
|
||||||
)
|
)
|
||||||
log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
|
log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
|
||||||
else:
|
else:
|
||||||
train_dataset, val_dataset = dataset, None
|
train_dataset, val_dataset = dataset, None
|
||||||
log.info("✅ Dataset split: train=all, val=0 (val_split=0)")
|
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@@ -139,7 +139,7 @@ def main(cfg: DictConfig):
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
pin_memory=(cfg.train.device != "cpu"),
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
drop_last=True # Drop incomplete batches for stable training
|
drop_last=True # 丢弃不完整批次以稳定训练
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = None
|
val_loader = None
|
||||||
@@ -153,34 +153,14 @@ def main(cfg: DictConfig):
|
|||||||
drop_last=False
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"✅ Train loader batches per epoch: {len(train_loader)}")
|
log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}")
|
||||||
if val_loader is not None:
|
if val_loader is not None:
|
||||||
log.info(f"✅ Val loader batches per epoch: {len(val_loader)}")
|
log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 2. Instantiate VLA Agent
|
# 2. 加载数据集统计信息(将传递给 agent)
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
log.info("🤖 Initializing VLA Agent...")
|
log.info("💾 加载数据集统计信息...")
|
||||||
try:
|
|
||||||
agent = instantiate(cfg.agent)
|
|
||||||
agent.to(cfg.train.device)
|
|
||||||
agent.train()
|
|
||||||
log.info(f"✅ Agent initialized and moved to {cfg.train.device}")
|
|
||||||
|
|
||||||
# Count parameters
|
|
||||||
total_params = sum(p.numel() for p in agent.parameters())
|
|
||||||
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
|
|
||||||
log.info(f"📊 Total parameters: {total_params:,}")
|
|
||||||
log.info(f"📊 Trainable parameters: {trainable_params:,}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"❌ Failed to initialize agent: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# 2.5. Load Dataset Statistics (will be saved into checkpoints)
|
|
||||||
# =========================================================================
|
|
||||||
log.info("💾 Loading dataset statistics...")
|
|
||||||
dataset_stats = None
|
dataset_stats = None
|
||||||
try:
|
try:
|
||||||
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
|
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
|
||||||
@@ -201,22 +181,43 @@ def main(cfg: DictConfig):
|
|||||||
'qpos_min': stats['qpos']['min'].tolist(),
|
'qpos_min': stats['qpos']['min'].tolist(),
|
||||||
'qpos_max': stats['qpos']['max'].tolist(),
|
'qpos_max': stats['qpos']['max'].tolist(),
|
||||||
}
|
}
|
||||||
log.info(f"✅ Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})")
|
log.info(f"✅ 数据集统计信息加载完成 (归一化: {dataset_stats['normalization_type']})")
|
||||||
else:
|
else:
|
||||||
log.warning(f"⚠️ Statistics file not found: {stats_path}")
|
log.warning(f"⚠️ 统计文件未找到: {stats_path}")
|
||||||
log.warning("⚠️ Actions will not be denormalized during inference!")
|
log.warning("⚠️ 推理时动作将无法反归一化!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"⚠️ Failed to load statistics: {e}")
|
log.warning(f"⚠️ 统计信息加载失败: {e}")
|
||||||
log.warning("⚠️ Training will continue, but inference may not work correctly")
|
log.warning("⚠️ 训练将继续,但推理可能无法正常工作")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 3. Setup Optimizer & LR Scheduler
|
# 3. 实例化 VLA Agent
|
||||||
|
# =========================================================================
|
||||||
|
log.info("🤖 初始化 VLA Agent...")
|
||||||
|
try:
|
||||||
|
# 将 dataset_stats 和 normalization_type 传递给 agent
|
||||||
|
agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
|
||||||
|
agent.to(cfg.train.device)
|
||||||
|
agent.train()
|
||||||
|
log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")
|
||||||
|
|
||||||
|
# 统计参数量
|
||||||
|
total_params = sum(p.numel() for p in agent.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
|
||||||
|
log.info(f"📊 总参数量: {total_params:,}")
|
||||||
|
log.info(f"📊 可训练参数量: {trainable_params:,}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ Agent 初始化失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 4. 设置优化器与学习率调度器
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
||||||
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
|
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})")
|
||||||
|
|
||||||
# Setup learning rate scheduler with warmup
|
# 设置带预热的学習率调度器
|
||||||
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
||||||
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
|
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
|
||||||
min_lr = float(cfg.train.get('min_lr', 1e-6))
|
min_lr = float(cfg.train.get('min_lr', 1e-6))
|
||||||
@@ -228,33 +229,36 @@ def main(cfg: DictConfig):
|
|||||||
scheduler_type=scheduler_type,
|
scheduler_type=scheduler_type,
|
||||||
min_lr=min_lr
|
min_lr=min_lr
|
||||||
)
|
)
|
||||||
log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})")
|
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 4. Training Loop
|
# 5. 训练循环
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
log.info("🏋️ Starting training loop...")
|
log.info("🏋️ 开始训练循环...")
|
||||||
|
|
||||||
def build_agent_input(batch_data):
|
def build_agent_input(batch_data):
|
||||||
|
"""构建 agent 输入格式"""
|
||||||
images = {}
|
images = {}
|
||||||
|
# SimpleRobotDataset 返回 observation.{cam_name} 格式
|
||||||
for cam_name in cfg.data.camera_names:
|
for cam_name in cfg.data.camera_names:
|
||||||
key = f"image_{cam_name}"
|
key = f"observation.{cam_name}"
|
||||||
if key in batch_data:
|
if key in batch_data:
|
||||||
images[cam_name] = batch_data[key]
|
images[cam_name] = batch_data[key]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'images': images,
|
'images': images,
|
||||||
'qpos': batch_data['qpos'],
|
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
||||||
'action': batch_data['action']
|
'action': batch_data['action']
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_validation():
|
def run_validation():
|
||||||
|
"""运行验证"""
|
||||||
if val_loader is None:
|
if val_loader is None:
|
||||||
return None
|
return None
|
||||||
agent.eval()
|
agent.eval()
|
||||||
|
|
||||||
# 🔧 FIX: Set deterministic seed for validation to get reproducible loss
|
# 设置确定性种子以获得可重现的损失
|
||||||
# This ensures validation loss is comparable across different steps
|
# 这确保验证损失在不同步骤之间可比较
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed(42)
|
torch.cuda.manual_seed(42)
|
||||||
@@ -272,7 +276,7 @@ def main(cfg: DictConfig):
|
|||||||
return total_loss / max(num_batches, 1)
|
return total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
data_iter = iter(train_loader)
|
data_iter = iter(train_loader)
|
||||||
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
|
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
|
||||||
|
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
|
||||||
@@ -280,47 +284,47 @@ def main(cfg: DictConfig):
|
|||||||
try:
|
try:
|
||||||
batch = next(data_iter)
|
batch = next(data_iter)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Restart iterator when epoch ends
|
# 轮次结束时重启迭代器
|
||||||
data_iter = iter(train_loader)
|
data_iter = iter(train_loader)
|
||||||
batch = next(data_iter)
|
batch = next(data_iter)
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Move batch to device
|
# 将批次移至设备
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
batch = recursive_to_device(batch, cfg.train.device)
|
batch = recursive_to_device(batch, cfg.train.device)
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Prepare agent input
|
# 准备 agent 输入
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Dataset returns: {action, qpos, image_<cam_name>, ...}
|
# 数据集返回: {action, qpos, image_<cam_name>, ...}
|
||||||
# Agent expects: {images: dict, qpos: tensor, action: tensor}
|
# Agent 期望: {images: dict, qpos: tensor, action: tensor}
|
||||||
|
|
||||||
# Prepare agent input
|
# 准备 agent 输入
|
||||||
agent_input = build_agent_input(batch)
|
agent_input = build_agent_input(batch)
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Forward pass & compute loss
|
# 前向传播与损失计算
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
try:
|
try:
|
||||||
loss = agent.compute_loss(agent_input)
|
loss = agent.compute_loss(agent_input)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"❌ Forward pass failed at step {step}: {e}")
|
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Backward pass & optimization
|
# 反向传播与优化
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# Gradient clipping for stable training
|
# 梯度裁剪以稳定训练
|
||||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Logging
|
# 日志记录
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
if step % cfg.train.log_freq == 0:
|
if step % cfg.train.log_freq == 0:
|
||||||
current_lr = optimizer.param_groups[0]['lr']
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
@@ -329,16 +333,16 @@ def main(cfg: DictConfig):
|
|||||||
"lr": f"{current_lr:.2e}",
|
"lr": f"{current_lr:.2e}",
|
||||||
"best_loss": f"{best_loss:.4f}"
|
"best_loss": f"{best_loss:.4f}"
|
||||||
})
|
})
|
||||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}")
|
log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Checkpoint saving & Validation
|
# 检查点保存与验证
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
if step > 0 and step % cfg.train.save_freq == 0:
|
if step > 0 and step % cfg.train.save_freq == 0:
|
||||||
# Run validation
|
# 运行验证
|
||||||
val_loss = run_validation()
|
val_loss = run_validation()
|
||||||
if val_loss is not None:
|
if val_loss is not None:
|
||||||
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
|
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
|
||||||
|
|
||||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
@@ -351,9 +355,9 @@ def main(cfg: DictConfig):
|
|||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, checkpoint_path)
|
}, checkpoint_path)
|
||||||
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
log.info(f"💾 检查点已保存: {checkpoint_path}")
|
||||||
|
|
||||||
# Save best model based on validation loss
|
# 根据验证损失保存最佳模型
|
||||||
eval_loss = val_loss if val_loss is not None else loss.item()
|
eval_loss = val_loss if val_loss is not None else loss.item()
|
||||||
if eval_loss < best_loss:
|
if eval_loss < best_loss:
|
||||||
best_loss = eval_loss
|
best_loss = eval_loss
|
||||||
@@ -368,10 +372,10 @@ def main(cfg: DictConfig):
|
|||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, best_model_path)
|
}, best_model_path)
|
||||||
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
|
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 5. Save Final Model
|
# 6. 保存最终模型
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
||||||
torch.save({
|
torch.save({
|
||||||
@@ -383,11 +387,11 @@ def main(cfg: DictConfig):
|
|||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, final_model_path)
|
}, final_model_path)
|
||||||
log.info(f"💾 Final model saved: {final_model_path}")
|
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||||
|
|
||||||
log.info("✅ Training completed successfully!")
|
log.info("✅ 训练成功完成!")
|
||||||
log.info(f"📊 Final Loss: {loss.item():.4f}")
|
log.info(f"📊 最终损失: {loss.item():.4f}")
|
||||||
log.info(f"📊 Best Loss: {best_loss:.4f}")
|
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Dict, Optional, Any
|
from collections import deque
|
||||||
|
from typing import Dict, Optional, Any, Tuple
|
||||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
||||||
|
from roboimi.vla.models.normalization import NormalizationModule
|
||||||
|
|
||||||
class VLAAgent(nn.Module):
|
class VLAAgent(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vision_backbone, # 你之前定义的 ResNet 类
|
vision_backbone, # 视觉编码器(ResNet 等)
|
||||||
state_encoder,
|
state_encoder,
|
||||||
action_encoder,
|
action_encoder,
|
||||||
head,
|
head,
|
||||||
@@ -19,23 +21,35 @@ class VLAAgent(nn.Module):
|
|||||||
obs_dim, # 本体感知维度 (例如 关节角度)
|
obs_dim, # 本体感知维度 (例如 关节角度)
|
||||||
pred_horizon=16, # 预测未来多少步动作
|
pred_horizon=16, # 预测未来多少步动作
|
||||||
obs_horizon=4, # 使用多少步历史观测
|
obs_horizon=4, # 使用多少步历史观测
|
||||||
diffusion_steps=100,
|
diffusion_steps=100, # DDPM 加噪步数
|
||||||
|
inference_steps=10, # DDIM 推理步数
|
||||||
num_cams=3, # 视觉输入的摄像头数量
|
num_cams=3, # 视觉输入的摄像头数量
|
||||||
|
dataset_stats=None, # 数据集统计信息,用于归一化
|
||||||
|
normalization_type='gaussian', # 归一化类型: 'gaussian' 或 'min_max'
|
||||||
|
num_action_steps=1, # 每次推理实际执行多少步动作
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Store parameters
|
# 保存参数
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.obs_dim = obs_dim
|
self.obs_dim = obs_dim
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.num_cams = num_cams
|
self.num_cams = num_cams
|
||||||
|
self.num_action_steps = num_action_steps
|
||||||
|
self.inference_steps = inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
# 归一化模块 - 统一训练和推理的归一化逻辑
|
||||||
|
self.normalization = NormalizationModule(
|
||||||
|
stats=dataset_stats,
|
||||||
|
normalization_type=normalization_type
|
||||||
|
)
|
||||||
|
|
||||||
self.vision_encoder = vision_backbone
|
self.vision_encoder = vision_backbone
|
||||||
single_img_feat_dim = self.vision_encoder.output_dim
|
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||||
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
|
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
||||||
total_prop_dim = obs_dim * obs_horizon
|
total_prop_dim = obs_dim * obs_horizon
|
||||||
self.global_cond_dim = total_vision_dim + total_prop_dim
|
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||||
# self.global_cond_dim = total_vision_dim
|
|
||||||
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
self.noise_scheduler = DDPMScheduler(
|
||||||
num_train_timesteps=diffusion_steps,
|
num_train_timesteps=diffusion_steps,
|
||||||
@@ -44,7 +58,7 @@ class VLAAgent(nn.Module):
|
|||||||
prediction_type='epsilon' # 预测噪声
|
prediction_type='epsilon' # 预测噪声
|
||||||
)
|
)
|
||||||
|
|
||||||
# DDIM scheduler for faster inference
|
# DDIM 调度器用于快速推理
|
||||||
self.infer_scheduler = DDIMScheduler(
|
self.infer_scheduler = DDIMScheduler(
|
||||||
num_train_timesteps=diffusion_steps,
|
num_train_timesteps=diffusion_steps,
|
||||||
beta_schedule='squaredcos_cap_v2',
|
beta_schedule='squaredcos_cap_v2',
|
||||||
@@ -54,84 +68,256 @@ class VLAAgent(nn.Module):
|
|||||||
|
|
||||||
self.noise_pred_net = head(
|
self.noise_pred_net = head(
|
||||||
input_dim=action_dim,
|
input_dim=action_dim,
|
||||||
# input_dim = action_dim + obs_dim,
|
# input_dim = action_dim + obs_dim, # 备选:包含观测维度
|
||||||
global_cond_dim=self.global_cond_dim
|
global_cond_dim=self.global_cond_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
self.state_encoder = state_encoder
|
self.state_encoder = state_encoder
|
||||||
self.action_encoder = action_encoder
|
self.action_encoder = action_encoder
|
||||||
|
|
||||||
|
# 初始化队列(用于在线推理)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
|
||||||
# ==========================
|
# ==========================
|
||||||
# 训练阶段 (Training)
|
# 训练阶段 (Training)
|
||||||
# ==========================
|
# ==========================
|
||||||
def compute_loss(self, batch):
|
def compute_loss(self, batch):
|
||||||
"""
|
"""
|
||||||
batch: 包含 images, qpos (proprioception), action
|
计算训练损失
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: 包含 images, qpos (本体感知), action 的字典
|
||||||
"""
|
"""
|
||||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||||
B = actions.shape[0]
|
B = actions.shape[0]
|
||||||
|
|
||||||
|
# 归一化 states (qpos) 和 actions
|
||||||
|
states = self.normalization.normalize_qpos(states)
|
||||||
|
actions = self.normalization.normalize_action(actions)
|
||||||
|
|
||||||
state_features = self.state_encoder(states)
|
state_features = self.state_encoder(states)
|
||||||
|
|
||||||
# 1. 提取视觉特征
|
# 1. 提取视觉特征
|
||||||
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
|
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
|
||||||
action_features = self.action_encoder(actions)
|
action_features = self.action_encoder(actions)
|
||||||
|
|
||||||
# 3. 采样噪声
|
# 2. 采样噪声
|
||||||
noise = torch.randn_like(action_features)
|
noise = torch.randn_like(action_features)
|
||||||
|
|
||||||
# 4. 随机采样时间步 (Timesteps)
|
# 3. 随机采样时间步 (Timesteps)
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
0, self.noise_scheduler.config.num_train_timesteps,
|
0, self.noise_scheduler.config.num_train_timesteps,
|
||||||
(B,), device=action_features.device
|
(B,), device=action_features.device
|
||||||
).long()
|
).long()
|
||||||
|
|
||||||
# 5. 给动作加噪 (Forward Diffusion)
|
# 4. 给动作加噪 (Forward Diffusion)
|
||||||
noisy_actions = self.noise_scheduler.add_noise(
|
noisy_actions = self.noise_scheduler.add_noise(
|
||||||
action_features, noise, timesteps
|
action_features, noise, timesteps
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 网络预测噪声
|
# 5. 网络预测噪声
|
||||||
pred_noise = self.noise_pred_net(
|
pred_noise = self.noise_pred_net(
|
||||||
sample=noisy_actions,
|
sample=noisy_actions,
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
visual_features=visual_features,
|
visual_features=visual_features,
|
||||||
proprioception=state_features
|
proprioception=state_features
|
||||||
)
|
)
|
||||||
|
|
||||||
# 7. 计算 Loss (MSE)
|
# 6. 计算 Loss (MSE)
|
||||||
loss = nn.functional.mse_loss(pred_noise, noise)
|
loss = nn.functional.mse_loss(pred_noise, noise)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
# ==========================
|
# ==========================
|
||||||
# 推理阶段 (Inference)
|
# 队列管理 (Queue Management)
|
||||||
|
# ==========================
|
||||||
|
def reset(self):
|
||||||
|
"""清空观测和动作队列。应在 env.reset() 时调用"""
|
||||||
|
self._queues = {
|
||||||
|
'qpos': deque(maxlen=self.obs_horizon),
|
||||||
|
'images': deque(maxlen=self.obs_horizon),
|
||||||
|
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
|
||||||
|
}
|
||||||
|
|
||||||
|
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||||
|
"""
|
||||||
|
将新的观测添加到队列中。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: 包含 'qpos' 和 'images' 的字典
|
||||||
|
"""
|
||||||
|
# 添加本体感知
|
||||||
|
if 'qpos' in observation:
|
||||||
|
self._queues['qpos'].append(observation['qpos'].clone())
|
||||||
|
|
||||||
|
# 添加图像
|
||||||
|
if 'images' in observation:
|
||||||
|
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
|
||||||
|
|
||||||
|
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
从队列中准备用于推理的批量观测。
|
||||||
|
如果队列未满(首次调用时),用最新观测重复填充。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch: 包含堆叠后的历史观测的字典
|
||||||
|
"""
|
||||||
|
# 堆叠历史本体感知
|
||||||
|
qpos_list = list(self._queues['qpos'])
|
||||||
|
if len(qpos_list) == 0:
|
||||||
|
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
|
||||||
|
# 如果队列未满,用最后一个观测填充
|
||||||
|
while len(qpos_list) < self.obs_horizon:
|
||||||
|
qpos_list.append(qpos_list[-1])
|
||||||
|
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
|
||||||
|
|
||||||
|
# 堆叠历史图像
|
||||||
|
images_list = list(self._queues['images'])
|
||||||
|
if len(images_list) == 0:
|
||||||
|
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
|
||||||
|
# 如果队列未满,用最后一个观测填充
|
||||||
|
while len(images_list) < self.obs_horizon:
|
||||||
|
images_list.append(images_list[-1])
|
||||||
|
|
||||||
|
batch_images = {}
|
||||||
|
for cam_name in images_list[0].keys():
|
||||||
|
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
|
||||||
|
|
||||||
|
return {'qpos': batch_qpos, 'images': batch_images}
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 在线推理 (Online Inference)
|
||||||
|
# ==========================
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
根据当前观测选择单个动作。
|
||||||
|
|
||||||
|
这个方法维护一个历史观测和生成动作轨迹的缓存。工作流程:
|
||||||
|
- 缓存 `obs_horizon` 步的历史观测
|
||||||
|
- Diffusion 模型生成 `pred_horizon` 步的动作
|
||||||
|
- 实际执行 `num_action_steps` 步动作
|
||||||
|
|
||||||
|
示意图:
|
||||||
|
--------------------------------------------------------------
|
||||||
|
(图例: o=obs_horizon, h=pred_horizon, a=num_action_steps)
|
||||||
|
|时间步 | 0 | 1 | ... | o-1 | o | ... | h-1 |
|
||||||
|
|观测是否使用 | 是 | 是 | 是 | 是 | 否 | 否 | 否 |
|
||||||
|
|动作是否生成 | 是 | 是 | 是 | 是 | 是 | 是 | 是 |
|
||||||
|
|动作是否执行 | 否 | 否 | 否 | 否 | 是 | 是 | 是 |
|
||||||
|
--------------------------------------------------------------
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: 包含 'qpos' 和 'images' 的字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
action: (action_dim,) 单个动作
|
||||||
|
"""
|
||||||
|
# 检测设备并确保所有组件在同一设备上
|
||||||
|
# 尝试从观测中获取设备
|
||||||
|
device = None
|
||||||
|
for v in observation.values():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
device = v.device
|
||||||
|
break
|
||||||
|
|
||||||
|
if device is not None and self.normalization.enabled:
|
||||||
|
# 确保归一化参数在同一设备上
|
||||||
|
norm_device = self.normalization.qpos_mean.device
|
||||||
|
if device != norm_device:
|
||||||
|
self.normalization.to(device)
|
||||||
|
# 同时确保其他模块也在正确设备
|
||||||
|
self.vision_encoder.to(device)
|
||||||
|
self.state_encoder.to(device)
|
||||||
|
self.action_encoder.to(device)
|
||||||
|
self.noise_pred_net.to(device)
|
||||||
|
|
||||||
|
# 将所有 observation 移到正确设备
|
||||||
|
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in observation.items()}
|
||||||
|
|
||||||
|
# 将新观测添加到队列
|
||||||
|
self._populate_queues(observation)
|
||||||
|
|
||||||
|
# 如果动作队列为空,生成新的动作序列
|
||||||
|
if len(self._queues['action']) == 0:
|
||||||
|
# 从队列准备批量观测
|
||||||
|
batch = self._prepare_observation_batch()
|
||||||
|
|
||||||
|
# 生成动作块
|
||||||
|
actions = self.predict_action_chunk(batch) # (1, pred_horizon, action_dim)
|
||||||
|
|
||||||
|
# 提取可执行的动作部分
|
||||||
|
# 从 obs_horizon-1 开始,因为前面的动作对应过去的观测
|
||||||
|
start = self.obs_horizon - 1
|
||||||
|
end = start + self.num_action_steps
|
||||||
|
executable_actions = actions[:, start:end] # (1, num_action_steps, action_dim)
|
||||||
|
|
||||||
|
# 将动作添加到队列
|
||||||
|
for i in range(executable_actions.shape[1]):
|
||||||
|
self._queues['action'].append(executable_actions[:, i].squeeze(0)) # (action_dim,)
|
||||||
|
|
||||||
|
# 从队列中取出一个动作
|
||||||
|
action = self._queues['action'].popleft() # (action_dim,)
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
预测一个动作块(用于在线推理)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: 包含 'qpos' 和 'images' 的字典
|
||||||
|
- qpos: (B, obs_horizon, obs_dim)
|
||||||
|
- images: Dict[str, (B, obs_horizon, C, H, W)]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
actions: (B, pred_horizon, action_dim) 预测的动作序列
|
||||||
|
"""
|
||||||
|
return self.predict_action(batch['images'], batch['qpos'])
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 批量推理 (Batch Inference - 原有方法)
|
||||||
# ==========================
|
# ==========================
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action(self, images, proprioception):
|
def predict_action(self, images, proprioception):
|
||||||
|
"""
|
||||||
|
批量预测动作序列(用于训练和离线评估)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: 图像观测字典
|
||||||
|
proprioception: 本体感知观测 (qpos)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
denormalized_actions: 反归一化后的动作序列
|
||||||
|
"""
|
||||||
B = proprioception.shape[0]
|
B = proprioception.shape[0]
|
||||||
|
|
||||||
# 1. 提取当前观测特征 (只做一次)
|
# 归一化 proprioception (qpos)
|
||||||
|
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||||
|
|
||||||
|
# 1. 提取当前观测特征(只提取一次)
|
||||||
visual_features = self.vision_encoder(images)
|
visual_features = self.vision_encoder(images)
|
||||||
state_features = self.state_encoder(proprioception)
|
state_features = self.state_encoder(proprioception)
|
||||||
|
|
||||||
# 2. 初始化纯高斯噪声动作
|
# 2. 初始化纯高斯噪声动作
|
||||||
# Shape: (B, pred_horizon, action_dim)
|
# 形状: (B, pred_horizon, action_dim)
|
||||||
device = visual_features.device
|
device = visual_features.device
|
||||||
current_actions = torch.randn(
|
current_actions = torch.randn(
|
||||||
(B, self.pred_horizon, self.action_dim), device=device
|
(B, self.pred_horizon, self.action_dim), device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 逐步去噪循环 (Reverse Diffusion)
|
# 3. 逐步去噪循环 (Reverse Diffusion)
|
||||||
self.infer_scheduler.set_timesteps(10) # DDIM 推理步数
|
self.infer_scheduler.set_timesteps(self.inference_steps) # DDIM 推理步数
|
||||||
|
|
||||||
for t in self.infer_scheduler.timesteps:
|
for t in self.infer_scheduler.timesteps:
|
||||||
model_input = current_actions
|
model_input = current_actions
|
||||||
|
|
||||||
# 预测噪声
|
# 预测噪声
|
||||||
noise_pred = self.noise_pred_net(
|
noise_pred = self.noise_pred_net(
|
||||||
sample=model_input,
|
sample=model_input,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
visual_features=visual_features,
|
visual_features=visual_features,
|
||||||
proprioception=state_features
|
proprioception=state_features
|
||||||
)
|
)
|
||||||
@@ -141,5 +327,11 @@ class VLAAgent(nn.Module):
|
|||||||
noise_pred, t, current_actions
|
noise_pred, t, current_actions
|
||||||
).prev_sample
|
).prev_sample
|
||||||
|
|
||||||
# 4. 输出最终动作序列(归一化空间,由调用方负责反归一化)
|
# 4. 反归一化动作序列
|
||||||
return current_actions
|
denormalized_actions = self.normalization.denormalize_action(current_actions)
|
||||||
|
|
||||||
|
return denormalized_actions
|
||||||
|
|
||||||
|
def get_normalization_stats(self):
|
||||||
|
"""获取归一化统计信息(用于保存到 checkpoint)"""
|
||||||
|
return self.normalization.get_stats()
|
||||||
|
|||||||
@@ -9,14 +9,26 @@ defaults:
|
|||||||
|
|
||||||
_target_: roboimi.vla.agent.VLAAgent
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
|
|
||||||
# Action and Observation Dimensions
|
# ====================
|
||||||
action_dim: 16
|
# 模型维度配置
|
||||||
obs_dim: 16
|
# ====================
|
||||||
|
action_dim: 16 # 动作维度(机器人关节数)
|
||||||
|
obs_dim: 16 # 本体感知维度(关节位置)
|
||||||
|
|
||||||
# Prediction and Observation Horizons
|
# ====================
|
||||||
pred_horizon: 16
|
# 时间步配置
|
||||||
obs_horizon: 2
|
# ====================
|
||||||
|
pred_horizon: 16 # 预测未来多少步动作
|
||||||
|
obs_horizon: 2 # 使用多少步历史观测
|
||||||
|
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 相机配置
|
||||||
|
# ====================
|
||||||
|
num_cams: 3 # 摄像头数量 (r_vis, top, front)
|
||||||
|
|
||||||
# Camera Configuration
|
# ====================
|
||||||
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
# 扩散过程配置
|
||||||
|
# ====================
|
||||||
|
diffusion_steps: 100 # 扩散训练步数(DDPM)
|
||||||
|
inference_steps: 10 # 推理时的去噪步数(DDIM,固定为 10)
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
|
||||||
|
|
||||||
model_name: "microsoft/resnet-18"
|
|
||||||
freeze: true
|
|
||||||
@@ -1,8 +1,28 @@
|
|||||||
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
|
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
|
||||||
vision_backbone: "resnet18"
|
|
||||||
pretrained_backbone_weights: null
|
# ====================
|
||||||
input_shape: [3, 96, 96]
|
# 骨干网络选择
|
||||||
crop_shape: [84, 84]
|
# ====================
|
||||||
crop_is_random: true
|
vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
|
||||||
use_group_norm: true
|
pretrained_backbone_weights: null # 预训练权重路径或 null(ImageNet 权重)
|
||||||
spatial_softmax_num_keypoints: 32
|
|
||||||
|
# ====================
|
||||||
|
# 输入配置
|
||||||
|
# ====================
|
||||||
|
input_shape: [3, 96, 96] # 输入图像形状 (C, H, W)
|
||||||
|
crop_shape: [84, 84] # 裁剪后的图像形状 (H, W)
|
||||||
|
crop_is_random: true # 训练时使用随机裁剪,评估时使用中心裁剪
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 归一化和特征提取
|
||||||
|
# ====================
|
||||||
|
use_group_norm: true # 使用 GroupNorm 替代 BatchNorm(更适合小批次训练)
|
||||||
|
spatial_softmax_num_keypoints: 32 # Spatial Softmax 关键点数量
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 编码器模式
|
||||||
|
# ====================
|
||||||
|
# false: 共享编码器(所有摄像头共享一个 ResNet,参数少但容量受限)推荐!
|
||||||
|
# true: 独立编码器(每个摄像头有独立的 ResNet,参数多但容量大)
|
||||||
|
use_separate_rgb_encoder_per_camera: true
|
||||||
|
num_cameras: 3 # 摄像头数量
|
||||||
@@ -1,19 +1,41 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- agent: resnet_diffusion
|
- agent: resnet_diffusion
|
||||||
- data: resnet_dataset
|
- data: simpe_robot_dataset
|
||||||
- eval: eval
|
- eval: eval
|
||||||
- _self_
|
- _self_
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 训练配置
|
||||||
|
# ====================
|
||||||
train:
|
train:
|
||||||
batch_size: 8 # Batch size for training
|
# 基础训练参数
|
||||||
lr: 1e-4 # Learning rate
|
batch_size: 8 # 批次大小
|
||||||
max_steps: 20000 # Maximum training steps
|
lr: 1e-4 # 学习率
|
||||||
log_freq: 100 # Log frequency (steps)
|
max_steps: 100000 # 最大训练步数
|
||||||
save_freq: 2000 # Save checkpoint frequency (steps)
|
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||||
device: "cuda" # Device: "cuda" or "cpu"
|
|
||||||
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
|
|
||||||
|
|
||||||
# Learning rate scheduler with warmup
|
# 数据加载
|
||||||
warmup_steps: 500 # Number of warmup steps
|
num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||||
scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine"
|
val_split: 0.1 # 验证集比例
|
||||||
min_lr: 1e-6 # Minimum learning rate (for cosine decay)
|
seed: 42 # 随机种子(用于数据划分)
|
||||||
|
|
||||||
|
# 日志和检查点
|
||||||
|
log_freq: 100 # 日志记录频率(步数)
|
||||||
|
save_freq: 5000 # 保存检查点频率(步数)
|
||||||
|
|
||||||
|
# 学习率调度器(带预热)
|
||||||
|
warmup_steps: 500 # 预热步数
|
||||||
|
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
|
||||||
|
min_lr: 1e-6 # 最小学习率(用于余弦退火)
|
||||||
|
|
||||||
|
# 优化器
|
||||||
|
weight_decay: 1e-5 # 权重衰减(L2 正则化)
|
||||||
|
grad_clip: 1.0 # 梯度裁剪阈值
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 实验配置
|
||||||
|
# ====================
|
||||||
|
experiment:
|
||||||
|
name: "vla_diffusion" # 实验名称
|
||||||
|
notes: "" # 实验备注
|
||||||
|
tags: [] # 实验标签
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# @package data
|
|
||||||
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
|
|
||||||
|
|
||||||
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
|
|
||||||
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
|
|
||||||
|
|
||||||
# Horizon Parameters — 使用 Hydra 插值,从 agent 配置中引用,保持一致性
|
|
||||||
pred_horizon: ${agent.pred_horizon}
|
|
||||||
obs_horizon: ${agent.obs_horizon}
|
|
||||||
action_horizon: 8 # Action execution horizon (used during evaluation)
|
|
||||||
|
|
||||||
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
|
|
||||||
camera_names:
|
|
||||||
- r_vis
|
|
||||||
- top
|
|
||||||
- front
|
|
||||||
|
|
||||||
# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
|
|
||||||
normalization_type: min_max
|
|
||||||
21
roboimi/vla/conf/data/simpe_robot_dataset.yaml
Normal file
21
roboimi/vla/conf/data/simpe_robot_dataset.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# @package data
|
||||||
|
_target_: roboimi.vla.data.simpe_robot_dataset.SimpleRobotDataset
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 数据集路径
|
||||||
|
# ====================
|
||||||
|
dataset_dir: "roboimi/demos/dataset/sim_transfer"
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 时间步参数(从 agent 配置引用)
|
||||||
|
# ====================
|
||||||
|
pred_horizon: ${agent.pred_horizon} # 预测步数
|
||||||
|
obs_horizon: ${agent.obs_horizon} # 观测步数
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 相机配置
|
||||||
|
# ====================
|
||||||
|
camera_names:
|
||||||
|
- r_vis # 机器人视角相机
|
||||||
|
- top # 顶部相机
|
||||||
|
- front # 前方相机
|
||||||
@@ -1,19 +1,27 @@
|
|||||||
# @package eval
|
# @package eval
|
||||||
# Evaluation Configuration
|
# 评估配置
|
||||||
ckpt_path: "checkpoints/vla_model_best.pt" # Path to model checkpoint
|
ckpt_path: "checkpoints/vla_model_best.pt" # 模型检查点路径
|
||||||
num_episodes: 3 # Number of evaluation episodes
|
num_episodes: 3 # 评估回合数
|
||||||
max_timesteps: 700 # Maximum timesteps per episode
|
max_timesteps: 700 # 每回合最大时间步
|
||||||
device: ${train.device} # 与训练保持一致
|
device: ${train.device} # 与训练保持一致
|
||||||
task_name: "sim_transfer" # Task name for environment creation
|
task_name: "sim_transfer" # 环境任务名称
|
||||||
|
|
||||||
# Policy execution — 从 agent 配置中引用,保持一致性
|
# ====================
|
||||||
num_queries: 4 # 每次预测 pred_horizon 步后重新查询
|
# 策略执行参数
|
||||||
|
# ====================
|
||||||
|
# num_queries 已废弃,现在使用 agent 的 select_action() 自动管理队列
|
||||||
|
# 以下参数仅用于兼容旧代码,实际使用 agent.num_action_steps
|
||||||
|
num_queries: ${agent.num_action_steps}
|
||||||
obs_horizon: ${agent.obs_horizon}
|
obs_horizon: ${agent.obs_horizon}
|
||||||
|
|
||||||
# Camera names — 从 data 配置中引用,保持一致性
|
# ====================
|
||||||
|
# 相机配置
|
||||||
|
# ====================
|
||||||
camera_names: ${data.camera_names}
|
camera_names: ${data.camera_names}
|
||||||
|
|
||||||
# Action smoothing
|
# ====================
|
||||||
|
# 动作平滑
|
||||||
|
# ====================
|
||||||
use_smoothing: false
|
use_smoothing: false
|
||||||
smooth_method: "ema"
|
smooth_method: "ema"
|
||||||
smooth_alpha: 0.3
|
smooth_alpha: 0.3
|
||||||
|
|||||||
@@ -1,5 +1,15 @@
|
|||||||
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
|
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
|
||||||
_partial_: true
|
_partial_: true
|
||||||
|
|
||||||
kernel_size: 3
|
# ====================
|
||||||
cond_predict_scale: false
|
# UNet1D 配置
|
||||||
|
# ====================
|
||||||
|
kernel_size: 3 # 卷积核大小
|
||||||
|
cond_predict_scale: false # FiLM 条件化时是否同时预测 scale(bias + scale 或仅 bias)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 网络架构(默认值,可覆盖)
|
||||||
|
# ====================
|
||||||
|
# diffusion_step_embed_dim: 256 # 扩散时间步嵌入维度
|
||||||
|
# down_dims: [256, 512, 1024] # 下采样各层通道数
|
||||||
|
# n_groups: 8 # GroupNorm 分组数
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import glob
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
class RobotDiffusionDataset(Dataset):
|
|
||||||
def __init__(self,
|
|
||||||
dataset_dir,
|
|
||||||
pred_horizon=16,
|
|
||||||
obs_horizon=2,
|
|
||||||
action_horizon=8,
|
|
||||||
camera_names=['r_vis', 'top', 'front'],
|
|
||||||
normalization_type='gaussian'):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
dataset_dir: 存放 episode_*.hdf5 的文件夹路径
|
|
||||||
pred_horizon: 预测未来动作的长度 (Tp)
|
|
||||||
obs_horizon: 历史观测长度 (To)
|
|
||||||
action_horizon: 执行动作长度 (Ta) - 在Dataset中主要影响Evaluation,这里作为参数保留
|
|
||||||
"""
|
|
||||||
self.dataset_dir = dataset_dir
|
|
||||||
self.pred_horizon = pred_horizon
|
|
||||||
self.obs_horizon = obs_horizon
|
|
||||||
self.action_horizon = action_horizon
|
|
||||||
self.camera_names = camera_names
|
|
||||||
self.normalization_type = normalization_type
|
|
||||||
# 1. 扫描所有HDF5文件并建立索引
|
|
||||||
# 格式: [(file_path, episode_length), ...]
|
|
||||||
self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
|
|
||||||
self.indices = []
|
|
||||||
|
|
||||||
print(f"Found {len(self.episode_files)} episodes. Building index...")
|
|
||||||
|
|
||||||
for file_path in self.episode_files:
|
|
||||||
with h5py.File(file_path, 'r') as f:
|
|
||||||
# 获取该 episode 的长度 (例如 700)
|
|
||||||
l = f['action'].shape[0]
|
|
||||||
# 保存每个有效 step 的索引信息
|
|
||||||
# (file_path, episode_length, current_step_index)
|
|
||||||
for i in range(l):
|
|
||||||
self.indices.append((file_path, l, i))
|
|
||||||
|
|
||||||
# 2. 统计数据
|
|
||||||
with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
|
|
||||||
self.stats = pickle.load(f)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.indices)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
file_path, episode_len, start_ts = self.indices[idx]
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 1. 打开文件
|
|
||||||
# -----------------------------
|
|
||||||
# 注意: 在 __getitem__ 中打开文件对多进程 DataLoader 更友好
|
|
||||||
# 如果追求极致IO性能,可以考虑使用 h5py 的 swmr 模式或内存缓存
|
|
||||||
with h5py.File(file_path, 'r') as root:
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 2. 处理 Action (Prediction Target)
|
|
||||||
# -----------------------------
|
|
||||||
# 目标: 获取 [t, t + pred_horizon] 的动作
|
|
||||||
action_start = start_ts
|
|
||||||
action_end = min(start_ts + self.pred_horizon, episode_len)
|
|
||||||
|
|
||||||
actions = root['action'][action_start:action_end] # shape: (T_subset, 16)
|
|
||||||
|
|
||||||
# Padding: 如果剩余动作不足 pred_horizon,复制最后一步
|
|
||||||
if len(actions) < self.pred_horizon:
|
|
||||||
pad_len = self.pred_horizon - len(actions)
|
|
||||||
last_action = actions[-1]
|
|
||||||
# 重复最后一行
|
|
||||||
pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
|
|
||||||
actions = np.concatenate([actions, pad_content], axis=0)
|
|
||||||
|
|
||||||
# 归一化 Action
|
|
||||||
if self.stats:
|
|
||||||
actions = self._normalize_data(actions, self.stats['action'])
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# 3. 处理 Observations (History)
|
|
||||||
# -----------------------------
|
|
||||||
# 目标: 获取 [t - obs_horizon + 1, t + 1] 的观测
|
|
||||||
# 索引逻辑:
|
|
||||||
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
|
|
||||||
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
|
|
||||||
|
|
||||||
start_idx_raw = start_ts - (self.obs_horizon - 1)
|
|
||||||
start_idx = max(start_idx_raw, 0)
|
|
||||||
end_idx = start_ts + 1
|
|
||||||
pad_len = max(0, -start_idx_raw)
|
|
||||||
|
|
||||||
# Qpos
|
|
||||||
qpos_data = root['observations/qpos']
|
|
||||||
qpos_val = qpos_data[start_idx:end_idx]
|
|
||||||
|
|
||||||
if pad_len > 0:
|
|
||||||
first_frame = qpos_val[0]
|
|
||||||
padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
|
|
||||||
qpos_val = np.concatenate([padding, qpos_val], axis=0)
|
|
||||||
|
|
||||||
if self.stats:
|
|
||||||
qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
|
|
||||||
|
|
||||||
# Images
|
|
||||||
image_dict = {}
|
|
||||||
for cam_name in self.camera_names:
|
|
||||||
img_dset = root['observations']['images'][cam_name]
|
|
||||||
imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
|
|
||||||
|
|
||||||
if pad_len > 0:
|
|
||||||
first_frame = imgs_np[0]
|
|
||||||
padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
|
|
||||||
imgs_np = np.concatenate([padding, imgs_np], axis=0)
|
|
||||||
|
|
||||||
# 转换为 Tensor: (T, H, W, C) -> (T, C, H, W)
|
|
||||||
imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
|
|
||||||
imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
|
|
||||||
image_dict[cam_name] = imgs_tensor
|
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# 3. 组装 Batch
|
|
||||||
# ==============================
|
|
||||||
data_batch = {
|
|
||||||
'action': torch.from_numpy(actions).float(),
|
|
||||||
'qpos': torch.from_numpy(qpos_val).float(),
|
|
||||||
}
|
|
||||||
for cam_name, img_tensor in image_dict.items():
|
|
||||||
data_batch[f'image_{cam_name}'] = img_tensor
|
|
||||||
|
|
||||||
return data_batch
|
|
||||||
|
|
||||||
def _normalize_data(self, data, stats):
|
|
||||||
if self.normalization_type == 'min_max':
|
|
||||||
# 之前的逻辑: [-1, 1]
|
|
||||||
min_val = stats['min']
|
|
||||||
max_val = stats['max']
|
|
||||||
data = (data - min_val) / (max_val - min_val + 1e-8)
|
|
||||||
return data * 2 - 1
|
|
||||||
|
|
||||||
elif self.normalization_type == 'gaussian':
|
|
||||||
# 新逻辑: Mean/Std
|
|
||||||
mean = stats['mean']
|
|
||||||
std = stats['std']
|
|
||||||
# (data - mean) / std
|
|
||||||
# 这里的 data 是 numpy array
|
|
||||||
return (data - mean) / (std + 1e-8)
|
|
||||||
@@ -1,523 +1,199 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import h5py
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Union
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class SimpleRobotDataset(Dataset):
|
class SimpleRobotDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
LeRobotDataset 简化版 - 图像以字典形式存储
|
HDF5 懒加载数据集 - LeRobotDataset 格式
|
||||||
|
|
||||||
与真实 LeRobotDataset 保持一致:
|
返回格式:
|
||||||
- Dataset 返回字典,每个摄像头单独的 key
|
- observation.state: (obs_horizon, state_dim)
|
||||||
- Policy 负责在 forward 时 stack 图像
|
- observation.{cam_name}: (obs_horizon, C, H, W)
|
||||||
|
- action: (pred_horizon, action_dim)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
frames: List[Dict],
|
dataset_dir: Union[str, Path],
|
||||||
obs_horizon: int = 2,
|
obs_horizon: int = 2,
|
||||||
pred_horizon: int = 8,
|
pred_horizon: int = 8,
|
||||||
image_keys: List[str] = None,
|
camera_names: List[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
frames: 帧数据列表。每个元素是一个字典,包含:
|
dataset_dir: HDF5 文件目录路径
|
||||||
- "episode_index" (int): [必须] 该帧所属的 Episode ID。Dataset 使用它来确定 Episode 的边界(用于 Padding)。
|
|
||||||
- "task" (str): [必须] 任务描述字符串(例如 "pick_up_cube")。
|
|
||||||
- "observation.state" (torch.Tensor): (state_dim,) [必须] 当前帧的机器人状态向量(例如关节角度)。
|
|
||||||
- "action" (torch.Tensor): (action_dim,) [必须] 当前帧对应的动作向量。
|
|
||||||
- "{image_key}" (torch.Tensor): (C, H, W) [可选] 当前帧的图像数据。键名必须与初始化 Dataset 时传入的 image_keys 列表一致。
|
|
||||||
obs_horizon: 观察过去多少帧
|
obs_horizon: 观察过去多少帧
|
||||||
pred_horizon: 预测未来多少帧动作
|
pred_horizon: 预测未来多少帧动作
|
||||||
image_keys: 哪些 key 是图像数据(例如 ["observation.image_0", "observation.image_1"])
|
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
||||||
|
|
||||||
|
HDF5 文件格式:
|
||||||
|
- action: [T, action_dim]
|
||||||
|
- observations/qpos: [T, obs_dim]
|
||||||
|
- observations/images/{cam_name}: [T, H, W, C]
|
||||||
"""
|
"""
|
||||||
self.frames = frames
|
|
||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.image_keys = image_keys or []
|
self.camera_names = camera_names or []
|
||||||
|
|
||||||
# 构建 episode 索引
|
self.dataset_dir = Path(dataset_dir)
|
||||||
|
if not self.dataset_dir.exists():
|
||||||
|
raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")
|
||||||
|
|
||||||
|
# 查找 HDF5 文件
|
||||||
|
self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
|
||||||
|
if not self.hdf5_files:
|
||||||
|
self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
|
||||||
|
if not self.hdf5_files:
|
||||||
|
raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")
|
||||||
|
|
||||||
|
# 构建 episode 索引(只存储元数据,不加载数据)
|
||||||
self.episodes = {}
|
self.episodes = {}
|
||||||
for idx, frame in enumerate(frames):
|
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
|
||||||
ep_idx = frame["episode_index"]
|
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
|
||||||
if ep_idx not in self.episodes:
|
with h5py.File(hdf5_path, 'r') as f:
|
||||||
self.episodes[ep_idx] = []
|
T = f['action'].shape[0]
|
||||||
self.episodes[ep_idx].append(idx)
|
start_idx = len(self.frame_meta)
|
||||||
|
for t in range(T):
|
||||||
|
self.frame_meta.append({
|
||||||
|
"ep_idx": ep_idx,
|
||||||
|
"frame_idx": t,
|
||||||
|
"hdf5_path": hdf5_path,
|
||||||
|
})
|
||||||
|
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
|
||||||
|
|
||||||
|
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.frames)
|
return len(self.frame_meta)
|
||||||
|
|
||||||
|
def _load_frame(self, idx: int) -> Dict:
|
||||||
|
"""从 HDF5 文件懒加载单帧数据"""
|
||||||
|
meta = self.frame_meta[idx]
|
||||||
|
with h5py.File(meta["hdf5_path"], 'r') as f:
|
||||||
|
frame = {
|
||||||
|
"episode_index": meta["ep_idx"],
|
||||||
|
"frame_index": meta["frame_idx"],
|
||||||
|
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
|
||||||
|
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
|
||||||
|
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
h5_path = f'observations/images/{cam_name}'
|
||||||
|
if h5_path in f:
|
||||||
|
img = f[h5_path][meta["frame_idx"]]
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||||
frame = self.frames[idx]
|
frame = self._load_frame(idx)
|
||||||
ep_idx = frame["episode_index"]
|
ep_idx = frame["episode_index"]
|
||||||
|
|
||||||
# 获取当前 episode 的帧索引范围
|
# 获取当前 episode 的帧索引范围
|
||||||
ep_indices = self.episodes[ep_idx]
|
ep_indices = self.episodes[ep_idx]
|
||||||
ep_start = ep_indices[0]
|
ep_start = ep_indices[0]
|
||||||
ep_end = ep_indices[-1]
|
ep_end = ep_indices[-1]
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 1. 加载观察(过去 obs_horizon 帧)
|
# 1. 加载观察(过去 obs_horizon 帧)
|
||||||
# ============================================
|
# ============================================
|
||||||
observations = {
|
observations = {
|
||||||
"state": [], # 状态数据
|
"state": [], # 状态数据
|
||||||
}
|
}
|
||||||
# 为每个摄像头初始化独立列表(字典形式)
|
# 为每个摄像头初始化独立列表
|
||||||
for cam_key in self.image_keys:
|
for cam_name in self.camera_names:
|
||||||
observations[cam_key] = []
|
observations[f"observation.{cam_name}"] = []
|
||||||
|
|
||||||
observation_is_pad = []
|
observation_is_pad = []
|
||||||
|
|
||||||
for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2
|
for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2
|
||||||
target_idx = idx + delta
|
target_idx = idx + delta
|
||||||
|
|
||||||
# 边界检查
|
# 边界检查
|
||||||
if ep_start <= target_idx <= ep_end:
|
if ep_start <= target_idx <= ep_end:
|
||||||
target_frame = self.frames[target_idx]
|
target_frame = self._load_frame(target_idx)
|
||||||
is_pad = False
|
is_pad = False
|
||||||
else:
|
else:
|
||||||
# 超出边界,用边界帧填充
|
# 超出边界,用边界帧填充
|
||||||
if target_idx < ep_start:
|
if target_idx < ep_start:
|
||||||
target_frame = self.frames[ep_start]
|
target_frame = self._load_frame(ep_start)
|
||||||
else:
|
else:
|
||||||
target_frame = self.frames[ep_end]
|
target_frame = self._load_frame(ep_end)
|
||||||
is_pad = True
|
is_pad = True
|
||||||
|
|
||||||
# 收集状态
|
# 收集状态
|
||||||
observations["state"].append(target_frame["observation.state"])
|
observations["state"].append(target_frame["observation.state"])
|
||||||
|
|
||||||
# 收集每个摄像头的图像(字典形式,不 stack)
|
# 收集每个摄像头的图像
|
||||||
for cam_key in self.image_keys:
|
for cam_name in self.camera_names:
|
||||||
observations[cam_key].append(target_frame[cam_key])
|
observations[f"observation.{cam_name}"].append(target_frame[f"observation.{cam_name}"])
|
||||||
|
|
||||||
observation_is_pad.append(is_pad)
|
observation_is_pad.append(is_pad)
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 2. 加载动作(未来 pred_horizon 帧)
|
# 2. 加载动作(未来 pred_horizon 帧)
|
||||||
# ============================================
|
# ============================================
|
||||||
actions = []
|
actions = []
|
||||||
action_is_pad = []
|
action_is_pad = []
|
||||||
|
|
||||||
for delta in range(self.pred_horizon):
|
for delta in range(self.pred_horizon):
|
||||||
target_idx = idx + delta
|
target_idx = idx + delta
|
||||||
|
|
||||||
if target_idx <= ep_end:
|
if target_idx <= ep_end:
|
||||||
actions.append(self.frames[target_idx]["action"])
|
actions.append(self._load_frame(target_idx)["action"])
|
||||||
action_is_pad.append(False)
|
action_is_pad.append(False)
|
||||||
else:
|
else:
|
||||||
actions.append(self.frames[ep_end]["action"])
|
actions.append(self._load_frame(ep_end)["action"])
|
||||||
action_is_pad.append(True)
|
action_is_pad.append(True)
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 3. 组装返回数据(字典形式)
|
# 3. 组装返回数据(LeRobotDataset 格式)
|
||||||
# ============================================
|
# ============================================
|
||||||
result = {
|
result = {
|
||||||
# 状态观察: (obs_horizon, state_dim)
|
# 状态观察: (obs_horizon, state_dim)
|
||||||
"observation.state": torch.stack(observations["state"]),
|
"observation.state": torch.stack(observations["state"]),
|
||||||
"observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
|
"observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
|
||||||
|
|
||||||
# 动作: (pred_horizon, action_dim)
|
# 动作: (pred_horizon, action_dim)
|
||||||
"action": torch.stack(actions),
|
"action": torch.stack(actions),
|
||||||
"action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
|
"action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
|
||||||
|
|
||||||
# 任务
|
# 任务
|
||||||
"task": frame["task"],
|
"task": frame["task"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# 图像:每个摄像头独立的 key(字典形式)
|
# 图像:每个摄像头独立的 key
|
||||||
# 形状: (obs_horizon, C, H, W)
|
# 形状: (obs_horizon, C, H, W)
|
||||||
for cam_key in self.image_keys:
|
for cam_name in self.camera_names:
|
||||||
result[cam_key] = torch.stack(observations[cam_key])
|
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_keys(self) -> list[str]:
|
def camera_keys(self) -> list[str]:
|
||||||
"""获取所有相机键名"""
|
"""获取所有相机键名 (LeRobotDataset 格式)"""
|
||||||
return self.image_keys
|
return [f"observation.{cam_name}" for cam_name in self.camera_names]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_info(self) -> dict:
|
def camera_info(self) -> dict:
|
||||||
"""获取相机信息"""
|
"""获取相机信息"""
|
||||||
if not self.image_keys:
|
if not self.camera_names:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# 从第一个样本获取形状
|
# 从第一个样本获取形状
|
||||||
sample = self[0]
|
sample = self[0]
|
||||||
info = {}
|
info = {}
|
||||||
for cam_key in self.image_keys:
|
for cam_name in self.camera_names:
|
||||||
if cam_key in sample:
|
key = f"observation.{cam_name}"
|
||||||
info[cam_key] = {
|
if key in sample:
|
||||||
"shape": sample[cam_key].shape,
|
info[key] = {
|
||||||
"dtype": str(sample[cam_key].dtype),
|
"shape": sample[key].shape,
|
||||||
|
"dtype": str(sample[key].dtype),
|
||||||
}
|
}
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
class SimpleDiffusionPolicy(torch.nn.Module):
|
|
||||||
"""简化的 Diffusion Policy - 展示如何在 forward 时 stack 图像"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dim: int,
|
|
||||||
action_dim: int,
|
|
||||||
image_features: Dict[str, tuple] = None,
|
|
||||||
obs_horizon: int = 2,
|
|
||||||
pred_horizon: int = 8,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.state_dim = state_dim
|
|
||||||
self.action_dim = action_dim
|
|
||||||
self.obs_horizon = obs_horizon
|
|
||||||
self.pred_horizon = pred_horizon
|
|
||||||
self.image_features = image_features or {}
|
|
||||||
|
|
||||||
self.state_encoder = torch.nn.Linear(state_dim, 64)
|
|
||||||
if image_features:
|
|
||||||
num_cameras = len(image_features)
|
|
||||||
self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
|
|
||||||
self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
|
|
||||||
else:
|
|
||||||
self.fusion = torch.nn.Linear(64, 128)
|
|
||||||
|
|
||||||
self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)
|
|
||||||
|
|
||||||
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
|
||||||
"""前向传播"""
|
|
||||||
# 处理状态
|
|
||||||
state_features = self.state_encoder(batch["observation.state"])
|
|
||||||
state_features = state_features.mean(dim=1)
|
|
||||||
|
|
||||||
# 处理图像(字典形式 → stack)
|
|
||||||
if self.image_features:
|
|
||||||
image_tensors = [batch[key] for key in self.image_features.keys()]
|
|
||||||
stacked_images = torch.stack(image_tensors, dim=1)
|
|
||||||
|
|
||||||
B, num_cam, T, C, H, W = stacked_images.shape
|
|
||||||
images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
|
|
||||||
image_features = self.image_encoder(images_flat)
|
|
||||||
image_features = image_features.mean(dim=[2, 3])
|
|
||||||
image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
|
|
||||||
image_features = image_features.reshape(B, -1)
|
|
||||||
|
|
||||||
features = torch.cat([state_features, image_features], dim=-1)
|
|
||||||
else:
|
|
||||||
features = state_features
|
|
||||||
|
|
||||||
fused = self.fusion(features)
|
|
||||||
pred_actions = self.action_head(fused)
|
|
||||||
pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim)
|
|
||||||
|
|
||||||
return pred_actions
|
|
||||||
|
|
||||||
|
|
||||||
def create_demo_data_with_images():
|
|
||||||
"""创建包含图像的模拟数据"""
|
|
||||||
frames = []
|
|
||||||
|
|
||||||
# Episode 0: pick_up_cube task
|
|
||||||
for t in range(10):
|
|
||||||
frames.append({
|
|
||||||
"episode_index": 0,
|
|
||||||
"frame_index": t,
|
|
||||||
"task": "pick_up_cube",
|
|
||||||
"observation.state": torch.randn(6),
|
|
||||||
"observation.image_high_resize": torch.randn(3, 64, 64),
|
|
||||||
"observation.image_left_wrist": torch.randn(3, 64, 64),
|
|
||||||
"action": torch.randn(6),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Episode 1: stack_blocks task
|
|
||||||
for t in range(10):
|
|
||||||
frames.append({
|
|
||||||
"episode_index": 1,
|
|
||||||
"frame_index": t,
|
|
||||||
"task": "stack_blocks",
|
|
||||||
"observation.state": torch.randn(6),
|
|
||||||
"observation.image_high_resize": torch.randn(3, 64, 64),
|
|
||||||
"observation.image_left_wrist": torch.randn(3, 64, 64),
|
|
||||||
"action": torch.randn(6),
|
|
||||||
})
|
|
||||||
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
def print_section(title: str):
|
|
||||||
"""打印分节标题"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print(f" {title}")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_basic_info(dataset):
|
|
||||||
"""测试数据集基本信息"""
|
|
||||||
print("\n📊 数据集基本信息:")
|
|
||||||
print(f" 总帧数: {len(dataset)}")
|
|
||||||
print(f" 总 episode 数: {len(dataset.episodes)}")
|
|
||||||
print(f" 观察窗口: {dataset.obs_horizon}")
|
|
||||||
print(f" 预测窗口: {dataset.pred_horizon}")
|
|
||||||
|
|
||||||
print(f"\n📷 相机信息:")
|
|
||||||
cameras = dataset.camera_keys
|
|
||||||
print(f" 相机数量: {len(cameras)}")
|
|
||||||
for cam in cameras:
|
|
||||||
print(f" - {cam}")
|
|
||||||
|
|
||||||
print(f"\n相机详细信息:")
|
|
||||||
cam_info = dataset.camera_info
|
|
||||||
for cam, info in cam_info.items():
|
|
||||||
print(f" {cam}:")
|
|
||||||
print(f" shape: {info['shape']}")
|
|
||||||
print(f" dtype: {info['dtype']}")
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_sample(dataset):
|
|
||||||
"""测试单个样本"""
|
|
||||||
print_section("1. 测试单个样本")
|
|
||||||
|
|
||||||
# Episode 中间的样本
|
|
||||||
sample = dataset[5]
|
|
||||||
|
|
||||||
print("\n样本结构 (字典形式):")
|
|
||||||
for key, value in sample.items():
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
|
|
||||||
elif isinstance(value, str):
|
|
||||||
print(f" {key:30s}: {value}")
|
|
||||||
|
|
||||||
# 验证图像是字典形式
|
|
||||||
print("\n✅ 验证图像存储形式:")
|
|
||||||
print(" 图像以字典形式存储,每个摄像头独立的 key:")
|
|
||||||
for cam_key in dataset.camera_keys:
|
|
||||||
if cam_key in sample:
|
|
||||||
print(f" - {cam_key}: {sample[cam_key].shape}")
|
|
||||||
|
|
||||||
# 验证时间维度
|
|
||||||
print("\n✅ 验证时间维度:")
|
|
||||||
print(f" observation.state: {sample['observation.state'].shape}")
|
|
||||||
print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
|
|
||||||
assert sample['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误"
|
|
||||||
print(f" action: {sample['action'].shape}")
|
|
||||||
print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
|
|
||||||
assert sample['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误"
|
|
||||||
print(" ✓ 时间维度验证通过")
|
|
||||||
|
|
||||||
|
|
||||||
def test_edge_cases(dataset):
|
|
||||||
"""测试边界情况"""
|
|
||||||
print_section("2. 测试边界情况")
|
|
||||||
|
|
||||||
test_cases = [
|
|
||||||
("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
|
|
||||||
("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
|
|
||||||
("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
|
|
||||||
("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
|
|
||||||
]
|
|
||||||
|
|
||||||
for name, idx, expected in test_cases:
|
|
||||||
print(f"\n📍 {name} (idx={idx}):")
|
|
||||||
sample = dataset[idx]
|
|
||||||
|
|
||||||
obs_pad = sample["observation_is_pad"].tolist()
|
|
||||||
action_pad_count = sample["action_is_pad"].sum().item()
|
|
||||||
|
|
||||||
print(f" observation_is_pad: {obs_pad}")
|
|
||||||
print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
|
|
||||||
print(f" action padding 数量: {action_pad_count}")
|
|
||||||
|
|
||||||
# 验证观察 padding
|
|
||||||
if name == "Episode 开头":
|
|
||||||
assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding"
|
|
||||||
elif name == "跨 Episode":
|
|
||||||
assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding"
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataloader(dataset):
|
|
||||||
"""测试 DataLoader"""
|
|
||||||
print_section("3. 测试 DataLoader 集成")
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=4,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=0, # 测试时用 0
|
|
||||||
)
|
|
||||||
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
|
|
||||||
print("\n📦 Batch 结构:")
|
|
||||||
for key in ["observation.state", "observation.image_high_resize",
|
|
||||||
"observation.image_left_wrist", "action", "task"]:
|
|
||||||
if key in batch:
|
|
||||||
value = batch[key]
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
|
|
||||||
else:
|
|
||||||
print(f" {key:30s}: {type(value).__name__} (length={len(value)})")
|
|
||||||
|
|
||||||
print("\n✅ 验证 Batch 形状:")
|
|
||||||
B = len(batch["observation.state"])
|
|
||||||
print(f" Batch size: {B}")
|
|
||||||
|
|
||||||
# 验证每个摄像头的形状
|
|
||||||
for cam_key in dataset.camera_keys:
|
|
||||||
expected_shape = (B, dataset.obs_horizon, 3, 64, 64)
|
|
||||||
actual_shape = batch[cam_key].shape
|
|
||||||
print(f" {cam_key}:")
|
|
||||||
print(f" 预期: {expected_shape}")
|
|
||||||
print(f" 实际: {actual_shape}")
|
|
||||||
assert actual_shape == expected_shape, f"{cam_key} 形状不匹配"
|
|
||||||
print(" ✓ Batch 形状验证通过")
|
|
||||||
|
|
||||||
|
|
||||||
def test_policy_forward(dataset):
|
|
||||||
"""测试 Policy 前向传播"""
|
|
||||||
print_section("4. 测试 Policy 前向传播")
|
|
||||||
|
|
||||||
# 创建 Policy
|
|
||||||
policy = SimpleDiffusionPolicy(
|
|
||||||
state_dim=6,
|
|
||||||
action_dim=6,
|
|
||||||
image_features={
|
|
||||||
"observation.image_high_resize": (3, 64, 64),
|
|
||||||
"observation.image_left_wrist": (3, 64, 64),
|
|
||||||
},
|
|
||||||
obs_horizon=dataset.obs_horizon,
|
|
||||||
pred_horizon=dataset.pred_horizon,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 DataLoader
|
|
||||||
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
|
|
||||||
print("\n🔄 Policy.forward() 流程:")
|
|
||||||
|
|
||||||
# 1. Stack 之前
|
|
||||||
print("\n 1️⃣ Stack 之前 (字典形式):")
|
|
||||||
for cam_key in policy.image_features.keys():
|
|
||||||
print(f" batch['{cam_key}']: {batch[cam_key].shape}")
|
|
||||||
|
|
||||||
# 2. 模拟 Stack 操作
|
|
||||||
print("\n 2️⃣ Stack 操作:")
|
|
||||||
image_tensors = [batch[key] for key in policy.image_features.keys()]
|
|
||||||
stacked = torch.stack(image_tensors, dim=1)
|
|
||||||
print(f" stacked_images: {stacked.shape}")
|
|
||||||
print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
|
|
||||||
print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")
|
|
||||||
|
|
||||||
# 3. 前向传播
|
|
||||||
print("\n 3️⃣ 前向传播:")
|
|
||||||
with torch.no_grad():
|
|
||||||
pred_actions = policy(batch)
|
|
||||||
|
|
||||||
print(f" 输入:")
|
|
||||||
print(f" observation.state: {batch['observation.state'].shape}")
|
|
||||||
print(f" 图像已 stack")
|
|
||||||
print(f" 输出:")
|
|
||||||
print(f" pred_actions: {pred_actions.shape}")
|
|
||||||
print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")
|
|
||||||
|
|
||||||
print("\n✅ Policy 前向传播验证通过")
|
|
||||||
|
|
||||||
|
|
||||||
def test_data_consistency(dataset):
|
|
||||||
"""测试数据一致性"""
|
|
||||||
print_section("5. 测试数据一致性")
|
|
||||||
|
|
||||||
print("\n🔍 验证图像 padding 的正确性:")
|
|
||||||
|
|
||||||
# Episode 开头的样本
|
|
||||||
sample = dataset[0]
|
|
||||||
if sample["observation_is_pad"][0]:
|
|
||||||
img_0 = sample["observation.image_high_resize"][0]
|
|
||||||
img_1 = sample["observation.image_high_resize"][1]
|
|
||||||
print(f" Episode 开头 (idx=0):")
|
|
||||||
print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}")
|
|
||||||
print(f" 第0帧图像 = 第1帧图像: {torch.equal(img_0, img_1)}")
|
|
||||||
assert torch.equal(img_0, img_1), "Padding 应该复制边界帧"
|
|
||||||
print(" ✓ Padding 正确")
|
|
||||||
|
|
||||||
# Episode 中间的样本
|
|
||||||
sample = dataset[5]
|
|
||||||
if not sample["observation_is_pad"].any():
|
|
||||||
img_0 = sample["observation.image_high_resize"][0]
|
|
||||||
img_1 = sample["observation.image_high_resize"][1]
|
|
||||||
print(f"\n Episode 中间 (idx=5):")
|
|
||||||
print(f" 没有 padding: {sample['observation_is_pad']}")
|
|
||||||
print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(img_0, img_1)}")
|
|
||||||
print(" ✓ 正常帧不重复")
|
|
||||||
|
|
||||||
print("\n✅ 数据一致性验证通过")
|
|
||||||
|
|
||||||
|
|
||||||
def test_task_info(dataset):
|
|
||||||
"""测试任务信息"""
|
|
||||||
print_section("6. 测试任务信息")
|
|
||||||
|
|
||||||
print("\n📋 统计任务分布:")
|
|
||||||
task_count = {}
|
|
||||||
for frame in dataset.frames:
|
|
||||||
task = frame["task"]
|
|
||||||
task_count[task] = task_count.get(task, 0) + 1
|
|
||||||
|
|
||||||
for task, count in task_count.items():
|
|
||||||
print(f" {task}: {count} 帧")
|
|
||||||
|
|
||||||
# 验证 sample 中的 task 信息
|
|
||||||
sample = dataset[0]
|
|
||||||
print(f"\n样本 task: {sample['task']}")
|
|
||||||
print(f" 类型: {type(sample['task'])}")
|
|
||||||
|
|
||||||
# 验证 DataLoader 中的 task
|
|
||||||
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
print(f"\nBatch task:")
|
|
||||||
print(f" 值: {batch['task']}")
|
|
||||||
print(f" 类型: {type(batch['task'])}")
|
|
||||||
print(f" 长度: {len(batch['task'])}")
|
|
||||||
|
|
||||||
print("\n✅ 任务信息验证通过")
|
|
||||||
|
|
||||||
|
|
||||||
def run_all_tests():
|
|
||||||
"""运行所有测试"""
|
|
||||||
print("\n" + "🚀" * 40)
|
|
||||||
print(" SimpleRobotDataset 完整测试套件")
|
|
||||||
print("🚀" * 40)
|
|
||||||
|
|
||||||
# 创建数据集
|
|
||||||
print("\n创建测试数据...")
|
|
||||||
frames = create_demo_data_with_images()
|
|
||||||
dataset = SimpleRobotDataset(
|
|
||||||
frames,
|
|
||||||
obs_horizon=2,
|
|
||||||
pred_horizon=8,
|
|
||||||
image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
|
|
||||||
)
|
|
||||||
print("✓ 数据集创建完成")
|
|
||||||
|
|
||||||
# 运行测试
|
|
||||||
test_dataset_basic_info(dataset)
|
|
||||||
test_single_sample(dataset)
|
|
||||||
test_edge_cases(dataset)
|
|
||||||
test_dataloader(dataset)
|
|
||||||
test_policy_forward(dataset)
|
|
||||||
test_data_consistency(dataset)
|
|
||||||
test_task_info(dataset)
|
|
||||||
|
|
||||||
# 总结
|
|
||||||
print_section("✅ 测试总结")
|
|
||||||
print("\n所有测试通过!✨")
|
|
||||||
print("\n关键验证点:")
|
|
||||||
print(" ✓ 图像以字典形式存储")
|
|
||||||
print(" ✓ 每个摄像头独立的 key")
|
|
||||||
print(" ✓ Policy 在 forward 时 stack 图像")
|
|
||||||
print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)")
|
|
||||||
print(" ✓ Padding 处理正确")
|
|
||||||
print(" ✓ DataLoader 集成正确")
|
|
||||||
print(" ✓ Task 信息传递正确")
|
|
||||||
print("\n与 LeRobotDataset 设计完全一致!🎉")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
run_all_tests()
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
# Backbone models
|
# Backbone models
|
||||||
from .resnet import ResNetBackbone
|
from .resnet_diffusion import ResNetDiffusionBackbone
|
||||||
|
|
||||||
__all__ = ["ResNetBackbone"]
|
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
from roboimi.vla.core.interfaces import VLABackbone
|
|
||||||
from transformers import ResNetModel
|
|
||||||
from torchvision import transforms
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class ResNetBackbone(VLABackbone):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name = "microsoft/resnet-18",
|
|
||||||
freeze: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.model = ResNetModel.from_pretrained(model_name)
|
|
||||||
self.out_channels = self.model.config.hidden_sizes[-1]
|
|
||||||
self.transform = transforms.Compose([
|
|
||||||
transforms.Resize((384, 384)),
|
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
||||||
])
|
|
||||||
self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
|
|
||||||
if freeze:
|
|
||||||
self._freeze_parameters()
|
|
||||||
|
|
||||||
def _freeze_parameters(self):
|
|
||||||
print("❄️ Freezing ResNet Backbone parameters")
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
def train(self, mode=True):
|
|
||||||
"""
|
|
||||||
Override train() to keep frozen ResNet in eval mode.
|
|
||||||
This ensures BatchNorm layers use running statistics consistently.
|
|
||||||
"""
|
|
||||||
super().train(mode)
|
|
||||||
if hasattr(self, 'model'):
|
|
||||||
self.model.eval() # Always keep ResNet in eval mode
|
|
||||||
return self
|
|
||||||
|
|
||||||
def forward_single_image(self, image):
|
|
||||||
B, T, C, H, W = image.shape
|
|
||||||
image = image.view(B * T, C, H, W)
|
|
||||||
image = self.transform(image)
|
|
||||||
feature_map = self.model(image).last_hidden_state # (B*T, D, H', W')
|
|
||||||
features = self.spatial_softmax(feature_map) # (B*T, D*2)
|
|
||||||
return features
|
|
||||||
|
|
||||||
def forward(self, images):
|
|
||||||
any_tensor = next(iter(images.values()))
|
|
||||||
B, T = any_tensor.shape[:2]
|
|
||||||
features_all = []
|
|
||||||
sorted_cam_names = sorted(images.keys())
|
|
||||||
for cam_name in sorted_cam_names:
|
|
||||||
img = images[cam_name]
|
|
||||||
features = self.forward_single_image(img) # (B*T, D*2)
|
|
||||||
features_all.append(features)
|
|
||||||
combined_features = torch.cat(features_all, dim=1) # (B*T, Num_Cams*D*2)
|
|
||||||
return combined_features.view(B, T, -1)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_dim(self):
|
|
||||||
"""Output dimension after spatial softmax: out_channels * 2"""
|
|
||||||
return self.out_channels * 2
|
|
||||||
|
|
||||||
class SpatialSoftmax(nn.Module):
|
|
||||||
"""
|
|
||||||
将特征图 (N, C, H, W) 转换为坐标特征 (N, C*2)
|
|
||||||
"""
|
|
||||||
def __init__(self, num_rows, num_cols, temperature=None):
|
|
||||||
super().__init__()
|
|
||||||
self.temperature = nn.Parameter(torch.ones(1))
|
|
||||||
# 创建网格坐标
|
|
||||||
pos_x, pos_y = torch.meshgrid(
|
|
||||||
torch.linspace(-1, 1, num_rows),
|
|
||||||
torch.linspace(-1, 1, num_cols),
|
|
||||||
indexing='ij'
|
|
||||||
)
|
|
||||||
self.register_buffer('pos_x', pos_x.reshape(-1))
|
|
||||||
self.register_buffer('pos_y', pos_y.reshape(-1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
N, C, H, W = x.shape
|
|
||||||
x = x.view(N, C, -1) # (N, C, H*W)
|
|
||||||
|
|
||||||
# 计算 Softmax 注意力图
|
|
||||||
softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
|
|
||||||
|
|
||||||
# 计算期望坐标 (x, y)
|
|
||||||
expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
|
|
||||||
expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
|
|
||||||
|
|
||||||
# 拼接并展平 -> (N, C*2)
|
|
||||||
return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)
|
|
||||||
@@ -91,20 +91,21 @@ class SpatialSoftmax(nn.Module):
|
|||||||
|
|
||||||
return feature_keypoints
|
return feature_keypoints
|
||||||
|
|
||||||
class ResNetDiffusionBackbone(VLABackbone):
|
class _SingleRgbEncoder(nn.Module):
|
||||||
|
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vision_backbone: str = "resnet18",
|
vision_backbone: str,
|
||||||
pretrained_backbone_weights: str | None = None,
|
pretrained_backbone_weights: str | None,
|
||||||
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
|
input_shape: Tuple[int, int, int],
|
||||||
crop_shape: Optional[Tuple[int, int]] = None,
|
crop_shape: Optional[Tuple[int, int]],
|
||||||
crop_is_random: bool = True,
|
crop_is_random: bool,
|
||||||
use_group_norm: bool = True,
|
use_group_norm: bool,
|
||||||
spatial_softmax_num_keypoints: int = 32,
|
spatial_softmax_num_keypoints: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 设置可选的预处理。
|
# 设置可选的预处理
|
||||||
if crop_shape is not None:
|
if crop_shape is not None:
|
||||||
self.do_crop = True
|
self.do_crop = True
|
||||||
# 评估时始终使用中心裁剪
|
# 评估时始终使用中心裁剪
|
||||||
@@ -117,14 +118,14 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
self.do_crop = False
|
self.do_crop = False
|
||||||
crop_shape = input_shape[1:]
|
crop_shape = input_shape[1:]
|
||||||
|
|
||||||
# 设置骨干网络。
|
# 设置骨干网络
|
||||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||||
weights=pretrained_backbone_weights
|
weights=pretrained_backbone_weights
|
||||||
)
|
)
|
||||||
|
|
||||||
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
||||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||||
|
|
||||||
if use_group_norm:
|
if use_group_norm:
|
||||||
self.backbone = _replace_submodules(
|
self.backbone = _replace_submodules(
|
||||||
root_module=self.backbone,
|
root_module=self.backbone,
|
||||||
@@ -132,12 +133,12 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置池化和最终层。
|
# 设置池化和最终层
|
||||||
# 使用试运行来获取特征图形状。
|
# 使用试运行来获取特征图形状
|
||||||
dummy_shape = (1, input_shape[0], *crop_shape)
|
dummy_shape = (1, input_shape[0], *crop_shape)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
||||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||||
|
|
||||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
||||||
self.feature_dim = spatial_softmax_num_keypoints * 2
|
self.feature_dim = spatial_softmax_num_keypoints * 2
|
||||||
@@ -150,58 +151,205 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetDiffusionBackbone(VLABackbone):
    """Multi-camera ResNet image encoder for the diffusion policy.

    Wraps one or more ``_SingleRgbEncoder`` instances: either a single
    encoder shared by all cameras (default), or one independent encoder
    per camera when ``use_separate_rgb_encoder_per_camera`` is True.
    """

    def __init__(
        self,
        vision_backbone: str = "resnet18",
        pretrained_backbone_weights: str | None = None,
        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
        crop_shape: Optional[Tuple[int, int]] = None,
        crop_is_random: bool = True,
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = False,  # new: one independent encoder per camera
        num_cameras: int = 1,  # new: camera count (only used in separate-encoder mode)
    ):
        super().__init__()

        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
        self.num_cameras = num_cameras

        if use_separate_rgb_encoder_per_camera:
            # Separate-encoder mode: build one independent encoder per camera.
            encoders = [
                _SingleRgbEncoder(
                    vision_backbone=vision_backbone,
                    pretrained_backbone_weights=pretrained_backbone_weights,
                    input_shape=input_shape,
                    crop_shape=crop_shape,
                    crop_is_random=crop_is_random,
                    use_group_norm=use_group_norm,
                    spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                )
                for _ in range(num_cameras)
            ]
            self.rgb_encoder = nn.ModuleList(encoders)
            # Important: feature_dim is always the *per-encoder* feature size
            # (kept consistent with lerobot's convention).
            self.feature_dim = encoders[0].feature_dim
        else:
            # Shared-encoder mode: all cameras go through the same encoder.
            self.rgb_encoder = _SingleRgbEncoder(
                vision_backbone=vision_backbone,
                pretrained_backbone_weights=pretrained_backbone_weights,
                input_shape=input_shape,
                crop_shape=crop_shape,
                crop_is_random=crop_is_random,
                use_group_norm=use_group_norm,
                spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
            )
            self.feature_dim = self.rgb_encoder.feature_dim

    def forward(self, images):
        """
        Args:
            images: Dict[str, Tensor] mapping camera name to an image batch
                of shape (B, T, C, H, W).

        Returns:
            Tensor: (B, T, total_feature_dim), the per-camera features
            concatenated with cameras ordered by sorted name.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        cam_names = sorted(images.keys())
        if self.use_separate_rgb_encoder_per_camera:
            # Separate-encoder mode: each camera uses its matching encoder
            # (index pairing relies on the sorted camera-name order).
            features_all = []
            for cam_idx, cam_name in enumerate(cam_names):
                img = images[cam_name]
                encoder = self.rgb_encoder[cam_idx]
                features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
                features_all.append(features)
            return torch.cat(features_all, dim=1).view(B, T, -1)
        else:
            # Shared-encoder mode: every camera reuses the single encoder.
            features_all = []
            for cam_name in cam_names:
                img = images[cam_name]
                features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
                features_all.append(features)
            return torch.cat(features_all, dim=1).view(B, T, -1)

    @property
    def output_dim(self):
        # Per-encoder feature dimension (NOT multiplied by the camera count).
        return self.feature_dim
||||||
|
|
||||||
if __name__ == "__main__":
    # Smoke tests for ResNetDiffusionBackbone: shared vs. separate encoders.
    print("🚀 Testing ResNetDiffusionBackbone")
    print("=" * 60)

    # Configuration
    B, T = 2, 5
    C, H, W = 3, 96, 96
    crop_h, crop_w = 84, 84
    num_keypoints = 32
    feature_dim_per_cam = num_keypoints * 2

    # Create dummy input (2 cameras)
    images = {
        "cam_high": torch.randn(B, T, C, H, W),
        "cam_wrist": torch.randn(B, T, C, H, W)
    }
    num_cameras = len(images)

    # ============================================================================
    # Test 1: Shared Encoder (default mode)
    # ============================================================================
    print("\n[Test 1] Shared Encoder Mode")
    print("-" * 60)
    backbone_shared = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,  # Speed up test
        input_shape=(C, H, W),
        crop_shape=(crop_h, crop_w),
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=num_keypoints,
        use_separate_rgb_encoder_per_camera=False,  # shared encoder
    )

    print(f"✅ Shared encoder model instantiated")
    print(f" Output dim per camera: {feature_dim_per_cam}")
    print(f" Number of cameras: {num_cameras}")
    print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")

    output = backbone_shared(images)
    print(f"\n🔄 Forward pass completed")
    print(f" Input shapes: {[v.shape for v in images.values()]}")
    print(f" Output shape: {output.shape}")

    expected_dim = num_cameras * feature_dim_per_cam
    assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
    print(f"✨ Test passed!")

    # ============================================================================
    # Test 2: Separate Encoders (one independent encoder per camera)
    # ============================================================================
    print("\n[Test 2] Separate Encoders Mode")
    print("-" * 60)
    backbone_separate = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,  # Speed up test
        input_shape=(C, H, W),
        crop_shape=(crop_h, crop_w),
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=num_keypoints,
        use_separate_rgb_encoder_per_camera=True,  # separate encoders
        num_cameras=num_cameras,
    )

    print(f"✅ Separate encoders model instantiated")
    print(f" Output dim per camera: {feature_dim_per_cam}")
    print(f" Number of cameras: {num_cameras}")
    print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")

    output = backbone_separate(images)
    print(f"\n🔄 Forward pass completed")
    print(f" Input shapes: {[v.shape for v in images.values()]}")
    print(f" Output shape: {output.shape}")

    expected_dim = num_cameras * feature_dim_per_cam
    assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
    print(f"✨ Test passed!")

    # ============================================================================
    # Test 3: Verify parameters count
    # ============================================================================
    print("\n[Test 3] Parameter Count Comparison")
    print("-" * 60)
    shared_params = sum(p.numel() for p in backbone_shared.parameters())
    separate_params = sum(p.numel() for p in backbone_separate.parameters())

    print(f" Shared encoder parameters: {shared_params:,}")
    print(f" Separate encoders parameters: {separate_params:,}")
    print(f" Ratio: {separate_params / shared_params:.2f}x")

    assert separate_params > shared_params, "Separate encoders should have more parameters"
    print(f"✨ Verification passed!")

    # ============================================================================
    # Test 4: Verify independent parameters
    # ============================================================================
    print("\n[Test 4] Verify Independent Parameters")
    print("-" * 60)
    # Check that encoders have independent parameters
    encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
    encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]

    # Modify first encoder's parameter
    with torch.no_grad():
        encoder_0_first_param += 1.0

    # Verify they are not the same tensor
    assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
        "Encoders should have independent parameters"

    print(f"✅ Encoders have independent parameters")
    print(f"✨ All tests passed!")

    print("\n" + "=" * 60)
    print("🎉 All tests completed successfully!")
    print("=" * 60)
|
||||||
128
roboimi/vla/models/normalization.py
Normal file
128
roboimi/vla/models/normalization.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
归一化模块 - 统一训练和推理的归一化逻辑
|
||||||
|
|
||||||
|
支持两种归一化方式:
|
||||||
|
1. Gaussian (z-score): (x - mean) / std
|
||||||
|
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional, Dict, Literal
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizationModule(nn.Module):
    """Unified normalization module.

    Normalizes / denormalizes qpos and action inside the agent so that
    training and inference share the exact same logic.

    Supported schemes:
      1. Gaussian (z-score): (x - mean) / std
      2. MinMax: 2 * (x - min) / (max - min) - 1  -> [-1, 1]
    """

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Literal['gaussian', 'min_max'] = 'gaussian'
    ):
        """
        Args:
            stats: Dataset statistics dict of the form:
                {
                    'normalization_type': 'gaussian' | 'min_max',
                    'qpos_mean': [...],
                    'qpos_std': [...],
                    'qpos_min': [...],    # min_max only
                    'qpos_max': [...],    # min_max only
                    'action_mean': [...],
                    'action_std': [...],
                    'action_min': [...],  # min_max only
                    'action_max': [...],  # min_max only
                }
                When None, the module is disabled and every method is a
                pass-through.
            normalization_type: 'gaussian' or 'min_max'; overridden by
                stats['normalization_type'] when the stats dict provides it.
        """
        super().__init__()

        self.normalization_type = normalization_type
        self.enabled = stats is not None

        if self.enabled:
            # The stats dict may carry its own normalization type; it wins.
            self.normalization_type = stats.get('normalization_type', normalization_type)

            # Registered as buffers: excluded from optimization, but saved
            # and moved to device together with the model.
            self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
            self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
            self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
            self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))

            # MinMax normalization additionally needs min/max bounds;
            # default to [0, 1] per dimension when not supplied.
            if self.normalization_type == 'min_max':
                qpos_min = stats.get('qpos_min', [0.0] * len(stats['qpos_mean']))
                qpos_max = stats.get('qpos_max', [1.0] * len(stats['qpos_mean']))
                action_min = stats.get('action_min', [0.0] * len(stats['action_mean']))
                action_max = stats.get('action_max', [1.0] * len(stats['action_mean']))

                self.register_buffer('qpos_min', torch.tensor(qpos_min, dtype=torch.float32))
                self.register_buffer('qpos_max', torch.tensor(qpos_max, dtype=torch.float32))
                self.register_buffer('action_min', torch.tensor(action_min, dtype=torch.float32))
                self.register_buffer('action_max', torch.tensor(action_max, dtype=torch.float32))

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize qpos (pass-through when disabled)."""
        if not self.enabled:
            return qpos

        if self.normalization_type == 'gaussian':
            return (qpos - self.qpos_mean) / self.qpos_std
        else:  # min_max
            return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Denormalize qpos (inverse of normalize_qpos)."""
        if not self.enabled:
            return qpos

        if self.normalization_type == 'gaussian':
            return qpos * self.qpos_std + self.qpos_mean
        else:  # min_max
            return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize action (pass-through when disabled)."""
        if not self.enabled:
            return action

        if self.normalization_type == 'gaussian':
            return (action - self.action_mean) / self.action_std
        else:  # min_max
            return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Denormalize action (inverse of normalize_action)."""
        if not self.enabled:
            return action

        if self.normalization_type == 'gaussian':
            return action * self.action_std + self.action_mean
        else:  # min_max
            return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min

    def get_stats(self) -> Optional[Dict]:
        """Export the statistics (for saving into a checkpoint); None when disabled."""
        if not self.enabled:
            return None

        stats = {
            'normalization_type': self.normalization_type,
            'qpos_mean': self.qpos_mean.cpu().tolist(),
            'qpos_std': self.qpos_std.cpu().tolist(),
            'action_mean': self.action_mean.cpu().tolist(),
            'action_std': self.action_std.cpu().tolist(),
        }

        if self.normalization_type == 'min_max':
            stats['qpos_min'] = self.qpos_min.cpu().tolist()
            stats['qpos_max'] = self.qpos_max.cpu().tolist()
            stats['action_min'] = self.action_min.cpu().tolist()
            stats['action_max'] = self.action_max.cpu().tolist()

        return stats
|
||||||
Reference in New Issue
Block a user