feat: 注册了自定义 resolver计算长度

This commit is contained in:
gouhanke
2026-02-06 16:08:56 +08:00
parent 7a9ca06aa0
commit ea49e63eb7
3 changed files with 84 additions and 19 deletions

View File

@@ -32,6 +32,10 @@ sys.path.append(os.getcwd())
log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
# Guarded so repeated imports of this module don't raise on double registration.
if not OmegaConf.has_resolver("len"):
    # `len` is already a callable of one argument; wrapping it in a lambda adds nothing.
    OmegaConf.register_new_resolver("len", len)
class VLAEvaluator:
"""

View File

@@ -7,7 +7,7 @@ import hydra
import torch
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from pathlib import Path
@@ -18,6 +18,10 @@ from hydra.utils import instantiate
log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
# Guarded so repeated imports of this module don't raise on double registration.
if not OmegaConf.has_resolver("len"):
    # `len` is already a callable of one argument; wrapping it in a lambda adds nothing.
    OmegaConf.register_new_resolver("len", len)
def recursive_to_device(data, device):
"""
@@ -75,15 +79,45 @@ def main(cfg: DictConfig):
log.error(f"❌ Failed to load dataset: {e}")
raise
dataloader = DataLoader(
dataset,
# Train/Val split
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
if val_size > 0:
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(seed)
)
log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
else:
train_dataset, val_dataset = dataset, None
log.info("✅ Dataset split: train=all, val=0 (val_split=0)")
train_loader = DataLoader(
train_dataset,
batch_size=cfg.train.batch_size,
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
drop_last=True # Drop incomplete batches for stable training
)
log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")
val_loader = None
if val_dataset is not None:
val_loader = DataLoader(
val_dataset,
batch_size=cfg.train.batch_size,
shuffle=False,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
drop_last=False
)
log.info(f"✅ Train loader batches per epoch: {len(train_loader)}")
if val_loader is not None:
log.info(f"✅ Val loader batches per epoch: {len(val_loader)}")
# =========================================================================
# 2. Instantiate VLA Agent
@@ -149,7 +183,36 @@ def main(cfg: DictConfig):
# =========================================================================
log.info("🏋️ Starting training loop...")
data_iter = iter(dataloader)
def build_agent_input(batch_data):
    """Repack a raw dataset batch into the dict layout the agent consumes.

    Gathers every per-camera tensor stored under an ``image_<cam>`` key into
    an ``images`` sub-dict (cameras missing from the batch are skipped) and
    forwards ``qpos`` and ``action`` through unchanged.
    """
    camera_images = {
        cam: batch_data[f"image_{cam}"]
        for cam in cfg.data.camera_names
        if f"image_{cam}" in batch_data
    }
    return {
        'images': camera_images,
        'qpos': batch_data['qpos'],
        'action': batch_data['action'],
    }
def run_validation():
    """Run one full pass over the validation loader and return the mean batch loss.

    Returns:
        float mean of per-batch losses, or None when no validation split
        was configured (``val_loader is None``).
    """
    if val_loader is None:
        return None
    agent.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for val_batch in val_loader:
                val_batch = recursive_to_device(val_batch, cfg.train.device)
                val_input = build_agent_input(val_batch)
                val_loss = agent.compute_loss(val_input)
                total_loss += val_loss.item()
                num_batches += 1
    finally:
        # Always restore training mode — without this, an exception in any
        # validation batch would leave the agent stuck in eval mode for the
        # rest of the training loop.
        agent.train()
    # max(..., 1) guards against division by zero on an empty loader.
    return total_loss / max(num_batches, 1)
data_iter = iter(train_loader)
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
best_loss = float('inf')
@@ -159,7 +222,7 @@ def main(cfg: DictConfig):
batch = next(data_iter)
except StopIteration:
# Restart iterator when epoch ends
data_iter = iter(dataloader)
data_iter = iter(train_loader)
batch = next(data_iter)
# =====================================================================
@@ -173,19 +236,8 @@ def main(cfg: DictConfig):
# Dataset returns: {action, qpos, image_<cam_name>, ...}
# Agent expects: {images: dict, qpos: tensor, action: tensor}
# Extract images into a dictionary
images = {}
for cam_name in cfg.data.camera_names:
key = f"image_{cam_name}"
if key in batch:
images[cam_name] = batch[key] # (B, obs_horizon, C, H, W)
# Prepare agent input
agent_input = {
'images': images, # Dict of camera images
'qpos': batch['qpos'], # (B, obs_horizon, obs_dim)
'action': batch['action'] # (B, pred_horizon, action_dim)
}
agent_input = build_agent_input(batch)
# =====================================================================
# Forward pass & compute loss
@@ -217,6 +269,15 @@ def main(cfg: DictConfig):
})
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
# =====================================================================
# Validation
# =====================================================================
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
val_loss = run_validation()
if val_loss is not None:
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
# =====================================================================
# Checkpoint saving
# =====================================================================

View File

@@ -19,4 +19,4 @@ obs_horizon: 2
diffusion_steps: 100 # Number of diffusion timesteps for training
# Camera Configuration
num_cams: ${oc.len:data.camera_names} # 自动从 data.camera_names 列表长度获取
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取