# roboimi/vla/agent.py
import torch
|
||
import torch.nn as nn
|
||
from typing import Dict, Optional, Any
|
||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||
|
||
class VLAAgent(nn.Module):
    """Top-level VLA agent that wires together the three pluggable stages.

    Pipeline: observations -> backbone features -> projected embeddings
    -> head output (predicted actions and/or training loss).
    """

    def __init__(
        self,
        backbone: VLABackbone,
        projector: VLAProjector,
        head: VLAHead,
    ):
        """Store the three sub-modules; registering them as attributes of an
        ``nn.Module`` makes their parameters visible to optimizers/`.to()`.

        Args:
            backbone: Encodes raw observations into feature sequences.
            projector: Maps backbone features into the head's embedding space.
            head: Consumes embeddings and produces actions and/or a loss.
        """
        super().__init__()
        self.backbone = backbone
        self.projector = projector
        self.head = head

    def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Run one pass through the full backbone -> projector -> head chain.

        Args:
            batch: Dict with key 'obs' (image/text observations) and, during
                training, 'actions' (ground-truth actions). 'actions' is
                optional; when absent the head receives ``None``.

        Returns:
            Whatever dict the head produces (per the class docstring,
            action predictions and/or a loss).
        """
        # Stage 1: encode observations.
        # Expected shape per the original notes: (B, Seq, Backbone_Dim).
        feats = self.backbone(batch['obs'])

        # Stage 2: project into the head's embedding space,
        # (B, Seq, Head_Dim) per the original notes.
        emb = self.projector(feats)

        # Stage 3: ground-truth actions exist only in training mode; the
        # head branches on their presence (loss vs. pure prediction).
        return self.head(embeddings=emb, actions=batch.get('actions'))
|
||
|
||
# --- Legacy implementation (commented out, kept for reference) ---
# roboimi/vla/agent.py

# import torch
# import torch.nn as nn
# from typing import Optional, Dict, Union

# class VLAAgent(nn.Module):
#     def __init__(self,
#                  vlm_backbone: nn.Module,
#                  img_projector: nn.Module,
#                  action_head: nn.Module,
#                  state_dim: int,
#                  embed_dim: int):
#         super().__init__()
#         self.vlm_backbone = vlm_backbone
#         self.img_projector = img_projector
#         self.action_head = action_head

#         # Simple state encoder (no elaborate config needed; inline is fine)
#         self.state_encoder = nn.Sequential(
#             nn.Linear(state_dim, embed_dim),
#             nn.Mish(),
#             nn.Linear(embed_dim, embed_dim)
#         )

#     def forward(self,
#                 images: torch.Tensor,
#                 state: torch.Tensor,
#                 text: Optional[Union[str, list]] = None,
#                 actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]:
#         """
#         Args:
#             images: [Batch, Obs_Horizon, C, H, W]  Note: the time dimension must be handled here
#             state: [Batch, Obs_Horizon, State_Dim]
#             text: Optional text instructions
#             actions: [Batch, Pred_Horizon, Action_Dim] (Training only)

#         Returns:
#             Training: Loss scalar
#             Inference: Predicted actions
#         """

#         B, T, C, H, W = images.shape

#         # 1. Image encoding (flatten the time dimension for efficiency)
#         # [B*T, C, H, W] -> [B*T, Vision_Dim]
#         flat_images = images.view(B * T, C, H, W)
#         vision_feats_dict = self.vlm_backbone(flat_images)
#         raw_img_emb = vision_feats_dict['image_embeds']  # [B*T, Vision_Dim]

#         # Project, then restore the time dimension -> [B, T, Embed_Dim]
#         img_emb = self.img_projector(raw_img_emb)
#         img_emb = img_emb.view(B, T, -1)

#         # 2. State encoding
#         state_emb = self.state_encoder(state)  # [B, T, Embed_Dim]

#         # 3. Feature fusion (a simple early-fusion example)
#         # Concatenate image and state features, either along the feature
#         # dimension or along the time dimension. One could use only the most
#         # recent frame as context, or all historical features.
#         # Shown here: Context = (Image_History + State_History)
#         # [B, T, Embed] + [B, T, Embed] -> [B, 2*T, Embed] (concat on time)
#         context = torch.cat([img_emb, state_emb], dim=1)

#         # 4. Action-head branch
#         if actions is not None:
#             # --- Training Mode ---
#             # Must return the loss
#             return self.action_head.compute_loss(context, actions)
#         else:
#             # --- Inference Mode ---
#             # Must return the predicted action sequence
#             return self.action_head.predict_action(context)