refactor: 归一化从agent解耦到训练、推理脚本中
This commit is contained in:
@@ -48,7 +48,8 @@ class VLAEvaluator:
|
||||
pred_horizon: int = 16,
|
||||
use_smoothing: bool = False,
|
||||
smooth_method: str = 'ema',
|
||||
smooth_alpha: float = 0.3
|
||||
smooth_alpha: float = 0.3,
|
||||
dataset_stats: dict = None
|
||||
):
|
||||
self.agent = agent.to(device)
|
||||
self.device = device
|
||||
@@ -57,6 +58,21 @@ class VLAEvaluator:
|
||||
self.obs_horizon = obs_horizon
|
||||
self.pred_horizon = pred_horizon
|
||||
|
||||
# Dataset statistics for normalization/denormalization
|
||||
self.stats = dataset_stats
|
||||
if self.stats is not None:
|
||||
self.normalization_type = self.stats.get('normalization_type', 'gaussian')
|
||||
self.qpos_mean = torch.tensor(self.stats['qpos_mean'], dtype=torch.float32)
|
||||
self.qpos_std = torch.tensor(self.stats['qpos_std'], dtype=torch.float32)
|
||||
self.qpos_min = torch.tensor(self.stats.get('qpos_min', []), dtype=torch.float32)
|
||||
self.qpos_max = torch.tensor(self.stats.get('qpos_max', []), dtype=torch.float32)
|
||||
self.action_mean = torch.tensor(self.stats['action_mean'], dtype=torch.float32)
|
||||
self.action_std = torch.tensor(self.stats['action_std'], dtype=torch.float32)
|
||||
self.action_min = torch.tensor(self.stats.get('action_min', []), dtype=torch.float32)
|
||||
self.action_max = torch.tensor(self.stats.get('action_max', []), dtype=torch.float32)
|
||||
else:
|
||||
self.normalization_type = None
|
||||
|
||||
# Action smoothing
|
||||
self.use_smoothing = use_smoothing
|
||||
self.smooth_method = smooth_method
|
||||
@@ -124,7 +140,15 @@ class VLAEvaluator:
|
||||
if len(self.obs_buffer['qpos']) > self.obs_horizon:
|
||||
self.obs_buffer['qpos'] = self.obs_buffer['qpos'][-self.obs_horizon:]
|
||||
|
||||
qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0)
|
||||
qpos_tensor = torch.stack(self.obs_buffer['qpos'], dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
|
||||
|
||||
# Normalize qpos
|
||||
if self.stats is not None:
|
||||
if self.normalization_type == 'gaussian':
|
||||
qpos_tensor = (qpos_tensor - self.qpos_mean) / self.qpos_std
|
||||
else: # min_max: normalize to [-1, 1]
|
||||
qpos_tensor = 2 * (qpos_tensor - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
|
||||
|
||||
return qpos_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -141,6 +165,13 @@ class VLAEvaluator:
|
||||
proprioception=qpos
|
||||
)
|
||||
|
||||
# Denormalize actions
|
||||
if self.stats is not None:
|
||||
if self.normalization_type == 'gaussian':
|
||||
predicted_actions = predicted_actions * self.action_std.to(self.device) + self.action_mean.to(self.device)
|
||||
else: # min_max
|
||||
predicted_actions = (predicted_actions + 1) / 2 * (self.action_max.to(self.device) - self.action_min.to(self.device)) + self.action_min.to(self.device)
|
||||
|
||||
self.cached_actions = predicted_actions.squeeze(0).cpu().numpy()
|
||||
self.query_step = 0
|
||||
|
||||
@@ -208,36 +239,29 @@ def load_checkpoint(
|
||||
agent = instantiate(agent_cfg)
|
||||
|
||||
# Load model state
|
||||
if 'model_state_dict' in checkpoint:
|
||||
agent.load_state_dict(checkpoint['model_state_dict'])
|
||||
log.info(f"✅ Model state loaded (step: {checkpoint.get('step', 'unknown')})")
|
||||
elif 'state_dict' in checkpoint:
|
||||
agent.load_state_dict(checkpoint['state_dict'])
|
||||
log.info("✅ Model state loaded")
|
||||
else:
|
||||
agent.load_state_dict(checkpoint)
|
||||
log.info("✅ Model state loaded")
|
||||
|
||||
# Load dataset statistics for denormalization
|
||||
stats = checkpoint.get('dataset_stats', None)
|
||||
|
||||
if stats is not None:
|
||||
log.info(f"✅ Dataset statistics loaded (normalization: {stats.get('normalization_type', 'gaussian')})")
|
||||
else:
|
||||
# Fallback: try external JSON file (兼容旧 checkpoint)
|
||||
stats_path = ckpt_path.parent / 'dataset_stats.json'
|
||||
if stats_path.exists():
|
||||
with open(stats_path, 'r') as f:
|
||||
stats = json.load(f)
|
||||
agent.action_mean = np.array(stats['action_mean'])
|
||||
agent.action_std = np.array(stats['action_std'])
|
||||
agent.qpos_mean = np.array(stats['qpos_mean'])
|
||||
agent.qpos_std = np.array(stats['qpos_std'])
|
||||
log.info("✅ Dataset statistics loaded for denormalization")
|
||||
log.info("✅ Dataset statistics loaded from external JSON (legacy)")
|
||||
else:
|
||||
log.warning(f"⚠️ {stats_path} not found. Actions will not be denormalized!")
|
||||
agent.action_mean = None
|
||||
agent.action_std = None
|
||||
log.warning("⚠️ No dataset statistics found. Actions will not be denormalized!")
|
||||
|
||||
agent.eval()
|
||||
agent.to(device)
|
||||
|
||||
log.info(f"✅ Model loaded successfully on {device}")
|
||||
return agent
|
||||
return agent, stats
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||
@@ -262,7 +286,7 @@ def main(cfg: DictConfig):
|
||||
|
||||
# Load model
|
||||
log.info(f"🚀 Loading model from {eval_cfg.ckpt_path}...")
|
||||
agent = load_checkpoint(
|
||||
agent, dataset_stats = load_checkpoint(
|
||||
ckpt_path=eval_cfg.ckpt_path,
|
||||
agent_cfg=cfg.agent,
|
||||
device=device
|
||||
@@ -277,7 +301,8 @@ def main(cfg: DictConfig):
|
||||
obs_horizon=eval_cfg.obs_horizon,
|
||||
use_smoothing=eval_cfg.use_smoothing,
|
||||
smooth_method=eval_cfg.smooth_method,
|
||||
smooth_alpha=eval_cfg.smooth_alpha
|
||||
smooth_alpha=eval_cfg.smooth_alpha,
|
||||
dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
# Create environment
|
||||
@@ -293,9 +318,6 @@ def main(cfg: DictConfig):
|
||||
env.reset(box_pos)
|
||||
evaluator.reset()
|
||||
|
||||
success = False
|
||||
success_timestep = 0
|
||||
|
||||
with torch.inference_mode():
|
||||
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"Episode {episode_idx + 1}"):
|
||||
obs = env._get_image_obs()
|
||||
@@ -307,17 +329,7 @@ def main(cfg: DictConfig):
|
||||
|
||||
env.render()
|
||||
|
||||
if env.rew == 1.0:
|
||||
success = True
|
||||
success_timestep = t
|
||||
print(f"\n✅ Task completed at timestep {t}!")
|
||||
break
|
||||
|
||||
print(f"\nEpisode {episode_idx + 1} Summary:")
|
||||
print(f" Success: {success}")
|
||||
if success:
|
||||
print(f" Success Timestep: {success_timestep}")
|
||||
print(f" Length: {t + 1} timesteps")
|
||||
print(f"\nEpisode {episode_idx + 1} completed ({eval_cfg.max_timesteps} timesteps)")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("Evaluation complete!")
|
||||
|
||||
@@ -106,43 +106,36 @@ def main(cfg: DictConfig):
|
||||
raise
|
||||
|
||||
# =========================================================================
|
||||
# 2.5. Save Dataset Statistics as JSON
|
||||
# 2.5. Load Dataset Statistics (will be saved into checkpoints)
|
||||
# =========================================================================
|
||||
log.info("💾 Saving dataset statistics...")
|
||||
log.info("💾 Loading dataset statistics...")
|
||||
dataset_stats = None
|
||||
try:
|
||||
# Get dataset_dir from config
|
||||
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
|
||||
stats_path = Path(dataset_dir) / 'data_stats.pkl'
|
||||
|
||||
if stats_path.exists():
|
||||
# Load pickle file
|
||||
with open(stats_path, 'rb') as f:
|
||||
stats = pickle.load(f)
|
||||
|
||||
# Extract action statistics
|
||||
action_mean = stats['action']['mean'].tolist() if 'action' in stats else []
|
||||
action_std = stats['action']['std'].tolist() if 'action' in stats else []
|
||||
qpos_mean = stats['qpos']['mean'].tolist() if 'qpos' in stats else []
|
||||
qpos_std = stats['qpos']['std'].tolist() if 'qpos' in stats else []
|
||||
|
||||
# Save as JSON
|
||||
json_stats = {
|
||||
'action_mean': action_mean,
|
||||
'action_std': action_std,
|
||||
'qpos_mean': qpos_mean,
|
||||
'qpos_std': qpos_std
|
||||
dataset_stats = {
|
||||
'normalization_type': cfg.data.get('normalization_type', 'gaussian'),
|
||||
'action_mean': stats['action']['mean'].tolist(),
|
||||
'action_std': stats['action']['std'].tolist(),
|
||||
'action_min': stats['action']['min'].tolist(),
|
||||
'action_max': stats['action']['max'].tolist(),
|
||||
'qpos_mean': stats['qpos']['mean'].tolist(),
|
||||
'qpos_std': stats['qpos']['std'].tolist(),
|
||||
'qpos_min': stats['qpos']['min'].tolist(),
|
||||
'qpos_max': stats['qpos']['max'].tolist(),
|
||||
}
|
||||
json_path = checkpoint_dir / 'dataset_stats.json'
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(json_stats, f, indent=2)
|
||||
|
||||
log.info(f"✅ Dataset statistics saved to {json_path}")
|
||||
log.info(f"✅ Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})")
|
||||
else:
|
||||
log.warning(f"⚠️ Statistics file not found: {stats_path}")
|
||||
log.warning("⚠️ Actions will not be denormalized during inference!")
|
||||
|
||||
except Exception as e:
|
||||
log.warning(f"⚠️ Failed to save statistics as JSON: {e}")
|
||||
log.warning(f"⚠️ Failed to load statistics: {e}")
|
||||
log.warning("⚠️ Training will continue, but inference may not work correctly")
|
||||
|
||||
# =========================================================================
|
||||
@@ -234,6 +227,7 @@ def main(cfg: DictConfig):
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'dataset_stats': dataset_stats,
|
||||
}, checkpoint_path)
|
||||
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
||||
|
||||
@@ -246,6 +240,7 @@ def main(cfg: DictConfig):
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'dataset_stats': dataset_stats,
|
||||
}, best_model_path)
|
||||
log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
|
||||
|
||||
@@ -258,6 +253,7 @@ def main(cfg: DictConfig):
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'dataset_stats': dataset_stats,
|
||||
}, final_model_path)
|
||||
log.info(f"💾 Final model saved: {final_model_path}")
|
||||
|
||||
|
||||
@@ -107,14 +107,6 @@ class VLAAgent(nn.Module):
|
||||
# 1. 提取当前观测特征 (只做一次)
|
||||
visual_features = self.vision_encoder(images).view(B, -1)
|
||||
proprioception = proprioception.view(B, -1)
|
||||
if hasattr(self, 'qpos_mean') and hasattr(self, 'qpos_std') and self.qpos_mean is not None:
|
||||
# Convert to tensor for normalization
|
||||
qpos_mean = torch.from_numpy(self.qpos_mean).float().to(proprioception.device)
|
||||
qpos_std = torch.from_numpy(self.qpos_std).float().to(proprioception.device)
|
||||
qpos_mean = qpos_mean.repeat(2)
|
||||
qpos_std = qpos_std.repeat(2)
|
||||
# Normalize: (qpos - mean) / std
|
||||
proprioception = (proprioception - qpos_mean.unsqueeze(0)) / qpos_std.unsqueeze(0)
|
||||
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
||||
|
||||
# 2. 初始化纯高斯噪声动作
|
||||
@@ -141,13 +133,5 @@ class VLAAgent(nn.Module):
|
||||
noise_pred, t, current_actions
|
||||
).prev_sample
|
||||
|
||||
# 4. 反归一化动作 (Denormalize actions)
|
||||
if hasattr(self, 'action_mean') and hasattr(self, 'action_std') and self.action_mean is not None:
|
||||
# Convert to numpy for denormalization
|
||||
action_mean = torch.from_numpy(self.action_mean).float().to(current_actions.device)
|
||||
action_std = torch.from_numpy(self.action_std).float().to(current_actions.device)
|
||||
# Denormalize: action * std + mean
|
||||
current_actions = current_actions * action_std.unsqueeze(0).unsqueeze(0) + action_mean.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# 5. 输出最终动作序列
|
||||
return current_actions # 返回去噪后的干净动作
|
||||
# 4. 输出最终动作序列(归一化空间,由调用方负责反归一化)
|
||||
return current_actions
|
||||
Reference in New Issue
Block a user