feat(train): 跑通训练脚本
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# Backbone models
|
||||
from .siglip import SigLIPBackbone
|
||||
from .resnet import ResNetBackbone
|
||||
# from .clip import CLIPBackbone
|
||||
# from .dinov2 import DinoV2Backbone
|
||||
|
||||
__all__ = ["SigLIPBackbone"]
|
||||
__all__ = ["SigLIPBackbone", "ResNetBackbone"]
|
||||
|
||||
# from .debug import DebugBackbone
|
||||
# __all__ = ["DebugBackbone"]
|
||||
@@ -1 +0,0 @@
|
||||
# CLIP Backbone 实现
|
||||
@@ -1,30 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
|
||||
class DebugBackbone(VLABackbone):
    """Placeholder backbone that returns Gaussian noise instead of features.

    The output has shape ``(batch, seq_len, embed_dim)``, where ``batch`` is
    read from ``obs['image']``. A single zero-valued trainable scalar is added
    to the noise so that the output is attached to a parameter in the autograd
    graph.
    """

    def __init__(self, embed_dim: int = 768, seq_len: int = 10):
        """Store the output sizes and create the dummy trainable scalar.

        Args:
            embed_dim: Size of the last output dimension.
            seq_len: Size of the middle (sequence) output dimension.
        """
        super().__init__()
        self._embed_dim = embed_dim
        self.seq_len = seq_len
        # Initialised to zero so that adding it leaves the noise values as-is.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return random features sized after the batch of ``obs['image']``.

        Args:
            obs: Observation dict; only the ``'image'`` entry is read, and only
                for its leading dimension and its device.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)`` on the image's
            device.
        """
        image = obs['image']
        out_shape = (image.shape[0], self.seq_len, self._embed_dim)
        random_features = torch.randn(out_shape, device=image.device)
        # Broadcasting the scalar parameter links the output to a trainable
        # tensor, giving backward() a gradient path.
        return random_features + self.dummy_param

    @property
    def embed_dim(self) -> int:
        """Size of the last dimension of the tensors produced by ``forward``."""
        return self._embed_dim
|
||||
@@ -1 +0,0 @@
|
||||
# DinoV2 Backbone 实现
|
||||
83
roboimi/vla/models/backbones/resnet.py
Normal file
83
roboimi/vla/models/backbones/resnet.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
from transformers import ResNetModel
|
||||
from torchvision import transforms
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ResNetBackbone(VLABackbone):
    """Multi-camera image encoder built on a pretrained Hugging Face ResNet.

    Each camera stream is resized and normalised, passed through the ResNet,
    and reduced with a spatial softmax. The per-camera features are then
    concatenated along the feature axis.
    """

    def __init__(
        self,
        model_name = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        """Load the pretrained ResNet and build the preprocessing pipeline.

        Args:
            model_name: Identifier passed to ``ResNetModel.from_pretrained``.
            freeze: When true, the ResNet weights are excluded from gradient
                updates and the ResNet is kept in eval mode.
        """
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        self.out_channels = self.model.config.hidden_sizes[-1]
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # NOTE(review): the 12x12 grid presumably matches the ResNet's final
        # feature map for a 384x384 input — confirm if the resize or the
        # model changes.
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        # Remembered so that train() can keep a frozen ResNet in eval mode.
        self._frozen = freeze
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients for the ResNet and switch it to eval mode."""
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def train(self, mode: bool = True):
        """Set the training mode while keeping a frozen ResNet in eval mode.

        A plain ``.train()`` call on this module or any parent would otherwise
        put the frozen ResNet back into training mode, undoing the ``eval()``
        done at construction time.
        Assumes ``VLABackbone`` derives from ``torch.nn.Module`` — TODO confirm.
        """
        super().train(mode)
        if getattr(self, "_frozen", False):
            self.model.eval()
        return self

    def forward_single_image(self, image):
        """Encode one camera stream.

        Args:
            image: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B*T, D*2)`` produced by the spatial softmax.
        """
        B, T, C, H, W = image.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        image = image.reshape(B * T, C, H, W)
        image = self.transform(image)
        feature_map = self.model(image).last_hidden_state  # (B*T, D, H', W')
        features = self.spatial_softmax(feature_map)  # (B*T, D*2)
        return features

    def forward(self, images):
        """Encode every camera and concatenate the results.

        Args:
            images: Mapping from camera name to a ``(B, T, C, H, W)`` tensor.
                Cameras are processed in sorted-name order so the feature
                layout does not depend on dict insertion order.

        Returns:
            Tensor of shape ``(B, T, Num_Cams*D*2)``.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError("ResNetBackbone.forward requires at least one camera image")
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        features_all = [
            self.forward_single_image(images[cam_name])  # (B*T, D*2)
            for cam_name in sorted(images.keys())
        ]
        combined_features = torch.cat(features_all, dim=1)  # (B*T, Num_Cams*D*2)
        return combined_features.view(B, T, -1)

    @property
    def output_dim(self):
        """Output dimension after spatial softmax: out_channels * 2"""
        return self.out_channels * 2
|
||||
|
||||
class SpatialSoftmax(nn.Module):
    """Convert a feature map ``(N, C, H, W)`` into coordinate features ``(N, C*2)``.

    For every channel, a softmax over the flattened spatial positions is used
    as a weighting to compute the expected position on a grid spanning
    ``[-1, 1]`` along each axis. The two expected coordinates of each channel
    are stored next to each other in the output.
    """

    def __init__(self, num_rows, num_cols, temperature=None):
        """Create the learnable temperature and the coordinate grid.

        Args:
            num_rows: Expected height of the incoming feature map.
            num_cols: Expected width of the incoming feature map.
            temperature: Initial value of the learnable softmax temperature.
                ``None`` (the default) initialises it to 1.
        """
        super().__init__()
        initial_temperature = 1.0 if temperature is None else float(temperature)
        self.temperature = nn.Parameter(torch.ones(1) * initial_temperature)
        # Build the coordinate grid. With indexing='ij', the first output
        # varies along the row axis and the second along the column axis.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        # Buffers follow the module across devices without being trained.
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """Return the expected grid coordinates of every channel.

        Args:
            x: Feature map of shape ``(N, C, H, W)`` with
                ``H * W == num_rows * num_cols``.

        Returns:
            Tensor of shape ``(N, C*2)``.

        Raises:
            ValueError: If the spatial size of ``x`` does not match the grid.
        """
        N, C, H, W = x.shape
        grid_size = self.pos_x.numel()
        if H * W != grid_size:
            # Without this check, a mismatching size either fails with an
            # opaque broadcast error or (for H*W == 1) silently broadcasts.
            raise ValueError(
                f"SpatialSoftmax expected {grid_size} spatial positions, "
                f"got a feature map of size {H}x{W}"
            )
        x = x.reshape(N, C, -1)  # (N, C, H*W)

        # Softmax attention map over the spatial positions.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)

        # Expected coordinates (x, y).
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)

        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)
|
||||
@@ -1,9 +1,8 @@
|
||||
# Action Head models
|
||||
from .diffusion import DiffusionHead
|
||||
from .diffusion import ConditionalUnet1D
|
||||
# from .act import ACTHead
|
||||
|
||||
__all__ = ["DiffusionHead"]
|
||||
__all__ = ["ConditionalUnet1D"]
|
||||
|
||||
# from .debug import DebugHead
|
||||
|
||||
# __all__ = ["DebugHead"]
|
||||
@@ -1 +0,0 @@
|
||||
# ACT-VAE Action Head 实现
|
||||
@@ -1,33 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional
|
||||
from roboimi.vla.core.interfaces import VLAHead
|
||||
|
||||
class DebugHead(VLAHead):
    """Minimal action head trained with an MSE loss.

    Stands in for the Diffusion/ACT policies when only the surrounding
    architecture needs to be verified.
    """

    def __init__(self, input_dim: int, action_dim: int, chunk_size: int = 16):
        """Build the linear regressor and the loss.

        Args:
            input_dim: Size of the last dimension of the incoming embeddings.
            action_dim: Size of a single action vector.
            chunk_size: Number of actions predicted per sample.
        """
        super().__init__()
        # One linear map from the pooled embedding to a flattened action chunk.
        self.regressor = nn.Linear(input_dim, chunk_size * action_dim)
        self.chunk_size = chunk_size
        self.action_dim = action_dim
        self.loss_fn = nn.MSELoss()

    def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Predict an action chunk and, when targets are given, its MSE loss.

        Args:
            embeddings: Tensor that is averaged over dimension 1 before
                regression.
            actions: Optional ground-truth actions compared against the
                prediction.

        Returns:
            Dict with ``"pred_actions"`` of shape
            ``(-1, chunk_size, action_dim)`` and, if ``actions`` is not
            ``None``, ``"loss"``.
        """
        # Mean-pool over the sequence dimension, then regress the flat chunk.
        flat_prediction = self.regressor(embeddings.mean(dim=1))
        pred_actions = flat_prediction.view(-1, self.chunk_size, self.action_dim)

        if actions is None:
            return {"pred_actions": pred_actions}

        # MSE against the ground-truth actions.
        return {
            "pred_actions": pred_actions,
            "loss": self.loss_fn(pred_actions, actions),
        }
