# roboimi/vla/agent.py

import torch
|
||
import torch.nn as nn
|
||
from typing import Optional, Dict, Union
|
||
|
||
class VLAAgent(nn.Module):
    """Agent that encodes an image/state history and hands it to an action head.

    The agent owns three injected sub-modules plus a small built-in state
    encoder:

    * ``vlm_backbone`` is called with a ``[B*T, C, H, W]`` image batch and must
      return a mapping that contains the key ``'image_embeds'``.
    * ``img_projector`` is applied to those embeddings.
    * ``action_head`` must expose ``compute_loss(context, actions)`` and
      ``predict_action(context)``.
    """

    def __init__(self,
                 vlm_backbone: nn.Module,
                 img_projector: nn.Module,
                 action_head: nn.Module,
                 state_dim: int,
                 embed_dim: int):
        """Store the injected sub-modules and build the state encoder.

        Args:
            vlm_backbone: Image encoder; see the class docstring for the
                expected output format.
            img_projector: Module applied to the backbone's image embeddings.
                Its output width must match ``embed_dim`` so it can be
                concatenated with the state embeddings in ``forward``.
            action_head: Module providing ``compute_loss`` and
                ``predict_action``.
            state_dim: Size of the last dimension of the ``state`` input.
            embed_dim: Output width of the state encoder.
        """
        super().__init__()
        self.vlm_backbone = vlm_backbone
        self.img_projector = img_projector
        self.action_head = action_head

        # Simple state encoder: small enough that it is defined inline here
        # rather than being driven by a separate config.
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.Mish(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self,
                images: torch.Tensor,
                state: torch.Tensor,
                text: Optional[Union[str, list]] = None,
                actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]:
        """Encode the observation history and run the action head.

        Args:
            images: ``[Batch, Obs_Horizon, C, H, W]``. May be non-contiguous
                (e.g. a transposed or sliced buffer).
            state: ``[Batch, Obs_Horizon, State_Dim]``.
            text: Optional text instructions. Accepted for interface
                compatibility; this method does not currently read it.
            actions: ``[Batch, Pred_Horizon, Action_Dim]``. Supplying it
                selects the training branch.

        Returns:
            Training (``actions`` given): the result of
            ``action_head.compute_loss(context, actions)``.
            Inference (``actions`` is ``None``): the result of
            ``action_head.predict_action(context)``.

        Raises:
            ValueError: If ``images`` is not a 5-D tensor.
        """
        if images.dim() != 5:
            raise ValueError(
                "images must have shape [Batch, Obs_Horizon, C, H, W], "
                f"got {tuple(images.shape)}"
            )
        B, T, C, H, W = images.shape

        # 1. Image encoding: fold the time axis into the batch axis so the
        #    backbone sees a plain image batch of shape [B*T, C, H, W].
        #    reshape (not view) is used because the batch and time axes of a
        #    transposed or sliced input cannot always be merged without a copy,
        #    in which case view raises a RuntimeError.
        flat_images = images.reshape(B * T, C, H, W)
        vision_feats_dict = self.vlm_backbone(flat_images)
        # Expected to be [B*T, Vision_Dim] -- depends on the injected backbone.
        raw_img_emb = vision_feats_dict['image_embeds']

        # Project, then restore the time axis -> [B, T, Embed_Dim].
        img_emb = self.img_projector(raw_img_emb)
        img_emb = img_emb.reshape(B, T, -1)

        # 2. State encoding -> [B, T, Embed_Dim].
        state_emb = self.state_encoder(state)

        # 3. Feature fusion (a simple early-fusion example). The image history
        #    and the state history are concatenated along the time axis:
        #    [B, T, Embed] + [B, T, Embed] -> [B, 2*T, Embed].
        context = torch.cat([img_emb, state_emb], dim=1)

        # 4. Action head branch.
        if actions is not None:
            # Training mode: return the loss computed by the action head.
            return self.action_head.compute_loss(context, actions)

        # Inference mode: return the predicted action sequence.
        return self.action_head.predict_action(context)