Files
roboimi/roboimi/vla/agent.py
2026-02-03 14:18:30 +08:00

73 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# roboimi/vla/agent.py
import torch
import torch.nn as nn
from typing import Optional, Dict, Union
class VLAAgent(nn.Module):
    """Vision-Language-Action agent.

    Composes a (frozen or trainable) VLM vision backbone, an image-feature
    projector, a small MLP state encoder, and an action head. Image and
    proprioceptive-state histories are fused by concatenation along the time
    axis and handed to the action head, which either computes a training loss
    or predicts an action sequence.
    """

    def __init__(self,
                 vlm_backbone: nn.Module,
                 img_projector: nn.Module,
                 action_head: nn.Module,
                 state_dim: int,
                 embed_dim: int):
        """
        Args:
            vlm_backbone: Module mapping a flat image batch [N, C, H, W] to a
                dict containing key 'image_embeds' of shape [N, Vision_Dim].
            img_projector: Module projecting [N, Vision_Dim] -> [N, Embed_Dim].
            action_head: Module exposing ``compute_loss(context, actions)``
                (training) and ``predict_action(context)`` (inference).
            state_dim: Dimensionality of the proprioceptive state vector.
            embed_dim: Shared embedding width for fused context tokens.
        """
        super().__init__()
        self.vlm_backbone = vlm_backbone
        self.img_projector = img_projector
        self.action_head = action_head
        # Simple state encoder; no separate config object needed — an inline
        # two-layer MLP (Linear -> Mish -> Linear) is sufficient here.
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.Mish(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self,
                images: torch.Tensor,
                state: torch.Tensor,
                text: Optional[Union[str, list]] = None,
                actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]:
        """
        Args:
            images: [Batch, Obs_Horizon, C, H, W] — the time dimension is
                folded into the batch before the backbone call.
            state: [Batch, Obs_Horizon, State_Dim]
            text: Optional text instructions.
                NOTE(review): currently accepted but unused — the backbone is
                called with images only; confirm whether instructions should
                be forwarded to the VLM.
            actions: [Batch, Pred_Horizon, Action_Dim] (training only).

        Returns:
            Training (``actions`` given): whatever ``action_head.compute_loss``
            returns (a loss scalar by convention).
            Inference: whatever ``action_head.predict_action`` returns.
        """
        B, T, C, H, W = images.shape

        # 1. Image encoding: flatten the time axis into the batch axis so the
        #    backbone processes all frames in one pass. ``reshape`` (unlike
        #    ``view``) also accepts non-contiguous inputs, e.g. slices or
        #    permuted tensors coming out of a dataloader.
        flat_images = images.reshape(B * T, C, H, W)      # [B*T, C, H, W]
        vision_feats_dict = self.vlm_backbone(flat_images)
        raw_img_emb = vision_feats_dict['image_embeds']   # [B*T, Vision_Dim]

        # Project to the shared width and restore the time axis.
        img_emb = self.img_projector(raw_img_emb)         # [B*T, Embed_Dim]
        img_emb = img_emb.reshape(B, T, -1)               # [B, T, Embed_Dim]

        # 2. State encoding.
        state_emb = self.state_encoder(state)             # [B, T, Embed_Dim]

        # 3. Feature fusion — simple early fusion: concatenate image and state
        #    token histories along the TIME axis, yielding 2*T context tokens.
        context = torch.cat([img_emb, state_emb], dim=1)  # [B, 2*T, Embed_Dim]

        # 4. Action-head branch.
        if actions is not None:
            # Training mode: must return the loss.
            return self.action_head.compute_loss(context, actions)
        # Inference mode: must return the predicted action sequence.
        return self.action_head.predict_action(context)