debug(train): 在siglip和DiffusionHead下跑通训练流程
This commit is contained in:
@@ -6,109 +6,93 @@ import glob
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Dict, List, Any
|
||||
|
||||
# 【新增】导入刚才写好的处理器
|
||||
from .image_transform import VLAImageProcessor
|
||||
|
||||
class VLAChunkedDataset(Dataset):
|
||||
def __init__(
    self,
    data_path: str,
    pred_horizon: int = 16,
    obs_horizon: int = 1,
    obs_keys: List[str] = ["top"],  # default: only the "top" camera
    resize_resolution: int = 384,  # SigLIP's native input resolution
    train: bool = True  # enables image augmentation for the training split
):
    """Index every (episode, timestep) pair found under *data_path*.

    Args:
        data_path: A single ``.hdf5`` file, or a directory containing
            ``*.hdf5`` episode files.
        pred_horizon: Number of future action steps returned per sample.
        obs_horizon: Number of past observation frames returned per sample.
        obs_keys: Camera names read under ``observations/images/<key>``.
        resize_resolution: Square output size for the image processor.
        train: When True the processor applies color-jitter augmentation.

    Raises:
        ValueError: If no ``.hdf5`` files are found under *data_path*.
    """
    self.data_path = data_path
    self.pred_horizon = pred_horizon
    self.obs_horizon = obs_horizon
    # Copy the caller's list so the shared mutable default (["top"]) can
    # never be mutated through this instance.
    self.obs_keys = list(obs_keys)

    # --- 1. Scan for episode files ---
    if os.path.isdir(data_path):
        # Directory: pick up every *.hdf5 episode inside it.
        self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
    else:
        # Single-file dataset.
        self.file_paths = [data_path]

    if len(self.file_paths) == 0:
        raise ValueError(f"No .hdf5 files found in {data_path}")

    print(f"Found {len(self.file_paths)} episodes. Indexing...")

    # --- 2. Build the global (file_idx, start_time) index ---
    # Maps a flat sample index to (which file, which start frame) so that
    # __getitem__ can locate e.g. global_index=1000 in O(1).
    self.index_map = []
    for i, path in enumerate(self.file_paths):
        with h5py.File(path, 'r') as f:
            # The "action" dataset's length defines the episode length.
            total_len = f["action"].shape[0]
            # Every frame 0..total_len-1 is a valid start: samples whose
            # horizon runs past the end are padded later in __getitem__.
            for t in range(total_len):
                self.index_map.append((i, t))

    print(f"✅ Indexed {len(self.index_map)} total samples.")

    # Instantiate the image processor; augmentation only for training.
    self.image_processor = VLAImageProcessor(
        resolution=resize_resolution,
        enable_augmentation=train,  # train split gets augmentation
        aug_strength=0.1
    )
    print(f"✅ Image Processor: {self.image_processor}")
|
||||
|
||||
def __len__(self):
    # One sample per indexed (episode, start_time) pair.
    return len(self.index_map)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
