feat: 更新框架,新增数据集定义和backbone
This commit is contained in:
@@ -1,103 +1,156 @@
|
||||
import os
import glob
import pickle
from typing import Any, Dict, List, Optional

import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

# Image preprocessor (resize / normalize / optional augmentation) implemented
# alongside this dataset module.
from .image_transform import VLAImageProcessor


class VLAChunkedDataset(Dataset):
    """Chunked HDF5 dataset for VLA-style training.

    Each sample pairs a short history of camera observations with the next
    ``pred_horizon`` action steps read from the ``action`` dataset of an HDF5
    episode file.  Raw images are converted to normalized tensors by
    :class:`VLAImageProcessor`.

    NOTE(review): assumes each HDF5 file contains ``action`` (T, A),
    ``observations/qpos`` (T, D) and ``observations/images/<key>`` (T, H, W, 3)
    datasets, plus an optional ``language`` file attribute — confirm against
    the data-collection pipeline.
    """

    def __init__(
        self,
        data_path: str,
        pred_horizon: int = 16,
        obs_horizon: int = 1,
        obs_keys: Optional[List[str]] = None,
        resize_resolution: int = 384,  # SigLIP default input resolution
        train: bool = True,            # training split -> enable augmentation
    ):
        """
        Args:
            data_path: A single ``.hdf5`` file, or a directory scanned for
                ``*.hdf5`` files.
            pred_horizon: Number of future action steps to predict (Tp).
            obs_horizon: Number of past observation frames to return (To).
            obs_keys: Camera names under ``observations/images``.
                Defaults to ``["top"]``.
            resize_resolution: Square output resolution of the processor.
            train: When True, the image processor applies data augmentation.
        """
        # Fix: the original used a mutable default argument (obs_keys=["top"]),
        # which is shared across all instances; use a None sentinel instead.
        if obs_keys is None:
            obs_keys = ["top"]

        self.data_path = data_path
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.obs_keys = obs_keys

        # Collect the HDF5 files to index (sorted for determinism).
        if os.path.isdir(data_path):
            self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
        else:
            self.file_paths = [data_path]

        # Flat sample index: one (file_idx, timestep) entry per valid start
        # step, so __getitem__ is O(1) and DataLoader shuffling is uniform.
        self.index_map = []
        for i, path in enumerate(self.file_paths):
            with h5py.File(path, 'r') as f:
                total_len = f["action"].shape[0]
                for t in range(total_len):
                    self.index_map.append((i, t))

        # Instantiate the image processor; augmentation only during training.
        self.image_processor = VLAImageProcessor(
            resolution=resize_resolution,
            enable_augmentation=train,
            aug_strength=0.1,
        )
        print(f"✅ Image Processor: {self.image_processor}")

    def __len__(self) -> int:
        # One sample per (file, timestep) pair.
        return len(self.index_map)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one training sample.

        Returns a dict with:
            ``obs``         — {camera: (To, C, H, W) float tensor}
            ``qpos``        — (D,) float32 proprioception at t_start
            ``actions``     — (Tp, A) float32, padded past episode end
            ``action_mask`` — (Tp,) float32, 0 on padded steps
            ``language``    — instruction string (``"placeholder"`` if absent)
        """
        file_idx, t_start = self.index_map[idx]
        file_path = self.file_paths[file_idx]

        # Opening the file inside __getitem__ keeps the dataset safe to use
        # with multi-process DataLoader workers (no shared h5py handles).
        with h5py.File(file_path, 'r') as f:
            # --- Actions: slice [t_start, t_start + Tp), clipped to episode end.
            total_len = f["action"].shape[0]
            t_end = min(t_start + self.pred_horizon, total_len)
            actions_np = f["action"][t_start:t_end]

            # Pad to a fixed Tp by repeating the last action row.
            actual_len = actions_np.shape[0]
            if actual_len < self.pred_horizon:
                pad_len = self.pred_horizon - actual_len
                pad_block = np.tile(actions_np[-1], (pad_len, 1))
                actions_np = np.concatenate([actions_np, pad_block], axis=0)

            # --- Images: To frames per camera; history before step 0 is
            # padded by clamping the index to the first frame.
            obs_dict = {}
            for key in self.obs_keys:
                imgs = []
                for i in range(self.obs_horizon):
                    query_t = max(0, t_start - (self.obs_horizon - 1) + i)
                    # Raw uint8 HWC frame from disk.
                    raw_img = f[f"observations/images/{key}"][query_t]
                    # Numpy -> normalized (C, res, res) tensor.
                    processed_img = self.image_processor(raw_img)
                    imgs.append(processed_img)
                # Stack the time dimension -> (To, C, H, W).
                obs_dict[key] = torch.stack(imgs)

            # --- Proprioception and language instruction.
            qpos = f["observations/qpos"][t_start].astype(np.float32)
            lang = f.attrs.get("language", "placeholder")
            if isinstance(lang, bytes):
                lang = lang.decode("utf-8")

        # Mask padded action steps so the training loss can ignore them.
        action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
        if actual_len < self.pred_horizon:
            action_mask[actual_len:] = 0.0

        return {
            "obs": obs_dict,
            "qpos": torch.from_numpy(qpos),
            "actions": torch.from_numpy(actions_np).float(),
            "action_mask": action_mask,
            "language": lang,
        }


class RobotDiffusionDataset(Dataset):
    """Legacy diffusion-policy dataset over ``episode_*.hdf5`` files.

    Actions and qpos are normalized with precomputed statistics loaded from
    ``data_stats.pkl`` in the dataset directory.
    """

    def __init__(self,
                 dataset_dir,
                 pred_horizon=16,
                 obs_horizon=1,
                 action_horizon=8,
                 camera_names=None,
                 normalization_type='gaussian'):
        """
        Args:
            dataset_dir: Directory containing ``episode_*.hdf5`` files and
                ``data_stats.pkl``.
            pred_horizon: Length of the predicted future action chunk (Tp).
            obs_horizon: Observation history length (To).
            action_horizon: Executed action length (Ta) — mainly relevant at
                evaluation time; kept here as a parameter.
            camera_names: Cameras to load; defaults to ``['r_vis', 'top']``.
            normalization_type: ``'gaussian'`` (mean/std) or ``'min_max'``
                (scale to [-1, 1]).
        """
        # Fix: avoid a mutable default argument for camera_names.
        if camera_names is None:
            camera_names = ['r_vis', 'top']

        self.dataset_dir = dataset_dir
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        self.camera_names = camera_names
        self.normalization_type = normalization_type

        # 1. Scan all HDF5 files and build a flat step index.
        #    Entries: (file_path, episode_length, current_step_index)
        self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
        self.indices = []
        print(f"Found {len(self.episode_files)} episodes. Building index...")
        for file_path in self.episode_files:
            with h5py.File(file_path, 'r') as f:
                # Episode length (e.g. 700 steps).
                l = f['action'].shape[0]
                for i in range(l):
                    self.indices.append((file_path, l, i))

        # 2. Precomputed normalization statistics.
        # NOTE(review): pickle.load is unsafe on untrusted files — acceptable
        # only because data_stats.pkl is produced by our own pipeline.
        with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
            self.stats = pickle.load(f)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        file_path, episode_len, start_ts = self.indices[idx]

        # Opening the file per item is friendlier to multi-process DataLoader
        # workers; for extreme IO needs consider h5py SWMR or an in-memory cache.
        with h5py.File(file_path, 'r') as root:
            # -----------------------------
            # 1. Actions (prediction target): [t, t + pred_horizon)
            # -----------------------------
            action_start = start_ts
            action_end = min(start_ts + self.pred_horizon, episode_len)
            actions = root['action'][action_start:action_end]  # (T_subset, A)

            # Padding: replicate the last step when fewer than Tp remain.
            if len(actions) < self.pred_horizon:
                pad_len = self.pred_horizon - len(actions)
                last_action = actions[-1]
                pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
                actions = np.concatenate([actions, pad_content], axis=0)

            # Normalize actions with the precomputed stats.
            if self.stats:
                actions = self._normalize_data(actions, self.stats['action'])

            # -----------------------------
            # 2. Observation history indices: [t - To + 1, t]
            #    e.g. To=2, t=0 -> [0, 0] (first-frame padding); t=5 -> [4, 5]
            # -----------------------------
            indices = []
            for i in range(self.obs_horizon):
                query_ts = start_ts - (self.obs_horizon - 1) + i
                query_ts = max(query_ts, 0)  # clamp at episode start
                indices.append(query_ts)

            # Proprioception history, normalized like the actions.
            qpos = root['observations/qpos'][indices]  # fancy indexing
            if self.stats:
                qpos = self._normalize_data(qpos, self.stats['qpos'])

            # 3. Images: one stacked tensor per camera view.
            image_dict = {}
            for cam_name in self.camera_names:
                img_dset = root['observations']['images'][cam_name]
                imgs = []
                for t in indices:
                    img = img_dset[t]  # (H, W, 3) uint8
                    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
                    imgs.append(img)
                # Stack time dimension: (obs_horizon, 3, H, W).
                image_dict[cam_name] = torch.stack(imgs)

        # -----------------------------
        # 4. Assemble the batch dict.
        # -----------------------------
        data_batch = {
            'action': torch.from_numpy(actions).float(),  # (Tp, A)
            'qpos': torch.from_numpy(qpos).float(),       # (To, D)
        }
        for cam_name, img_tensor in image_dict.items():
            data_batch[f'image_{cam_name}'] = img_tensor  # (To, 3, H, W)

        # TODO: add the language instruction.  If all episodes share one task
        # this can be a fixed embedding; otherwise a meta JSON mapping
        # file_path -> text is needed.
        # data_batch['lang_text'] = "pick up the red cube"

        return data_batch

    def _normalize_data(self, data, stats):
        """Normalize a numpy array with the configured scheme."""
        if self.normalization_type == 'min_max':
            # Scale into [-1, 1]; epsilon guards a degenerate (flat) range.
            min_val = stats['min']
            max_val = stats['max']
            data = (data - min_val) / (max_val - min_val + 1e-8)
            return data * 2 - 1
        elif self.normalization_type == 'gaussian':
            # Standardize with mean/std; epsilon guards zero variance.
            mean = stats['mean']
            std = stats['std']
            return (data - mean) / (std + 1e-8)
        # Fix: the original fell through and silently returned None for an
        # unknown scheme; fail loudly instead.
        raise ValueError(f"Unknown normalization_type: {self.normalization_type!r}")
|
||||
Reference in New Issue
Block a user