feat(dataset): 定义VLAChunkedDataset类,构建数据可视化工具
This commit is contained in:
@@ -2,87 +2,80 @@ import h5py
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
class VLAChunkedDataset(Dataset):
    """Chunked HDF5 dataset for VLA (vision-language-action) training.

    Each sample is an action chunk of length ``pred_horizon`` starting at
    timestep ``idx``, together with a short history (``obs_horizon`` frames)
    of camera images, the current proprioceptive state (``qpos``) and the
    episode-level language instruction.

    Expected HDF5 layout (single file) — inferred from the reads below,
    TODO confirm against the data-generation side:
        /action                      -> (T, action_dim)
        /observations/qpos           -> (T, qpos_dim)
        /observations/images/<cam>   -> (T, H, W, C) uint8, one per camera
        attrs["language"]            -> optional instruction (str or bytes)
    """

    def __init__(
        self,
        data_path: str,
        pred_horizon: int = 16,
        obs_horizon: int = 2,
        obs_keys: Optional[List[str]] = None,
    ):
        """
        Args:
            data_path: path to the HDF5 file.
            pred_horizon: number of future actions per sample; chunks that
                run past the end of the data are padded (see __getitem__).
            obs_horizon: number of past image frames per sample; clamped at
                the start of the episode.
            obs_keys: camera names under ``observations/images``. Defaults
                to ``["top", "angle"]``.
        """
        self.data_path = data_path
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        # Build the default per-instance to avoid the shared
        # mutable-default-argument pitfall.
        self.obs_keys = list(obs_keys) if obs_keys is not None else ["top", "angle"]
        # Opened lazily, one handle per DataLoader worker (see _get_handle).
        self.file_handle = None

        # Read only metadata here (shapes/attributes, no bulk data), then
        # close the file so no handle is shared across forked workers.
        with h5py.File(self.data_path, 'r') as f:
            self.total_len = f["action"].shape[0]
            # Language instruction is expected as a root attribute; fall
            # back to a generic instruction when absent.
            self.lang_instruction = f.attrs.get("language", "执行任务")
            if isinstance(self.lang_instruction, bytes):
                self.lang_instruction = self.lang_instruction.decode("utf-8")

    def _get_handle(self):
        """Return the per-process read handle, opening it on first use.

        Opening lazily (instead of in __init__) ensures every DataLoader
        worker gets its own independent file handle; SWMR mode allows
        concurrent reading while a writer appends.
        """
        if self.file_handle is None:
            self.file_handle = h5py.File(self.data_path, 'r', swmr=True)
        return self.file_handle

    def __del__(self):
        # Best-effort cleanup of the lazily opened handle; getattr guards
        # against partially-constructed instances.
        handle = getattr(self, "file_handle", None)
        if handle is not None:
            try:
                handle.close()
            except Exception:
                pass

    def __len__(self):
        # One sample per timestep; end-of-episode chunks are padded in
        # __getitem__ instead of being dropped.
        return self.total_len

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        f = self._get_handle()
        t_start = idx

        # --- 1. Actions & loss mask ---
        t_end = min(t_start + self.pred_horizon, self.total_len)
        actual_len = t_end - t_start

        # HDF5 supports direct slice reads, so this is fast.
        actions_np = f["action"][t_start:t_end]

        # Mask: 1 = real data, 0 = padding, so the loss can ignore the
        # repeated tail actions.
        action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)

        if actual_len < self.pred_horizon:
            pad_len = self.pred_horizon - actual_len
            # Pad by repeating the last valid action.
            pad_block = np.tile(actions_np[-1], (pad_len, 1))
            actions_np = np.concatenate([actions_np, pad_block], axis=0)
            action_mask[actual_len:] = 0.0

        # --- 2. Image observations ---
        obs_dict = {}
        for key in self.obs_keys:
            imgs = []
            for i in range(self.obs_horizon):
                # Clamp at t=0 so early timesteps repeat the first frame.
                t_query = max(0, t_start - (self.obs_horizon - 1) + i)
                imgs.append(f[f"observations/images/{key}"][t_query])

            # Stack and convert: [T, H, W, C] uint8 -> [T, C, H, W]
            # float32 in [0, 1].
            img_stack = np.stack(imgs).astype(np.float32) / 255.0
            img_stack = img_stack.transpose(0, 3, 1, 2)
            obs_dict[key] = torch.from_numpy(img_stack)

        # --- 3. Low-dimensional state (joint positions) ---
        qpos = f["observations/qpos"][t_start].astype(np.float32)

        return {
            "obs": obs_dict,                                  # visual input
            "qpos": torch.from_numpy(qpos),                   # proprioception
            "actions": torch.from_numpy(actions_np).float(),
            "action_mask": action_mask,                       # loss mask
            "language": self.lang_instruction,                # text instruction
        }
|
||||
Reference in New Issue
Block a user