refactor: 归一化从agent解耦到训练、推理脚本中

This commit is contained in:
gouhanke
2026-02-06 14:29:36 +08:00
parent a43a2e3d18
commit f4a5c77b7c
3 changed files with 72 additions and 80 deletions

View File

@@ -48,7 +48,8 @@ class VLAEvaluator:
pred_horizon: int = 16,
use_smoothing: bool = False,
smooth_method: str = 'ema',
smooth_alpha: float = 0.3
smooth_alpha: float = 0.3,
dataset_stats: dict = None
):
self.agent = agent.to(device)
self.device = device
@@ -57,6 +58,21 @@ class VLAEvaluator:
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
# Dataset statistics for normalization/denormalization
self.stats = dataset_stats
if self.stats is not None:
self.normalization_type = self.stats.get('normalization_type', 'gaussian')
self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
else:
self.normalization_type = None
# Action smoothing
self.use_smoothing = use_smoothing
self.smooth_method = smooth_method
@@ -124,7 +140,15 @@ class VLAEvaluator:
if len(self.obs_buffer['qpos']) > self.obs_horizon:
self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]
qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)
qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
# Normalize qpos
if self.stats is not None:
if self.normalization_type == 'gaussian':
qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
else: # min_max: normalize to [-1, 1]
qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
return qpos_tensor
@torch.no_grad()
@@ -141,6 +165,13 @@ class VLAEvaluator:
proprioception=qpos
)
# Denormalize actions
if self.stats is not None:
if self.normalization_type == 'gaussian':
predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
else: # min_max
predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)
self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
self.query_step = 0
@@ -208,36 +239,29 @@ def load_checkpoint(
agent = instantiate(agent_cfg)
# Load model state
if 'model_state_dict' in checkpoint:
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
elif 'state_dict' in checkpoint:
agent.load_state_dict(checkpoint['state_dict'])
log.info("✅ Model state loaded")
else:
agent.load_state_dict(checkpoint)
log.info("✅ Model state loaded")
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
# Load dataset statistics for denormalization
stats_path = ckpt_path.parent / 'dataset_stats.json'
if stats_path.exists():
with open(stats_path, 'r') as f:
stats = json.load(f)
agent.action_mean = np.array(stats['action_mean'])
agent.action_std = np.array(stats['action_std'])
agent.qpos_mean = np.array(stats['qpos_mean'])
agent.qpos_std = np.array(stats['qpos_std'])
log.info("✅ Dataset statistics loaded for denormalization")
stats = checkpoint.get('dataset_stats', None)
if stats is not None:
log.info(f"✅ Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})")
else:
log.warning(f"⚠️ {stats_path} not found. Actions will not be denormalized!")
agent.action_mean = None
agent.action_std = None
# Fallback: try external JSON file (for compatibility with legacy checkpoints)
stats_path = ckpt_path.parent / 'dataset_stats.json'
if stats_path.exists():
with open(stats_path, 'r') as f:
stats = json.load(f)
log.info("✅ Dataset statistics loaded from external JSON (legacy)")
else:
log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!")
agent.eval()
agent.to(device)
log.info(f"✅ Model loaded successfully on {device}")
return agent
return agent, stats
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
@@ -262,7 +286,7 @@ def main(cfg: DictConfig):
# Load model
log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
agent = load_checkpoint(
agent, dataset_stats = load_checkpoint(
ckpt_path=eval_cfg.ckpt_path,
agent_cfg=cfg.agent,
device=device
@@ -277,7 +301,8 @@ def main(cfg: DictConfig):
obs_horizon=eval_cfg.obs_horizon,
use_smoothing=eval_cfg.use_smoothing,
smooth_method=eval_cfg.smooth_method,
smooth_alpha=eval_cfg.smooth_alpha
smooth_alpha=eval_cfg.smooth_alpha,
dataset_stats=dataset_stats
)
# Create environment
@@ -293,9 +318,6 @@ def main(cfg: DictConfig):
env.reset(box_pos)
evaluator.reset()
success = False
success_timestep = 0
with torch.inference_mode():
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
obs = env._get_image_obs()
@@ -307,17 +329,7 @@ def main(cfg: DictConfig):
env.render()
if env.rew == 1.0:
success = True
success_timestep = t
print(f"\n✅ Task completed at timestep {t}!")
break
print(f"\nEpisode {episode_idx + 1} Summary:")
print(f" Success: {success}")
if success:
print(f" Success Timestep: {success_timestep}")
print(f" Length: {t + 1} timesteps")
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
print(f"\n{'='*60}")
print("Evaluation complete!")

View File

@@ -106,43 +106,36 @@ def main(cfg: DictConfig):
raise
# =========================================================================
# 2.5. Save Dataset Statistics as JSON
# 2.5. Load Dataset Statistics (will be saved into checkpoints)
# =========================================================================
log.info("💾 Saving dataset statistics...")
log.info("💾 Loading dataset statistics...")
dataset_stats = None
try:
# Get dataset_dir from config
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
stats_path = Path(dataset_dir) / 'data_stats.pkl'
if stats_path.exists():
# Load pickle file
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
# Extract action statistics
action_mean = stats['action']['mean'].tolist() if 'action' in stats else []
action_std = stats['action']['std'].tolist() if 'action' in stats else []
qpos_mean = stats['qpos']['mean'].tolist() if 'qpos' in stats else []
qpos_std = stats['qpos']['std'].tolist() if 'qpos' in stats else []
# Save as JSON
json_stats = {
'action_mean': action_mean,
'action_std': action_std,
'qpos_mean': qpos_mean,
'qpos_std': qpos_std
dataset_stats = {
'normalization_type': cfg.data.get('normalization_type', 'gaussian'),
'action_mean': stats['action']['mean'].tolist(),
'action_std': stats['action']['std'].tolist(),
'action_min': stats['action']['min'].tolist(),
'action_max': stats['action']['max'].tolist(),
'qpos_mean': stats['qpos']['mean'].tolist(),
'qpos_std': stats['qpos']['std'].tolist(),
'qpos_min': stats['qpos']['min'].tolist(),
'qpos_max': stats['qpos']['max'].tolist(),
}
json_path = checkpoint_dir / 'dataset_stats.json'
with open(json_path, 'w') as f:
json.dump(json_stats, f, indent=2)
log.info(f"✅ Dataset statistics saved to {json_path}")
log.info(f"✅ Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})")
else:
log.warning(f"⚠️ Statistics file not found: {stats_path}")
log.warning("⚠️ Actions will not be denormalized during inference!")
except Exception as e:
log.warning(f"⚠️ Failed to save statistics as JSON: {e}")
log.warning(f"⚠️ Failed to load statistics: {e}")
log.warning("⚠️ Training will continue, but inference may not work correctly")
# =========================================================================
@@ -234,6 +227,7 @@ def main(cfg: DictConfig):
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'dataset_stats': dataset_stats,
}, checkpoint_path)
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
@@ -246,6 +240,7 @@ def main(cfg: DictConfig):
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'dataset_stats': dataset_stats,
}, best_model_path)
log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
@@ -258,6 +253,7 @@ def main(cfg: DictConfig):
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'dataset_stats': dataset_stats,
}, final_model_path)
log.info(f"💾 Final model saved: {final_model_path}")

