Files
roboimi/roboimi/vla/agent.py
2026-02-03 16:14:54 +08:00

114 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
from typing import Dict, Optional, Any
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
class VLAAgent(nn.Module):
    """
    Top-level assembly of a VLA model from three pluggable stages.

    Data flow: Obs -> Backbone -> Projector -> Head -> Action/Loss
    """

    def __init__(
        self,
        backbone: VLABackbone,
        projector: VLAProjector,
        head: VLAHead
    ):
        super().__init__()
        self.backbone = backbone
        self.projector = projector
        self.head = head

    def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """
        Run one pass through backbone, projector and head.

        Args:
            batch: Dict containing 'obs' (image/text) and, in training
                mode, 'actions' (ground-truth actions).

        Returns:
            The head's output dict (predictions and/or losses).
        """
        # Stage 1: encode raw observations into sequence features.
        # Shape: (B, Seq, Backbone_Dim)
        backbone_feats = self.backbone(batch['obs'])

        # Stage 2: map features into the head's embedding space.
        # Shape: (B, Seq, Head_Dim)
        projected = self.projector(backbone_feats)

        # Stage 3: ground-truth actions are forwarded only when the
        # batch carries them (training mode); otherwise None.
        gt_actions = batch.get('actions')
        return self.head(embeddings=projected, actions=gt_actions)
# # roboimi/vla/agent.py
# import torch
# import torch.nn as nn
# from typing import Optional, Dict, Union
# class VLAAgent(nn.Module):
#     def __init__(self,
#                  vlm_backbone: nn.Module,
#                  img_projector: nn.Module,
#                  action_head: nn.Module,
#                  state_dim: int,
#                  embed_dim: int):
#         super().__init__()
#         self.vlm_backbone = vlm_backbone
#         self.img_projector = img_projector
#         self.action_head = action_head
#         # Simple state encoder (usually no elaborate config is needed; defining it inline is fine)
#         self.state_encoder = nn.Sequential(
#             nn.Linear(state_dim, embed_dim),
#             nn.Mish(),
#             nn.Linear(embed_dim, embed_dim)
#         )
#     def forward(self,
#                 images: torch.Tensor,
#                 state: torch.Tensor,
#                 text: Optional[Union[str, list]] = None,
#                 actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]:
#         """
#         Args:
#             images: [Batch, Obs_Horizon, C, H, W] Note: the time dimension must be handled here
#             state: [Batch, Obs_Horizon, State_Dim]
#             text: Optional text instructions
#             actions: [Batch, Pred_Horizon, Action_Dim] (Training only)
#         Returns:
#             Training: Loss scalar
#             Inference: Predicted actions
#         """
#         B, T, C, H, W = images.shape
#         # 1. Image encoding (Flatten time dimension for efficiency)
#         # [B*T, C, H, W] -> [B*T, Vision_Dim]
#         flat_images = images.view(B * T, C, H, W)
#         vision_feats_dict = self.vlm_backbone(flat_images)
#         raw_img_emb = vision_feats_dict['image_embeds'] # [B*T, Vision_Dim]
#         # Project and restore the time dimension -> [B, T, Embed_Dim]
#         img_emb = self.img_projector(raw_img_emb)
#         img_emb = img_emb.view(B, T, -1)
#         # 2. State encoding
#         state_emb = self.state_encoder(state) # [B, T, Embed_Dim]
#         # 3. Feature fusion (a simple early-fusion example)
#         # Concatenate image and state features along the feature dim, or along the time dim
#         # Assume we use only the most recent image frame as context, or use all history features as context
#         # Demonstrated here: Context = (Image_History + State_History)
#         # [B, T, Embed] + [B, T, Embed] -> [B, 2*T, Embed] (Concat on time)
#         context = torch.cat([img_emb, state_emb], dim=1)
#         # 4. Action head branch
#         if actions is not None:
#             # --- Training Mode ---
#             # Must return the loss
#             return self.action_head.compute_loss(context, actions)
#         else:
#             # --- Inference Mode ---
#             # Must return the predicted action sequence
#             return self.action_head.predict_action(context)