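"""One-step flow-matching VLA agent with optional LeWM future-token prediction.

Editor-added summary (not original author documentation): ``IMFVLAAgent``
extends ``VLAAgent`` with (a) a training objective that regresses a compound
velocity ``u + (t - r) * du/dt`` (the time derivative obtained via a JVP) onto
the conditional velocity ``e - x``, (b) one-step action sampling
(``inference_steps`` is fixed to 1), and (c) optional LeWM components: either
a ``lewm_predictor`` whose predicted future tokens are appended to the
condition sequence, or a supervision-only ``future_decoder`` driven by
learned query tokens.
"""
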
from __future__ import annotations

from collections import deque
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.agent import VLAAgent

# Prefer the functional forward-mode JVP when available; _compute_u_and_du_dt
# falls back to torch.autograd.functional.jvp otherwise.
try:
    from torch.func import jvp as TORCH_FUNC_JVP
except ImportError:  # pragma: no cover
    TORCH_FUNC_JVP = None


class IMFVLAAgent(VLAAgent):
    def __init__(
        self,
        *args,
        inference_steps: int = 1,
        lewm_history_horizon: Optional[int] = None,
        lewm_query_offsets: Optional[Sequence[int]] = None,
        lewm_predictor: Optional[nn.Module] = None,
        lewm_pred_projector: Optional[nn.Module] = None,
        future_decoder: Optional[nn.Module] = None,
        future_query_init_std: float = 0.02,
        lewm_sigreg: Optional[nn.Module] = None,
        lewm_sigreg_weight: float = 0.09,
        lewm_loss_weight: float = 0.0,
        lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
        **kwargs,
    ):
        if inference_steps != 1:
            raise ValueError(
                'IMFVLAAgent only supports one-step inference; '
                f'inference_steps must be 1, got {inference_steps}.'
            )
        lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
        inferred_extra_condition_tokens = len(lewm_query_offsets)
        default_extra_condition_tokens = (
            0 if future_decoder is not None else inferred_extra_condition_tokens
        )
        kwargs.setdefault('extra_condition_tokens', default_extra_condition_tokens)
        # Pre-seed attributes through __dict__ so they exist before
        # super().__init__ runs (nn.Module.__setattr__ rejects module
        # assignments until nn.Module.__init__ has been called). They are
        # re-assigned normally below so submodules register properly.
        self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
        self.__dict__['lewm_query_offsets'] = lewm_query_offsets
        self.__dict__['lewm_predictor'] = lewm_predictor
        self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
        self.__dict__['future_decoder'] = future_decoder
        self.__dict__['future_query_tokens'] = None
        self.__dict__['future_query_init_std'] = float(future_query_init_std)
        self.__dict__['lewm_sigreg'] = lewm_sigreg
        self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
        self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
        self.__dict__['_last_loss_breakdown'] = {
            'action_loss': 0.0,
            'lewm_pred_loss': 0.0,
            'lewm_sigreg_loss': 0.0,
            'lewm_loss': 0.0,
            'loss': 0.0,
        }
        super().__init__(*args, inference_steps=inference_steps, **kwargs)
        self.inference_steps = 1
        self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
        self.lewm_predictor = lewm_predictor
        self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
        # `future_decoder` may be passed as a zero-argument factory instead of
        # a constructed module; instantiate it in that case.
        if future_decoder is not None and not isinstance(future_decoder, nn.Module):
            self.future_decoder = future_decoder()
        else:
            self.future_decoder = future_decoder
        self.future_query_tokens = None
        self.future_query_init_std = float(future_query_init_std)
        self.lewm_sigreg = lewm_sigreg
        self.lewm_sigreg_weight = float(lewm_sigreg_weight)
        if self.lewm_predictor is not None and self.future_decoder is not None:
            raise ValueError('lewm_predictor and future_decoder are mutually exclusive')
        if self.lewm_predictor is None and self.extra_condition_tokens > 0:
            raise ValueError(
                'extra_condition_tokens > 0 requires lewm_predictor to be provided'
            )
        if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
            raise ValueError(
                'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
            )
        if self.future_decoder is not None:
            if inferred_extra_condition_tokens <= 0:
                raise ValueError('future_decoder requires non-empty lewm_query_offsets')
            if self.extra_condition_tokens != 0:
                raise ValueError('future_decoder requires extra_condition_tokens=0')
            self.future_query_tokens = nn.Parameter(
                torch.randn(
                    1,
                    inferred_extra_condition_tokens,
                    self.per_step_cond_dim,
                ) * self.future_query_init_std
            )
        if lewm_pretrained_ckpt is not None:
            self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
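
    # Editor note: the checks above admit three configurations.
    #   1. Neither lewm_predictor nor future_decoder: plain one-step agent,
    #      extra_condition_tokens == 0.
    #   2. lewm_predictor set: extra_condition_tokens must equal
    #      len(lewm_query_offsets); predicted future tokens are concatenated
    #      onto the condition sequence in _build_full_condition.
    #   3. future_decoder set: extra_condition_tokens must be 0; learned
    #      future_query_tokens drive the decoder, and its predictions are
    #      supervised only (they never enter the action condition).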

    @staticmethod
    def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        # Append singleton dims so a per-batch scalar (e.g. a timestep) can be
        # broadcast against a (batch, horizon, dim) tensor.
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    @staticmethod
    def _apply_conditioning(
        trajectory: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if condition_data is None or condition_mask is None:
            return trajectory
        conditioned = trajectory.clone()
        conditioned[condition_mask] = condition_data[condition_mask]
        return conditioned

    @staticmethod
    def _jvp_math_sdp_context(z_t: torch.Tensor):
        # The fused flash / memory-efficient attention kernels generally lack
        # forward-mode AD support, so restrict scaled-dot-product attention to
        # the math backend while taking the JVP on CUDA.
        if z_t.is_cuda:
            return torch.backends.cuda.sdp_kernel(
                enable_flash=False,
                enable_math=True,
                enable_mem_efficient=False,
                enable_cudnn=False,
            )
        return nullcontext()

    @staticmethod
    def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
        return v.detach(), torch.zeros_like(r), torch.ones_like(t)
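
    # Editor note on the tangents above: the JVP is taken along the flow
    # direction, (dz, dr, dt) = (v, 0, 1). With z evolving at velocity v and
    # r held fixed, the resulting directional derivative is the total time
    # derivative d/dt u(z_t, r, t) consumed by _compound_velocity.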

    def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
        return self.noise_pred_net(z, r, t, cond=cond)

    def _compute_u_and_du_dt(
        self,
        z_t: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond,
        v: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ):
        tangents = self._jvp_tangents(v, r, t)

        def g(z, r_value, t_value):
            conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
            return self.fn(conditioned_z, r_value, t_value, cond=cond)

        with self._jvp_math_sdp_context(z_t):
            if TORCH_FUNC_JVP is not None:
                try:
                    return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
                except (RuntimeError, TypeError, NotImplementedError):
                    # Some ops lack forward-mode support; fall back to the
                    # double-backward implementation below.
                    pass

            u = g(z_t, r, t)
            _, du_dt = torch.autograd.functional.jvp(
                g,
                (z_t, r, t),
                tangents,
                create_graph=False,
                strict=False,
            )
            return u, du_dt
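
    # Editor note: `u` is the network's average-velocity estimate u(z_t, r, t)
    # and `du_dt` its directional derivative along (v, 0, 1). For example,
    # with g(z, r, t) = z * t as a toy stand-in for the network and tangents
    # (v, 0, 1), the JVP returns (z_t * t, v * t + z_t): the product rule
    # applied along the flow.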

    def _compound_velocity(
        self,
        u: torch.Tensor,
        du_dt: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        # Stop-gradient on du_dt: gradients flow only through u, not through
        # the JVP term.
        delta = self._broadcast_batch_time(t - r, u)
        return u + delta * du_dt.detach()
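
    # Editor note: the training target rests on the average/instantaneous
    # velocity identity  v(z_t, t) = u(z_t, r, t) + (t - r) * d/dt u(z_t, r, t).
    # compute_loss regresses the compound velocity above onto the conditional
    # velocity e - x, implicitly supervising u so that a single Euler step
    # z_1 - (1 - 0) * u recovers the sample (see _sample_one_step).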

    def _sample_one_step(
        self,
        z_t: torch.Tensor,
        r: Optional[torch.Tensor] = None,
        t: Optional[torch.Tensor] = None,
        cond=None,
    ) -> torch.Tensor:
        batch_size = z_t.shape[0]
        # Default to the full interval [r, t] = [0, 1]: one Euler step from
        # pure noise (t = 1) back to data (r = 0).
        if t is None:
            t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
        if r is None:
            r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
        u = self.fn(z_t, r, t, cond=cond)
        delta = self._broadcast_batch_time(t - r, z_t)
        return z_t - delta * u

    def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
        if not self.normalization.enabled:
            return qpos

        # Prefer explicit qpos statistics on the normalization helper, then
        # dataset_stats, then fall back to the generic normalizer.
        qpos_mean = getattr(self.normalization, 'qpos_mean', None)
        qpos_std = getattr(self.normalization, 'qpos_std', None)
        if qpos_mean is not None and qpos_std is not None:
            return (qpos - qpos_mean) / qpos_std
        if isinstance(self.dataset_stats, dict):
            mean = self.dataset_stats.get('qpos_mean', None)
            std = self.dataset_stats.get('qpos_std', None)
            if mean is not None and std is not None:
                mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
                std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
                return (qpos - mean) / std
        return self.normalization.normalize_qpos(qpos)

    def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
        if predicted_tokens.ndim != 3:
            raise ValueError(
                f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
            )
        batch_size, token_count, token_dim = predicted_tokens.shape
        flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
        projected = self.lewm_pred_projector(flattened)
        if projected.ndim != 2:
            raise ValueError(
                f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
            )
        return projected.reshape(batch_size, token_count, projected.shape[-1])

    @staticmethod
    def _load_checkpoint_payload(
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> Mapping[str, torch.Tensor]:
        if isinstance(checkpoint_or_path, (str, Path)):
            # weights_only=False permits arbitrary pickled objects; only load
            # checkpoints from trusted sources.
            payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
        else:
            payload = checkpoint_or_path
        state_dict = payload.get('state_dict', payload)
        if not isinstance(state_dict, Mapping):
            raise TypeError('checkpoint payload must contain a mapping state_dict')
        return state_dict

    @staticmethod
    def _extract_prefixed_state_dict(
        state_dict: Mapping[str, torch.Tensor],
        prefix: str,
    ) -> Dict[str, torch.Tensor]:
        extracted = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        if not extracted:
            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
        return extracted
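
    # Editor note, a minimal illustration of the prefix stripping above:
    #   _extract_prefixed_state_dict(
    #       {'model.encoder.proj.weight': w, 'model.predictor.bias': b},
    #       'model.encoder.',
    #   )
    #   -> {'proj.weight': w}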

    @staticmethod
    def _adapt_and_load_state_dict(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
        *,
        query_key: str = 'query_tokens',
        pos_key: str = 'pos_embedding',
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        adapted_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []
        for key, current_tensor in current_state_dict.items():
            if key not in incoming_state_dict:
                continue
            source_tensor = incoming_state_dict[key]
            if source_tensor.shape == current_tensor.shape:
                adapted_state_dict[key] = source_tensor
                loaded_keys.append(key)
                continue

            # Query / positional token tables may differ only in token count
            # (dim 1): copy the overlapping tokens and repeat the last source
            # token into any remaining slots. (The original code handled
            # query_key and pos_key in two identical branches.)
            if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
                patched = current_tensor.clone()
                copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
                patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
                if 0 < copy_count < current_tensor.shape[1]:
                    patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
                adapted_state_dict[key] = patched
                loaded_keys.append(key)
                continue
            mismatched_keys.append(key)

        # Keys present in the checkpoint but absent from the module are
        # reported under 'missing_keys'.
        for key in incoming_state_dict.keys():
            if key not in current_state_dict:
                missing_keys.append(key)
        module.load_state_dict(adapted_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(set(mismatched_keys))),
            'missing_keys': tuple(sorted(set(missing_keys))),
        }

    @staticmethod
    def _load_state_dict_ignoring_shape_mismatches(
        module: nn.Module,
        incoming_state_dict: Mapping[str, torch.Tensor],
    ) -> Dict[str, Sequence[str]]:
        current_state_dict = module.state_dict()
        merged_state_dict = dict(current_state_dict)
        loaded_keys = []
        mismatched_keys = []
        missing_keys = []

        for key, value in incoming_state_dict.items():
            if key not in current_state_dict:
                missing_keys.append(key)
                continue
            if current_state_dict[key].shape != value.shape:
                # Keep the module's own tensor when shapes disagree.
                mismatched_keys.append(key)
                continue
            merged_state_dict[key] = value
            loaded_keys.append(key)

        module.load_state_dict(merged_state_dict, strict=True)
        return {
            'loaded_keys': tuple(sorted(loaded_keys)),
            'mismatched_keys': tuple(sorted(mismatched_keys)),
            'missing_keys': tuple(sorted(missing_keys)),
        }

    def load_lewm_pretrained_components(
        self,
        checkpoint_or_path: str | Path | Mapping[str, Any],
    ) -> None:
        state_dict = self._load_checkpoint_payload(checkpoint_or_path)

        # Vision encoder: use the encoder's own loading hook when it exists,
        # falling back to a prefix-filtered load if the hook fails or is absent.
        loaded_vision = False
        if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
            try:
                self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
                loaded_vision = True
            except RuntimeError:
                pass
        if not loaded_vision:
            vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
            self._load_state_dict_ignoring_shape_mismatches(self.vision_encoder, vision_state_dict)

        state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
        self._load_state_dict_ignoring_shape_mismatches(self.state_encoder, state_encoder_state_dict)

        # The checkpoint stores the projector under 'model.projector.proj.*';
        # this agent's cond_projector nests the same weights under 'linear.*'.
        projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
        mapped_projector_state_dict = {
            f'linear.{key}': value
            for key, value in projector_state_dict.items()
        }
        self._load_state_dict_ignoring_shape_mismatches(self.cond_projector, mapped_projector_state_dict)

        if self.lewm_predictor is not None:
            predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
            self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)

        if self.lewm_pred_projector is not None:
            pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
            self._load_state_dict_ignoring_shape_mismatches(
                self.lewm_pred_projector,
                pred_projector_state_dict,
            )
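
    # Editor note: the pretrained LeWM checkpoint is expected to expose these
    # prefixes (inferred from the loaders above):
    #   model.encoder.*         -> self.vision_encoder
    #   model.state_encoder.*   -> self.state_encoder
    #   model.projector.proj.*  -> self.cond_projector.linear
    #   model.predictor.*       -> self.lewm_predictor (token-count adapted)
    #   model.pred_proj.*       -> self.lewm_pred_projector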

    def _predict_future_tokens_with_decoder(self, history_cond: torch.Tensor) -> torch.Tensor:
        if self.future_decoder is None or self.future_query_tokens is None:
            raise RuntimeError('future_decoder path requested but not initialized')
        batch_size = history_cond.shape[0]
        query_tokens = self.future_query_tokens.expand(batch_size, -1, -1)
        # Decode over the full [0, 1] interval, mirroring one-step sampling.
        r = torch.zeros(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        t = torch.ones(batch_size, device=history_cond.device, dtype=history_cond.dtype)
        return self.future_decoder(query_tokens, r, t, cond=history_cond)
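
    # Editor note: with lewm_predictor enabled, the condition fed to the
    # action head is the concatenation
    #   cond = [history tokens | projected future tokens]   (along dim=1)
    # of shape (B, condition_sequence_length, per_step_cond_dim), where the
    # future block contributes len(lewm_query_offsets) tokens.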

    def _build_full_condition(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        normalized_proprioception = self.normalization.normalize_qpos(proprioception)
        history_cond = self._build_cond(images, normalized_proprioception)
        predicted_future_tokens = None
        lewm_history_cond = None

        if self.lewm_predictor is None and self.future_decoder is None:
            return history_cond, predicted_future_tokens, lewm_history_cond

        lewm_images = lewm_images if lewm_images is not None else images
        lewm_proprioception = (
            lewm_proprioception if lewm_proprioception is not None else proprioception
        )
        lewm_history_cond = self._build_cond(
            lewm_images,
            self._normalize_qpos_for_lewm(lewm_proprioception),
        )
        cond = history_cond
        if self.lewm_predictor is not None:
            predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
            predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
            cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
            if cond.shape[1] != self.condition_sequence_length:
                raise RuntimeError(
                    f"full condition sequence length mismatch: got {cond.shape[1]}, expected {self.condition_sequence_length}"
                )
            if cond.shape[-1] != self.per_step_cond_dim:
                raise RuntimeError(
                    f"full condition dimension mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
                )
        elif self.future_decoder is not None:
            predicted_future_tokens = self._predict_future_tokens_with_decoder(lewm_history_cond)
        return cond, predicted_future_tokens, lewm_history_cond

    @staticmethod
    def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Despite its name, no mask is currently applied; this is a plain MSE.
        return F.mse_loss(pred, target)
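
    # Editor note, the objective assembled in compute_loss:
    #   z_t = (1 - t) * x + t * e,          e ~ N(0, I),  0 <= r <= t <= 1
    #   L_action = || u + (t - r) * sg(du/dt) - (e - x) ||^2   (pad-masked)
    #   L_total  = L_action + lewm_loss_weight *
    #              (L_lewm_pred + lewm_sigreg_weight * L_sigreg)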

    def compute_loss(self, batch):
        actions, states, images = batch['action'], batch['qpos'], batch['images']
        action_is_pad = batch.get('action_is_pad', None)
        batch_size = actions.shape[0]

        actions = self.normalization.normalize_action(actions)
        cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
            images,
            states,
            lewm_images=batch.get('lewm_images', None),
            lewm_proprioception=batch.get('lewm_qpos', None),
        )

        x = actions
        e = torch.randn_like(x)
        # Sample an ordered time pair 0 <= r <= t <= 1.
        t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        t, r = torch.maximum(t, r), torch.minimum(t, r)

        t_broadcast = self._broadcast_batch_time(t, x)
        z_t = (1 - t_broadcast) * x + t_broadcast * e

        # v = fn(z, t, t) is the instantaneous velocity estimate (the average
        # velocity over the degenerate interval [t, t]); it seeds the JVP tangent.
        v = self.fn(z_t, t, t, cond=cond)
        u, du_dt = self._compute_u_and_du_dt(z_t, r, t, cond=cond, v=v)
        V = self._compound_velocity(u, du_dt, r, t)
        target = e - x

        loss = F.mse_loss(V, target, reduction='none')
        if action_is_pad is not None:
            # Average only over non-padded action steps.
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
            valid_count = mask.sum() * loss.shape[-1]
            action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            action_loss = loss.mean()

        lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
        lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
        if predicted_future_tokens is not None:
            lewm_future_images = batch.get('lewm_future_images', None)
            lewm_future_qpos = batch.get('lewm_future_qpos', None)
            if lewm_future_images is not None and lewm_future_qpos is not None:
                # Supervise predicted future tokens against the encoded
                # ground-truth future observations.
                future_target = self._build_cond(
                    lewm_future_images,
                    self._normalize_qpos_for_lewm(lewm_future_qpos),
                )
                lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
            if self.lewm_sigreg is not None and lewm_history_cond is not None:
                lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))

        lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
        total_loss = action_loss + self.lewm_loss_weight * lewm_loss
        self._last_loss_breakdown = {
            'action_loss': float(action_loss.detach().item()),
            'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
            'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
            'lewm_loss': float(lewm_loss.detach().item()),
            'loss': float(total_loss.detach().item()),
        }
        return total_loss

    def get_last_loss_breakdown(self) -> Dict[str, float]:
        return dict(self._last_loss_breakdown)

    def reset(self):
        super().reset()
        if self.lewm_predictor is not None:
            self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
            self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        super()._populate_queues(observation)
        if self.lewm_predictor is None:
            return
        if 'qpos' in observation:
            self._queues['lewm_qpos'].append(observation['qpos'].clone())
        if 'images' in observation:
            ordered_images = self._order_images(observation['images'])
            self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        batch = super()._prepare_observation_batch()
        if self.lewm_predictor is None:
            return batch

        qpos_list = list(self._queues['lewm_qpos'])
        images_list = list(self._queues['lewm_images'])
        if len(qpos_list) == 0 or len(images_list) == 0:
            raise ValueError('LeWM observation queues are empty; call _populate_queues first to add observations')
        # Pad a short history by repeating the most recent observation.
        while len(qpos_list) < self.lewm_history_horizon:
            qpos_list.append(qpos_list[-1])
        while len(images_list) < self.lewm_history_horizon:
            images_list.append(images_list[-1])

        batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
        batch['lewm_images'] = {}
        camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
        for cam_name in camera_names:
            batch['lewm_images'][cam_name] = torch.stack(
                [img[cam_name] for img in images_list],
                dim=0,
            ).unsqueeze(0)
        return batch

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.predict_action(
            batch['images'],
            batch['qpos'],
            lewm_images=batch.get('lewm_images', None),
            lewm_proprioception=batch.get('lewm_qpos', None),
        )

    @torch.no_grad()
    def predict_action(
        self,
        images,
        proprioception,
        *,
        lewm_images=None,
        lewm_proprioception=None,
    ):
        batch_size = proprioception.shape[0]
        if self.lewm_predictor is not None:
            cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
                images,
                proprioception,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_proprioception,
            )
        else:
            # future_decoder (if any) is supervision-only, so plain history
            # conditioning suffices at inference time.
            cond = self._build_cond(
                images,
                self.normalization.normalize_qpos(proprioception),
            )
        # One-step generation: draw z_1 ~ N(0, I) and take a single Euler step.
        z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
        action = self._sample_one_step(z_t, cond=cond)
        return self.normalization.denormalize_action(action)
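

if __name__ == '__main__':
    # Editor-added sketch: a smoke test of the static tensor helpers. The
    # agent itself is not constructed here because VLAAgent's required
    # constructor arguments are project-specific.
    x = torch.zeros(2, 8, 7)  # (batch, pred_horizon, action_dim)
    t = torch.rand(2)

    # Per-batch scalars gain trailing singleton dims for broadcasting.
    assert IMFVLAAgent._broadcast_batch_time(t, x).shape == (2, 1, 1)

    # _apply_conditioning overwrites masked entries with condition data.
    cond_data = torch.ones_like(x)
    cond_mask = torch.zeros_like(x, dtype=torch.bool)
    cond_mask[:, 0] = True
    out = IMFVLAAgent._apply_conditioning(x, cond_data, cond_mask)
    assert torch.equal(out[:, 0], torch.ones(2, 7)) and torch.equal(out[:, 1:], x[:, 1:])

    # The JVP tangents follow the flow: (dz, dr, dt) = (v, 0, 1).
    v = torch.randn_like(x)
    dz, dr, dt = IMFVLAAgent._jvp_tangents(v, torch.zeros(2), torch.ones(2))
    assert torch.equal(dz, v) and not dr.any() and dt.allclose(torch.ones(2))
    print('static-helper smoke test passed')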