feat: register a custom `len` resolver to compute list lengths in configs

This commit is contained in:
gouhanke
2026-02-06 16:08:56 +08:00
parent 7a9ca06aa0
commit ea49e63eb7
3 changed files with 84 additions and 19 deletions

View File

@@ -32,6 +32,10 @@ sys.path.append(os.getcwd())
log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x))
class VLAEvaluator:
"""

View File

@@ -7,7 +7,7 @@ import hydra
import torch
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from pathlib import Path
@@ -18,6 +18,10 @@ from hydra.utils import instantiate
log = logging.getLogger(__name__)
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x))
def recursive_to_device(data, device):
"""
@@ -75,15 +79,45 @@ def main(cfg: DictConfig):
log.error(f"❌ Failed to load dataset: {e}") log.error(f"❌ Failed to load dataset: {e}")
raise raise
# Train/Val split: carve a validation subset out of the full dataset.
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
if val_size > 0:
    # Seeded generator makes the split reproducible across runs.
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )
    log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
else:
    # val_split == 0 (or a tiny dataset): train on everything, skip validation.
    train_dataset, val_dataset = dataset, None
    log.info("✅ Dataset split: train=all, val=0 (val_split=0)")

train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.train.batch_size,
    shuffle=True,
    num_workers=cfg.train.num_workers,
    pin_memory=(cfg.train.device != "cpu"),
    drop_last=True  # Drop incomplete batches for stable training
)

val_loader = None
if val_dataset is not None:
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=False,        # deterministic validation order
        num_workers=cfg.train.num_workers,
        pin_memory=(cfg.train.device != "cpu"),
        drop_last=False       # evaluate every validation sample
    )

log.info(f"✅ Train loader batches per epoch: {len(train_loader)}")
if val_loader is not None:
    log.info(f"✅ Val loader batches per epoch: {len(val_loader)}")
# =========================================================================
# 2. Instantiate VLA Agent
@@ -149,7 +183,36 @@ def main(cfg: DictConfig):
# =========================================================================
log.info("🏋️ Starting training loop...")
def build_agent_input(batch_data):
    """Map a dataset batch dict onto the agent's expected input schema.

    The dataset emits per-camera tensors under ``image_<cam_name>`` keys
    (for each name in ``cfg.data.camera_names``); the agent expects them
    grouped under a single 'images' dict alongside 'qpos' and 'action'.
    Cameras missing from the batch are silently skipped.
    """
    images = {}
    for cam_name in cfg.data.camera_names:
        key = f"image_{cam_name}"
        if key in batch_data:
            images[cam_name] = batch_data[key]
    return {
        'images': images,
        'qpos': batch_data['qpos'],
        'action': batch_data['action']
    }
def run_validation():
    """Run one pass over the validation loader and return the mean loss.

    Returns None when no validation set was configured. Toggles the agent
    to eval mode for the pass and back to train mode afterwards.
    """
    if val_loader is None:
        return None
    agent.eval()
    batch_losses = []
    with torch.no_grad():
        for val_batch in val_loader:
            moved = recursive_to_device(val_batch, cfg.train.device)
            loss_tensor = agent.compute_loss(build_agent_input(moved))
            batch_losses.append(loss_tensor.item())
    agent.train()
    return sum(batch_losses) / max(len(batch_losses), 1)
data_iter = iter(train_loader)
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
best_loss = float('inf')
@@ -159,7 +222,7 @@ def main(cfg: DictConfig):
batch = next(data_iter)
except StopIteration:
# Restart iterator when epoch ends
data_iter = iter(train_loader)
batch = next(data_iter)
# ===================================================================== # =====================================================================
@@ -173,19 +236,8 @@ def main(cfg: DictConfig):
# Dataset returns: {action, qpos, image_<cam_name>, ...}
# Agent expects: {images: dict, qpos: tensor, action: tensor}
# Prepare agent input
agent_input = build_agent_input(batch)
# =====================================================================
# Forward pass & compute loss
@@ -217,6 +269,15 @@ def main(cfg: DictConfig):
})
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
# =====================================================================
# Validation
# =====================================================================
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
val_loss = run_validation()
if val_loss is not None:
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
# =====================================================================
# Checkpoint saving
# =====================================================================

View File

@@ -19,4 +19,4 @@ obs_horizon: 2
diffusion_steps: 100 # Number of diffusion timesteps for training
# Camera Configuration
num_cams: ${len:${data.camera_names}} # automatically derived from the length of the data.camera_names list