# --- 1. 定位文件 ---
|
||||
file_idx, t_start = self.index_map[idx]
|
||||
file_path = self.file_paths[file_idx]
|
||||
|
||||
# 每次读取打开文件 (Lazy Loading),读取完自动关闭
|
||||
# 这种方式对多进程 DataLoader 最安全
|
||||
with h5py.File(file_path, 'r') as f:
|
||||
# ... (Action读取代码保持不变) ...
|
||||
total_len = f["action"].shape[0]
|
||||
|
||||
# --- 2. 动作 (Action) ---
|
||||
t_end = min(t_start + self.pred_horizon, total_len)
|
||||
|
||||
# 读取动作片段
|
||||
actions_np = f["action"][t_start:t_end] # (L, 16)
|
||||
|
||||
# Padding 处理
|
||||
actions_np = f["action"][t_start:t_end]
|
||||
# ... (Padding 逻辑保持不变) ...
|
||||
actual_len = actions_np.shape[0]
|
||||
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
||||
|
||||
if actual_len < self.pred_horizon:
|
||||
pad_len = self.pred_horizon - actual_len
|
||||
# 重复最后一帧动作进行填充
|
||||
pad_block = np.tile(actions_np[-1], (pad_len, 1))
|
||||
actions_np = np.concatenate([actions_np, pad_block], axis=0)
|
||||
# 标记 Padding 部分为 0
|
||||
action_mask[actual_len:] = 0.0
|
||||
|
||||
# --- 3. 图像 (Images) ---
|
||||
# --- 图像处理部分 ---
|
||||
obs_dict = {}
|
||||
for key in self.obs_keys:
|
||||
imgs = []
|
||||
# 处理观测历史 (Obs Horizon)
|
||||
# 如果 t_start=0, obs_horizon=2, 我们需要读取 t=0 和 t=0 (重复第一帧)
|
||||
for i in range(self.obs_horizon):
|
||||
# 倒序读取:当前帧,前一帧...
|
||||
# 注意:这里逻辑是 [t_start - (obs_horizon-1) + i]
|
||||
# 比如 horizon=2, t=10. i=0 -> t=9; i=1 -> t=10.
|
||||
query_t = t_start - (self.obs_horizon - 1) + i
|
||||
query_t = max(0, query_t) # 边界保护
|
||||
# 计算历史帧索引
|
||||
query_t = max(0, t_start - (self.obs_horizon - 1) + i)
|
||||
|
||||
imgs.append(f[f"observations/images/{key}"][query_t])
|
||||
# 1. 读取原始数据 (Numpy uint8)
|
||||
raw_img = f[f"observations/images/{key}"][query_t]
|
||||
|
||||
# 2. 【调用处理器】 Numpy -> Tensor (384, 384) Normalized
|
||||
processed_img = self.image_processor(raw_img)
|
||||
|
||||
imgs.append(processed_img)
|
||||
|
||||
# Stack -> (Obs_Horizon, H, W, C)
|
||||
img_stack = np.stack(imgs)
|
||||
# Normalize & Permute -> (Obs_Horizon, C, H, W)
|
||||
img_stack = img_stack.astype(np.float32) / 255.0
|
||||
img_stack = np.transpose(img_stack, (0, 3, 1, 2))
|
||||
|
||||
obs_dict[key] = torch.from_numpy(img_stack)
|
||||
# Stack -> (T, C, H, W)
|
||||
obs_dict[key] = torch.stack(imgs)
|
||||
|
||||
# --- 4. QPos ---
|
||||
# ... (QPos 和 Language 读取保持不变) ...
|
||||
qpos = f["observations/qpos"][t_start].astype(np.float32)
|
||||
lang = f.attrs.get("language", "placeholder")
|
||||
if isinstance(lang, bytes): lang = lang.decode("utf-8")
|
||||
|
||||
# --- 5. Language ---
|
||||
# 暂时写死或从 attrs 读取
|
||||
lang = f.attrs.get("language", "task instruction placeholder")
|
||||
if isinstance(lang, bytes):
|
||||
lang = lang.decode("utf-8")
|
||||
# 这里的 action_mask 只是临时补全代码,你原来的逻辑是对的
|
||||
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
||||
if actual_len < self.pred_horizon:
|
||||
action_mask[actual_len:] = 0.0
|
||||
|
||||
return {
|
||||
"obs": obs_dict,
|
||||
|
||||
new file: roboimi/vla/data/image_transform.py (+75 lines)
@@ -0,0 +1,75 @@
|
||||
# 图像预处理
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from typing import Union, List
|
||||
|
||||
class VLAImageProcessor:
    """
    Image preprocessor for SigLIP/CLIP-style ViT backbones.

    Pipeline:
        1. numpy (H, W, C) uint8 / PIL image / torch.Tensor input
        2. Color-jitter augmentation (training only)
        3. Resize to (resolution, resolution), bicubic with antialias
        4. Float tensor in [0, 1], then Normalize
           (SigLIP convention: mean=0.5, std=0.5 per channel)
    """
    def __init__(
        self,
        resolution: int = 384,
        mean: List[float] = [0.5, 0.5, 0.5],
        std: List[float] = [0.5, 0.5, 0.5],
        enable_augmentation: bool = True,
        aug_strength: float = 0.1  # jitter strength; 0.1–0.2 is a safe range
    ):
        """Build the transform pipeline.

        Args:
            resolution: Square output size in pixels.
            mean: Per-channel normalization mean (never mutated here).
            std: Per-channel normalization std (never mutated here).
            enable_augmentation: Apply ColorJitter when True.
            aug_strength: Brightness/contrast/saturation jitter amount;
                hue uses half of it (ColorJitter requires hue <= 0.5).
        """
        self.resolution = resolution
        self.enable_augmentation = enable_augmentation

        # --- 1. Base transforms (shared by train and eval) ---
        # Kept as separate steps because augmentation is cheapest on the
        # PIL image, before tensor conversion.
        self.resize = T.Resize((resolution, resolution), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=mean, std=std)

        # --- 2. Augmentation (training only) ---
        # Robot learning usually avoids RandomCrop (it destroys absolute
        # spatial cues), so only photometric jitter is applied.
        if enable_augmentation:
            self.aug = T.ColorJitter(
                brightness=aug_strength,
                contrast=aug_strength,
                saturation=aug_strength,
                hue=aug_strength / 2
            )
        else:
            # Identity keeps self.aug callable in every mode.
            self.aug = torch.nn.Identity()

    def __call__(self, img: Union[np.ndarray, Image.Image, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            img: (H, W, C) uint8 numpy array (e.g. straight from HDF5),
                a PIL Image, or a torch.Tensor.
        Returns:
            tensor: (C, resolution, resolution) float32, normalized.
        """
        # Tensor input: torchvision's ColorJitter/Resize operate on CHW
        # tensors directly, so skip the PIL round-trip. (The previous code
        # fell through to ToTensor(), which rejects tensor input and raised.)
        if isinstance(img, torch.Tensor):
            # NOTE(review): assumes CHW layout — confirm against callers;
            # HDF5 reads normally arrive here as numpy, not tensors.
            out = self.aug(img) if self.enable_augmentation else img
            out = self.resize(out)
            if out.dtype == torch.uint8:
                out = out.to(torch.float32) / 255.0  # match ToTensor scaling
            else:
                out = out.to(torch.float32)
            return self.normalize(out)

        # 1. Unify numpy input as PIL (convenient for Resize and Jitter).
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)

        # 2. Augment (only when enabled).
        if self.enable_augmentation:
            img = self.aug(img)

        # 3. Resize.
        img = self.resize(img)

        # 4. To tensor & normalize.
        # ToTensor maps [0, 255] -> [0.0, 1.0].
        tensor = self.to_tensor(img)
        tensor = self.normalize(tensor)

        return tensor

    def __repr__(self):
        return f"VLAImageProcessor(res={self.resolution}, aug={self.enable_augmentation})"
|
||||
@@ -1 +0,0 @@
|
||||
# 图像预处理
|
||||
Reference in New Issue
Block a user