feat(vla): vla框架初始化

This commit is contained in:
gouhanke
2026-02-03 14:18:30 +08:00
parent c1ce560b32
commit 57acfd645f
40 changed files with 443 additions and 63 deletions

View File

View File

@@ -0,0 +1,88 @@
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset
class VLAHDF5Dataset(Dataset):
    """Sliding-window dataset over demonstrations stored in a single HDF5 file.

    The file is expected to contain a ``data`` group with one sub-group per
    demonstration; each demonstration provides an ``actions`` dataset, an
    ``obs/<image_key>`` dataset and, optionally, a ``<lang_attr>`` attribute.

    Every time step of every demonstration is one sample. A sample holds the
    ``obs_horizon`` most recent frames (the earliest frame is repeated when the
    window would start before step 0) and the next ``pred_horizon`` actions
    (the last action is repeated, and masked out, when the demonstration ends
    early).

    Args:
        dataset_path: Path to the HDF5 file.
        pred_horizon: Number of future actions returned per sample (>= 1).
        obs_horizon: Number of observation frames returned per sample (>= 1).
        transform: Optional callable applied to the stacked image tensor.
        image_key: Name of the image dataset inside each demo's ``obs`` group.
        lang_attr: Name of the demo attribute used as the language
            instruction; an empty string is returned when it is missing.

    Raises:
        ValueError: If ``pred_horizon`` or ``obs_horizon`` is smaller than 1.
    """

    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int = 16,
                 obs_horizon: int = 2,
                 transform=None,
                 *,
                 image_key: str = "agentview_rgb",
                 lang_attr: str = "model_file"):
        # Fail fast, before touching the file: a non-positive horizon would
        # otherwise only surface later as an obscure stacking / shape error.
        if pred_horizon < 1:
            raise ValueError(f"pred_horizon must be >= 1, got {pred_horizon}")
        if obs_horizon < 1:
            raise ValueError(f"obs_horizon must be >= 1, got {obs_horizon}")

        self.dataset_path = dataset_path
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.transform = transform
        self.image_key = image_key
        self.lang_attr = lang_attr

        # Only metadata (demo names and trajectory lengths) is read here; the
        # actual arrays are loaded lazily in __getitem__.
        with h5py.File(self.dataset_path, 'r') as root:
            data = root['data']
            self.demo_keys = list(data.keys())
            # Index table: one (demo_key, start_time) entry per time step.
            self.indices = [
                (key, t)
                for key in self.demo_keys
                for t in range(data[key]['actions'].shape[0])
            ]

    def __len__(self):
        """Return the total number of (demo, time step) samples."""
        return len(self.indices)

    @staticmethod
    def _pad_actions(actions, pred_horizon: int) -> tuple:
        """Convert an action chunk to a tensor padded to ``pred_horizon`` steps.

        Args:
            actions: Array-like whose first axis is time; it holds at most
                ``pred_horizon`` steps and must not be empty.
            pred_horizon: Required number of steps along the first axis.

        Returns:
            ``(actions, action_mask)``: the padded action tensor and a float
            mask of length ``pred_horizon`` that is 1 for real steps and 0 for
            padded ones.

        Raises:
            IndexError: If ``actions`` contains no steps.
        """
        actions = torch.from_numpy(np.asarray(actions))
        valid = actions.shape[0]
        if valid >= pred_horizon:
            return actions, torch.ones(pred_horizon)
        if valid == 0:
            raise IndexError("cannot pad an empty action chunk")

        pad_len = pred_horizon - valid
        # Repeat the last step along the time axis only, so the padding works
        # for any number of trailing action dimensions (including none).
        tail = actions[-1:].expand(pad_len, *actions.shape[1:])
        padded = torch.cat([actions, tail])
        action_mask = torch.cat([torch.ones(valid), torch.zeros(pad_len)])
        return padded, action_mask

    @staticmethod
    def _decode_text(value):
        """Return ``value`` as ``str`` when it is a bytes object.

        HDF5 string attributes can come back as ``bytes`` / ``numpy.bytes_``;
        decoding here keeps the ``"text"`` field a consistent type for the
        tokenizer. Any other value is returned unchanged.
        """
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with:
              * ``"images"``: float tensor of the ``obs_horizon`` frames with
                axes permuted ``(0, 3, 1, 2)`` and values divided by 255
                (assumes channel-last 8-bit frames - TODO confirm), after the
                optional ``transform``.
              * ``"text"``: the language attribute (tokenization is left to
                the collate function).
              * ``"actions"``: tensor with ``pred_horizon`` steps.
              * ``"action_mask"``: 1 for real action steps, 0 for padding.
        """
        key, t_start = self.indices[idx]

        # The file is opened per call on purpose: the intent is that every
        # DataLoader worker ends up with its own independent file handle.
        with h5py.File(self.dataset_path, 'r') as root:
            demo = root['data'][key]

            # --- actions: a direct slice read from the HDF5 dataset ---
            action_ds = demo['actions']
            t_end = min(t_start + self.pred_horizon, action_ds.shape[0])
            action_chunk = action_ds[t_start:t_end]

            # --- images: history window, clamped to frame 0 at the start ---
            frame_ds = demo['obs'][self.image_key]
            first = t_start - self.obs_horizon + 1
            frames = [frame_ds[max(0, first + i)]
                      for i in range(self.obs_horizon)]

            # --- language instruction stored as a demo attribute ---
            lang_raw = demo.attrs.get(self.lang_attr, "")

        # Everything below works on in-memory copies, so the file can already
        # be closed.
        actions, action_mask = self._pad_actions(action_chunk,
                                                 self.pred_horizon)

        # Stack to [T, H, W, C], then move channels forward: [T, C, H, W].
        images = torch.from_numpy(np.stack(frames))
        images = images.permute(0, 3, 1, 2).float() / 255.0

        # `is not None` so that a transform that happens to be falsy (e.g. an
        # empty container module) is still applied.
        if self.transform is not None:
            images = self.transform(images)

        return {
            "images": images,
            "text": self._decode_text(lang_raw),
            "actions": actions,
            "action_mask": action_mask,
        }

View File

@@ -0,0 +1 @@
# 图像预处理

View File

@@ -0,0 +1 @@
# 文本 Tokenizer 包装