refactor:大重构
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Backbone models
|
||||
from .resnet import ResNetBackbone
|
||||
from .resnet_diffusion import ResNetDiffusionBackbone
|
||||
|
||||
__all__ = ["ResNetBackbone"]
|
||||
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
from transformers import ResNetModel
|
||||
from torchvision import transforms
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ResNetBackbone(VLABackbone):
    """Multi-camera image encoder: HuggingFace ResNet followed by spatial softmax.

    Each camera stream is resized, normalized, passed through the ResNet and
    reduced to per-channel expected coordinates by ``SpatialSoftmax``. The
    per-camera features are concatenated in sorted camera-name order.

    Args:
        model_name: Identifier passed to ``ResNetModel.from_pretrained``.
        freeze: When True, the ResNet parameters get ``requires_grad=False``
            and the ResNet is kept in eval mode even while the rest of the
            model trains. When False, the ResNet follows the usual
            ``train()`` / ``eval()`` switching.
    """

    def __init__(
        self,
        model_name="microsoft/resnet-18",
        freeze: bool = True,
    ):
        super().__init__()
        # Remembered so train() only pins the ResNet to eval mode when frozen.
        self._frozen = False
        self.model = ResNetModel.from_pretrained(model_name)
        self.out_channels = self.model.config.hidden_sizes[-1]
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # NOTE(review): the 12x12 grid presumably corresponds to the 384x384
        # input at an overall stride of 32 -- confirm for other model_name values.
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients for the ResNet and switch it to eval mode."""
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self._frozen = True
        self.model.eval()

    def train(self, mode=True):
        """Set training mode, keeping a *frozen* ResNet in eval mode.

        A frozen backbone must keep using its BatchNorm running statistics,
        so it is forced back to eval after the regular mode switch. A
        non-frozen backbone follows ``mode`` like any other submodule.
        """
        super().train(mode)
        # getattr/hasattr guard against train() being invoked before __init__
        # has finished setting these attributes.
        if getattr(self, "_frozen", False) and hasattr(self, "model"):
            self.model.eval()
        return self

    def forward_single_image(self, image):
        """Encode one camera stream.

        Args:
            image: Tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B*T, D*2) with D = ``self.out_channels``.
        """
        B, T, C, H, W = image.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
        # tensors, are accepted as well.
        image = image.reshape(B * T, C, H, W)
        image = self.transform(image)
        feature_map = self.model(image).last_hidden_state  # (B*T, D, H', W')
        features = self.spatial_softmax(feature_map)  # (B*T, D*2)
        return features

    def forward(self, images):
        """Encode all cameras and concatenate their features.

        Args:
            images: Mapping ``{cam_name: tensor of shape (B, T, C, H, W)}``.

        Returns:
            Tensor of shape (B, T, Num_Cams * D * 2); cameras are concatenated
            in sorted name order.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError("ResNetBackbone.forward requires at least one camera image")
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        # Sorted order keeps the feature layout independent of dict insertion order.
        features_all = [
            self.forward_single_image(images[cam_name])  # (B*T, D*2)
            for cam_name in sorted(images.keys())
        ]
        combined_features = torch.cat(features_all, dim=1)  # (B*T, Num_Cams*D*2)
        return combined_features.view(B, T, -1)

    @property
    def output_dim(self):
        """Per-camera output dimension after spatial softmax: out_channels * 2."""
        return self.out_channels * 2
