feat: 注册自定义 OmegaConf resolver `len`，用于在配置中计算列表长度
This commit is contained in:
@@ -32,6 +32,10 @@ sys.path.append(os.getcwd())
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Expose a ``len`` resolver to the configs so that an interpolation such as
# ${len:${data.camera_names}} evaluates to the number of items in that list.
# Registration is skipped when a resolver named "len" is already present.
if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", lambda seq: len(seq))
|
||||
|
||||
|
||||
class VLAEvaluator:
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ import hydra
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.optim import AdamW
|
||||
from pathlib import Path
|
||||
|
||||
@@ -18,6 +18,10 @@ from hydra.utils import instantiate
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Expose a ``len`` resolver to the configs so that an interpolation such as
# ${len:${data.camera_names}} evaluates to the number of items in that list.
# Registration is skipped when a resolver named "len" is already present.
if not OmegaConf.has_resolver("len"):
    OmegaConf.register_new_resolver("len", lambda seq: len(seq))
|
||||
|
||||
|
||||
def recursive_to_device(data, device):
|
||||
"""
|
||||
@@ -75,15 +79,45 @@ def main(cfg: DictConfig):
|
||||
log.error(f"❌ Failed to load dataset: {e}")
|
||||
raise
|
||||
|
||||
dataloader = DataLoader(
|
||||
# Train/Val split
|
||||
val_split = float(cfg.train.get('val_split', 0.1))
|
||||
seed = int(cfg.train.get('seed', 42))
|
||||
val_size = int(len(dataset) * val_split)
|
||||
train_size = len(dataset) - val_size
|
||||
if val_size > 0:
|
||||
train_dataset, val_dataset = random_split(
|
||||
dataset,
|
||||
[train_size, val_size],
|
||||
generator=torch.Generator().manual_seed(seed)
|
||||
)
|
||||
log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
|
||||
else:
|
||||
train_dataset, val_dataset = dataset, None
|
||||
log.info("✅ Dataset split: train=all, val=0 (val_split=0)")
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.train.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=cfg.train.num_workers,
|
||||
pin_memory=(cfg.train.device != "cpu"),
|
||||
drop_last=True # Drop incomplete batches for stable training
|
||||
)
|
||||
log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")
|
||||
|
||||
val_loader = None
|
||||
if val_dataset is not None:
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.train.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.train.num_workers,
|
||||
pin_memory=(cfg.train.device != "cpu"),
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
log.info(f"✅ Train loader batches per epoch: {len(train_loader)}")
|
||||
if val_loader is not None:
|
||||
log.info(f"✅ Val loader batches per epoch: {len(val_loader)}")
|
||||
|
||||
# =========================================================================
|
||||
# 2. Instantiate VLA Agent
|
||||
@@ -149,7 +183,36 @@ def main(cfg: DictConfig):
|
||||
# =========================================================================
|
||||
log.info("🏋️ Starting training loop...")
|
||||
|
||||
data_iter = iter(dataloader)
|
||||
def build_agent_input(batch_data):
    """Repackage a dataset batch into the dict handed to the agent.

    Entries stored under ``image_<cam_name>`` are gathered into one ``images``
    dict keyed by camera name. Cameras listed in ``cfg.data.camera_names``
    whose key is absent from the batch are skipped. ``qpos`` and ``action``
    are forwarded unchanged.

    Args:
        batch_data: Mapping with ``qpos``, ``action`` and optional
            ``image_<cam_name>`` entries.

    Returns:
        Dict with the keys ``images``, ``qpos`` and ``action``.
    """
    per_camera = {
        name: batch_data[f"image_{name}"]
        for name in cfg.data.camera_names
        if f"image_{name}" in batch_data
    }
    return {
        'images': per_camera,
        'qpos': batch_data['qpos'],
        'action': batch_data['action'],
    }
|
||||
|
||||
def run_validation():
    """Compute the mean loss over the validation loader.

    The agent is switched to eval mode for the duration of the pass and is
    always switched back to train mode afterwards, including when a batch
    raises.

    Returns:
        ``None`` when no validation loader exists; otherwise the average of
        ``agent.compute_loss`` over all validation batches (``0.0`` if the
        loader yields no batches).
    """
    if val_loader is None:
        return None

    agent.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for val_batch in val_loader:
                val_batch = recursive_to_device(val_batch, cfg.train.device)
                val_input = build_agent_input(val_batch)
                val_loss = agent.compute_loss(val_input)
                total_loss += val_loss.item()
                num_batches += 1
    finally:
        # Restore train mode even if validation fails part-way, so a caller
        # that handles the error does not keep training in eval mode.
        agent.train()

    # max(..., 1) guards against division by zero for an empty loader.
    return total_loss / max(num_batches, 1)
|
||||
|
||||
data_iter = iter(train_loader)
|
||||
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
|
||||
|
||||
best_loss = float('inf')
|
||||
@@ -159,7 +222,7 @@ def main(cfg: DictConfig):
|
||||
batch = next(data_iter)
|
||||
except StopIteration:
|
||||
# Restart iterator when epoch ends
|
||||
data_iter = iter(dataloader)
|
||||
data_iter = iter(train_loader)
|
||||
batch = next(data_iter)
|
||||
|
||||
# =====================================================================
|
||||
@@ -173,19 +236,8 @@ def main(cfg: DictConfig):
|
||||
# Dataset returns: {action, qpos, image_<cam_name>, ...}
|
||||
# Agent expects: {images: dict, qpos: tensor, action: tensor}
|
||||
|
||||
# Extract images into a dictionary
|
||||
images = {}
|
||||
for cam_name in cfg.data.camera_names:
|
||||
key = f"image_{cam_name}"
|
||||
if key in batch:
|
||||
images[cam_name] = batch[key] # (B, obs_horizon, C, H, W)
|
||||
|
||||
# Prepare agent input
|
||||
agent_input = {
|
||||
'images': images, # Dict of camera images
|
||||
'qpos': batch['qpos'], # (B, obs_horizon, obs_dim)
|
||||
'action': batch['action'] # (B, pred_horizon, action_dim)
|
||||
}
|
||||
agent_input = build_agent_input(batch)
|
||||
|
||||
# =====================================================================
|
||||
# Forward pass & compute loss
|
||||
@@ -217,6 +269,15 @@ def main(cfg: DictConfig):
|
||||
})
|
||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
||||
|
||||
# =====================================================================
|
||||
# Validation
|
||||
# =====================================================================
|
||||
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
|
||||
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
|
||||
val_loss = run_validation()
|
||||
if val_loss is not None:
|
||||
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
|
||||
|
||||
# =====================================================================
|
||||
# Checkpoint saving
|
||||
# =====================================================================
|
||||
|
||||
@@ -19,4 +19,4 @@ obs_horizon: 2
|
||||
diffusion_steps: 100 # Number of diffusion timesteps for training
|
||||
|
||||
# Camera Configuration
|
||||
num_cams: ${oc.len:data.camera_names} # 自动从 data.camera_names 列表长度获取
|
||||
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
||||
Reference in New Issue
Block a user