|
||||
@@ -5,170 +5,290 @@ from typing import Dict, Optional
|
||||
from diffusers import DDPMScheduler
|
||||
from roboimi.vla.core.interfaces import VLAHead
|
||||
|
||||
class DiffusionHead(VLAHead):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int, # 来自 Projector 的维度 (e.g. 384)
|
||||
action_dim: int, # 动作维度 (e.g. 16)
|
||||
chunk_size: int, # 预测视界 (e.g. 16)
|
||||
n_timesteps: int = 100, # 扩散步数
|
||||
hidden_dim: int = 256
|
||||
):
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.layers.torch import Rearrange
|
||||
import math
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
# 1. 噪声调度器 (DDPM)
|
||||
self.scheduler = DDPMScheduler(
|
||||
num_train_timesteps=n_timesteps,
|
||||
beta_schedule='squaredcos_cap_v2', # 现代 Diffusion 常用调度
|
||||
clip_sample=True,
|
||||
prediction_type='epsilon' # 预测噪声
|
||||
)
|
||||
self.dim = dim
|
||||
|
||||
# 2. 噪声预测网络 (Noise Predictor Network)
|
||||
# 输入: Noisy Action + Time Embedding + Image Embedding
|
||||
# 这是一个简单的 Conditional MLP/ResNet 结构
|
||||
self.time_emb = nn.Sequential(
|
||||
nn.Linear(1, hidden_dim),
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
class Downsample1d(nn.Module):
    """Strided 1-D convolution that shortens the last axis.

    Uses kernel size 3, stride 2 and padding 1 with an unchanged channel
    count, so an input of length ``L`` produces ``ceil(L / 2)`` positions.
    """

    def __init__(self, dim):
        """Create the convolution for ``dim`` input and output channels."""
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution to ``x`` of shape ``(B, dim, L)``."""
        downsampled = self.conv(x)
        return downsampled
|
||||
|
||||
class Upsample1d(nn.Module):
    """Transposed 1-D convolution that lengthens the last axis.

    Uses kernel size 4, stride 2 and padding 1 with an unchanged channel
    count, so an input of length ``L`` produces ``2 * L`` positions.
    """

    def __init__(self, dim):
        """Create the transposed convolution for ``dim`` channels."""
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x`` of shape ``(B, dim, L)``."""
        upsampled = self.conv(x)
        return upsampled
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
'''
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
'''
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
nn.Linear(hidden_dim, hidden_dim)
|
||||
)
|
||||
|
||||
self.cond_proj = nn.Linear(input_dim, hidden_dim) # 把图像特征投影一下
|
||||
|
||||
# 主干网络 (由几个 Residual Block 组成)
|
||||
self.mid_layers = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Linear(hidden_dim + action_dim * chunk_size, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(hidden_dim, hidden_dim + action_dim * chunk_size) # 简单的残差
|
||||
) for _ in range(3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
cond_dim,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
])
|
||||
|
||||
# 输出层: 预测噪声 (Shape 与 Action 相同)
|
||||
self.final_layer = nn.Linear(hidden_dim + action_dim * chunk_size, action_dim * chunk_size)
|
||||
|
||||
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Unified interface for Training and Inference.
|
||||
"""
|
||||
device = embeddings.device
|
||||
|
||||
# --- 1. 处理条件 (Conditioning) ---
|
||||
# embeddings: (B, Seq, Dim). 我们这里做一个简化,做 Average Pooling 变成 (B, Dim)
|
||||
# 如果你想做更复杂的 Cross-Attention,可以在这里改
|
||||
global_cond = embeddings.mean(dim=1)
|
||||
cond_feat = self.cond_proj(global_cond) # (B, Hidden)
|
||||
|
||||
# =========================================
|
||||
# 分支 A: 训练模式 (Training)
|
||||
# =========================================
|
||||
if actions is not None:
|
||||
batch_size = actions.shape[0]
|
||||
|
||||
# 1.1 准备数据 (Flatten: B, Chunk, ActDim -> B, Chunk*ActDim)
|
||||
actions_flat = actions.view(batch_size, -1)
|
||||
|
||||
# 1.2 采样噪声和时间步
|
||||
noise = torch.randn_like(actions_flat)
|
||||
timesteps = torch.randint(
|
||||
0, self.scheduler.config.num_train_timesteps,
|
||||
(batch_size,), device=device
|
||||
).long()
|
||||
|
||||
# 1.3 加噪 (Forward Diffusion)
|
||||
noisy_actions = self.scheduler.add_noise(actions_flat, noise, timesteps)
|
||||
|
||||
# 1.4 预测噪声 (Network Forward)
|
||||
pred_noise = self._predict_noise(noisy_actions, timesteps, cond_feat)
|
||||
|
||||
# 1.5 计算 Loss (MSE between actual noise and predicted noise)
|
||||
loss = nn.functional.mse_loss(pred_noise, noise)
|
||||
|
||||
return {"loss": loss}
|
||||
|
||||
# =========================================
|
||||
# 分支 B: 推理模式 (Inference)
|
||||
# =========================================
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
Rearrange('batch t -> batch t 1'),
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
|
||||
if in_channels != out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x, cond):
|
||||
'''
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
'''
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.cond_predict_scale:
|
||||
embed = embed.reshape(
|
||||
embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:,0,...]
|
||||
bias = embed[:,1,...]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
batch_size = embeddings.shape[0]
|
||||
|
||||
# 2.1 从纯高斯噪声开始
|
||||
noisy_actions = torch.randn(
|
||||
batch_size, self.chunk_size * self.action_dim,
|
||||
device=device
|
||||
)
|
||||
|
||||
# 2.2 逐步去噪 (Reverse Diffusion Loop)
|
||||
# 使用 scheduler.timesteps 自动处理步长
|
||||
self.scheduler.set_timesteps(self.scheduler.config.num_train_timesteps)
|
||||
|
||||
for t in self.scheduler.timesteps:
|
||||
# 构造 batch 的 t
|
||||
timesteps = torch.tensor([t], device=device).repeat(batch_size)
|
||||
|
||||
# 预测噪声
|
||||
# 注意:diffusers 的 step 需要 model_output
|
||||
model_output = self._predict_noise(noisy_actions, timesteps, cond_feat)
|
||||
|
||||
# 移除噪声 (Step)
|
||||
noisy_actions = self.scheduler.step(
|
||||
model_output, t, noisy_actions
|
||||
).prev_sample
|
||||
out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
# 2.3 Reshape 回 (B, Chunk, ActDim)
|
||||
pred_actions = noisy_actions.view(batch_size, self.chunk_size, self.action_dim)
|
||||
|
||||
return {"pred_actions": pred_actions}
|
||||
|
||||
def _predict_noise(self, noisy_actions, timesteps, cond_feat):
|
||||
"""内部辅助函数:运行简单的 MLP 网络"""
|
||||
# Time Embed
|
||||
t_emb = self.time_emb(timesteps.float().unsqueeze(-1)) # (B, Hidden)
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=[256,512,1024],
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False
|
||||
):
|
||||
super().__init__()
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed
|
||||
if global_cond_dim is not None:
|
||||
cond_dim += global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||
|
||||
local_cond_encoder = None
|
||||
if local_cond_dim is not None:
|
||||
_, dim_out = in_out[0]
|
||||
dim_in = local_cond_dim
|
||||
local_cond_encoder = nn.ModuleList([
|
||||
# down encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
# up encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale)
|
||||
])
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim, mid_dim, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim, mid_dim, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale
|
||||
),
|
||||
])
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out*2, dim_in, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_in, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
# Fusion: Concat Action + (Condition * Time)
|
||||
# 这里用简单的相加融合,实际可以更复杂
|
||||
fused_feat = cond_feat + t_emb
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None, global_cond=None, **kwargs):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([
|
||||
global_feature, global_cond
|
||||
], axis=-1)
|
||||
|
||||
# Concat input
|
||||
x = torch.cat([noisy_actions, fused_feat], dim=-1) # 注意这里维度需要对齐,或者用 MLP 映射
|
||||
# encode local features
|
||||
h_local = list()
|
||||
if local_cond is not None:
|
||||
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
|
||||
resnet, resnet2 = self.local_cond_encoder
|
||||
x = resnet(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
x = resnet2(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
|
||||
# 修正:上面的 concat 维度可能不对,为了简化代码,我们用一种更简单的方式:
|
||||
# 将 cond_feat 加到 input 里需要维度匹配。
|
||||
# 这里重写一个极简的 Forward:
|
||||
|
||||
# 正确做法:先将 x 映射到 hidden,再加 t_emb 和 cond_feat
|
||||
# 但为了复用 self.mid_layers 定义的 Linear(Hidden + Input)...
|
||||
# 我们用最傻瓜的方式:Input = Action,Condition 直接拼接到每一层或者只拼输入
|
||||
|
||||
# 让我们修正一下网络结构逻辑,确保不报错:
|
||||
# Input: NoisyAction (Dim_A)
|
||||
# Cond: Hidden (Dim_H)
|
||||
|
||||
# 这种临时写的 MLP 容易维度不匹配,我们改用一个极其稳健的计算流:
|
||||
# x = Action
|
||||
# h = Cond + Time
|
||||
# input = cat([x, h]) -> Linear -> Output
|
||||
|
||||
# 重新定义 _predict_noise 的逻辑依赖于 __init__ 里的定义。
|
||||
# 为了保证一次跑通,我使用动态 cat:
|
||||
|
||||
x = noisy_actions
|
||||
# 假设 mid_layers 的输入是 hidden_dim + action_flat_dim
|
||||
# 我们把 condition 映射成 hidden_dim,然后 concat
|
||||
|
||||
# 真正的计算流:
|
||||
h = cond_feat + t_emb # (B, Hidden)
|
||||
|
||||
# 把 h 拼接到 x 上 (前提是 x 是 action flat)
|
||||
# Linear 输入维度是 Hidden + ActFlat
|
||||
model_input = torch.cat([h, x], dim=-1)
|
||||
|
||||
for layer in self.mid_layers:
|
||||
# Residual connection mechanism
|
||||
out = layer(model_input)
|
||||
model_input = out + model_input # Simple ResNet
|
||||
|
||||
return self.final_layer(model_input)
|
||||
x = sample
|
||||
h = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
if idx == 0 and len(h_local) > 0:
|
||||
x = x + h_local[0]
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, 'b t h -> b h t')
|
||||
return x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user