|
||||
|
||||
class SpatialSoftmax(nn.Module):
    """Convert a feature map (N, C, H, W) into coordinate features (N, C*2).

    For every channel a softmax over all H*W locations is taken and used as
    weights to average a fixed coordinate grid spanning [-1, 1] along each
    axis, giving two expected coordinates per channel.

    Args:
        num_rows: Height H of the feature maps this module will receive.
        num_cols: Width W of the feature maps this module will receive.
        temperature: Initial value of the learnable softmax temperature.
            ``None`` keeps the previous default of 1.0.

    Raises:
        ValueError: If ``temperature`` is given and is not strictly positive.
    """

    def __init__(self, num_rows, num_cols, temperature=None):
        super().__init__()
        # The argument used to be ignored; it now seeds the learnable parameter.
        initial_temperature = 1.0 if temperature is None else float(temperature)
        if initial_temperature <= 0.0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = nn.Parameter(torch.full((1,), initial_temperature))

        # Build the coordinate grid. With indexing='ij' and the row axis given
        # first, pos_x varies along rows and pos_y along columns.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij',
        )
        # Buffers follow the module across devices but are not trained.
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """Return expected coordinates per channel.

        Args:
            x: Feature map of shape (N, C, H, W) with H*W == num_rows*num_cols.

        Returns:
            Tensor of shape (N, C*2), laid out per channel as
            (expected_x, expected_y).
        """
        N, C, H, W = x.shape
        # reshape (not view) so non-contiguous feature maps are accepted.
        x = x.reshape(N, C, -1)  # (N, C, H*W)

        # Softmax attention map over spatial locations.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)

        # Expected coordinates under the attention distribution.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)

        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)
