debug(train): 在siglip和DiffusionHead下跑通训练流程
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
# Backbone models
|
||||
# Uncomment when these are implemented:
|
||||
# from .siglip import SigLIPBackbone
|
||||
from .siglip import SigLIPBackbone
|
||||
# from .clip import CLIPBackbone
|
||||
# from .dinov2 import DinoV2Backbone
|
||||
from .debug import DebugBackbone
|
||||
|
||||
__all__ = ["DebugBackbone"]
|
||||
__all__ = ["SigLIPBackbone"]
|
||||
|
||||
# from .debug import DebugBackbone
|
||||
# __all__ = ["DebugBackbone"]
|
||||
@@ -1 +1,62 @@
|
||||
# SigLIP Backbone 实现
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModel, AutoProcessor, SiglipVisionModel
|
||||
from typing import Dict, Optional
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
|
||||
class SigLIPBackbone(VLABackbone):
    """Vision backbone wrapping Google's SigLIP vision encoder.

    Only the vision tower is loaded — the text tower is unnecessary here,
    since SigLIP's towers are already aligned and image features alone are
    what downstream components consume.

    HuggingFace ID example: "google/siglip-so400m-patch14-384"
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        freeze: bool = True,
        embed_dim: Optional[int] = None
    ):
        """Load the pretrained vision tower and resolve the feature dim.

        Args:
            model_name: HuggingFace model identifier to load.
            freeze: when True, disable gradients on the whole vision tower.
            embed_dim: explicit feature dimensionality; when None it is read
                from the loaded model's config.
        """
        super().__init__()
        print(f"Loading SigLIP: {model_name} ...")

        # Vision tower only — no text encoder is instantiated.
        self.vision_model = SiglipVisionModel.from_pretrained(model_name)

        # A caller-supplied embed_dim takes precedence; otherwise fall back
        # to the checkpoint's config (SigLIP so400m is typically 1152).
        if embed_dim is None:
            self._embed_dim = self.vision_model.config.hidden_size
            print(f"✓ Auto-detected embed_dim: {self._embed_dim}")
        else:
            self._embed_dim = embed_dim
            print(f"✓ Using configured embed_dim: {embed_dim}")

        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients on the vision tower and put it in eval mode."""
        print("❄️ Freezing Vision Backbone parameters")
        for p in self.vision_model.parameters():
            p.requires_grad = False
        self.vision_model.eval()

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode a batch of images into per-patch features.

        Args:
            obs['image']: (B, C, H, W) normalized tensor.

        Returns:
            features: (B, Seq_Len, Embed_Dim) — the encoder's last hidden
            state (HuggingFace vision models return a
            BaseModelOutputWithPooling; last_hidden_state is
            (B, Num_Patches, Embed_Dim)).
        """
        encoded = self.vision_model(pixel_values=obs['image'])
        return encoded.last_hidden_state

    @property
    def embed_dim(self) -> int:
        """Dimensionality of the features returned by forward()."""
        return self._embed_dim
||||
@@ -1,9 +1,9 @@
|
||||
# # Action Head models
|
||||
# from .diffusion import DiffusionActionHead
|
||||
from .diffusion import DiffusionHead
|
||||
# from .act import ACTHead
|
||||
|
||||
# __all__ = ["DiffusionActionHead", "ACTHead"]
|
||||
__all__ = ["DiffusionHead"]
|
||||
|
||||
from .debug import DebugHead
|
||||
# from .debug import DebugHead
|
||||
|
||||
__all__ = ["DebugHead"]
|
||||
# __all__ = ["DebugHead"]
|
||||
@@ -1 +1,174 @@
|
||||
# Diffusion Policy Action Head 实现
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional
|
||||
from diffusers import DDPMScheduler
|
||||
from roboimi.vla.core.interfaces import VLAHead
|
||||
|
||||
class DiffusionHead(VLAHead):
    """DDPM-style diffusion-policy action head.

    Training: noise ground-truth action chunks and learn to predict the
    injected noise (epsilon prediction). Inference: start from pure
    Gaussian noise and iteratively denoise with the DDPM scheduler.

    The denoiser is a small conditional MLP: the (flattened) noisy action
    chunk is concatenated with a fused condition/time embedding and pushed
    through a stack of residual blocks.
    """

    def __init__(
        self,
        input_dim: int,          # feature dim from the projector (e.g. 384)
        action_dim: int,         # per-step action dimensionality (e.g. 16)
        chunk_size: int,         # prediction horizon (e.g. 16)
        n_timesteps: int = 100,  # number of diffusion steps
        hidden_dim: int = 256
    ):
        super().__init__()
        self.action_dim = action_dim
        self.chunk_size = chunk_size

        # 1. Noise scheduler (DDPM) — epsilon prediction with a cosine-style
        # beta schedule; samples are clipped to the scheduler's valid range.
        self.scheduler = DDPMScheduler(
            num_train_timesteps=n_timesteps,
            beta_schedule='squaredcos_cap_v2',
            clip_sample=True,
            prediction_type='epsilon'
        )

        # 2. Noise-prediction network.
        # Timestep embedding: scalar t -> (B, hidden_dim).
        self.time_emb = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Project the (pooled) image features into the hidden space.
        self.cond_proj = nn.Linear(input_dim, hidden_dim)

        # Residual trunk: each block maps (hidden + flat_action) -> itself,
        # so a plain additive skip connection is well-defined.
        self.mid_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim + action_dim * chunk_size, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim + action_dim * chunk_size)
            ) for _ in range(3)
        ])

        # Output head: predicted noise, same size as the flattened actions.
        self.final_layer = nn.Linear(hidden_dim + action_dim * chunk_size, action_dim * chunk_size)

    def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Unified interface for training and inference.

        Args:
            embeddings: (B, Seq, Dim) conditioning features.
            actions: (B, chunk_size, action_dim) ground-truth actions.
                When provided the head runs in training mode and returns a
                diffusion loss; when None it samples actions.

        Returns:
            {"loss": scalar} in training mode, or
            {"pred_actions": (B, chunk_size, action_dim)} in inference mode.
        """
        device = embeddings.device

        # --- 1. Conditioning ---
        # Simplification: average-pool (B, Seq, Dim) -> (B, Dim).
        # Swap in cross-attention here if richer conditioning is needed.
        global_cond = embeddings.mean(dim=1)
        cond_feat = self.cond_proj(global_cond)  # (B, Hidden)

        if actions is not None:
            # =========================================
            # Branch A: Training
            # =========================================
            batch_size = actions.shape[0]

            # 1.1 Flatten: (B, Chunk, ActDim) -> (B, Chunk*ActDim)
            actions_flat = actions.view(batch_size, -1)

            # 1.2 Sample noise and per-sample timesteps.
            noise = torch.randn_like(actions_flat)
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps,
                (batch_size,), device=device
            ).long()

            # 1.3 Forward diffusion: add scheduled noise.
            noisy_actions = self.scheduler.add_noise(actions_flat, noise, timesteps)

            # 1.4 Predict the injected noise.
            pred_noise = self._predict_noise(noisy_actions, timesteps, cond_feat)

            # 1.5 MSE between actual and predicted noise.
            loss = nn.functional.mse_loss(pred_noise, noise)

            return {"loss": loss}
        else:
            # =========================================
            # Branch B: Inference
            # =========================================
            batch_size = embeddings.shape[0]

            # 2.1 Start from pure Gaussian noise.
            noisy_actions = torch.randn(
                batch_size, self.chunk_size * self.action_dim,
                device=device
            )

            # 2.2 Reverse diffusion loop; scheduler.timesteps handles the
            # step spacing once set_timesteps has been called.
            self.scheduler.set_timesteps(self.scheduler.config.num_train_timesteps)

            for t in self.scheduler.timesteps:
                # Broadcast the scalar step to the batch.
                timesteps = torch.tensor([t], device=device).repeat(batch_size)

                # diffusers' step() consumes the raw model output (epsilon).
                model_output = self._predict_noise(noisy_actions, timesteps, cond_feat)

                # One denoising step.
                noisy_actions = self.scheduler.step(
                    model_output, t, noisy_actions
                ).prev_sample

            # 2.3 Reshape back to (B, Chunk, ActDim).
            pred_actions = noisy_actions.view(batch_size, self.chunk_size, self.action_dim)

            return {"pred_actions": pred_actions}

    def _predict_noise(self, noisy_actions, timesteps, cond_feat):
        """Run the conditional MLP denoiser.

        Args:
            noisy_actions: (B, chunk_size * action_dim) flattened noisy actions.
            timesteps: (B,) integer diffusion timesteps.
            cond_feat: (B, hidden_dim) projected conditioning features.

        Returns:
            (B, chunk_size * action_dim) predicted noise.
        """
        # Timestep embedding: (B,) -> (B, 1) -> (B, Hidden).
        t_emb = self.time_emb(timesteps.float().unsqueeze(-1))

        # Additive fusion of condition and time embeddings.
        # (Fix: the original also built a throwaway torch.cat of the actions
        # with this sum and immediately discarded it — dead compute removed.)
        h = cond_feat + t_emb  # (B, Hidden)

        # Concatenate fused condition with the flattened noisy actions;
        # matches the Linear(hidden + act_flat, ...) layers in __init__.
        model_input = torch.cat([h, noisy_actions], dim=-1)

        # Residual trunk.
        for layer in self.mid_layers:
            model_input = layer(model_input) + model_input

        return self.final_layer(model_input)
|
||||
Reference in New Issue
Block a user