View File

@@ -107,14 +107,6 @@ class VLAAgent(nn.Module):
# 1. Extract features from the current observation (done only once)
visual_features = self.vision_encoder(images).view(B, -1)
proprioception = proprioception.view(B, -1)
if hasattr(self, 'qpos_mean') and hasattr(self, 'qpos_std') and self.qpos_mean is not None:
# Convert to tensor for normalization
qpos_mean = torch.from_numpy(self.qpos_mean).float().to(proprioception.device)
qpos_std = torch.from_numpy(self.qpos_std).float().to(proprioception.device)
qpos_mean = qpos_mean.repeat(2)
qpos_std = qpos_std.repeat(2)
# Normalize: (qpos - mean) / std
proprioception = (proprioception - qpos_mean.unsqueeze(0)) / qpos_std.unsqueeze(0)
global_cond = torch.cat([visual_features, proprioception], dim=-1)
# 2. Initialize the actions as pure Gaussian noise
@@ -141,13 +133,5 @@ class VLAAgent(nn.Module):
noise_pred, t, current_actions
).prev_sample
# 4. Denormalize actions
if hasattr(self, 'action_mean') and hasattr(self, 'action_std') and self.action_mean is not None:
# Convert to numpy for denormalization
action_mean = torch.from_numpy(self.action_mean).float().to(current_actions.device)
action_std = torch.from_numpy(self.action_std).float().to(current_actions.device)
# Denormalize: action * std + mean
current_actions = current_actions * action_std.unsqueeze(0).unsqueeze(0) + action_mean.unsqueeze(0).unsqueeze(0)
# 5. Output the final action sequence
return current_actions # 返回去噪后的干净动作
# 4. Output the final action sequence (in normalized space; the caller is responsible for denormalization)
return current_actions