feat: add vision transfer backbones and IMF variants

This commit is contained in:
Logic
2026-04-09 14:02:24 +08:00
parent d51b3ecafa
commit ff7c9c1f2a
58 changed files with 2788 additions and 26 deletions

View File

@@ -1,7 +1,7 @@
import torch
import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Union
from typing import List, Dict, Union, Optional, Sequence
from pathlib import Path
from collections import OrderedDict
@@ -22,6 +22,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: int = 2,
pred_horizon: int = 8,
camera_names: List[str] = None,
image_resize_shape: Optional[Sequence[int]] = (224, 224),
max_open_files: int = 64,
):
"""
@@ -30,6 +31,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
image_resize_shape: 图像缩放尺寸 (W, H);为 None 时保留原始分辨率
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
HDF5 文件格式:
@@ -40,6 +42,10 @@ class SimpleRobotDataset(Dataset):
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.camera_names = camera_names or []
self.image_resize_shape = (
tuple(int(v) for v in image_resize_shape)
if image_resize_shape is not None else None
)
self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
@@ -123,9 +129,9 @@ class SimpleRobotDataset(Dataset):
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
# Resize图像到224x224减少内存和I/O负担
import cv2
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
if self.image_resize_shape is not None:
import cv2
img = cv2.resize(img, self.image_resize_shape, interpolation=cv2.INTER_LINEAR)
# 转换为float并归一化到 [0, 1]
img = torch.from_numpy(img).float() / 255.0
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW