跑通配置和训练脚本
This commit is contained in:
@@ -1,45 +1,108 @@
|
|||||||
import hydra
|
import sys
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
from hydra.utils import instantiate
|
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
import hydra
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.optim import AdamW
|
||||||
|
|
||||||
# 必须指向你的配置文件所在路径
|
# 确保导入路径正确
|
||||||
# config_path 是相对于当前脚本的路径,或者绝对路径
|
sys.path.append(os.getcwd())
|
||||||
# config_name 是不带 .yaml 后缀的主文件名
|
|
||||||
@hydra.main(version_base=None, config_path="../../roboimi/vla/conf", config_name="config")
|
from roboimi.vla.agent import VLAAgent
|
||||||
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="../../../roboimi/vla/conf", config_name="config")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
print(f"Working directory : {os.getcwd()}")
|
print(OmegaConf.to_yaml(cfg))
|
||||||
print(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")
|
log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})")
|
||||||
|
|
||||||
# 1. 实例化 Agent
|
# --- 1. 实例化 Dataset & DataLoader ---
|
||||||
# Hydra 会自动查找 _target_ 并递归实例化 vlm_backbone 和 action_head
|
# Hydra 根据 conf/data/custom_hdf5.yaml 实例化类
|
||||||
print(">>> Instantiating VLA Agent...")
|
dataset = instantiate(cfg.data)
|
||||||
agent = instantiate(cfg.agent)
|
|
||||||
|
|
||||||
# 将模型移至 GPU
|
dataloader = DataLoader(
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
dataset,
|
||||||
agent.to(device)
|
batch_size=cfg.train.batch_size,
|
||||||
print(f">>> Agent created successfully. Backbone: {type(agent.vlm).__name__}")
|
|
||||||
|
|
||||||
# 2. 实例化 DataLoader (假设你也为 Data 写了 yaml)
|
|
||||||
# 实例化 Dataset
|
|
||||||
dataset = hydra.utils.instantiate(cfg.data)
|
|
||||||
|
|
||||||
# 封装进 DataLoader
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=cfg.train.batch_size,
|
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4
|
num_workers=cfg.train.num_workers,
|
||||||
|
pin_memory=(cfg.train.device != "cpu")
|
||||||
)
|
)
|
||||||
|
log.info(f"✅ Dataset loaded. Size: {len(dataset)}")
|
||||||
|
|
||||||
# 3. 实例化 Optimizer (Hydra 也支持 partial 实例化)
|
# --- 2. 实例化 Agent ---
|
||||||
# optimizer = instantiate(cfg.train.optimizer, params=agent.parameters())
|
agent: VLAAgent = instantiate(cfg.agent)
|
||||||
|
agent.to(cfg.train.device)
|
||||||
|
agent.train()
|
||||||
|
|
||||||
# 4. 模拟训练循环
|
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr)
|
||||||
print(f">>> Starting training with batch size: {cfg.train.batch_size}")
|
|
||||||
# ... training loop logic here ...
|
# --- 3. Training Loop ---
|
||||||
|
# 使用一个无限迭代器或者 epoch 循环
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
pbar = tqdm(range(cfg.train.max_steps), desc="Training")
|
||||||
|
|
||||||
|
for step in pbar:
|
||||||
|
try:
|
||||||
|
batch = next(data_iter)
|
||||||
|
except StopIteration:
|
||||||
|
#而在 epoch 结束时重新开始
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
batch = next(data_iter)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
# 注意:这里需要递归地将字典里的 tensor 移到 GPU
|
||||||
|
batch = recursive_to_device(batch, cfg.train.device)
|
||||||
|
|
||||||
|
# --- 4. Adapter Layer (适配层) ---
|
||||||
|
# Dataset 返回的是具体的相机 key (如 'agentview_image' 或 'top')
|
||||||
|
# Agent 期望的是通用的 'image'
|
||||||
|
# 我们在这里做一个映射,模拟多模态融合前的处理
|
||||||
|
|
||||||
|
# 假设我们只用配置里的第一个 key 作为主视觉
|
||||||
|
primary_cam_key = cfg.data.obs_keys[0]
|
||||||
|
|
||||||
|
# Dataset 返回 shape: (B, Obs_Horizon, C, H, W)
|
||||||
|
# DebugBackbone 期望: (B, C, H, W) 或者 (B, Seq, Dim)
|
||||||
|
# 这里我们取 Obs_Horizon 的最后一帧 (Current Frame)
|
||||||
|
input_img = batch['obs'][primary_cam_key][:, -1, :, :, :]
|
||||||
|
|
||||||
|
agent_input = {
|
||||||
|
"obs": {
|
||||||
|
"image": input_img,
|
||||||
|
"text": batch["language"] # 传递语言指令
|
||||||
|
},
|
||||||
|
"actions": batch["actions"] # (B, Chunk, Dim)
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- 5. Forward & Backward ---
|
||||||
|
outputs = agent(agent_input)
|
||||||
|
|
||||||
|
# 处理 Loss 掩码 (如果在真实训练中,需要在这里应用 action_mask)
|
||||||
|
# 目前 DebugHead 内部直接算了 MSE,还没用 mask,我们在下一阶段优化 Policy 时加上
|
||||||
|
loss = outputs['loss']
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if step % cfg.train.log_freq == 0:
|
||||||
|
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||||
|
|
||||||
|
log.info("✅ Training Loop with Real HDF5 Finished!")
|
||||||
|
|
||||||
|
def recursive_to_device(data, device):
|
||||||
|
if isinstance(data, torch.Tensor):
|
||||||
|
return data.to(device)
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
return {k: recursive_to_device(v, device) for k, v in data.items()}
|
||||||
|
elif isinstance(data, list):
|
||||||
|
return [recursive_to_device(v, device) for v in data]
|
||||||
|
return data
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -1 +1,26 @@
|
|||||||
# Tiny debug model
# @package agent
_target_: roboimi.vla.agent.VLAAgent

# --- 1. Backbone (VLM) ---
backbone:
  _target_: roboimi.vla.models.backbones.debug.DebugBackbone
  embed_dim: 768  # source feature dimension
  seq_len: 10

# --- 2. Projector (Adapter) ---
projector:
  _target_: roboimi.vla.models.projectors.mlp.MLPProjector
  # Dependency injection: reads embed_dim from the backbone node automatically
  input_dim: ${..backbone.embed_dim}
  output_dim: 128  # bottleneck dimension (tiny scale)

# --- 3. Head (Policy) ---
head:
  _target_: roboimi.vla.models.heads.debug.DebugHead
  input_dim: ${..projector.output_dim}

  # Set to 16 to match the sim data's action dimension
  action_dim: 16

  chunk_size: 16
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
defaults:
  - _self_
  - agent: tiny
  - data: custom_hdf5  # activates the data configuration

train:
  batch_size: 4   # small batch for easier debugging
  lr: 1e-4
  max_steps: 100
  log_freq: 10
  device: "cpu"
  num_workers: 0  # 0 while debugging; raise to 2 or 4 once verified
|
||||||
8
roboimi/vla/conf/data/custom_hdf5.yaml
Normal file
8
roboimi/vla/conf/data/custom_hdf5.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
_target_: roboimi.vla.data.dataset.VLAChunkedDataset

# Points to the directory containing the episode HDF5 files
data_path: "/home/d51/workspace/work/robo-imi-act/roboimi/demos/dataset/sim_transfer"

pred_horizon: 16
obs_horizon: 1  # single-frame observations for initial debugging
obs_keys: ["top"]  # data contains top, angle, r_vis; start with top only
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
import h5py
|
import h5py
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
@@ -10,72 +12,108 @@ class VLAChunkedDataset(Dataset):
|
|||||||
data_path: str,
|
data_path: str,
|
||||||
pred_horizon: int = 16,
|
pred_horizon: int = 16,
|
||||||
obs_horizon: int = 2,
|
obs_horizon: int = 2,
|
||||||
obs_keys: List[str] = ["top", "angle"]
|
obs_keys: List[str] = ["top"] # 默认只用 top
|
||||||
):
|
):
|
||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.obs_keys = obs_keys
|
self.obs_keys = obs_keys
|
||||||
self.file_handle = None
|
|
||||||
|
|
||||||
with h5py.File(self.data_path, 'r') as f:
|
# --- 1. 扫描文件 ---
|
||||||
self.total_len = f["action"].shape[0]
|
if os.path.isdir(data_path):
|
||||||
# 尝试从属性或特定路径读取语言指令
|
# 如果是文件夹,读取所有 episode_*.hdf5
|
||||||
# 假设你的格式中语言存在根目录属性里,或者你手动指定
|
self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
|
||||||
self.lang_instruction = f.attrs.get("language", "执行任务")
|
else:
|
||||||
if isinstance(self.lang_instruction, bytes):
|
# 如果是单文件
|
||||||
self.lang_instruction = self.lang_instruction.decode("utf-8")
|
self.file_paths = [data_path]
|
||||||
|
|
||||||
def _get_handle(self):
|
if len(self.file_paths) == 0:
|
||||||
if self.file_handle is None:
|
raise ValueError(f"No .hdf5 files found in {data_path}")
|
||||||
self.file_handle = h5py.File(self.data_path, 'r', swmr=True)
|
|
||||||
return self.file_handle
|
print(f"Found {len(self.file_paths)} episodes. Indexing...")
|
||||||
|
|
||||||
|
# --- 2. 建立全局索引 (Episode, Time) ---
|
||||||
|
# 我们需要知道 global_index=1000 对应的是哪个文件的第几帧
|
||||||
|
self.index_map = [] # [(file_idx, start_time), ...]
|
||||||
|
|
||||||
|
for i, path in enumerate(self.file_paths):
|
||||||
|
with h5py.File(path, 'r') as f:
|
||||||
|
# 假设所有文件的 action 长度就是 episode 长度
|
||||||
|
total_len = f["action"].shape[0]
|
||||||
|
# 有效的起始点:从 0 到 total_len - 1
|
||||||
|
# 即使到了最后几帧,因为有 padding,所以也是有效的 sample
|
||||||
|
for t in range(total_len):
|
||||||
|
self.index_map.append((i, t))
|
||||||
|
|
||||||
|
print(f"✅ Indexed {len(self.index_map)} total samples.")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.total_len
|
return len(self.index_map)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||||
f = self._get_handle()
|
# --- 1. 定位文件 ---
|
||||||
t_start = idx
|
file_idx, t_start = self.index_map[idx]
|
||||||
|
file_path = self.file_paths[file_idx]
|
||||||
|
|
||||||
# --- 1. 动作与掩码 (Action & Mask) ---
|
# 每次读取打开文件 (Lazy Loading),读取完自动关闭
|
||||||
t_end = min(t_start + self.pred_horizon, self.total_len)
|
# 这种方式对多进程 DataLoader 最安全
|
||||||
actual_len = t_end - t_start
|
with h5py.File(file_path, 'r') as f:
|
||||||
|
total_len = f["action"].shape[0]
|
||||||
actions_np = f["action"][t_start:t_end]
|
|
||||||
|
|
||||||
# 创建掩码:1 表示真实数据,0 表示 Padding
|
|
||||||
# 这是为了在计算 Loss 时屏蔽掉末端重复的动作
|
|
||||||
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
|
||||||
|
|
||||||
if actual_len < self.pred_horizon:
|
|
||||||
pad_len = self.pred_horizon - actual_len
|
|
||||||
# 填充最后一个有效动作
|
|
||||||
pad_block = np.tile(actions_np[-1], (pad_len, 1))
|
|
||||||
actions_np = np.concatenate([actions_np, pad_block], axis=0)
|
|
||||||
# 将填充部分的掩码置为 0
|
|
||||||
action_mask[actual_len:] = 0.0
|
|
||||||
|
|
||||||
# --- 2. 观察值 (Observations) ---
|
|
||||||
obs_dict = {}
|
|
||||||
for key in self.obs_keys:
|
|
||||||
imgs = []
|
|
||||||
for i in range(self.obs_horizon):
|
|
||||||
t_query = max(0, t_start - (self.obs_horizon - 1) + i)
|
|
||||||
imgs.append(f[f"observations/images/{key}"][t_query])
|
|
||||||
|
|
||||||
img_stack = np.stack(imgs).astype(np.float32) / 255.0
|
# --- 2. 动作 (Action) ---
|
||||||
img_stack = img_stack.transpose(0, 3, 1, 2)
|
t_end = min(t_start + self.pred_horizon, total_len)
|
||||||
obs_dict[key] = torch.from_numpy(img_stack)
|
|
||||||
|
# 读取动作片段
|
||||||
|
actions_np = f["action"][t_start:t_end] # (L, 16)
|
||||||
|
|
||||||
|
# Padding 处理
|
||||||
|
actual_len = actions_np.shape[0]
|
||||||
|
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
||||||
|
|
||||||
|
if actual_len < self.pred_horizon:
|
||||||
|
pad_len = self.pred_horizon - actual_len
|
||||||
|
# 重复最后一帧动作进行填充
|
||||||
|
pad_block = np.tile(actions_np[-1], (pad_len, 1))
|
||||||
|
actions_np = np.concatenate([actions_np, pad_block], axis=0)
|
||||||
|
# 标记 Padding 部分为 0
|
||||||
|
action_mask[actual_len:] = 0.0
|
||||||
|
|
||||||
|
# --- 3. 图像 (Images) ---
|
||||||
|
obs_dict = {}
|
||||||
|
for key in self.obs_keys:
|
||||||
|
imgs = []
|
||||||
|
# 处理观测历史 (Obs Horizon)
|
||||||
|
# 如果 t_start=0, obs_horizon=2, 我们需要读取 t=0 和 t=0 (重复第一帧)
|
||||||
|
for i in range(self.obs_horizon):
|
||||||
|
# 倒序读取:当前帧,前一帧...
|
||||||
|
# 注意:这里逻辑是 [t_start - (obs_horizon-1) + i]
|
||||||
|
# 比如 horizon=2, t=10. i=0 -> t=9; i=1 -> t=10.
|
||||||
|
query_t = t_start - (self.obs_horizon - 1) + i
|
||||||
|
query_t = max(0, query_t) # 边界保护
|
||||||
|
|
||||||
|
imgs.append(f[f"observations/images/{key}"][query_t])
|
||||||
|
|
||||||
|
# Stack -> (Obs_Horizon, H, W, C)
|
||||||
|
img_stack = np.stack(imgs)
|
||||||
|
# Normalize & Permute -> (Obs_Horizon, C, H, W)
|
||||||
|
img_stack = img_stack.astype(np.float32) / 255.0
|
||||||
|
img_stack = np.transpose(img_stack, (0, 3, 1, 2))
|
||||||
|
|
||||||
|
obs_dict[key] = torch.from_numpy(img_stack)
|
||||||
|
|
||||||
# --- 3. 状态值 (Low-dim State) ---
|
# --- 4. QPos ---
|
||||||
# 对应你文件里的 qpos
|
qpos = f["observations/qpos"][t_start].astype(np.float32)
|
||||||
qpos = f["observations/qpos"][t_start].astype(np.float32)
|
|
||||||
|
# --- 5. Language ---
|
||||||
|
# 暂时写死或从 attrs 读取
|
||||||
|
lang = f.attrs.get("language", "task instruction placeholder")
|
||||||
|
if isinstance(lang, bytes):
|
||||||
|
lang = lang.decode("utf-8")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"obs": obs_dict, # 视觉输入
|
"obs": obs_dict,
|
||||||
"qpos": torch.from_numpy(qpos), # 本体感受 (关节角)
|
"qpos": torch.from_numpy(qpos),
|
||||||
"actions": torch.from_numpy(actions_np).float(),
|
"actions": torch.from_numpy(actions_np).float(),
|
||||||
"action_mask": action_mask, # Loss 掩码
|
"action_mask": action_mask,
|
||||||
"language": self.lang_instruction # 文本指令
|
"language": lang
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user