refactor:大重构

This commit is contained in:
gouhanke
2026-02-11 15:53:55 +08:00
parent 1e95d40bf9
commit 130d4bb3c5
19 changed files with 1411 additions and 1223 deletions

View File

@@ -1,4 +1,4 @@
# Backbone models
from .resnet import ResNetBackbone
from .resnet_diffusion import ResNetDiffusionBackbone
__all__ = ["ResNetBackbone"]
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]

View File

@@ -1,93 +0,0 @@
from roboimi.vla.core.interfaces import VLABackbone
from transformers import ResNetModel
from torchvision import transforms
import torch
import torch.nn as nn
class ResNetBackbone(VLABackbone):
    """Frozen-by-default ResNet image encoder for VLA policies.

    Wraps a pretrained HuggingFace ResNet and compresses its final feature
    map into per-channel keypoint coordinates via spatial softmax.
    """

    def __init__(
        self,
        model_name = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        """
        Args:
            model_name: HuggingFace model id to load pretrained weights from.
            freeze: when True, backbone weights are frozen and the ResNet is
                kept in eval mode permanently.
        """
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        # Channel count of the last ResNet stage.
        self.out_channels = self.model.config.hidden_sizes[-1]
        # Resize + ImageNet normalization applied to every frame.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # 384 / 32 (ResNet stride) = 12, hence a 12x12 coordinate grid.
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        # Disable gradients for every backbone weight and switch to eval.
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def train(self, mode=True):
        """
        Override train() to keep frozen ResNet in eval mode.
        This ensures BatchNorm layers use running statistics consistently.
        """
        super().train(mode)
        if hasattr(self, 'model'):
            # Always keep the ResNet in eval mode, regardless of `mode`.
            self.model.eval()
        return self

    def forward_single_image(self, image):
        """Encode one camera stream of shape (B, T, C, H, W) -> (B*T, D*2)."""
        batch, steps, channels, height, width = image.shape
        frames = image.view(batch * steps, channels, height, width)
        frames = self.transform(frames)
        fmap = self.model(frames).last_hidden_state  # (B*T, D, H', W')
        return self.spatial_softmax(fmap)  # (B*T, D*2)

    def forward(self, images):
        """Encode every camera and concatenate the features.

        Args:
            images: dict mapping camera name -> tensor (B, T, C, H, W).
        Returns:
            Tensor of shape (B, T, num_cams * D * 2); cameras are processed
            in sorted-name order so the layout is deterministic.
        """
        reference = next(iter(images.values()))
        batch, steps = reference.shape[:2]
        per_camera = [
            self.forward_single_image(images[name]) for name in sorted(images.keys())
        ]
        combined = torch.cat(per_camera, dim=1)  # (B*T, Num_Cams*D*2)
        return combined.view(batch, steps, -1)

    @property
    def output_dim(self):
        """Output dimension after spatial softmax: out_channels * 2"""
        return self.out_channels * 2
class SpatialSoftmax(nn.Module):
    """
    Convert a feature map (N, C, H, W) into expected keypoint coordinates
    (N, C*2): each channel is softmax-normalized into a spatial probability
    map, and the probability-weighted average over a [-1, 1] coordinate grid
    gives that channel's expected (x, y) position.
    """
    def __init__(self, num_rows, num_cols, temperature=None):
        """
        Args:
            num_rows: expected feature-map height H.
            num_cols: expected feature-map width W.
            temperature: optional fixed softmax temperature. If None
                (default), the temperature is a learnable parameter
                initialized to 1 — the original behavior. (Bug fix: the
                argument was previously accepted but silently ignored.)
        """
        super().__init__()
        if temperature is None:
            # Learnable temperature, initialized to 1.
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            # Fixed, user-specified temperature (saved with the model,
            # not optimized).
            self.register_buffer('temperature', torch.full((1,), float(temperature)))
        # [-1, 1] coordinate grid matching the expected feature-map size.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """
        Args:
            x: feature map of shape (N, C, H, W); H*W must equal
               num_rows*num_cols from the constructor.
        Returns:
            Tensor of shape (N, C*2) — per-channel (expected_x, expected_y)
            pairs, flattened.
        """
        N, C, H, W = x.shape
        x = x.view(N, C, -1)  # (N, C, H*W)
        # Per-channel spatial attention map.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
        # Expected coordinates under the attention distribution.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)

View File

@@ -91,20 +91,21 @@ class SpatialSoftmax(nn.Module):
return feature_keypoints
class ResNetDiffusionBackbone(VLABackbone):
class _SingleRgbEncoder(nn.Module):
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
def __init__(
self,
vision_backbone: str = "resnet18",
pretrained_backbone_weights: str | None = None,
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
crop_shape: Optional[Tuple[int, int]] = None,
crop_is_random: bool = True,
use_group_norm: bool = True,
spatial_softmax_num_keypoints: int = 32,
vision_backbone: str,
pretrained_backbone_weights: str | None,
input_shape: Tuple[int, int, int],
crop_shape: Optional[Tuple[int, int]],
crop_is_random: bool,
use_group_norm: bool,
spatial_softmax_num_keypoints: int,
):
super().__init__()
# 设置可选的预处理
# 设置可选的预处理
if crop_shape is not None:
self.do_crop = True
# 评估时始终使用中心裁剪
@@ -117,14 +118,14 @@ class ResNetDiffusionBackbone(VLABackbone):
self.do_crop = False
crop_shape = input_shape[1:]
# 设置骨干网络
# 设置骨干网络
backbone_model = getattr(torchvision.models, vision_backbone)(
weights=pretrained_backbone_weights
)
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
self.backbone = _replace_submodules(
root_module=self.backbone,
@@ -132,12 +133,12 @@ class ResNetDiffusionBackbone(VLABackbone):
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# 设置池化和最终层
# 使用试运行来获取特征图形状
# 设置池化和最终层
# 使用试运行来获取特征图形状
dummy_shape = (1, input_shape[0], *crop_shape)
with torch.no_grad():
dummy_out = self.backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
self.feature_dim = spatial_softmax_num_keypoints * 2
@@ -150,58 +151,205 @@ class ResNetDiffusionBackbone(VLABackbone):
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
return x
class ResNetDiffusionBackbone(VLABackbone):
    """Multi-camera RGB backbone for diffusion policies.

    Each camera stream is encoded with a `_SingleRgbEncoder`. Cameras either
    share one encoder (default) or each get their own independent encoder,
    selected via `use_separate_rgb_encoder_per_camera`.

    Bug fix: `forward` previously contained an unreachable duplicate of the
    old shared-encoder loop (code after an early `return`); the dead block
    has been removed.
    """

    def __init__(
        self,
        vision_backbone: str = "resnet18",
        pretrained_backbone_weights: str | None = None,
        input_shape: Tuple[int, int, int] = (3, 84, 84),  # (C, H, W)
        crop_shape: Optional[Tuple[int, int]] = None,
        crop_is_random: bool = True,
        use_group_norm: bool = True,
        spatial_softmax_num_keypoints: int = 32,
        use_separate_rgb_encoder_per_camera: bool = False,  # one independent encoder per camera
        num_cameras: int = 1,  # number of cameras (only used in separate-encoder mode)
    ):
        super().__init__()
        self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
        self.num_cameras = num_cameras
        if use_separate_rgb_encoder_per_camera:
            # Separate-encoder mode: build one independent encoder per camera.
            encoders = [
                _SingleRgbEncoder(
                    vision_backbone=vision_backbone,
                    pretrained_backbone_weights=pretrained_backbone_weights,
                    input_shape=input_shape,
                    crop_shape=crop_shape,
                    crop_is_random=crop_is_random,
                    use_group_norm=use_group_norm,
                    spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
                )
                for _ in range(num_cameras)
            ]
            self.rgb_encoder = nn.ModuleList(encoders)
            # IMPORTANT: feature_dim is always the PER-ENCODER feature
            # dimension (consistent with lerobot), not the concatenated total.
            self.feature_dim = encoders[0].feature_dim
        else:
            # Shared-encoder mode: all cameras go through the same encoder.
            self.rgb_encoder = _SingleRgbEncoder(
                vision_backbone=vision_backbone,
                pretrained_backbone_weights=pretrained_backbone_weights,
                input_shape=input_shape,
                crop_shape=crop_shape,
                crop_is_random=crop_is_random,
                use_group_norm=use_group_norm,
                spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
            )
            self.feature_dim = self.rgb_encoder.feature_dim

    def forward(self, images):
        """
        Args:
            images: Dict[str, Tensor], one entry per camera,
                shaped {cam_name: (B, T, C, H, W)}.
        Returns:
            Tensor: (B, T, total_feature_dim) where total_feature_dim is
            len(images) * feature_dim.
        """
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        # Cameras are processed in sorted-name order so the feature layout is
        # deterministic; in separate-encoder mode this order also fixes which
        # encoder serves which camera.
        # NOTE(review): separate-encoder mode assumes len(images) <= num_cameras
        # and a stable camera-name set across calls — confirm with callers.
        cam_names = sorted(images.keys())
        features_all = []
        if self.use_separate_rgb_encoder_per_camera:
            # Separate-encoder mode: route each camera to its own encoder.
            for cam_idx, cam_name in enumerate(cam_names):
                img = images[cam_name]
                encoder = self.rgb_encoder[cam_idx]
                features_all.append(
                    encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
                )
        else:
            # Shared-encoder mode: every camera goes through the same encoder.
            for cam_name in cam_names:
                img = images[cam_name]
                features_all.append(
                    self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
                )
        return torch.cat(features_all, dim=1).view(B, T, -1)

    @property
    def output_dim(self):
        # Per-encoder feature dimension (see __init__ note).
        return self.feature_dim
if __name__ == "__main__":
    # Smoke-test script for ResNetDiffusionBackbone. The diff rendering had
    # left the pre-refactor single-test body interleaved with the new 4-test
    # suite; the superseded (dead) lines are removed here.
    print("🚀 Testing ResNetDiffusionBackbone")
    print("=" * 60)

    # Configuration
    B, T = 2, 5
    C, H, W = 3, 96, 96
    crop_h, crop_w = 84, 84
    num_keypoints = 32
    feature_dim_per_cam = num_keypoints * 2

    # Create dummy input (2 cameras)
    images = {
        "cam_high": torch.randn(B, T, C, H, W),
        "cam_wrist": torch.randn(B, T, C, H, W)
    }
    num_cameras = len(images)

    # ============================================================================
    # Test 1: Shared Encoder (default mode)
    # ============================================================================
    print("\n[Test 1] Shared Encoder Mode")
    print("-" * 60)
    backbone_shared = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,  # Speed up test
        input_shape=(C, H, W),
        crop_shape=(crop_h, crop_w),
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=num_keypoints,
        use_separate_rgb_encoder_per_camera=False,  # shared encoder
    )
    print("✅ Shared encoder model instantiated")
    print(f" Output dim per camera: {feature_dim_per_cam}")
    print(f" Number of cameras: {num_cameras}")
    print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")
    output = backbone_shared(images)
    print("\n🔄 Forward pass completed")
    print(f" Input shapes: {[v.shape for v in images.values()]}")
    print(f" Output shape: {output.shape}")
    expected_dim = num_cameras * feature_dim_per_cam
    assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
    print("✨ Test passed!")

    # ============================================================================
    # Test 2: Separate Encoders (one independent encoder per camera)
    # ============================================================================
    print("\n[Test 2] Separate Encoders Mode")
    print("-" * 60)
    backbone_separate = ResNetDiffusionBackbone(
        vision_backbone="resnet18",
        pretrained_backbone_weights=None,  # Speed up test
        input_shape=(C, H, W),
        crop_shape=(crop_h, crop_w),
        crop_is_random=True,
        use_group_norm=True,
        spatial_softmax_num_keypoints=num_keypoints,
        use_separate_rgb_encoder_per_camera=True,  # separate encoders
        num_cameras=num_cameras,
    )
    print("✅ Separate encoders model instantiated")
    print(f" Output dim per camera: {feature_dim_per_cam}")
    print(f" Number of cameras: {num_cameras}")
    print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")
    output = backbone_separate(images)
    print("\n🔄 Forward pass completed")
    print(f" Input shapes: {[v.shape for v in images.values()]}")
    print(f" Output shape: {output.shape}")
    expected_dim = num_cameras * feature_dim_per_cam
    assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
    print("✨ Test passed!")

    # ============================================================================
    # Test 3: Verify parameters count
    # ============================================================================
    print("\n[Test 3] Parameter Count Comparison")
    print("-" * 60)
    shared_params = sum(p.numel() for p in backbone_shared.parameters())
    separate_params = sum(p.numel() for p in backbone_separate.parameters())
    print(f" Shared encoder parameters: {shared_params:,}")
    print(f" Separate encoders parameters: {separate_params:,}")
    print(f" Ratio: {separate_params / shared_params:.2f}x")
    assert separate_params > shared_params, "Separate encoders should have more parameters"
    print("✨ Verification passed!")

    # ============================================================================
    # Test 4: Verify independent parameters
    # ============================================================================
    print("\n[Test 4] Verify Independent Parameters")
    print("-" * 60)
    # Check that encoders have independent parameters: an in-place change to
    # encoder 0 must not be visible through encoder 1's tensor.
    encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
    encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]
    # Modify first encoder's parameter
    with torch.no_grad():
        encoder_0_first_param += 1.0
    # Verify they are not the same tensor
    assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
        "Encoders should have independent parameters"
    print("✅ Encoders have independent parameters")
    print("✨ All tests passed!")

    print("\n" + "=" * 60)
    print("🎉 All tests completed successfully!")
    print("=" * 60)

View File

@@ -0,0 +1,128 @@
"""
归一化模块 - 统一训练和推理的归一化逻辑
支持两种归一化方式:
1. Gaussian (z-score): (x - mean) / std
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Literal
class NormalizationModule(nn.Module):
    """Unified normalization module.

    Normalizes/denormalizes `qpos` and `action` inside the Agent so that
    training and inference share one code path. Two schemes are supported:

    1. Gaussian (z-score): (x - mean) / std
    2. MinMax:             2 * (x - min) / (max - min) - 1  ->  [-1, 1]
    """

    # Buffer name prefixes this module manages.
    _FIELDS = ('qpos', 'action')

    def __init__(
        self,
        stats: Optional[Dict] = None,
        normalization_type: Literal['gaussian', 'min_max'] = 'gaussian'
    ):
        """
        Args:
            stats: dataset statistics dict of the form::

                {
                    'normalization_type': 'gaussian' | 'min_max',
                    'qpos_mean': [...], 'qpos_std': [...],
                    'qpos_min': [...],  'qpos_max': [...],   # min_max only
                    'action_mean': [...], 'action_std': [...],
                    'action_min': [...],  'action_max': [...],  # min_max only
                }

                When None, the module is disabled and acts as an identity.
            normalization_type: scheme to use ('gaussian' or 'min_max');
                overridden by stats['normalization_type'] when present.
        """
        super().__init__()
        self.normalization_type = normalization_type
        self.enabled = stats is not None
        if not self.enabled:
            return
        # Stats may carry their own normalization type; it takes precedence.
        self.normalization_type = stats.get('normalization_type', normalization_type)
        # Registered as buffers: not optimized, but saved with the model.
        for field in self._FIELDS:
            for suffix in ('mean', 'std'):
                values = stats[f'{field}_{suffix}']
                self.register_buffer(
                    f'{field}_{suffix}', torch.tensor(values, dtype=torch.float32)
                )
        if self.normalization_type == 'min_max':
            # MinMax additionally needs min/max; default to [0, 1] per dim.
            for field in self._FIELDS:
                dim = len(stats[f'{field}_mean'])
                lo = stats.get(f'{field}_min', [0.0] * dim)
                hi = stats.get(f'{field}_max', [1.0] * dim)
                self.register_buffer(f'{field}_min', torch.tensor(lo, dtype=torch.float32))
                self.register_buffer(f'{field}_max', torch.tensor(hi, dtype=torch.float32))

    def _encode(self, x: torch.Tensor, field: str) -> torch.Tensor:
        # Shared forward transform for both fields.
        if self.normalization_type == 'gaussian':
            return (x - getattr(self, f'{field}_mean')) / getattr(self, f'{field}_std')
        lo = getattr(self, f'{field}_min')
        hi = getattr(self, f'{field}_max')
        return 2 * (x - lo) / (hi - lo) - 1

    def _decode(self, x: torch.Tensor, field: str) -> torch.Tensor:
        # Shared inverse transform for both fields.
        if self.normalization_type == 'gaussian':
            return x * getattr(self, f'{field}_std') + getattr(self, f'{field}_mean')
        lo = getattr(self, f'{field}_min')
        hi = getattr(self, f'{field}_max')
        return (x + 1) / 2 * (hi - lo) + lo

    def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Normalize qpos (identity when the module is disabled)."""
        return self._encode(qpos, 'qpos') if self.enabled else qpos

    def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
        """Denormalize qpos (identity when the module is disabled)."""
        return self._decode(qpos, 'qpos') if self.enabled else qpos

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Normalize action (identity when the module is disabled)."""
        return self._encode(action, 'action') if self.enabled else action

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
        """Denormalize action (identity when the module is disabled)."""
        return self._decode(action, 'action') if self.enabled else action

    def get_stats(self) -> Optional[Dict]:
        """Export the statistics (for saving to a checkpoint); None if disabled."""
        if not self.enabled:
            return None
        exported = {'normalization_type': self.normalization_type}
        suffixes = ('mean', 'std')
        if self.normalization_type == 'min_max':
            suffixes += ('min', 'max')
        for field in self._FIELDS:
            for suffix in suffixes:
                exported[f'{field}_{suffix}'] = getattr(self, f'{field}_{suffix}').cpu().tolist()
        return exported