feat: 更新框架,新增数据及定义和backbone
This commit is contained in:
@@ -44,7 +44,7 @@ smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none"
|
|||||||
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
|
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
|
||||||
|
|
||||||
# transformer settings
|
# transformer settings
|
||||||
batch_size: 15
|
batch_size: 10
|
||||||
state_dim: 16
|
state_dim: 16
|
||||||
action_dim: 16
|
action_dim: 16
|
||||||
lr_backbone: 0.00001
|
lr_backbone: 0.00001
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from hydra.utils import instantiate
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="../../../roboimi/vla/conf", config_name="config")
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
print(OmegaConf.to_yaml(cfg))
|
print(OmegaConf.to_yaml(cfg))
|
||||||
log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})")
|
log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})")
|
||||||
@@ -64,21 +64,30 @@ def main(cfg: DictConfig):
|
|||||||
# 我们在这里做一个映射,模拟多模态融合前的处理
|
# 我们在这里做一个映射,模拟多模态融合前的处理
|
||||||
|
|
||||||
# 假设我们只用配置里的第一个 key 作为主视觉
|
# 假设我们只用配置里的第一个 key 作为主视觉
|
||||||
primary_cam_key = cfg.data.obs_keys[0]
|
# primary_cam_key = cfg.data.obs_keys[0]
|
||||||
|
|
||||||
# Dataset 返回 shape: (B, Obs_Horizon, C, H, W)
|
# Dataset 返回 shape: (B, Obs_Horizon, C, H, W)
|
||||||
# DebugBackbone 期望: (B, C, H, W) 或者 (B, Seq, Dim)
|
# DebugBackbone 期望: (B, C, H, W) 或者 (B, Seq, Dim)
|
||||||
# 这里我们取 Obs_Horizon 的最后一帧 (Current Frame)
|
# 这里我们取 Obs_Horizon 的最后一帧 (Current Frame)
|
||||||
input_img = batch['obs'][primary_cam_key][:, -1, :, :, :]
|
# input_img = batch['obs'][primary_cam_key][:, -1, :, :, :]
|
||||||
|
|
||||||
|
# agent_input = {
|
||||||
|
# "obs": {
|
||||||
|
# "image": input_img,
|
||||||
|
# "text": batch["language"] # 传递语言指令
|
||||||
|
# },
|
||||||
|
# "actions": batch["actions"] # (B, Chunk, Dim)
|
||||||
|
# }
|
||||||
agent_input = {
|
agent_input = {
|
||||||
"obs": {
|
"action": batch["action"],
|
||||||
"image": input_img,
|
"qpos": batch["qpos"],
|
||||||
"text": batch["language"] # 传递语言指令
|
"images": {}
|
||||||
},
|
|
||||||
"actions": batch["actions"] # (B, Chunk, Dim)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for cam_name in cfg.data.camera_names:
|
||||||
|
key = f"image_{cam_name}"
|
||||||
|
agent_input["images"][cam_name] = batch[key].squeeze(1)
|
||||||
|
|
||||||
# --- 5. Forward & Backward ---
|
# --- 5. Forward & Backward ---
|
||||||
outputs = agent(agent_input)
|
outputs = agent(agent_input)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ SIM_TASK_CONFIGS = {
|
|||||||
# },
|
# },
|
||||||
'sim_transfer': {
|
'sim_transfer': {
|
||||||
'dataset_dir': DATASET_DIR + '/sim_transfer',
|
'dataset_dir': DATASET_DIR + '/sim_transfer',
|
||||||
'num_episodes': 7,
|
'num_episodes': 20,
|
||||||
'episode_len': 700,
|
'episode_len': 700,
|
||||||
'camera_names': ['top','r_vis'],
|
'camera_names': ['top','r_vis'],
|
||||||
'xml_dir': HOME_PATH + '/assets'
|
'xml_dir': HOME_PATH + '/assets'
|
||||||
|
|||||||
@@ -4,29 +4,27 @@ from typing import Dict, Optional, Any
|
|||||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||||
|
|
||||||
class VLAAgent(nn.Module):
|
class VLAAgent(nn.Module):
|
||||||
"""
|
|
||||||
The main assembly class.
|
|
||||||
Flow: Obs -> Backbone -> Projector -> Head -> Action/Loss
|
|
||||||
"""
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backbone: VLABackbone,
|
backbone: VLABackbone,
|
||||||
projector: VLAProjector,
|
projector: VLAProjector,
|
||||||
head: VLAHead
|
head: VLAHead,
|
||||||
|
state_encoder: nn.Module
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.projector = projector
|
self.projector = projector
|
||||||
self.head = head
|
self.head = head
|
||||||
|
self.state_encoder = state_encoder
|
||||||
|
|
||||||
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
|
||||||
Args:
|
action = batch["action"]
|
||||||
batch: Dict containing 'obs' (image/text) and 'actions' (ground truth)
|
state = batch["qpos"]
|
||||||
"""
|
images = batch["images"]
|
||||||
# 1. Extract Features
|
|
||||||
# Shape: (B, Seq, Backbone_Dim)
|
state_emb = self.state_encoder(state)
|
||||||
features = self.backbone(batch['obs'])
|
|
||||||
|
|
||||||
# 2. Project Features
|
# 2. Project Features
|
||||||
# Shape: (B, Seq, Head_Dim)
|
# Shape: (B, Seq, Head_Dim)
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
_target_: roboimi.vla.data.dataset.VLAChunkedDataset
|
|
||||||
|
|
||||||
data_path: "/home/d51/workspace/work/robo-imi-act/roboimi/demos/dataset/sim_transfer"
|
|
||||||
pred_horizon: 16
|
|
||||||
obs_horizon: 1
|
|
||||||
obs_keys: ["top"]
|
|
||||||
|
|
||||||
# 【新增】SigLIP 必须参数
|
|
||||||
resize_resolution: 384
|
|
||||||
train: true # 开启数据增强
|
|
||||||
8
roboimi/vla/conf/data/siglip2.yaml
Normal file
8
roboimi/vla/conf/data/siglip2.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
|
||||||
|
|
||||||
|
dataset_dir: "/home/d51/workspace/work/robo-imi-act/roboimi/demos/dataset/sim_transfer"
|
||||||
|
pred_horizon: 16
|
||||||
|
obs_horizon: 1
|
||||||
|
action_horizon: 8
|
||||||
|
camera_names: ['r_vis', 'top'] # ['angle', 'r_vis', 'top']
|
||||||
|
normalization_type: 'gaussian' # 'min_max' or 'gaussian'
|
||||||
@@ -18,11 +18,6 @@ class VLABackbone(nn.Module, abc.ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def embed_dim(self) -> int:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class VLAProjector(nn.Module, abc.ABC):
|
class VLAProjector(nn.Module, abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,103 +1,156 @@
|
|||||||
import h5py
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
from torch.utils.data import Dataset
|
import pickle
|
||||||
from typing import Dict, List, Any
|
|
||||||
|
|
||||||
# 【新增】导入刚才写好的处理器
|
class RobotDiffusionDataset(Dataset):
|
||||||
from .image_transform import VLAImageProcessor
|
def __init__(self,
|
||||||
|
dataset_dir,
|
||||||
class VLAChunkedDataset(Dataset):
|
pred_horizon=16,
|
||||||
def __init__(
|
obs_horizon=1,
|
||||||
self,
|
action_horizon=8,
|
||||||
data_path: str,
|
camera_names=['r_vis', 'top'],
|
||||||
pred_horizon: int = 16,
|
normalization_type='gaussian'):
|
||||||
obs_horizon: int = 1,
|
"""
|
||||||
obs_keys: List[str] = ["top"],
|
Args:
|
||||||
resize_resolution: int = 384, # SigLIP 默认 384
|
dataset_dir: 存放 episode_*.hdf5 的文件夹路径
|
||||||
train: bool = True # 【新增】控制是否增强
|
pred_horizon: 预测未来动作的长度 (Tp)
|
||||||
):
|
obs_horizon: 历史观测长度 (To)
|
||||||
self.data_path = data_path
|
action_horizon: 执行动作长度 (Ta) - 在Dataset中主要影响Evaluation,这里作为参数保留
|
||||||
|
"""
|
||||||
|
self.dataset_dir = dataset_dir
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.obs_keys = obs_keys
|
self.action_horizon = action_horizon
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.normalization_type = normalization_type
|
||||||
|
# 1. 扫描所有HDF5文件并建立索引
|
||||||
|
# 格式: [(file_path, episode_length), ...]
|
||||||
|
self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
|
||||||
|
self.indices = []
|
||||||
|
|
||||||
# ... (这里保留之前的扫描文件代码 self.file_paths ...) ...
|
print(f"Found {len(self.episode_files)} episodes. Building index...")
|
||||||
if os.path.isdir(data_path):
|
|
||||||
self.file_paths = sorted(glob.glob(os.path.join(data_path, "*.hdf5")))
|
|
||||||
else:
|
|
||||||
self.file_paths = [data_path]
|
|
||||||
|
|
||||||
# ... (这里保留之前的建立索引代码 self.index_map ...) ...
|
for file_path in self.episode_files:
|
||||||
self.index_map = []
|
with h5py.File(file_path, 'r') as f:
|
||||||
for i, path in enumerate(self.file_paths):
|
# 获取该 episode 的长度 (例如 700)
|
||||||
with h5py.File(path, 'r') as f:
|
l = f['action'].shape[0]
|
||||||
total_len = f["action"].shape[0]
|
# 保存每个有效 step 的索引信息
|
||||||
for t in range(total_len):
|
# (file_path, episode_length, current_step_index)
|
||||||
self.index_map.append((i, t))
|
for i in range(l):
|
||||||
|
self.indices.append((file_path, l, i))
|
||||||
|
|
||||||
# 【核心修改】实例化处理器
|
# 2. 统计数据
|
||||||
self.image_processor = VLAImageProcessor(
|
with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
|
||||||
resolution=resize_resolution,
|
self.stats = pickle.load(f)
|
||||||
enable_augmentation=train, # 训练集开启增强
|
|
||||||
aug_strength=0.1
|
|
||||||
)
|
|
||||||
print(f"✅ Image Processor: {self.image_processor}")
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.index_map)
|
return len(self.indices)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
def __getitem__(self, idx):
|
||||||
file_idx, t_start = self.index_map[idx]
|
file_path, episode_len, start_ts = self.indices[idx]
|
||||||
file_path = self.file_paths[file_idx]
|
|
||||||
|
|
||||||
with h5py.File(file_path, 'r') as f:
|
# -----------------------------
|
||||||
# ... (Action读取代码保持不变) ...
|
# 1. 打开文件
|
||||||
total_len = f["action"].shape[0]
|
# -----------------------------
|
||||||
t_end = min(t_start + self.pred_horizon, total_len)
|
# 注意: 在 __getitem__ 中打开文件对多进程 DataLoader 更友好
|
||||||
actions_np = f["action"][t_start:t_end]
|
# 如果追求极致IO性能,可以考虑使用 h5py 的 swmr 模式或内存缓存
|
||||||
# ... (Padding 逻辑保持不变) ...
|
with h5py.File(file_path, 'r') as root:
|
||||||
actual_len = actions_np.shape[0]
|
|
||||||
if actual_len < self.pred_horizon:
|
# -----------------------------
|
||||||
pad_len = self.pred_horizon - actual_len
|
# 2. 处理 Action (Prediction Target)
|
||||||
pad_block = np.tile(actions_np[-1], (pad_len, 1))
|
# -----------------------------
|
||||||
actions_np = np.concatenate([actions_np, pad_block], axis=0)
|
# 目标: 获取 [t, t + pred_horizon] 的动作
|
||||||
|
action_start = start_ts
|
||||||
|
action_end = min(start_ts + self.pred_horizon, episode_len)
|
||||||
|
|
||||||
|
actions = root['action'][action_start:action_end] # shape: (T_subset, 16)
|
||||||
|
|
||||||
|
# Padding: 如果剩余动作不足 pred_horizon,复制最后一步
|
||||||
|
if len(actions) < self.pred_horizon:
|
||||||
|
pad_len = self.pred_horizon - len(actions)
|
||||||
|
last_action = actions[-1]
|
||||||
|
# 重复最后一行
|
||||||
|
pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
|
||||||
|
actions = np.concatenate([actions, pad_content], axis=0)
|
||||||
|
|
||||||
|
# 归一化 Action
|
||||||
|
if self.stats:
|
||||||
|
actions = self._normalize_data(actions, self.stats['action'])
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 3. 处理 Observations (History)
|
||||||
|
# -----------------------------
|
||||||
|
# 目标: 获取 [t - obs_horizon + 1, t + 1] 的观测
|
||||||
|
# 索引逻辑:
|
||||||
|
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
|
||||||
|
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
|
||||||
|
|
||||||
|
indices = []
|
||||||
|
for i in range(self.obs_horizon):
|
||||||
|
# t - (To - 1) + i
|
||||||
|
query_ts = start_ts - (self.obs_horizon - 1) + i
|
||||||
|
# 边界处理 (Padding first frame)
|
||||||
|
query_ts = max(query_ts, 0)
|
||||||
|
indices.append(query_ts)
|
||||||
|
|
||||||
|
# 读取 qpos (proprioception)
|
||||||
|
qpos_data = root['observations/qpos']
|
||||||
|
qpos = qpos_data[indices] # smart indexing
|
||||||
|
if self.stats:
|
||||||
|
qpos = self._normalize_data(qpos, self.stats['qpos'])
|
||||||
|
|
||||||
|
# 读取 Images
|
||||||
|
# 你有三个视角: angle, r_vis, top
|
||||||
|
# 建议将它们分开返回,或者在 Dataset 里 Concat
|
||||||
|
image_dict = {}
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
# HDF5 dataset
|
||||||
|
img_dset = root['observations']['images'][cam_name]
|
||||||
|
|
||||||
# --- 图像处理部分 ---
|
|
||||||
obs_dict = {}
|
|
||||||
for key in self.obs_keys:
|
|
||||||
imgs = []
|
imgs = []
|
||||||
for i in range(self.obs_horizon):
|
for t in indices:
|
||||||
# 计算历史帧索引
|
img = img_dset[t] # (480, 640, 3) uint8
|
||||||
query_t = max(0, t_start - (self.obs_horizon - 1) + i)
|
img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W)
|
||||||
|
imgs.append(img)
|
||||||
|
|
||||||
# 1. 读取原始数据 (Numpy uint8)
|
# Stack time dimension: (obs_horizon, 3, H, W)
|
||||||
raw_img = f[f"observations/images/{key}"][query_t]
|
image_dict[cam_name] = torch.stack(imgs)
|
||||||
|
|
||||||
# 2. 【调用处理器】 Numpy -> Tensor (384, 384) Normalized
|
# -----------------------------
|
||||||
processed_img = self.image_processor(raw_img)
|
# 4. 组装 Batch
|
||||||
|
# -----------------------------
|
||||||
|
data_batch = {
|
||||||
|
'action': torch.from_numpy(actions).float(), # (Tp, 16)
|
||||||
|
'qpos': torch.from_numpy(qpos).float(), # (To, 16)
|
||||||
|
}
|
||||||
|
# 将图像放入 batch
|
||||||
|
for cam_name, img_tensor in image_dict.items():
|
||||||
|
data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W)
|
||||||
|
|
||||||
imgs.append(processed_img)
|
# TODO: 添加 Language Instruction
|
||||||
|
# 如果所有 episode 共享任务,这里可以是固定 embedding
|
||||||
|
# 如果每个 episode 任务不同,你需要一个额外的 meta json 来映射 file_path -> text
|
||||||
|
# data_batch['lang_text'] = "pick up the red cube"
|
||||||
|
|
||||||
# Stack -> (T, C, H, W)
|
return data_batch
|
||||||
obs_dict[key] = torch.stack(imgs)
|
|
||||||
|
|
||||||
# ... (QPos 和 Language 读取保持不变) ...
|
def _normalize_data(self, data, stats):
|
||||||
qpos = f["observations/qpos"][t_start].astype(np.float32)
|
if self.normalization_type == 'min_max':
|
||||||
lang = f.attrs.get("language", "placeholder")
|
# 之前的逻辑: [-1, 1]
|
||||||
if isinstance(lang, bytes): lang = lang.decode("utf-8")
|
min_val = stats['min']
|
||||||
|
max_val = stats['max']
|
||||||
|
data = (data - min_val) / (max_val - min_val + 1e-8)
|
||||||
|
return data * 2 - 1
|
||||||
|
|
||||||
# 这里的 action_mask 只是临时补全代码,你原来的逻辑是对的
|
elif self.normalization_type == 'gaussian':
|
||||||
action_mask = torch.ones(self.pred_horizon, dtype=torch.float32)
|
# 新逻辑: Mean/Std
|
||||||
if actual_len < self.pred_horizon:
|
mean = stats['mean']
|
||||||
action_mask[actual_len:] = 0.0
|
std = stats['std']
|
||||||
|
# (data - mean) / std
|
||||||
return {
|
# 这里的 data 是 numpy array
|
||||||
"obs": obs_dict,
|
return (data - mean) / (std + 1e-8)
|
||||||
"qpos": torch.from_numpy(qpos),
|
|
||||||
"actions": torch.from_numpy(actions_np).float(),
|
|
||||||
"action_mask": action_mask,
|
|
||||||
"language": lang
|
|
||||||
}
|
|
||||||
37
roboimi/vla/models/backbones/siglip2.py
Normal file
37
roboimi/vla/models/backbones/siglip2.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from transformers import SiglipVisionModel
|
||||||
|
from roboimi.vla.core.interfaces import VLABackbone
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
class SigLIP2(VLABackbone):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name = "google/siglip2-base-patch16-384",
|
||||||
|
freeze: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.vision_model = SiglipVisionModel.from_pretrained(model_name)
|
||||||
|
self.transform = transforms.Compose([
|
||||||
|
transforms.Resize((384, 384), antialias=True),
|
||||||
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
])
|
||||||
|
|
||||||
|
if freeze:
|
||||||
|
self._freeze_parameters()
|
||||||
|
|
||||||
|
def _freeze_parameters(self):
|
||||||
|
print("❄️ Freezing Vision Backbone parameters")
|
||||||
|
for param in self.vision_model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.vision_model.eval()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
images
|
||||||
|
):
|
||||||
|
# images: (B, C, H, W), 归一化到 [0, 1]
|
||||||
|
images = self.transform(images) # 归一化到 [-1, 1]
|
||||||
|
|
||||||
|
outputs = self.vision_model(pixel_values=images)
|
||||||
|
|
||||||
|
return outputs.last_hidden_state
|
||||||
@@ -30,17 +30,17 @@ class MLP(nn.Module):
|
|||||||
class SinusoidalPositionalEncoding(nn.Module):
|
class SinusoidalPositionalEncoding(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
emb_dim
|
embed_dim
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.emb_dim = emb_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def forward(self, timesteps):
|
def forward(self, timesteps):
|
||||||
timesteps = timesteps.float()
|
timesteps = timesteps.float()
|
||||||
B, T = timesteps.shape
|
B, T = timesteps.shape
|
||||||
device = timesteps.device
|
device = timesteps.device
|
||||||
|
|
||||||
half_dim = self.emb_dim // 2
|
half_dim = self.embed_dim // 2
|
||||||
|
|
||||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||||
torch.log(torch.tensor(10000.0)) / half_dim
|
torch.log(torch.tensor(10000.0)) / half_dim
|
||||||
@@ -58,14 +58,14 @@ class ActionEncoder(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
action_dim,
|
action_dim,
|
||||||
emb_dim,
|
embed_dim,
|
||||||
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.W1 = nn.Linear(action_dim, emb_dim)
|
self.W1 = nn.Linear(action_dim, embed_dim)
|
||||||
self.W2 = nn.Linear(2 * action_dim, action_dim)
|
self.W2 = nn.Linear(2 * action_dim, action_dim)
|
||||||
self.W3 = nn.Linear(emb_dim, emb_dim)
|
self.W3 = nn.Linear(embed_dim, embed_dim)
|
||||||
self.pos_encoder = SinusoidalPositionalEncoding(emb_dim)
|
self.pos_encoder = SinusoidalPositionalEncoding(embed_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -89,13 +89,13 @@ class StateEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
state_dim,
|
state_dim,
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
emb_dim
|
embed_dim
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mlp = MLP(
|
self.mlp = MLP(
|
||||||
state_dim,
|
state_dim,
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
emb_dim
|
embed_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -103,4 +103,4 @@ class StateEncoder(nn.Module):
|
|||||||
states
|
states
|
||||||
):
|
):
|
||||||
state_emb = self.mlp(states)
|
state_emb = self.mlp(states)
|
||||||
return state_emb # [B, 1, emb_dim]
|
return state_emb # [B, 1, embed_dim]
|
||||||
Reference in New Issue
Block a user