# roboimi/roboimi/vla/agent_imf.py
from __future__ import annotations
from contextlib import nullcontext
from collections import deque
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from roboimi.vla.agent import VLAAgent
try:
    from torch.func import jvp as TORCH_FUNC_JVP
except ImportError:  # pragma: no cover
    TORCH_FUNC_JVP = None


class IMFVLAAgent(VLAAgent):
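    """One-step flow-matching VLA agent with an optional LeWM auxiliary head.

    Training (see ``compute_loss``) regresses the compound velocity
    ``u + (t - r) * sg(du/dt)`` onto the flow target ``e - x``; inference
    (see ``predict_action``) integrates the learned velocity in a single
    step. A LeWM predictor can append predicted future observation tokens
    to the condition sequence, or a future decoder can produce them purely
    as an auxiliary prediction target.
    """
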
    def __init__(
        self,
        *args,
        inference_steps: int = 1,
        lewm_history_horizon: Optional[int] = None,
        lewm_query_offsets: Optional[Sequence[int]] = None,
        lewm_predictor: Optional[nn.Module] = None,
        lewm_pred_projector: Optional[nn.Module] = None,
        future_decoder: Optional[nn.Module] = None,
        future_query_init_std: float = 0.02,
        lewm_sigreg: Optional[nn.Module] = None,
        lewm_sigreg_weight: float = 0.09,
        lewm_loss_weight: float = 0.0,
        lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
        **kwargs,
    ):
        if inference_steps != 1:
            raise ValueError(
                'IMFVLAAgent only supports one-step inference; '
                f'inference_steps must be 1, got {inference_steps}.'
            )
        lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
        inferred_extra_condition_tokens = len(lewm_query_offsets)
        default_extra_condition_tokens = (
            0 if future_decoder is not None else inferred_extra_condition_tokens
        )
        kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
        # Stash LeWM attributes through __dict__ so they already exist while
        # super().__init__ runs, without tripping nn.Module.__setattr__
        # (which rejects sub-module assignment before nn.Module.__init__).
        self.__dict__['lewm_history_horizon'] = int(
            lewm_history_horizon or kwargs.get('obs_horizon', 1)
        )
        self.__dict__['lewm_query_offsets'] = lewm_query_offsets
        self.__dict__['lewm_predictor'] = lewm_predictor
        self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
        self.__dict__['future_decoder'] = future_decoder
        self.__dict__['future_query_tokens'] = None
        self.__dict__['future_query_init_std'] = float(future_query_init_std)
        self.__dict__['lewm_sigreg'] = lewm_sigreg
        self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
        self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
        self.__dict__['_last_loss_breakdown'] = {
            'action_loss': 0.0,
            'lewm_pred_loss': 0.0,
            'lewm_sigreg_loss': 0.0,
            'lewm_loss': 0.0,
            'loss': 0.0,
        }
        super().__init__(*args, inference_steps=inference_steps, **kwargs)
        self.inference_steps = 1
        # Re-assign after super().__init__ so nn.Module registers the
        # sub-modules and parameters properly.
        self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
        self.lewm_predictor = lewm_predictor
        self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
        if future_decoder is not None and not isinstance(future_decoder, nn.Module):
            # Allow passing a factory callable instead of an instantiated module.
            self.future_decoder = future_decoder()
        else:
            self.future_decoder = future_decoder
        self.future_query_tokens = None
        self.future_query_init_std = float(future_query_init_std)
        self.lewm_sigreg = lewm_sigreg
        self.lewm_sigreg_weight = float(lewm_sigreg_weight)
        if self.lewm_predictor is not None and self.future_decoder is not None:
            raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
        if self.lewm_predictor is None and self.extra_condition_tokens > 0:
            raise ValueError(
                'extra_condition_tokens > 0 requires lewm_predictor to be provided'
            )
        if (
            self.lewm_predictor is not None
            and self.extra_condition_tokens != inferred_extra_condition_tokens
        ):
            raise ValueError(
                'extra_condition_tokens must equal len(lewm_query_offsets) '
                'when lewm_predictor is enabled'
            )
        if self.future_decoder is not None:
            if inferred_extra_condition_tokens <= 0:
                raise ValueError('future_decoder requires non-empty lewm_query_offsets')
            if self.extra_condition_tokens != 0:
                raise ValueError('future_decoder requires extra_condition_tokens=0')
            # One learnable query token per requested future offset.
            self.future_query_tokens = nn.Parameter(
                torch.randn(
                    1,
                    inferred_extra_condition_tokens,
                    self.per_step_cond_dim,
                )
                * self.future_query_init_std
            )
        if lewm_pretrained_ckpt is not None:
            self.load_lewm_pretrained_components(lewm_pretrained_ckpt)

    @staticmethod
    def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    @staticmethod
    def _apply_conditioning(
        trajectory: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if condition_data is None or condition_mask is None:
            return trajectory
        conditioned = trajectory.clone()
        conditioned[condition_mask] = condition_data[condition_mask]
        return conditioned

    @staticmethod
    def _jvp_math_sdp_context(z_t: torch.Tensor):
        # Force the math SDPA backend on CUDA: the fused flash / mem-efficient
        # kernels generally do not support the forward-mode autodiff used by
        # the JVP below. (torch.backends.cuda.sdp_kernel is deprecated in
        # newer PyTorch in favor of torch.nn.attention.sdpa_kernel.)
        if z_t.is_cuda:
            return torch.backends.cuda.sdp_kernel(
                enable_flash=False,
                enable_math=True,
                enable_mem_efficient=False,
                enable_cudnn=False,
            )
        return nullcontext()

    @staticmethod
    def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
        # Tangent direction (dz, dr, dt) = (v, 0, 1): differentiate along the
        # time axis while the sample moves with velocity v.
        return v.detach(), torch.zeros_like(r), torch.ones_like(t)

    def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
        return self.noise_pred_net(z, r, t, cond=cond)

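    # _compute_u_and_du_dt evaluates u = fn(z_t, r, t) together with its total
    # derivative du/dt along the tangent (v, 0, 1) in one forward-mode pass.
    # torch.func.jvp is preferred; torch.autograd.functional.jvp (which costs
    # an extra forward evaluation of g) is the fallback when torch.func is
    # unavailable or the model is incompatible with it.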
    def _compute_u_and_du_dt(
        self,
        z_t: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond,
        v: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ):
        tangents = self._jvp_tangents(v, r, t)

        def g(z, r_value, t_value):
            conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
            return self.fn(conditioned_z, r_value, t_value, cond=cond)

        with self._jvp_math_sdp_context(z_t):
            if TORCH_FUNC_JVP is not None:
                try:
                    return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
                except (RuntimeError, TypeError, NotImplementedError):
                    pass
            u = g(z_t, r, t)
            _, du_dt = torch.autograd.functional.jvp(
                g,
                (z_t, r, t),
                tangents,
                create_graph=False,
                strict=False,
            )
        return u, du_dt

    def _compound_velocity(
        self,
        u: torch.Tensor,
        du_dt: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        # Compound velocity u + (t - r) * du/dt; du/dt is detached so the
        # derivative term acts as a constant correction in the loss.
        delta = self._broadcast_batch_time(t - r, u)
        return u + delta * du_dt.detach()

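    # One-step sampling: with (r, t) = (0, 1) and z_1 ~ N(0, I), the update
    # z_0 = z_1 - (t - r) * u integrates the learned velocity over the whole
    # interval in a single step, consistent with inference_steps == 1.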
    def _sample_one_step(
        self,
        z_t: torch.Tensor,
        r: Optional[torch.Tensor] = None,
        t: Optional[torch.Tensor] = None,
        cond=None,
    ) -> torch.Tensor:
        batch_size = z_t.shape[0]
        if t is None:
            t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
        if r is None:
            r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
        u = self.fn(z_t, r, t, cond=cond)
        delta = self._broadcast_batch_time(t - r, z_t)
        return z_t - delta * u

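    # _normalize_qpos_for_lewm mirrors the agent's qpos normalization but
    # tolerates several stat layouts: attributes on self.normalization first,
    # then entries in self.dataset_stats, then the normalizer's own method.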
    def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
        if not self.normalization.enabled:
            return qpos
        qpos_mean = getattr(self.normalization, 'qpos_mean', None)
        qpos_std = getattr(self.normalization, 'qpos_std', None)
        if qpos_mean is not None and qpos_std is not None:
            return (qpos - qpos_mean) / qpos_std
        if isinstance(self.dataset_stats, dict):
            mean = self.dataset_stats.get('qpos_mean', None)
            std = self.dataset_stats.get('qpos_std', None)
            if mean is not None and std is not None:
                mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
                std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
                return (qpos - mean) / std
        return self.normalization.normalize_qpos(qpos)

    def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
        if predicted_tokens.ndim != 3:
            raise ValueError(
                f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
            )
        batch_size, token_count, token_dim = predicted_tokens.shape
        flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
        projected = self.lewm_pred_projector(flattened)
        if projected.ndim != 2:
            raise ValueError(
                f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
            )
        return projected.reshape(batch_size, token_count, projected.shape[-1])

    @staticmethod
    def _load_checkpoint_payload(
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> Mapping[str, torch.Tensor]:
        if isinstance(checkpoint_or_path, (str, Path)):
            payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
        else:
            payload = checkpoint_or_path
        state_dict = payload.get('state_dict', payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError('checkpoint payload must contain a mapping state_dict')
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted

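    # _adapt_and_load_state_dict copies shape-matching tensors verbatim and
    # specially adapts the token-count dimension (dim 1) of `query_tokens` /
    # `pos_embedding` when the checkpoint was trained with a different number
    # of future queries: overlapping tokens are copied and any extra slots
    # are tiled from the last checkpoint token.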
    @staticmethod
    def _adapt_and_load_state_dict(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
        *,
        query_key: str = 'query_tokens',
        pos_key: str = 'pos_embedding',
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        adapted_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []
        for key, current_tensor in current_state_dict.items():
            if key not in incoming_state_dict:
                continue
            source_tensor = incoming_state_dict[key]
            if source_tensor.shape == current_tensor.shape:
                adapted_state_dict[key] = source_tensor
                loaded_keys.append(key)
                continue
            if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
                # Token-count mismatch: copy the overlapping tokens along
                # dim 1 and tile the last checkpoint token into the rest.
                patched = current_tensor.clone()
                copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
                patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
                if 0 < copy_count < current_tensor.shape[1]:
                    patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
                adapted_state_dict[key] = patched
                loaded_keys.append(key)
                continue
            mismatched_keys.append(key)
        # Note: 'missing_keys' collects checkpoint keys with no counterpart in
        # the module ("unexpected" keys in torch terminology).
        for key in incoming_state_dict.keys():
            if key not in current_state_dict:
                missing_keys.append(key)
        module.load_state_dict(adapted_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(set(mismatched_keys))),
            'missing_keys': tuple(sorted(set(missing_keys))),
        }

    @staticmethod
    def _load_state_dict_ignoring_shape_mismatches(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        merged_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []  # checkpoint keys with no counterpart in the module
        for key, value in incoming_state_dict.items():
            if key not in current_state_dict:
                missing_keys.append(key)
                continue
            if current_state_dict[key].shape != value.shape:
                mismatched_keys.append(key)
                continue
            merged_state_dict[key] = value
            loaded_keys.append(key)
        module.load_state_dict(merged_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(mismatched_keys)),
            'missing_keys': tuple(sorted(missing_keys)),
        }

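    # Expected LeWM checkpoint layout, derived from the prefixes consumed
    # below (the trainer that produces this checkpoint is assumed, not
    # documented in this file):
    #   model.encoder.*        -> self.vision_encoder
    #   model.state_encoder.*  -> self.state_encoder
    #   model.projector.proj.* -> self.cond_projector.linear
    #   model.predictor.*      -> self.lewm_predictor
    #   model.pred_proj.*      -> self.lewm_pred_projector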
    def load_lewm_pretrained_components(
        self,
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> None:
        state_dict = self._load_checkpoint_payload(checkpoint_or_path)
        if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
            try:
                self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
            except RuntimeError:
                vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
                self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
        else:
            vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
            self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)
        state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
        self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)
        projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
        mapped_projector_state_dict = {
            f'linear.{key}': value
            for key, value in projector_state_dict.items()
        }
        self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)
        if self.lewm_predictor is not None:
            predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
            self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)
        if self.lewm_pred_projector is not None:
            pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
            self._load_state_dict_ignoring_shape_mismatches(
                self.lewm_pred_projector,
                pred_projector_state_dict,
            )

    def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
        if self.future_decoder is None or self.future_query_tokens is None:
            raise RuntimeError('future_decoder path requested but not initialized')
        batch_size = history_cond.shape[0]
        query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
        r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        return self.future_decoder(query_tokens, r, t, cond=history_cond)

    def _build_full_condition(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        normalized_proprioception = self.normalization.normalize_qpos(proprioception)
        history_cond = self._build_cond(images, normalized_proprioception)
        predicted_future_tokens = None
        lewm_history_cond = None
        if self.lewm_predictor is None and self.future_decoder is None:
            return history_cond, predicted_future_tokens, lewm_history_cond
        lewm_images = lewm_images if lewm_images is not None else images
        lewm_proprioception = (
            lewm_proprioception if lewm_proprioception is not None else proprioception
        )
        lewm_history_cond = self._build_cond(
            lewm_images,
            self._normalize_qpos_for_lewm(lewm_proprioception),
        )
        cond = history_cond
        if self.lewm_predictor is not None:
            predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
            predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
            cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
            if cond.shape[1] != self.condition_sequence_length:
                raise RuntimeError(
                    'full condition sequence length mismatch: '
                    f'got {cond.shape[1]}, expected {self.condition_sequence_length}'
                )
            if cond.shape[-1] != self.per_step_cond_dim:
                raise RuntimeError(
                    'full condition dim mismatch: '
                    f'got {cond.shape[-1]}, expected {self.per_step_cond_dim}'
                )
        elif self.future_decoder is not None:
            predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
        return cond, predicted_future_tokens, lewm_history_cond

    @staticmethod
    def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Despite the name, no mask is currently applied; this is a plain MSE.
        return F.mse_loss(pred, target)

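    # Training objective, as implemented below: sample noise e and times
    # r <= t, form the interpolant z_t = (1 - t) * x + t * e, and regress the
    # compound velocity u + (t - r) * sg(du/dt) onto the flow target e - x.
    # The LeWM prediction loss and the optional sigreg regularizer are then
    # added with their configured weights.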
def compute_loss(self, batch):
actions, states, images = batch['action'], batch['qpos'], batch['images']
action_is_pad = batch.get('action_is_pad', None)
batch_size = actions.shape[0]
actions = self.normalization.normalize_action(actions)
cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
images,
states,
lewm_images=batch.get('lewm_images', None),
lewm_proprioception=batch.get('lewm_qpos', None),
)
x = actions
e = torch.randn_like(x)
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
t, r = torch.maximum(t, r), torch.minimum(t, r)
t_broadcast = self._broadcast_batch_time(t, x)
z_t = (1 - t_broadcast) * x + t_broadcast * e
v = self.fn(z_t, t, t, cond=cond)
u, du_dt = self._compute_u_and_du_dt(z_t, r, t, cond=cond, v=v)
V = self._compound_velocity(u, du_dt, r, t)
target = e - x
loss = F.mse_loss(V, target, reduction='none')
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
action_loss = loss.mean()
lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
if predicted_future_tokens is not None:
lewm_future_images = batch.get('lewm_future_images', None)
lewm_future_qpos = batch.get('lewm_future_qpos', None)
if lewm_future_images is not None and lewm_future_qpos is not None:
future_target = self._build_cond(
lewm_future_images,
self._normalize_qpos_for_lewm(lewm_future_qpos),
)
lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
if self.lewm_sigreg is not None and lewm_history_cond is not None:
lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))
lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
total_loss = action_loss + self.lewm_loss_weight * lewm_loss
self._last_loss_breakdown = {
'action_loss': float(action_loss.detach().item()),
'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
'lewm_loss': float(lewm_loss.detach().item()),
'loss': float(total_loss.detach().item()),
}
return total_loss
    def get_last_loss_breakdown(self) -> Dict[str, float]:
        return dict(self._last_loss_breakdown)

    def reset(self):
        super().reset()
        if self.lewm_predictor is not None:
            self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
            self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        super()._populate_queues(observation)
        if self.lewm_predictor is None:
            return
        if 'qpos' in observation:
            self._queues['lewm_qpos'].append(observation['qpos'].clone())
        if 'images' in observation:
            ordered_images = self._order_images(observation['images'])
            self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        batch = super()._prepare_observation_batch()
        if self.lewm_predictor is None:
            return batch
        qpos_list = list(self._queues['lewm_qpos'])
        images_list = list(self._queues['lewm_images'])
        if len(qpos_list) == 0 or len(images_list) == 0:
            raise ValueError(
                'LeWM observation queues are empty; call _populate_queues '
                'with at least one observation first'
            )
        # Pad short histories by repeating the most recent observation.
        while len(qpos_list) < self.lewm_history_horizon:
            qpos_list.append(qpos_list[-1])
        while len(images_list) < self.lewm_history_horizon:
            images_list.append(images_list[-1])
        batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
        batch['lewm_images'] = {}
        camera_names = (
            self.camera_names
            if self.camera_names is not None
            else tuple(sorted(images_list[0].keys()))
        )
        for cam_name in camera_names:
            batch['lewm_images'][cam_name] = torch.stack(
                [img[cam_name] for img in images_list],
                dim=0,
            ).unsqueeze(0)
        return batch

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.predict_action(
            batch['images'],
            batch['qpos'],
            lewm_images=batch.get('lewm_images', None),
            lewm_proprioception=batch.get('lewm_qpos', None),
        )

    @torch.no_grad()
    def predict_action(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        batch_size = proprioception.shape[0]
        if self.lewm_predictor is not None:
            cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
                images,
                proprioception,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_proprioception,
            )
        else:
            cond = self._build_cond(
                images,
                self.normalization.normalize_qpos(proprioception),
            )
        z_t = torch.randn(
            (batch_size, self.pred_horizon, self.action_dim),
            device=cond.device,
            dtype=cond.dtype,
        )
        action = self._sample_one_step(z_t, cond=cond)
        return self.normalization.denormalize_action(action)
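

# Minimal usage sketch (assumes a VLAAgent-compatible constructor; base-class
# arguments such as obs_horizon, pred_horizon, and the encoder modules are
# defined in roboimi.vla.agent and not shown here; the offsets, predictor,
# and checkpoint path below are hypothetical):
#
#     agent = IMFVLAAgent(
#         lewm_query_offsets=(1, 2, 4),
#         lewm_predictor=my_predictor,          # hypothetical nn.Module
#         lewm_loss_weight=0.1,
#         lewm_pretrained_ckpt='lewm.ckpt',     # hypothetical path
#         # ...base VLAAgent kwargs...
#     )
#     loss = agent.compute_loss(batch)          # training step
#     actions = agent.predict_action(images, qpos)  # one-step inference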