debug
This commit is contained in:
764
diffusion/modeling_diffusion.py
Normal file
764
diffusion/modeling_diffusion.py
Normal file
@@ -0,0 +1,764 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
|
||||||
|
TODO(alexander-soare):
|
||||||
|
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import deque
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torchvision
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.policies.utils import (
|
||||||
|
get_device_from_parameters,
|
||||||
|
get_dtype_from_parameters,
|
||||||
|
get_output_shape,
|
||||||
|
populate_queues,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionPolicy(PreTrainedPolicy):
|
||||||
|
"""
|
||||||
|
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
(paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = DiffusionConfig
|
||||||
|
name = "diffusion"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: DiffusionConfig,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||||
|
the configuration class is used.
|
||||||
|
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||||
|
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
|
self.diffusion = DiffusionModel(config)
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
return self.diffusion.parameters()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||||
|
self._queues = {
|
||||||
|
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||||
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
|
}
|
||||||
|
if self.config.image_features:
|
||||||
|
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
if self.config.env_state_feature:
|
||||||
|
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
"""Predict a chunk of actions given environment observations."""
|
||||||
|
# stack n latest observations from the queue
|
||||||
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
|
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
|
This method handles caching a history of observations and an action trajectory generated by the
|
||||||
|
underlying diffusion model. Here's how it works:
|
||||||
|
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||||
|
copied `n_obs_steps` times to fill the cache).
|
||||||
|
- The diffusion model generates `horizon` steps worth of actions.
|
||||||
|
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||||
|
Schematically this looks like:
|
||||||
|
----------------------------------------------------------------------------------------------
|
||||||
|
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||||
|
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||||
|
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||||
|
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||||
|
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||||
|
----------------------------------------------------------------------------------------------
|
||||||
|
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||||
|
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||||
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
|
"""
|
||||||
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
|
if ACTION in batch:
|
||||||
|
batch.pop(ACTION)
|
||||||
|
|
||||||
|
if self.config.image_features:
|
||||||
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||||
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
|
if len(self._queues[ACTION]) == 0:
|
||||||
|
actions = self.predict_action_chunk(batch, noise=noise)
|
||||||
|
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
|
action = self._queues[ACTION].popleft()
|
||||||
|
return action
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||||
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
if self.config.image_features:
|
||||||
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
loss = self.diffusion.compute_loss(batch)
|
||||||
|
# no output_dict so returning None
|
||||||
|
return loss, None
|
||||||
|
|
||||||
|
|
||||||
|
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||||
|
"""
|
||||||
|
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||||
|
to the scheduler.
|
||||||
|
"""
|
||||||
|
if name == "DDPM":
|
||||||
|
return DDPMScheduler(**kwargs)
|
||||||
|
elif name == "DDIM":
|
||||||
|
return DDIMScheduler(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionModel(nn.Module):
|
||||||
|
def __init__(self, config: DiffusionConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Build observation encoders (depending on which observations are provided).
|
||||||
|
global_cond_dim = self.config.robot_state_feature.shape[0]
|
||||||
|
if self.config.image_features:
|
||||||
|
num_images = len(self.config.image_features)
|
||||||
|
if self.config.use_separate_rgb_encoder_per_camera:
|
||||||
|
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||||
|
self.rgb_encoder = nn.ModuleList(encoders)
|
||||||
|
global_cond_dim += encoders[0].feature_dim * num_images
|
||||||
|
else:
|
||||||
|
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||||
|
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||||
|
if self.config.env_state_feature:
|
||||||
|
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||||
|
|
||||||
|
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||||
|
|
||||||
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
|
config.noise_scheduler_type,
|
||||||
|
num_train_timesteps=config.num_train_timesteps,
|
||||||
|
beta_start=config.beta_start,
|
||||||
|
beta_end=config.beta_end,
|
||||||
|
beta_schedule=config.beta_schedule,
|
||||||
|
clip_sample=config.clip_sample,
|
||||||
|
clip_sample_range=config.clip_sample_range,
|
||||||
|
prediction_type=config.prediction_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.num_inference_steps is None:
|
||||||
|
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||||
|
else:
|
||||||
|
self.num_inference_steps = config.num_inference_steps
|
||||||
|
|
||||||
|
# ========= inference ============
|
||||||
|
def conditional_sample(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
global_cond: Tensor | None = None,
|
||||||
|
generator: torch.Generator | None = None,
|
||||||
|
noise: Tensor | None = None,
|
||||||
|
) -> Tensor:
|
||||||
|
device = get_device_from_parameters(self)
|
||||||
|
dtype = get_dtype_from_parameters(self)
|
||||||
|
|
||||||
|
# Sample prior.
|
||||||
|
sample = (
|
||||||
|
noise
|
||||||
|
if noise is not None
|
||||||
|
else torch.randn(
|
||||||
|
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
|
for t in self.noise_scheduler.timesteps:
|
||||||
|
# Predict model output.
|
||||||
|
model_output = self.unet(
|
||||||
|
sample,
|
||||||
|
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||||
|
global_cond=global_cond,
|
||||||
|
)
|
||||||
|
# Compute previous image: x_t -> x_t-1
|
||||||
|
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Encode image features and concatenate them all together along with the state vector."""
|
||||||
|
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||||
|
global_cond_feats = [batch[OBS_STATE]]
|
||||||
|
# Extract image features.
|
||||||
|
if self.config.image_features:
|
||||||
|
if self.config.use_separate_rgb_encoder_per_camera:
|
||||||
|
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||||
|
images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
|
||||||
|
img_features_list = torch.cat(
|
||||||
|
[
|
||||||
|
encoder(images)
|
||||||
|
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||||
|
# feature dim (effectively concatenating the camera features).
|
||||||
|
img_features = einops.rearrange(
|
||||||
|
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||||
|
img_features = self.rgb_encoder(
|
||||||
|
einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
|
||||||
|
)
|
||||||
|
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||||
|
# feature dim (effectively concatenating the camera features).
|
||||||
|
img_features = einops.rearrange(
|
||||||
|
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||||
|
)
|
||||||
|
global_cond_feats.append(img_features)
|
||||||
|
|
||||||
|
if self.config.env_state_feature:
|
||||||
|
global_cond_feats.append(batch[OBS_ENV_STATE])
|
||||||
|
|
||||||
|
# Concatenate features then flatten to (B, global_cond_dim).
|
||||||
|
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
|
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
"""
|
||||||
|
This function expects `batch` to have:
|
||||||
|
{
|
||||||
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
|
|
||||||
|
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||||
|
AND/OR
|
||||||
|
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||||
|
assert n_obs_steps == self.config.n_obs_steps
|
||||||
|
|
||||||
|
# Encode image features and concatenate them all together along with the state vector.
|
||||||
|
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||||
|
|
||||||
|
# run sampling
|
||||||
|
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
|
||||||
|
|
||||||
|
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||||
|
start = n_obs_steps - 1
|
||||||
|
end = start + self.config.n_action_steps
|
||||||
|
actions = actions[:, start:end]
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
"""
|
||||||
|
This function expects `batch` to have (at least):
|
||||||
|
{
|
||||||
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
|
|
||||||
|
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||||
|
AND/OR
|
||||||
|
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||||
|
|
||||||
|
"action": (B, horizon, action_dim)
|
||||||
|
"action_is_pad": (B, horizon)
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Input validation.
|
||||||
|
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||||
|
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||||
|
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||||
|
horizon = batch[ACTION].shape[1]
|
||||||
|
assert horizon == self.config.horizon
|
||||||
|
assert n_obs_steps == self.config.n_obs_steps
|
||||||
|
|
||||||
|
# Encode image features and concatenate them all together along with the state vector.
|
||||||
|
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||||
|
|
||||||
|
# Forward diffusion.
|
||||||
|
trajectory = batch[ACTION]
|
||||||
|
# Sample noise to add to the trajectory.
|
||||||
|
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||||
|
# Sample a random noising timestep for each item in the batch.
|
||||||
|
timesteps = torch.randint(
|
||||||
|
low=0,
|
||||||
|
high=self.noise_scheduler.config.num_train_timesteps,
|
||||||
|
size=(trajectory.shape[0],),
|
||||||
|
device=trajectory.device,
|
||||||
|
).long()
|
||||||
|
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
||||||
|
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
|
||||||
|
|
||||||
|
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
|
||||||
|
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
||||||
|
|
||||||
|
# Compute the loss.
|
||||||
|
# The target is either the original trajectory, or the noise.
|
||||||
|
if self.config.prediction_type == "epsilon":
|
||||||
|
target = eps
|
||||||
|
elif self.config.prediction_type == "sample":
|
||||||
|
target = batch[ACTION]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||||
|
|
||||||
|
loss = F.mse_loss(pred, target, reduction="none")
|
||||||
|
|
||||||
|
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
|
||||||
|
if self.config.do_mask_loss_for_padding:
|
||||||
|
if "action_is_pad" not in batch:
|
||||||
|
raise ValueError(
|
||||||
|
"You need to provide 'action_is_pad' in the batch when "
|
||||||
|
f"{self.config.do_mask_loss_for_padding=}."
|
||||||
|
)
|
||||||
|
in_episode_bound = ~batch["action_is_pad"]
|
||||||
|
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||||
|
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialSoftmax(nn.Module):
|
||||||
|
"""
|
||||||
|
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||||
|
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
|
||||||
|
|
||||||
|
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||||
|
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||||
|
|
||||||
|
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||||
|
-----------------------------------------------------
|
||||||
|
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||||
|
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||||
|
| ... | ... | ... | ... |
|
||||||
|
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||||
|
-----------------------------------------------------
|
||||||
|
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||||
|
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||||
|
|
||||||
|
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||||
|
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||||
|
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_shape, num_kp=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_shape (list): (C, H, W) input feature map shape.
|
||||||
|
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert len(input_shape) == 3
|
||||||
|
self._in_c, self._in_h, self._in_w = input_shape
|
||||||
|
|
||||||
|
if num_kp is not None:
|
||||||
|
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||||
|
self._out_c = num_kp
|
||||||
|
else:
|
||||||
|
self.nets = None
|
||||||
|
self._out_c = self._in_c
|
||||||
|
|
||||||
|
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||||
|
# and causes a small degradation in pc_success of pre-trained models.
|
||||||
|
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||||
|
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
|
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
|
# register as buffer so it's moved to the correct device.
|
||||||
|
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||||
|
|
||||||
|
def forward(self, features: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
features: (B, C, H, W) input feature maps.
|
||||||
|
Returns:
|
||||||
|
(B, K, 2) image-space coordinates of keypoints.
|
||||||
|
"""
|
||||||
|
if self.nets is not None:
|
||||||
|
features = self.nets(features)
|
||||||
|
|
||||||
|
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||||
|
features = features.reshape(-1, self._in_h * self._in_w)
|
||||||
|
# 2d softmax normalization
|
||||||
|
attention = F.softmax(features, dim=-1)
|
||||||
|
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||||
|
expected_xy = attention @ self.pos_grid
|
||||||
|
# reshape to [B, K, 2]
|
||||||
|
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||||
|
|
||||||
|
return feature_keypoints
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionRgbEncoder(nn.Module):
|
||||||
|
"""Encodes an RGB image into a 1D feature vector.
|
||||||
|
|
||||||
|
Includes the ability to normalize and crop the image first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: DiffusionConfig):
|
||||||
|
super().__init__()
|
||||||
|
# Set up optional preprocessing.
|
||||||
|
if config.crop_shape is not None:
|
||||||
|
self.do_crop = True
|
||||||
|
# Always use center crop for eval
|
||||||
|
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||||
|
if config.crop_is_random:
|
||||||
|
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||||
|
else:
|
||||||
|
self.maybe_random_crop = self.center_crop
|
||||||
|
else:
|
||||||
|
self.do_crop = False
|
||||||
|
|
||||||
|
# Set up backbone.
|
||||||
|
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||||
|
weights=config.pretrained_backbone_weights
|
||||||
|
)
|
||||||
|
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||||
|
# TODO(alexander-soare): Use a safer alternative.
|
||||||
|
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||||
|
if config.use_group_norm:
|
||||||
|
if config.pretrained_backbone_weights:
|
||||||
|
raise ValueError(
|
||||||
|
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||||
|
)
|
||||||
|
self.backbone = _replace_submodules(
|
||||||
|
root_module=self.backbone,
|
||||||
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
|
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up pooling and final layers.
|
||||||
|
# Use a dry run to get the feature map shape.
|
||||||
|
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||||
|
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||||
|
# height and width from `config.image_features`.
|
||||||
|
|
||||||
|
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||||
|
images_shape = next(iter(config.image_features.values())).shape
|
||||||
|
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||||
|
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||||
|
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||||
|
|
||||||
|
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||||
|
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||||
|
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C, H, W) image tensor with pixel values in [0, 1].
|
||||||
|
Returns:
|
||||||
|
(B, D) image feature.
|
||||||
|
"""
|
||||||
|
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||||
|
if self.do_crop:
|
||||||
|
if self.training: # noqa: SIM108
|
||||||
|
x = self.maybe_random_crop(x)
|
||||||
|
else:
|
||||||
|
# Always use center crop for eval.
|
||||||
|
x = self.center_crop(x)
|
||||||
|
# Extract backbone feature.
|
||||||
|
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
||||||
|
# Final linear layer with non-linearity.
|
||||||
|
x = self.relu(self.out(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_submodules(
|
||||||
|
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
root_module: The module for which the submodules need to be replaced
|
||||||
|
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
|
||||||
|
func: Takes a module as an argument and returns a new module to replace it with.
|
||||||
|
Returns:
|
||||||
|
The root module with its submodules replaced.
|
||||||
|
"""
|
||||||
|
if predicate(root_module):
|
||||||
|
return func(root_module)
|
||||||
|
|
||||||
|
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||||
|
for *parents, k in replace_list:
|
||||||
|
parent_module = root_module
|
||||||
|
if len(parents) > 0:
|
||||||
|
parent_module = root_module.get_submodule(".".join(parents))
|
||||||
|
if isinstance(parent_module, nn.Sequential):
|
||||||
|
src_module = parent_module[int(k)]
|
||||||
|
else:
|
||||||
|
src_module = getattr(parent_module, k)
|
||||||
|
tgt_module = func(src_module)
|
||||||
|
if isinstance(parent_module, nn.Sequential):
|
||||||
|
parent_module[int(k)] = tgt_module
|
||||||
|
else:
|
||||||
|
setattr(parent_module, k, tgt_module)
|
||||||
|
# verify that all BN are replaced
|
||||||
|
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||||
|
return root_module
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionSinusoidalPosEmb(nn.Module):
|
||||||
|
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||||
|
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
device = x.device
|
||||||
|
half_dim = self.dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||||
|
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionConv1dBlock(nn.Module):
|
||||||
|
"""Conv1d --> GroupNorm --> Mish"""
|
||||||
|
|
||||||
|
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.block = nn.Sequential(
|
||||||
|
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||||
|
nn.GroupNorm(n_groups, out_channels),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionConditionalUnet1d(nn.Module):
|
||||||
|
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||||
|
|
||||||
|
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Encoder for the diffusion timestep.
|
||||||
|
self.diffusion_step_encoder = nn.Sequential(
|
||||||
|
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||||
|
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
# The FiLM conditioning dimension.
|
||||||
|
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
|
||||||
|
|
||||||
|
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||||
|
# just reverse these.
|
||||||
|
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
||||||
|
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unet encoder.
|
||||||
|
common_res_block_kwargs = {
|
||||||
|
"cond_dim": cond_dim,
|
||||||
|
"kernel_size": config.kernel_size,
|
||||||
|
"n_groups": config.n_groups,
|
||||||
|
"use_film_scale_modulation": config.use_film_scale_modulation,
|
||||||
|
}
|
||||||
|
self.down_modules = nn.ModuleList([])
|
||||||
|
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||||
|
is_last = ind >= (len(in_out) - 1)
|
||||||
|
self.down_modules.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
|
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||||
|
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||||
|
# Downsample as long as it is not the last block.
|
||||||
|
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Processing in the middle of the auto-encoder.
|
||||||
|
self.mid_modules = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DiffusionConditionalResidualBlock1d(
|
||||||
|
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||||
|
),
|
||||||
|
DiffusionConditionalResidualBlock1d(
|
||||||
|
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unet decoder.
|
||||||
|
self.up_modules = nn.ModuleList([])
|
||||||
|
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
|
||||||
|
is_last = ind >= (len(in_out) - 1)
|
||||||
|
self.up_modules.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
|
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||||
|
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||||
|
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||||
|
# Upsample as long as it is not the last block.
|
||||||
|
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_conv = nn.Sequential(
|
||||||
|
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||||
|
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, T, input_dim) tensor for input to the Unet.
|
||||||
|
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
||||||
|
global_cond: (B, global_cond_dim)
|
||||||
|
output: (B, T, input_dim)
|
||||||
|
Returns:
|
||||||
|
(B, T, input_dim) diffusion model prediction.
|
||||||
|
"""
|
||||||
|
# For 1D convolutions we'll need feature dimension first.
|
||||||
|
x = einops.rearrange(x, "b t d -> b d t")
|
||||||
|
|
||||||
|
timesteps_embed = self.diffusion_step_encoder(timestep)
|
||||||
|
|
||||||
|
# If there is a global conditioning feature, concatenate it to the timestep embedding.
|
||||||
|
if global_cond is not None:
|
||||||
|
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
|
||||||
|
else:
|
||||||
|
global_feature = timesteps_embed
|
||||||
|
|
||||||
|
# Run encoder, keeping track of skip features to pass to the decoder.
|
||||||
|
encoder_skip_features: list[Tensor] = []
|
||||||
|
for resnet, resnet2, downsample in self.down_modules:
|
||||||
|
x = resnet(x, global_feature)
|
||||||
|
x = resnet2(x, global_feature)
|
||||||
|
encoder_skip_features.append(x)
|
||||||
|
x = downsample(x)
|
||||||
|
|
||||||
|
for mid_module in self.mid_modules:
|
||||||
|
x = mid_module(x, global_feature)
|
||||||
|
|
||||||
|
# Run decoder, using the skip features from the encoder.
|
||||||
|
for resnet, resnet2, upsample in self.up_modules:
|
||||||
|
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
||||||
|
x = resnet(x, global_feature)
|
||||||
|
x = resnet2(x, global_feature)
|
||||||
|
x = upsample(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
x = einops.rearrange(x, "b d t -> b t d")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||||
|
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
cond_dim: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
n_groups: int = 8,
|
||||||
|
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
||||||
|
# FiLM just modulates bias).
|
||||||
|
use_film_scale_modulation: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.use_film_scale_modulation = use_film_scale_modulation
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||||
|
|
||||||
|
# FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||||
|
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||||
|
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||||
|
|
||||||
|
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||||
|
|
||||||
|
# A final convolution for dimension matching the residual (if needed).
|
||||||
|
self.residual_conv = (
|
||||||
|
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, in_channels, T)
|
||||||
|
cond: (B, cond_dim)
|
||||||
|
Returns:
|
||||||
|
(B, out_channels, T)
|
||||||
|
"""
|
||||||
|
out = self.conv1(x)
|
||||||
|
|
||||||
|
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
||||||
|
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
||||||
|
if self.use_film_scale_modulation:
|
||||||
|
# Treat the embedding as a list of scales and biases.
|
||||||
|
scale = cond_embed[:, : self.out_channels]
|
||||||
|
bias = cond_embed[:, self.out_channels :]
|
||||||
|
out = scale * out + bias
|
||||||
|
else:
|
||||||
|
# Treat the embedding as biases.
|
||||||
|
out = out + cond_embed
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = out + self.residual_conv(x)
|
||||||
|
return out
|
||||||
@@ -6,4 +6,3 @@ crop_shape: [84, 84]
|
|||||||
crop_is_random: true
|
crop_is_random: true
|
||||||
use_group_norm: true
|
use_group_norm: true
|
||||||
spatial_softmax_num_keypoints: 32
|
spatial_softmax_num_keypoints: 32
|
||||||
use_separate_rgb_encoder_per_camera: true
|
|
||||||
@@ -5,7 +5,7 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16 # Batch size for training
|
batch_size: 8 # Batch size for training
|
||||||
lr: 1e-4 # Learning rate
|
lr: 1e-4 # Learning rate
|
||||||
max_steps: 20000 # Maximum training steps
|
max_steps: 20000 # Maximum training steps
|
||||||
log_freq: 100 # Log frequency (steps)
|
log_freq: 100 # Log frequency (steps)
|
||||||
|
|||||||
@@ -101,20 +101,9 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
crop_is_random: bool = True,
|
crop_is_random: bool = True,
|
||||||
use_group_norm: bool = True,
|
use_group_norm: bool = True,
|
||||||
spatial_softmax_num_keypoints: int = 32,
|
spatial_softmax_num_keypoints: int = 32,
|
||||||
use_separate_rgb_encoder_per_camera: bool = True,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 保存所有参数作为实例变量
|
|
||||||
self.vision_backbone = vision_backbone
|
|
||||||
self.pretrained_backbone_weights = pretrained_backbone_weights
|
|
||||||
self.input_shape = input_shape
|
|
||||||
self.crop_shape = crop_shape
|
|
||||||
self.crop_is_random = crop_is_random
|
|
||||||
self.use_group_norm = use_group_norm
|
|
||||||
self.spatial_softmax_num_keypoints = spatial_softmax_num_keypoints
|
|
||||||
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
|
|
||||||
|
|
||||||
# 设置可选的预处理。
|
# 设置可选的预处理。
|
||||||
if crop_shape is not None:
|
if crop_shape is not None:
|
||||||
self.do_crop = True
|
self.do_crop = True
|
||||||
@@ -126,120 +115,49 @@ class ResNetDiffusionBackbone(VLABackbone):
|
|||||||
self.maybe_random_crop = self.center_crop
|
self.maybe_random_crop = self.center_crop
|
||||||
else:
|
else:
|
||||||
self.do_crop = False
|
self.do_crop = False
|
||||||
self.crop_shape = input_shape[1:]
|
crop_shape = input_shape[1:]
|
||||||
|
|
||||||
# 创建骨干网络的内部函数
|
# 设置骨干网络。
|
||||||
def _create_backbone():
|
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
weights=pretrained_backbone_weights
|
||||||
weights=pretrained_backbone_weights
|
|
||||||
)
|
|
||||||
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
|
||||||
backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
|
||||||
if use_group_norm:
|
|
||||||
backbone = _replace_submodules(
|
|
||||||
root_module=backbone,
|
|
||||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
|
||||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
|
||||||
)
|
|
||||||
return backbone
|
|
||||||
|
|
||||||
# 创建池化和最终层的内部函数
|
|
||||||
def _create_head(feature_map_shape):
|
|
||||||
pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
|
||||||
feature_dim = spatial_softmax_num_keypoints * 2
|
|
||||||
out = nn.Linear(spatial_softmax_num_keypoints * 2, feature_dim)
|
|
||||||
relu = nn.ReLU()
|
|
||||||
return pool, feature_dim, out, relu
|
|
||||||
|
|
||||||
# 使用试运行来获取特征图形状
|
|
||||||
dummy_shape = (1, input_shape[0], *self.crop_shape)
|
|
||||||
|
|
||||||
if self.use_separate_rgb_encoder_per_camera:
|
|
||||||
# 每个相机使用独立的编码器,我们先创建一个临时骨干网络来获取特征图形状
|
|
||||||
temp_backbone = _create_backbone()
|
|
||||||
with torch.no_grad():
|
|
||||||
dummy_out = temp_backbone(torch.zeros(dummy_shape))
|
|
||||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
|
||||||
del temp_backbone
|
|
||||||
|
|
||||||
# 注意:我们在 forward 方法中动态创建编码器,或者在知道相机数量时创建
|
|
||||||
# 这里我们先不创建具体的编码器实例,而是在 forward 时根据需要创建
|
|
||||||
# 或者,我们可以要求用户提供相机数量参数
|
|
||||||
self.camera_encoders = None
|
|
||||||
self.feature_dim = spatial_softmax_num_keypoints * 2
|
|
||||||
else:
|
|
||||||
# 所有相机共享同一个编码器
|
|
||||||
self.backbone = _create_backbone()
|
|
||||||
with torch.no_grad():
|
|
||||||
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
|
||||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
|
||||||
self.pool, self.feature_dim, self.out, self.relu = _create_head(feature_map_shape)
|
|
||||||
|
|
||||||
def _create_single_encoder(self):
|
|
||||||
"""内部方法:创建单个编码器(骨干网络 + 池化 + 输出层)"""
|
|
||||||
# 创建骨干网络
|
|
||||||
backbone_model = getattr(torchvision.models, self.vision_backbone)(
|
|
||||||
weights=self.pretrained_backbone_weights
|
|
||||||
)
|
)
|
||||||
backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
|
||||||
|
|
||||||
if self.use_group_norm:
|
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
||||||
backbone = _replace_submodules(
|
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||||
root_module=backbone,
|
|
||||||
|
if use_group_norm:
|
||||||
|
self.backbone = _replace_submodules(
|
||||||
|
root_module=self.backbone,
|
||||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取特征图形状
|
# 设置池化和最终层。
|
||||||
dummy_shape = (1, self.input_shape[0], *self.crop_shape)
|
# 使用试运行来获取特征图形状。
|
||||||
|
dummy_shape = (1, input_shape[0], *crop_shape)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
dummy_out = backbone(torch.zeros(dummy_shape))
|
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
||||||
feature_map_shape = dummy_out.shape[1:]
|
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||||
|
|
||||||
# 创建池化和输出层
|
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
||||||
pool = SpatialSoftmax(feature_map_shape, num_kp=self.spatial_softmax_num_keypoints)
|
self.feature_dim = spatial_softmax_num_keypoints * 2
|
||||||
out = nn.Linear(self.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||||
relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
return nn.ModuleList([backbone, pool, out, relu])
|
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
def forward_single_image(self, x: torch.Tensor, encoder: nn.ModuleList = None) -> torch.Tensor:
|
|
||||||
if self.do_crop:
|
if self.do_crop:
|
||||||
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
|
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
|
||||||
|
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||||
if self.use_separate_rgb_encoder_per_camera:
|
|
||||||
# 使用独立编码器
|
|
||||||
backbone, pool, out, relu = encoder
|
|
||||||
x = relu(out(torch.flatten(pool(backbone(x)), start_dim=1)))
|
|
||||||
else:
|
|
||||||
# 使用共享编码器
|
|
||||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, images):
|
def forward(self, images):
|
||||||
any_tensor = next(iter(images.values()))
|
any_tensor = next(iter(images.values()))
|
||||||
B, T = any_tensor.shape[:2]
|
B, T = any_tensor.shape[:2]
|
||||||
features_all = []
|
features_all = []
|
||||||
|
|
||||||
# 检查是否需要初始化独立编码器
|
|
||||||
if self.use_separate_rgb_encoder_per_camera and self.camera_encoders is None:
|
|
||||||
self.camera_encoders = nn.ModuleDict()
|
|
||||||
for cam_name in sorted(images.keys()):
|
|
||||||
self.camera_encoders[cam_name] = self._create_single_encoder()
|
|
||||||
|
|
||||||
for cam_name in sorted(images.keys()):
|
for cam_name in sorted(images.keys()):
|
||||||
img = images[cam_name]
|
img = images[cam_name]
|
||||||
if self.use_separate_rgb_encoder_per_camera:
|
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||||
# 使用该相机对应的独立编码器
|
|
||||||
features = self.forward_single_image(
|
|
||||||
img.view(B * T, *img.shape[2:]),
|
|
||||||
self.camera_encoders[cam_name]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 使用共享编码器
|
|
||||||
features = self.forward_single_image(img.view(B * T, *img.shape[2:]))
|
|
||||||
features_all.append(features)
|
features_all.append(features)
|
||||||
|
|
||||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user