debug(train): 在siglip和DiffusionHead下跑通训练流程
This commit is contained in:
25
roboimi/vla/conf/agent/base_siglip.yaml
Normal file
25
roboimi/vla/conf/agent/base_siglip.yaml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# @package agent
|
||||||
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
|
|
||||||
|
# --- Real Vision Backbone ---
|
||||||
|
backbone:
|
||||||
|
_target_: roboimi.vla.models.backbones.siglip.SigLIPBackbone
|
||||||
|
# Google SigLIP (SOTA Vision Encoder)
|
||||||
|
# 第一次运行会自动下载 (~1.5GB)
|
||||||
|
model_name: "google/siglip-so400m-patch14-384"
|
||||||
|
freeze: true # 初始阶段冻结视觉层,只训练 Head
|
||||||
|
embed_dim: 1152 # SigLIP so400m-patch14-384 的 hidden_size
|
||||||
|
|
||||||
|
# --- Adapter ---
|
||||||
|
projector:
|
||||||
|
_target_: roboimi.vla.models.projectors.mlp.MLPProjector
|
||||||
|
# 自动读取 SigLIP 的 1152 维
|
||||||
|
input_dim: ${..backbone.embed_dim}
|
||||||
|
output_dim: 384 # 压缩到 384 或 512 给 Policy 用
|
||||||
|
|
||||||
|
# --- Policy Head ---
|
||||||
|
head:
|
||||||
|
_target_: roboimi.vla.models.heads.debug.DebugHead
|
||||||
|
input_dim: ${..projector.output_dim}
|
||||||
|
action_dim: 16
|
||||||
|
chunk_size: 16
|
||||||
24
roboimi/vla/conf/agent/siglip_diffusion.yaml
Normal file
24
roboimi/vla/conf/agent/siglip_diffusion.yaml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# @package agent
|
||||||
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
|
|
||||||
|
# 1. Vision
|
||||||
|
backbone:
|
||||||
|
_target_: roboimi.vla.models.backbones.siglip.SigLIPBackbone
|
||||||
|
model_name: "google/siglip-so400m-patch14-384"
|
||||||
|
embed_dim: 1152
|
||||||
|
freeze: true
|
||||||
|
|
||||||
|
# 2. Adapter
|
||||||
|
projector:
|
||||||
|
_target_: roboimi.vla.models.projectors.mlp.MLPProjector
|
||||||
|
input_dim: ${..backbone.embed_dim}
|
||||||
|
output_dim: 256 # 压缩给 Diffusion 用
|
||||||
|
|
||||||
|
# 3. Diffusion Policy Head
|
||||||
|
head:
|
||||||
|
_target_: roboimi.vla.models.heads.diffusion.DiffusionHead
|
||||||
|
input_dim: ${..projector.output_dim}
|
||||||
|
action_dim: 16
|
||||||
|
chunk_size: 16
|
||||||
|
n_timesteps: 50 # 训练用100,这里调试用50快一点
|
||||||
|
hidden_dim: 256
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- _self_
|
- _self_
|
||||||
- agent: tiny
|
- agent: base_siglip
|
||||||
- data: custom_hdf5 # 新增这一行,激活数据配置
|
- data: custom_hdf5 # 新增这一行,激活数据配置
|
||||||
|
|
||||||
train:
|
train:
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
_target_: roboimi.vla.data.dataset.VLAChunkedDataset
|
_target_: roboimi.vla.data.dataset.VLAChunkedDataset
|
||||||
|
|
||||||
# 【关键修改】指向你的数据文件夹目录
|
|
||||||
data_path: "/home/d51/workspace/work/robo-imi-act/roboimi/demos/dataset/sim_transfer"
|
data_path: "/home/d51/workspace/work/robo-imi-act/roboimi/demos/dataset/sim_transfer"
|
||||||
|
|
||||||
pred_horizon: 16
|
pred_horizon: 16
|
||||||
obs_horizon: 1 # 先只用单帧调试
|
obs_horizon: 1
|
||||||
obs_keys: ["top"] # 数据里有 top, angle, r_vis,我们先拿 top 跑通
|
obs_keys: ["top"]
|
||||||
|
|
||||||
|
# 【新增】SigLIP 必须参数
|
||||||
|
resize_resolution: 384
|
||||||
|
train: true # 开启数据增强
|
||||||
@@ -6,109 +6,93 @@ import glob
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
|
# 【新增】导入刚才写好的处理器
|
||||||
|
from .image_transform import VLAImageProcessor
|
||||||
|
|
||||||
class VLAChunkedDataset(Dataset):
|
class VLAChunkedDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_path: str,
|
data_path: str,
|
||||||
pred_horizon: int = 16,
|
pred_horizon: int = 16,
|
||||||
obs_horizon: int = 2,
|
obs_horizon: int = 1,
|
||||||
obs_keys: List[str] = ["top"] # 默认只用 top
|
obs_keys: List[str] = ["top"],
|
||||||
|
resize_resolution: int = 384, # SigLIP 默认 384
|
||||||
|
train: bool = True # 【新增】控制是否增强
|
||||||
):
|
):
|
||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.obs_keys = obs_keys
|
self.obs_keys = obs_keys
|
||||||
|
|
||||||
# --- 1. 扫描文件 ---
|
# ... (这里保留之前的扫描文件代码 self.file_paths ...) ...
|
||||||
if os.path.isdir(data_path):
|
if os.path.isdir(data_path):
|
||||||
# 如果是文件夹,读取所有 episode_*.hdf5
|
|
||||||
self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
|
self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
|
||||||
else:
|
else:
|
||||||
# 如果是单文件
|
|
||||||
self.file_paths = [data_path]
|
self.file_paths = [data_path]
|
||||||
|
|
||||||
if len(self.file_paths) == 0:
|
# ... (这里保留之前的建立索引代码 self.index_map ...) ...
|
||||||
raise ValueError(f"No .hdf5 files found in {data_path}")
|
self.index_map = []
|
||||||
|
|
||||||
print(f"Found {len(self.file_paths)} episodes. Indexing...")
|
|
||||||
|
|
||||||
# --- 2. 建立全局索引 (Episode, Time) ---
|
|
||||||
# 我们需要知道 global_index=1000 对应的是哪个文件的第几帧
|
|
||||||
self.index_map = [] # [(file_idx, start_time), ...]
|
|
||||||
|
|
||||||
for i, path in enumerate(self.file_paths):
|
for i, path in enumerate(self.file_paths):
|
||||||
with h5py.File(path, 'r') as f:
|
with h5py.File(path, 'r') as f:
|
||||||
# 假设所有文件的 action 长度就是 episode 长度
|
|
||||||
total_len = f["action"].shape[0]
|
total_len = f["action"].shape[0]
|
||||||
# 有效的起始点:从 0 到 total_len - 1
|
|
||||||
# 即使到了最后几帧,因为有 padding,所以也是有效的 sample
|
|
||||||
for t in range(total_len):
|
for t in range(total_len):
|
||||||
self.index_map.append((i, t))
|
self.index_map.append((i, t))
|
||||||
|
|
||||||
print(f"✅ Indexed {len(self.index_map)} total samples.")
|
# 【核心修改】实例化处理器
|
||||||
|
self.image_processor = VLAImageProcessor(
|
||||||
|
resolution=resize_resolution,
|
||||||
|
enable_augmentation=train, # 训练集开启增强
|
||||||
|
aug_strength=0.1
|
||||||
|
)
|
||||||
|
print(f"✅ Image Processor: {self.image_processor}")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.index_map)
|
return len(self.index_map)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||||
# --- 1. 定位文件 ---
|
|
||||||
file_idx, t_start = self.index_map[idx]
|
file_idx, t_start = self.index_map[idx]
|
||||||
file_path = self.file_paths[file_idx]
|
file_path = self.file_paths[file_idx]
|
||||||
|
|
||||||
# 每次读取打开文件 (Lazy Loading),读取完自动关闭
|
|
||||||
# 这种方式对多进程 DataLoader 最安全
|
|
||||||
with h5py.File(file_path, 'r') as f:
|
with h5py.File(file_path, 'r') as f:
|
||||||
|
# ... (Action读取代码保持不变) ...
|
||||||
total_len = f["action"].shape[0]
|
total_len = f["action"].shape[0]
|
||||||
|
|
||||||
# --- 2. 动作 (Action) ---
|
|
||||||
t_end = min(t_start + self.pred_horizon, total_len)
|
t_end = min(t_start + self.pred_horizon, total_len)
|
||||||
|
actions_np = f["action"][t_start:t_end]
|
||||||
# 读取动作片段
|
# ... (Padding 逻辑保持不变) ...
|
||||||
actions_np = f["action"][t_start:t_end] # (L, 16)
|
|
||||||
|
|
||||||
# Padding 处理
|
|
||||||
actual_len = actions_np.shape[0]
|
actual_len = actions_np.shape[0]
|
||||||
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
|
||||||
|
|
||||||
if actual_len < self.pred_horizon:
|
if actual_len < self.pred_horizon:
|
||||||
pad_len = self.pred_horizon - actual_len
|
pad_len = self.pred_horizon - actual_len
|
||||||
# 重复最后一帧动作进行填充
|
|
||||||
pad_block = np.tile(actions_np[-1], (pad_len, 1))
|
pad_block = np.tile(actions_np[-1], (pad_len, 1))
|
||||||
actions_np = np.concatenate([actions_np, pad_block], axis=0)
|
actions_np = np.concatenate([actions_np, pad_block], axis=0)
|
||||||
# 标记 Padding 部分为 0
|
|
||||||
action_mask[actual_len:] = 0.0
|
|
||||||
|
|
||||||
# --- 3. 图像 (Images) ---
|
# --- 图像处理部分 ---
|
||||||
obs_dict = {}
|
obs_dict = {}
|
||||||
for key in self.obs_keys:
|
for key in self.obs_keys:
|
||||||
imgs = []
|
imgs = []
|
||||||
# 处理观测历史 (Obs Horizon)
|
|
||||||
# 如果 t_start=0, obs_horizon=2, 我们需要读取 t=0 和 t=0 (重复第一帧)
|
|
||||||
for i in range(self.obs_horizon):
|
for i in range(self.obs_horizon):
|
||||||
# 倒序读取:当前帧,前一帧...
|
# 计算历史帧索引
|
||||||
# 注意:这里逻辑是 [t_start - (obs_horizon-1) + i]
|
query_t = max(0, t_start - (self.obs_horizon - 1) + i)
|
||||||
# 比如 horizon=2, t=10. i=0 -> t=9; i=1 -> t=10.
|
|
||||||
query_t = t_start - (self.obs_horizon - 1) + i
|
|
||||||
query_t = max(0, query_t) # 边界保护
|
|
||||||
|
|
||||||
imgs.append(f[f"observations/images/{key}"][query_t])
|
# 1. 读取原始数据 (Numpy uint8)
|
||||||
|
raw_img = f[f"observations/images/{key}"][query_t]
|
||||||
|
|
||||||
# Stack -> (Obs_Horizon, H, W, C)
|
# 2. 【调用处理器】 Numpy -> Tensor (384, 384) Normalized
|
||||||
img_stack = np.stack(imgs)
|
processed_img = self.image_processor(raw_img)
|
||||||
# Normalize & Permute -> (Obs_Horizon, C, H, W)
|
|
||||||
img_stack = img_stack.astype(np.float32) / 255.0
|
|
||||||
img_stack = np.transpose(img_stack, (0, 3, 1, 2))
|
|
||||||
|
|
||||||
obs_dict[key] = torch.from_numpy(img_stack)
|
imgs.append(processed_img)
|
||||||
|
|
||||||
# --- 4. QPos ---
|
# Stack -> (T, C, H, W)
|
||||||
|
obs_dict[key] = torch.stack(imgs)
|
||||||
|
|
||||||
|
# ... (QPos 和 Language 读取保持不变) ...
|
||||||
qpos = f["observations/qpos"][t_start].astype(np.float32)
|
qpos = f["observations/qpos"][t_start].astype(np.float32)
|
||||||
|
lang = f.attrs.get("language", "placeholder")
|
||||||
|
if isinstance(lang, bytes): lang = lang.decode("utf-8")
|
||||||
|
|
||||||
# --- 5. Language ---
|
# 这里的 action_mask 只是临时补全代码,你原来的逻辑是对的
|
||||||
# 暂时写死或从 attrs 读取
|
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
||||||
lang = f.attrs.get("language", "task instruction placeholder")
|
if actual_len < self.pred_horizon:
|
||||||
if isinstance(lang, bytes):
|
action_mask[actual_len:] = 0.0
|
||||||
lang = lang.decode("utf-8")
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"obs": obs_dict,
|
"obs": obs_dict,
|
||||||
|
|||||||
75
roboimi/vla/data/image_transform.py
Normal file
75
roboimi/vla/data/image_transform.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# 图像预处理
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
class VLAImageProcessor:
    """Image preprocessor for ViT-style vision encoders such as SigLIP/CLIP.

    Steps applied by ``__call__``:
      1. Convert the input (numpy array / tensor / PIL image) to a PIL image.
      2. Optionally apply colour jitter (training-time augmentation only).
      3. Resize to ``resolution`` x ``resolution`` with bicubic interpolation.
      4. Convert to a float tensor and normalize with ``mean`` / ``std``.
    """

    def __init__(
        self,
        resolution: int = 384,
        mean: List[float] = [0.5, 0.5, 0.5],
        std: List[float] = [0.5, 0.5, 0.5],
        enable_augmentation: bool = True,
        aug_strength: float = 0.1,  # jitter strength; 0.1-0.2 is a conservative range
    ):
        """Build the transform pipeline.

        Args:
            resolution: Side length of the square output image.
            mean: Per-channel mean used for normalization.
            std: Per-channel standard deviation used for normalization.
            enable_augmentation: Apply colour jitter when True.
            aug_strength: Strength passed to brightness/contrast/saturation;
                hue uses half of this value.
        """
        self.resolution = resolution
        self.enable_augmentation = enable_augmentation

        # --- 1. Base processing (shared by training and evaluation) ---
        # Defined as separate steps because augmentation is applied on the
        # PIL image, before resizing and tensor conversion.
        self.resize = T.Resize(
            (resolution, resolution),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.to_tensor = T.ToTensor()
        # Copy the statistics so the shared default lists are never aliased
        # by the Normalize transform.
        self.normalize = T.Normalize(mean=list(mean), std=list(std))

        # --- 2. Augmentation (training only) ---
        # No RandomCrop here: cropping would discard absolute position
        # information, so only colour augmentation is used.
        if enable_augmentation:
            self.aug = T.ColorJitter(
                brightness=aug_strength,
                contrast=aug_strength,
                saturation=aug_strength,
                hue=aug_strength / 2,
            )
        else:
            self.aug = torch.nn.Identity()

    def __call__(self, img: "Union[np.ndarray, Image.Image, torch.Tensor]") -> torch.Tensor:
        """Preprocess a single image.

        Args:
            img: (H, W, C) uint8 numpy array (e.g. read from HDF5), a PIL
                image, or a tensor. Tensors are assumed to be (C, H, W) and
                are converted with torchvision's ``ToPILImage``.

        Returns:
            (C, H, W) float32 tensor, normalized with the configured mean/std.
        """
        # 1. Bring every supported input type to a PIL image so that the
        #    jitter / resize / ToTensor steps all see the same representation.
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif isinstance(img, torch.Tensor):
            # Previously tensors fell through untouched and ToTensor() then
            # rejected them; convert explicitly instead.
            img = T.ToPILImage()(img.detach().cpu())

        # 2. Augmentation (only when enabled)
        if self.enable_augmentation:
            img = self.aug(img)

        # 3. Resize
        img = self.resize(img)

        # 4. To tensor ([0, 255] -> [0.0, 1.0]) and normalize
        tensor = self.to_tensor(img)
        tensor = self.normalize(tensor)

        return tensor

    def __repr__(self):
        return f"VLAImageProcessor(res={self.resolution}, aug={self.enable_augmentation})"
|
||||||
@@ -1 +0,0 @@
|
|||||||
# 图像预处理
|
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
# Backbone models
|
# Backbone models
|
||||||
# Uncomment when these are implemented:
|
from .siglip import SigLIPBackbone
|
||||||
# from .siglip import SigLIPBackbone
|
|
||||||
# from .clip import CLIPBackbone
|
# from .clip import CLIPBackbone
|
||||||
# from .dinov2 import DinoV2Backbone
|
# from .dinov2 import DinoV2Backbone
|
||||||
from .debug import DebugBackbone
|
|
||||||
|
|
||||||
__all__ = ["DebugBackbone"]
|
__all__ = ["SigLIPBackbone"]
|
||||||
|
|
||||||
|
# from .debug import DebugBackbone
|
||||||
|
# __all__ = ["DebugBackbone"]
|
||||||
@@ -1 +1,62 @@
|
|||||||
# SigLIP Backbone 实现
|
# SigLIP Backbone 实现
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import AutoModel, AutoProcessor, SiglipVisionModel
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from roboimi.vla.core.interfaces import VLABackbone
|
||||||
|
|
||||||
|
class SigLIPBackbone(VLABackbone):
    """
    Wraps Google's SigLIP Vision Encoder.
    HuggingFace ID example: "google/siglip-so400m-patch14-384"
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        freeze: bool = True,
        embed_dim: Optional[int] = None
    ):
        """Load the SigLIP vision tower.

        Args:
            model_name: HuggingFace model identifier passed to
                ``SiglipVisionModel.from_pretrained``.
            freeze: When True, disable gradients for the vision model and
                keep it in eval mode.
            embed_dim: Optional explicit embedding size. When omitted, the
                value is read from the loaded model's ``config.hidden_size``.
        """
        super().__init__()
        print(f"Loading SigLIP: {model_name} ...")

        # Only the vision tower is loaded; the text tower is not needed for
        # extracting image features.
        self.vision_model = SiglipVisionModel.from_pretrained(model_name)

        # Remembered so that train() can keep a frozen encoder in eval mode.
        self._frozen = freeze

        # Prefer the configured embed_dim, otherwise read it from the model.
        detected_dim = self.vision_model.config.hidden_size
        if embed_dim is not None:
            if embed_dim != detected_dim:
                # A mismatch would otherwise only surface later as a shape
                # error in whatever consumes embed_dim.
                import warnings
                warnings.warn(
                    f"Configured embed_dim={embed_dim} differs from the "
                    f"model's hidden_size={detected_dim} for '{model_name}'."
                )
            self._embed_dim = embed_dim
            print(f"✓ Using configured embed_dim: {embed_dim}")
        else:
            self._embed_dim = detected_dim
            print(f"✓ Auto-detected embed_dim: {self._embed_dim}")

        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients for the vision model and switch it to eval mode."""
        print("❄️  Freezing Vision Backbone parameters")
        for param in self.vision_model.parameters():
            param.requires_grad = False
        self.vision_model.eval()

    def train(self, mode: bool = True):
        """Set training mode while keeping a frozen vision model in eval mode.

        NOTE(review): assumes VLABackbone derives from ``torch.nn.Module`` so
        that ``super().train`` exists — confirm against core.interfaces.
        """
        super().train(mode)
        if getattr(self, "_frozen", False):
            # A parent module's .train() would otherwise undo the eval()
            # call made in _freeze_parameters().
            self.vision_model.eval()
        return self

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            obs['image']: (B, C, H, W) normalized tensor
        Returns:
            features: (B, Seq_Len, Embed_Dim)
        """
        images = obs['image']

        # The HuggingFace vision model returns an output object; its
        # last_hidden_state holds the per-patch features
        # (B, Num_Patches, Embed_Dim).
        outputs = self.vision_model(pixel_values=images)

        return outputs.last_hidden_state

    @property
    def embed_dim(self) -> int:
        """Embedding size reported to downstream modules."""
        return self._embed_dim
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
# # Action Head models
|
# # Action Head models
|
||||||
# from .diffusion import DiffusionActionHead
|
from .diffusion import DiffusionHead
|
||||||
# from .act import ACTHead
|
# from .act import ACTHead
|
||||||
|
|
||||||
# __all__ = ["DiffusionActionHead", "ACTHead"]
|
__all__ = ["DiffusionHead"]
|
||||||
|
|
||||||
from .debug import DebugHead
|
# from .debug import DebugHead
|
||||||
|
|
||||||
__all__ = ["DebugHead"]
|
# __all__ = ["DebugHead"]
|
||||||
@@ -1 +1,174 @@
|
|||||||
# Diffusion Policy Action Head 实现
|
# Diffusion Policy Action Head 实现
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from diffusers import DDPMScheduler
|
||||||
|
from roboimi.vla.core.interfaces import VLAHead
|
||||||
|
|
||||||
|
class DiffusionHead(VLAHead):
    """Diffusion-policy action head.

    Training: adds scheduler noise to the flattened ground-truth action chunk
    and regresses that noise with a small conditional MLP (MSE loss).
    Inference: starts from Gaussian noise and iteratively denoises it with
    the DDPM scheduler to produce an action chunk.
    """

    def __init__(
        self,
        input_dim: int,          # feature size coming from the projector (e.g. 384)
        action_dim: int,         # action dimensionality (e.g. 16)
        chunk_size: int,         # prediction horizon (e.g. 16)
        n_timesteps: int = 100,  # number of diffusion steps
        hidden_dim: int = 256
    ):
        super().__init__()
        self.action_dim = action_dim
        self.chunk_size = chunk_size

        # 1. Noise scheduler (DDPM), configured to predict the noise (epsilon).
        # NOTE(review): clip_sample=True presumably expects actions scaled to
        # the scheduler's clip range — confirm against the data pipeline.
        self.scheduler = DDPMScheduler(
            num_train_timesteps=n_timesteps,
            beta_schedule='squaredcos_cap_v2',
            clip_sample=True,
            prediction_type='epsilon'
        )

        # 2. Noise prediction network.
        # Inputs: noisy action chunk + timestep embedding + image condition.
        self.time_emb = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Projects the pooled image feature into the hidden space.
        self.cond_proj = nn.Linear(input_dim, hidden_dim)

        # Width of the concatenated [condition, flattened action] vector that
        # flows through the residual trunk.
        trunk_dim = hidden_dim + action_dim * chunk_size

        # Trunk: three residual MLP blocks that preserve trunk_dim.
        self.mid_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(trunk_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, trunk_dim)
            ) for _ in range(3)
        ])

        # Output layer: predicted noise, same size as the flattened actions.
        self.final_layer = nn.Linear(trunk_dim, action_dim * chunk_size)

    def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Unified interface for training and inference.

        Args:
            embeddings: (B, Seq, Dim) features; mean-pooled over dim 1 here.
            actions: Ground-truth action chunk whose trailing dims flatten to
                chunk_size * action_dim. When given, the training loss is
                returned; when None, actions are sampled.

        Returns:
            ``{"loss": scalar}`` in training mode, or
            ``{"pred_actions": (B, chunk_size, action_dim)}`` in inference mode.
        """
        device = embeddings.device

        # --- Conditioning ---
        # Simplification: average-pool the sequence into one global vector.
        # Cross-attention over the sequence could replace this.
        global_cond = embeddings.mean(dim=1)
        cond_feat = self.cond_proj(global_cond)  # (B, Hidden)

        # =========================================
        # Branch A: training
        # =========================================
        if actions is not None:
            batch_size = actions.shape[0]

            # Flatten (B, Chunk, ActDim) -> (B, Chunk*ActDim). reshape (not
            # view) so non-contiguous inputs are accepted as well.
            actions_flat = actions.reshape(batch_size, -1)

            # Sample noise and one random timestep per batch element.
            noise = torch.randn_like(actions_flat)
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps,
                (batch_size,), device=device
            ).long()

            # Forward diffusion: add noise at the sampled timesteps.
            noisy_actions = self.scheduler.add_noise(actions_flat, noise, timesteps)

            # Predict the noise and regress it against the true noise.
            pred_noise = self._predict_noise(noisy_actions, timesteps, cond_feat)
            loss = nn.functional.mse_loss(pred_noise, noise)

            return {"loss": loss}

        # =========================================
        # Branch B: inference
        # =========================================
        else:
            batch_size = embeddings.shape[0]

            # Start from pure Gaussian noise.
            noisy_actions = torch.randn(
                batch_size, self.chunk_size * self.action_dim,
                device=device
            )

            # Reverse diffusion over the scheduler's own timestep sequence.
            self.scheduler.set_timesteps(self.scheduler.config.num_train_timesteps)

            for t in self.scheduler.timesteps:
                # Broadcast the current step to the whole batch.
                timesteps = torch.full(
                    (batch_size,), int(t), device=device, dtype=torch.long
                )

                model_output = self._predict_noise(noisy_actions, timesteps, cond_feat)

                # One denoising step.
                noisy_actions = self.scheduler.step(
                    model_output, t, noisy_actions
                ).prev_sample

            # Back to (B, Chunk, ActDim).
            pred_actions = noisy_actions.reshape(batch_size, self.chunk_size, self.action_dim)

            return {"pred_actions": pred_actions}

    def _predict_noise(self, noisy_actions, timesteps, cond_feat):
        """Run the conditional MLP and return the predicted noise.

        Args:
            noisy_actions: (B, chunk_size * action_dim) noisy flattened actions.
            timesteps: (B,) diffusion timesteps.
            cond_feat: (B, Hidden) projected condition feature.

        Returns:
            (B, chunk_size * action_dim) predicted noise.
        """
        # Timestep embedding from the raw scalar step: (B, Hidden).
        t_emb = self.time_emb(timesteps.float().unsqueeze(-1))

        # Fuse condition and time by addition, then prepend to the actions so
        # the width matches what mid_layers expect (Hidden + ActFlat).
        h = cond_feat + t_emb
        model_input = torch.cat([h, noisy_actions], dim=-1)

        # Residual trunk.
        for layer in self.mid_layers:
            model_input = layer(model_input) + model_input

        return self.final_layer(model_input)
|
||||||
Reference in New Issue
Block a user