feat: 注册自定义 OmegaConf resolver `len` 用于计算列表长度，并在训练中加入训练/验证集划分与验证
This commit is contained in:
@@ -32,6 +32,10 @@ sys.path.append(os.getcwd())
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
|
||||||
|
if not OmegaConf.has_resolver("len"):
|
||||||
|
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||||
|
|
||||||
|
|
||||||
class VLAEvaluator:
|
class VLAEvaluator:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import hydra
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, random_split
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -18,6 +18,10 @@ from hydra.utils import instantiate
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
|
||||||
|
if not OmegaConf.has_resolver("len"):
|
||||||
|
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||||
|
|
||||||
|
|
||||||
def recursive_to_device(data, device):
|
def recursive_to_device(data, device):
|
||||||
"""
|
"""
|
||||||
@@ -75,15 +79,45 @@ def main(cfg: DictConfig):
|
|||||||
log.error(f"❌ Failed to load dataset: {e}")
|
log.error(f"❌ Failed to load dataset: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
dataloader = DataLoader(
|
# Train/Val split
|
||||||
|
val_split = float(cfg.train.get('val_split', 0.1))
|
||||||
|
seed = int(cfg.train.get('seed', 42))
|
||||||
|
val_size = int(len(dataset) * val_split)
|
||||||
|
train_size = len(dataset) - val_size
|
||||||
|
if val_size > 0:
|
||||||
|
train_dataset, val_dataset = random_split(
|
||||||
dataset,
|
dataset,
|
||||||
|
[train_size, val_size],
|
||||||
|
generator=torch.Generator().manual_seed(seed)
|
||||||
|
)
|
||||||
|
log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
|
||||||
|
else:
|
||||||
|
train_dataset, val_dataset = dataset, None
|
||||||
|
log.info("✅ Dataset split: train=all, val=0 (val_split=0)")
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
batch_size=cfg.train.batch_size,
|
batch_size=cfg.train.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
pin_memory=(cfg.train.device != "cpu"),
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
drop_last=True # Drop incomplete batches for stable training
|
drop_last=True # Drop incomplete batches for stable training
|
||||||
)
|
)
|
||||||
log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")
|
|
||||||
|
val_loader = None
|
||||||
|
if val_dataset is not None:
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=cfg.train.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=cfg.train.num_workers,
|
||||||
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
|
drop_last=False
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(f"✅ Train loader batches per epoch: {len(train_loader)}")
|
||||||
|
if val_loader is not None:
|
||||||
|
log.info(f"✅ Val loader batches per epoch: {len(val_loader)}")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 2. Instantiate VLA Agent
|
# 2. Instantiate VLA Agent
|
||||||
@@ -149,7 +183,36 @@ def main(cfg: DictConfig):
|
|||||||
# =========================================================================
|
# =========================================================================
|
||||||
log.info("🏋️ Starting training loop...")
|
log.info("🏋️ Starting training loop...")
|
||||||
|
|
||||||
data_iter = iter(dataloader)
|
def build_agent_input(batch_data):
    """Repackage a flat dataset batch into the dict handed to the agent.

    Camera entries are looked up under ``image_<cam_name>`` for every name in
    ``cfg.data.camera_names``; a camera whose key is absent from
    ``batch_data`` is left out of the result instead of raising.

    Returns a dict with ``images`` (camera name -> batch entry), ``qpos`` and
    ``action``. ``qpos`` and ``action`` must be present in ``batch_data``.
    """
    camera_frames = {
        cam_name: batch_data[f"image_{cam_name}"]
        for cam_name in cfg.data.camera_names
        if f"image_{cam_name}" in batch_data
    }
    return dict(
        images=camera_frames,
        qpos=batch_data['qpos'],
        action=batch_data['action'],
    )
|
||||||
|
|
||||||
|
def run_validation():
    """Compute the mean per-batch loss over ``val_loader`` without gradients.

    Returns ``None`` when there is no validation loader. Otherwise returns the
    sum of ``agent.compute_loss(...)`` values divided by the number of batches
    (an unweighted mean over batches, so a smaller final batch counts the same
    as a full one).

    The agent is switched to eval mode for the pass and is always switched
    back to train mode afterwards, including when a batch raises.
    """
    if val_loader is None:
        return None

    agent.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for val_batch in val_loader:
                val_batch = recursive_to_device(val_batch, cfg.train.device)
                val_input = build_agent_input(val_batch)
                total_loss += agent.compute_loss(val_input).item()
                num_batches += 1
    finally:
        # Restore training mode even on failure, so a caller that catches the
        # error does not silently continue training with the agent in eval mode.
        agent.train()

    # max(..., 1) guards the division if the loader yields no batches.
    return total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
|
data_iter = iter(train_loader)
|
||||||
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
|
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
|
||||||
|
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
@@ -159,7 +222,7 @@ def main(cfg: DictConfig):
|
|||||||
batch = next(data_iter)
|
batch = next(data_iter)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Restart iterator when epoch ends
|
# Restart iterator when epoch ends
|
||||||
data_iter = iter(dataloader)
|
data_iter = iter(train_loader)
|
||||||
batch = next(data_iter)
|
batch = next(data_iter)
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
@@ -173,19 +236,8 @@ def main(cfg: DictConfig):
|
|||||||
# Dataset returns: {action, qpos, image_<cam_name>, ...}
|
# Dataset returns: {action, qpos, image_<cam_name>, ...}
|
||||||
# Agent expects: {images: dict, qpos: tensor, action: tensor}
|
# Agent expects: {images: dict, qpos: tensor, action: tensor}
|
||||||
|
|
||||||
# Extract images into a dictionary
|
|
||||||
images = {}
|
|
||||||
for cam_name in cfg.data.camera_names:
|
|
||||||
key = f"image_{cam_name}"
|
|
||||||
if key in batch:
|
|
||||||
images[cam_name] = batch[key] # (B, obs_horizon, C, H, W)
|
|
||||||
|
|
||||||
# Prepare agent input
|
# Prepare agent input
|
||||||
agent_input = {
|
agent_input = build_agent_input(batch)
|
||||||
'images': images, # Dict of camera images
|
|
||||||
'qpos': batch['qpos'], # (B, obs_horizon, obs_dim)
|
|
||||||
'action': batch['action'] # (B, pred_horizon, action_dim)
|
|
||||||
}
|
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Forward pass & compute loss
|
# Forward pass & compute loss
|
||||||
@@ -217,6 +269,15 @@ def main(cfg: DictConfig):
|
|||||||
})
|
})
|
||||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Validation
|
||||||
|
# =====================================================================
|
||||||
|
val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
|
||||||
|
if val_loader is not None and val_freq > 0 and step % val_freq == 0:
|
||||||
|
val_loss = run_validation()
|
||||||
|
if val_loss is not None:
|
||||||
|
log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Checkpoint saving
|
# Checkpoint saving
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
|
|||||||
@@ -19,4 +19,4 @@ obs_horizon: 2
|
|||||||
diffusion_steps: 100 # Number of diffusion timesteps for training
|
diffusion_steps: 100 # Number of diffusion timesteps for training
|
||||||
|
|
||||||
# Camera Configuration
|
# Camera Configuration
|
||||||
num_cams: ${oc.len:data.camera_names} # 自动从 data.camera_names 列表长度获取
|
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
||||||
Reference in New Issue
Block a user