|
||||
@@ -91,20 +91,21 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
class ResNetDiffusionBackbone(VLABackbone):
|
||||
class _SingleRgbEncoder(nn.Module):
|
||||
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone: str = "resnet18",
|
||||
pretrained_backbone_weights: str | None = None,
|
||||
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
|
||||
crop_shape: Optional[Tuple[int, int]] = None,
|
||||
crop_is_random: bool = True,
|
||||
use_group_norm: bool = True,
|
||||
spatial_softmax_num_keypoints: int = 32,
|
||||
vision_backbone: str,
|
||||
pretrained_backbone_weights: str | None,
|
||||
input_shape: Tuple[int, int, int],
|
||||
crop_shape: Optional[Tuple[int, int]],
|
||||
crop_is_random: bool,
|
||||
use_group_norm: bool,
|
||||
spatial_softmax_num_keypoints: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 设置可选的预处理。
|
||||
|
||||
# 设置可选的预处理
|
||||
if crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# 评估时始终使用中心裁剪
|
||||
@@ -117,14 +118,14 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
self.do_crop = False
|
||||
crop_shape = input_shape[1:]
|
||||
|
||||
# 设置骨干网络。
|
||||
# 设置骨干网络
|
||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||
weights=pretrained_backbone_weights
|
||||
)
|
||||
|
||||
|
||||
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
|
||||
|
||||
if use_group_norm:
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
@@ -132,12 +133,12 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# 设置池化和最终层。
|
||||
# 使用试运行来获取特征图形状。
|
||||
# 设置池化和最终层
|
||||
# 使用试运行来获取特征图形状
|
||||
dummy_shape = (1, input_shape[0], *crop_shape)
|
||||
with torch.no_grad():
|
||||
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
||||
self.feature_dim = spatial_softmax_num_keypoints * 2
|
||||
@@ -150,58 +151,205 @@ class ResNetDiffusionBackbone(VLABackbone):
|
||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
return x
|
||||
|
||||
|
||||
class ResNetDiffusionBackbone(VLABackbone):
    """Multi-camera RGB backbone built from ``_SingleRgbEncoder`` instances.

    Two modes are supported:

    * shared (default): one encoder processes every camera;
    * separate: one independent encoder per camera, assigned to cameras in
      sorted camera-name order.

    Args:
        vision_backbone: Backbone name forwarded to ``_SingleRgbEncoder``.
        pretrained_backbone_weights: Weights identifier forwarded to the encoder.
        input_shape: Image shape (C, H, W).
        crop_shape: Optional (H, W) crop forwarded to the encoder.
        crop_is_random: Forwarded to the encoder.
        use_group_norm: Forwarded to the encoder.
        spatial_softmax_num_keypoints: Forwarded to the encoder.
        use_separate_rgb_encoder_per_camera: Create one encoder per camera.
        num_cameras: Number of cameras (only used in separate-encoder mode).

    Raises:
        ValueError: If separate-encoder mode is requested with
            ``num_cameras < 1``.
    """

    def __init__(
        self,
        vision_backbone: str = "resnet18",
        pretrained_backbone_weights: str | None = None,
        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
        crop_shape: Optional[Tuple[int, int]] = None,
        crop_is_random: bool = True,
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = False,
        num_cameras: int = 1,
    ):
        super().__init__()

        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
        self.num_cameras = num_cameras

        # Both modes build encoders from exactly the same configuration.
        encoder_kwargs = dict(
            vision_backbone=vision_backbone,
            pretrained_backbone_weights=pretrained_backbone_weights,
            input_shape=input_shape,
            crop_shape=crop_shape,
            crop_is_random=crop_is_random,
            use_group_norm=use_group_norm,
            spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
        )

        if use_separate_rgb_encoder_per_camera:
            if num_cameras < 1:
                raise ValueError(
                    f"num_cameras must be >= 1 in separate-encoder mode, got {num_cameras}"
                )
            # Separate mode: an independent encoder for every camera.
            encoders = [_SingleRgbEncoder(**encoder_kwargs) for _ in range(num_cameras)]
            self.rgb_encoder = nn.ModuleList(encoders)
            # Important: feature_dim / output_dim always denotes the feature
            # size of a single encoder (kept consistent with lerobot).
            self.feature_dim = encoders[0].feature_dim
        else:
            # Shared mode: every camera goes through the same encoder.
            self.rgb_encoder = _SingleRgbEncoder(**encoder_kwargs)
            self.feature_dim = self.rgb_encoder.feature_dim

    def forward(self, images):
        """Encode every camera and concatenate the features.

        Args:
            images: Mapping ``{cam_name: tensor of shape (B, T, C, H, W)}``.

        Returns:
            Tensor of shape (B, T, total_feature_dim); cameras are
            concatenated in sorted name order.

        Raises:
            ValueError: If ``images`` is empty, or if in separate-encoder mode
                the number of cameras differs from the number of encoders
                (cameras would otherwise be paired with the wrong encoder).
        """
        if not images:
            raise ValueError("ResNetDiffusionBackbone.forward requires at least one camera image")

        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        cam_names = sorted(images.keys())

        if self.use_separate_rgb_encoder_per_camera:
            if len(cam_names) != len(self.rgb_encoder):
                raise ValueError(
                    f"Expected {len(self.rgb_encoder)} cameras for the separate encoders, "
                    f"got {len(cam_names)}: {cam_names}"
                )
            # Separate mode: the i-th sorted camera uses the i-th encoder.
            encoders = list(self.rgb_encoder)
        else:
            # Shared mode: the same encoder is reused for each camera.
            encoders = [self.rgb_encoder] * len(cam_names)

        features_all = []
        for encoder, cam_name in zip(encoders, cam_names):
            img = images[cam_name]
            # Fold time into the batch axis; reshape also accepts
            # non-contiguous tensors, which view would reject.
            flat_img = img.reshape(B * T, *img.shape[2:])
            features_all.append(encoder.forward_single_image(flat_img))

        return torch.cat(features_all, dim=1).reshape(B, T, -1)

    @property
    def output_dim(self):
        """Feature dimension produced by a single encoder (one camera)."""
        return self.feature_dim
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 Testing ResNetDiffusionBackbone...")
|
||||
|
||||
print("=" * 60)
|
||||
print("🚀 Testing ResNetDiffusionBackbone")
|
||||
print("=" * 60)
|
||||
|
||||
# Configuration
|
||||
B, T = 2, 5
|
||||
C, H, W = 3, 96, 96
|
||||
crop_h, crop_w = 84, 84
|
||||
num_keypoints = 32
|
||||
feature_dim_per_cam = num_keypoints * 2
|
||||
|
||||
# Instantiate model
|
||||
backbone = ResNetDiffusionBackbone(
|
||||
vision_backbone="resnet18",
|
||||
pretrained_backbone_weights=None, # Speed up test
|
||||
input_shape=(C, H, W),
|
||||
crop_shape=(crop_h, crop_w),
|
||||
crop_is_random=True,
|
||||
use_group_norm=True,
|
||||
spatial_softmax_num_keypoints=num_keypoints
|
||||
)
|
||||
|
||||
print(f"✅ Model instantiated. Output dim per camera: {backbone.output_dim}")
|
||||
|
||||
# Create dummy input
|
||||
|
||||
# Create dummy input (2 cameras)
|
||||
images = {
|
||||
"cam_high": torch.randn(B, T, C, H, W),
|
||||
"cam_wrist": torch.randn(B, T, C, H, W)
|
||||
}
|
||||
|
||||
# Forward pass
|
||||
print("🔄 Running forward pass...")
|
||||
output = backbone(images)
|
||||
|
||||
print(f"Input shapes: {[v.shape for v in images.values()]}")
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
# Verification
|
||||
expected_dim = len(images) * feature_dim_per_cam
|
||||
num_cameras = len(images)
|
||||
|
||||
# ============================================================================
|
||||
# Test 1: Shared Encoder (默认模式)
|
||||
# ============================================================================
|
||||
print("\n[Test 1] Shared Encoder Mode")
|
||||
print("-" * 60)
|
||||
backbone_shared = ResNetDiffusionBackbone(
|
||||
vision_backbone="resnet18",
|
||||
pretrained_backbone_weights=None, # Speed up test
|
||||
input_shape=(C, H, W),
|
||||
crop_shape=(crop_h, crop_w),
|
||||
crop_is_random=True,
|
||||
use_group_norm=True,
|
||||
spatial_softmax_num_keypoints=num_keypoints,
|
||||
use_separate_rgb_encoder_per_camera=False, # 共享编码器
|
||||
)
|
||||
|
||||
print(f"✅ Shared encoder model instantiated")
|
||||
print(f" Output dim per camera: {feature_dim_per_cam}")
|
||||
print(f" Number of cameras: {num_cameras}")
|
||||
print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")
|
||||
|
||||
output = backbone_shared(images)
|
||||
print(f"\n🔄 Forward pass completed")
|
||||
print(f" Input shapes: {[v.shape for v in images.values()]}")
|
||||
print(f" Output shape: {output.shape}")
|
||||
|
||||
expected_dim = num_cameras * feature_dim_per_cam
|
||||
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
|
||||
|
||||
print("✨ Test passed!")
|
||||
print(f"✨ Test passed!")
|
||||
|
||||
# ============================================================================
|
||||
# Test 2: Separate Encoders (独立编码器模式)
|
||||
# ============================================================================
|
||||
print("\n[Test 2] Separate Encoders Mode")
|
||||
print("-" * 60)
|
||||
backbone_separate = ResNetDiffusionBackbone(
|
||||
vision_backbone="resnet18",
|
||||
pretrained_backbone_weights=None, # Speed up test
|
||||
input_shape=(C, H, W),
|
||||
crop_shape=(crop_h, crop_w),
|
||||
crop_is_random=True,
|
||||
use_group_norm=True,
|
||||
spatial_softmax_num_keypoints=num_keypoints,
|
||||
use_separate_rgb_encoder_per_camera=True, # 独立编码器
|
||||
num_cameras=num_cameras,
|
||||
)
|
||||
|
||||
print(f"✅ Separate encoders model instantiated")
|
||||
print(f" Output dim per camera: {feature_dim_per_cam}")
|
||||
print(f" Number of cameras: {num_cameras}")
|
||||
print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")
|
||||
|
||||
output = backbone_separate(images)
|
||||
print(f"\n🔄 Forward pass completed")
|
||||
print(f" Input shapes: {[v.shape for v in images.values()]}")
|
||||
print(f" Output shape: {output.shape}")
|
||||
|
||||
expected_dim = num_cameras * feature_dim_per_cam
|
||||
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
|
||||
print(f"✨ Test passed!")
|
||||
|
||||
# ============================================================================
|
||||
# Test 3: Verify parameters count
|
||||
# ============================================================================
|
||||
print("\n[Test 3] Parameter Count Comparison")
|
||||
print("-" * 60)
|
||||
shared_params = sum(p.numel() for p in backbone_shared.parameters())
|
||||
separate_params = sum(p.numel() for p in backbone_separate.parameters())
|
||||
|
||||
print(f" Shared encoder parameters: {shared_params:,}")
|
||||
print(f" Separate encoders parameters: {separate_params:,}")
|
||||
print(f" Ratio: {separate_params / shared_params:.2f}x")
|
||||
|
||||
assert separate_params > shared_params, "Separate encoders should have more parameters"
|
||||
print(f"✨ Verification passed!")
|
||||
|
||||
# ============================================================================
|
||||
# Test 4: Verify independent parameters
|
||||
# ============================================================================
|
||||
print("\n[Test 4] Verify Independent Parameters")
|
||||
print("-" * 60)
|
||||
# Check that encoders have independent parameters
|
||||
encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
|
||||
encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]
|
||||
|
||||
# Modify first encoder's parameter
|
||||
with torch.no_grad():
|
||||
encoder_0_first_param += 1.0
|
||||
|
||||
# Verify they are not the same tensor
|
||||
assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
|
||||
"Encoders should have independent parameters"
|
||||
|
||||
print(f"✅ Encoders have independent parameters")
|
||||
print(f"✨ All tests passed!")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 All tests completed successfully!")
|
||||
print("=" * 60)
|
||||
128
roboimi/vla/models/normalization.py
Normal file
128
roboimi/vla/models/normalization.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
归一化模块 - 统一训练和推理的归一化逻辑
|
||||
|
||||
支持两种归一化方式:
|
||||
1. Gaussian (z-score): (x - mean) / std
|
||||
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Dict, Literal
|
||||
|
||||
|
||||
class NormalizationModule(nn.Module):
    """Unified normalization module.

    Used inside the Agent to normalize / denormalize ``qpos`` and ``action``.
    Two schemes are supported:

    1. ``'gaussian'`` (z-score): ``(x - mean) / std``
    2. ``'min_max'``: ``2 * (x - min) / (max - min) - 1`` -> ``[-1, 1]``

    Statistics are stored as buffers, so they are saved with the model and
    move with it across devices, but are never optimized.
    """

    _VALID_TYPES = ('gaussian', 'min_max')

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Literal['gaussian', 'min_max'] = 'gaussian',
        eps: float = 1e-8,
    ):
        """
        Args:
            stats: Dataset statistics dictionary, format:
                {
                    'normalization_type': 'gaussian' | 'min_max',
                    'qpos_mean': [...],
                    'qpos_std': [...],
                    'qpos_min': [...],    # only needed for min_max
                    'qpos_max': [...],    # only needed for min_max
                    'action_mean': [...],
                    'action_std': [...],
                    'action_min': [...],  # only needed for min_max
                    'action_max': [...],  # only needed for min_max
                }
                ``None`` disables normalization (all methods become identity).
            normalization_type: Scheme to use ('gaussian' or 'min_max');
                overridden by ``stats['normalization_type']`` when present.
            eps: Replacement divisor for dimensions whose std (or max - min)
                is smaller than ``eps`` in magnitude, avoiding NaN/inf.

        Raises:
            ValueError: If statistics are given and the resolved
                normalization type is not one of the supported schemes.
        """
        super().__init__()

        self.normalization_type = normalization_type
        self.enabled = stats is not None
        self.eps = eps

        if not self.enabled:
            return

        # The type stored with the statistics (if any) takes precedence.
        self.normalization_type = stats.get('normalization_type', normalization_type)
        if self.normalization_type not in self._VALID_TYPES:
            # Fail here rather than with an AttributeError on first use.
            raise ValueError(
                f"Unknown normalization_type {self.normalization_type!r}; "
                f"expected one of {self._VALID_TYPES}"
            )

        # Register as buffers (not optimized, but saved with the model).
        for key in ('qpos_mean', 'qpos_std', 'action_mean', 'action_std'):
            self.register_buffer(key, torch.tensor(stats[key], dtype=torch.float32))

        # MinMax normalization additionally needs min/max; missing entries
        # fall back to the range [0, 1] for every dimension.
        if self.normalization_type == 'min_max':
            for prefix in ('qpos', 'action'):
                dim = len(stats[f'{prefix}_mean'])
                lower = stats.get(f'{prefix}_min', [0.0] * dim)
                upper = stats.get(f'{prefix}_max', [1.0] * dim)
                self.register_buffer(f'{prefix}_min', torch.tensor(lower, dtype=torch.float32))
                self.register_buffer(f'{prefix}_max', torch.tensor(upper, dtype=torch.float32))

    def _safe_divisor(self, scale: torch.Tensor) -> torch.Tensor:
        """Return ``scale`` with near-zero entries replaced by ``eps``."""
        return torch.where(scale.abs() < self.eps, torch.full_like(scale, self.eps), scale)

    def _normalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Normalize ``x`` with the statistics stored under ``prefix``."""
        if not self.enabled:
            return x

        if self.normalization_type == 'gaussian':
            mean = getattr(self, f'{prefix}_mean')
            std = getattr(self, f'{prefix}_std')
            return (x - mean) / self._safe_divisor(std)

        # min_max
        lower = getattr(self, f'{prefix}_min')
        upper = getattr(self, f'{prefix}_max')
        return 2 * (x - lower) / self._safe_divisor(upper - lower) - 1

    def _denormalize(self, x: torch.Tensor, prefix: str) -> torch.Tensor:
        """Invert ``_normalize`` for the statistics stored under ``prefix``."""
        if not self.enabled:
            return x

        if self.normalization_type == 'gaussian':
            return x * getattr(self, f'{prefix}_std') + getattr(self, f'{prefix}_mean')

        # min_max
        lower = getattr(self, f'{prefix}_min')
        upper = getattr(self, f'{prefix}_max')
        return (x + 1) / 2 * (upper - lower) + lower

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize qpos."""
        return self._normalize(qpos, 'qpos')

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Denormalize qpos."""
        return self._denormalize(qpos, 'qpos')

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize action."""
        return self._normalize(action, 'action')

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Denormalize action."""
        return self._denormalize(action, 'action')

    def get_stats(self) -> Optional[Dict]:
        """Export the statistics (for saving into a checkpoint).

        Returns:
            A dictionary in the same format accepted by ``__init__``, or
            ``None`` when normalization is disabled.
        """
        if not self.enabled:
            return None

        stats = {
            'normalization_type': self.normalization_type,
            'qpos_mean': self.qpos_mean.cpu().tolist(),
            'qpos_std': self.qpos_std.cpu().tolist(),
            'action_mean': self.action_mean.cpu().tolist(),
            'action_std': self.action_std.cpu().tolist(),
        }

        if self.normalization_type == 'min_max':
            stats['qpos_min'] = self.qpos_min.cpu().tolist()
            stats['qpos_max'] = self.qpos_max.cpu().tolist()
            stats['action_min'] = self.action_min.cpu().tolist()
            stats['action_max'] = self.action_max.cpu().tolist()

        return stats
|
||||
Reference in New Issue
Block a user