64 Commits
ddt ... dev

Author SHA1 Message Date
gouhanke
23088e5e33 feat: 架构引入DiT 2026-03-06 11:31:37 +08:00
gouhanke
ca1716c67f chore: 导入gr00t 2026-03-06 11:31:37 +08:00
JiajunLI
642d41dd8f feat: 添加resume机制 2026-03-06 11:19:30 +08:00
gouhanke
7d39933a5b feat: 缓存worker内的句柄 2026-03-04 10:49:41 +08:00
gouhanke
8bcad5844e fix: 修复VLA设备与损失计算逻辑,并优化Transformer默认训练参数 2026-03-03 17:56:12 +08:00
gouhanke
cdb887c9bf feat: 添加transformer头 2026-02-28 19:07:27 +08:00
gouhanke
abb4f501e3 chore: 删除unet里的local_cond(未使用) 2026-02-28 10:42:16 +08:00
gouhanke
1d33db0ef0 chore: 缩小物块的大小 2026-02-27 18:23:30 +08:00
gouhanke
f27e397f98 chore: 修改了采数时的一些参数 2026-02-26 17:09:40 +08:00
gouhanke
4e0add4e1d debug: 修复episode首帧图像不正确的问题;修复前2个episode帧重复的问题 2026-02-26 16:17:54 +08:00
gouhanke
40c40695dd chore: 添加测试文件
- check_all_episodes.py:检查各个episode是否有重复帧。
- check_specific_frames.py:检查前几帧是否位于正确初始位置。
- generate_dataset_videos.py:根据hdf5生成视频
2026-02-26 13:59:47 +08:00
gouhanke
3deeffb9fe chore:改变了一些参数配置 2026-02-26 13:56:03 +08:00
gouhanke
0b05c01024 feat: 推理时输出action 2026-02-12 19:54:11 +08:00
gouhanke
926a78eb66 feat: 添加finetune 2026-02-12 19:31:44 +08:00
gouhanke
efbe4b6ac9 Revert "Merge branch 'dev' of gitlab.com:leeeezd0016-group/gouhanke-vla into dev"
This reverts commit acb1467473, reversing
changes made to 624b926e33.
2026-02-12 18:31:56 +08:00
gouhanke
acb1467473 Merge branch 'dev' of gitlab.com:leeeezd0016-group/gouhanke-vla into dev 2026-02-12 18:08:27 +08:00
gouhanke
624b926e33 debug: 添加推理时缩放,加大采数以及推理时物块的放置范围 2026-02-12 17:14:23 +08:00
gouhanke
926d8cf894 chore: 加载时将图像缩放到224*224, resnet禁用crop 2026-02-12 15:02:18 +08:00
gouhanke
116ba13fb9 chore: 验证归一化是否有效 2026-02-12 13:01:13 +08:00
gouhanke
37a47ac2dd debug: 保存stats到ckpt 2026-02-12 13:00:43 +08:00
gouhanke
ab971b3f96 debug: 归一化 2026-02-12 12:23:34 +08:00
gouhanke
83cd55e67b 添加pad_loss 2026-02-11 20:33:26 +08:00
gouhanke
eeb07cad15 feat: 冻结resnet 2026-02-11 20:11:25 +08:00
gouhanke
83d11ab640 Merge branch 'dev' of gitlab.com:leeeezd0016-group/gouhanke-vla into dev 2026-02-11 17:20:21 +08:00
JiajunLI
aba8779671 更改默认参数 2026-02-11 17:14:32 +08:00
gouhanke
b42c1c68fd debug: 将归一化放在GPU上 2026-02-11 17:13:55 +08:00
gouhanke
320369ffb8 debug: 归一化图像到[0, 1] 2026-02-11 16:47:39 +08:00
gouhanke
130d4bb3c5 refactor:大重构 2026-02-11 15:53:55 +08:00
gouhanke
1e95d40bf9 debug 2026-02-10 15:56:05 +08:00
gouhanke
3c27d6d793 refactor: 重构resnet 2026-02-10 15:26:10 +08:00
gouhanke
88b9c10a75 refactor(dataset): 重新创建robotdataset最小实现
- 内部实现__getitem__参数,可以通过滑动窗口进行采样
-
2026-02-10 10:26:19 +08:00
gouhanke
ac870f6110 chore: 计算推理频率 2026-02-09 15:39:22 +08:00
gouhanke
8b700b6d99 暂存 2026-02-09 14:41:35 +08:00
gouhanke
f833c6d9f1 添加readme文件 2026-02-07 09:57:59 +08:00
gouhanke
4332530a5f feat(train): 添加warmup学习率调度器 2026-02-06 22:54:34 +08:00
gouhanke
456056347f debug: 固定验证集上的随机噪声,修复resnet在训练时bn层会切换到train的问题 2026-02-06 21:31:19 +08:00
gouhanke
05f3cc1e47 chore: 删除detr和gr00t 2026-02-06 20:21:01 +08:00
gouhanke
a6fcb88203 chore: 删除多余文件 2026-02-06 20:19:11 +08:00
gouhanke
3d0c2ec5b1 feat(train): 添加验证集 2026-02-06 18:00:09 +08:00
gouhanke
ea49e63eb7 feat: 注册了自定义 resolver计算长度 2026-02-06 16:08:56 +08:00
gouhanke
7a9ca06aa0 feat(dependency): 生成environment.yml文件 2026-02-06 15:40:24 +08:00
gouhanke
f006d50814 chore: 自动获取cameras的长度 2026-02-06 15:33:07 +08:00
gouhanke
f4a5c77b7c refactor: 归一化从agent解耦到训练、推理脚本中 2026-02-06 14:29:36 +08:00
gouhanke
a43a2e3d18 chore: 删除多余脚本 2026-02-06 13:45:35 +08:00
gouhanke
31419a6fc1 chore(camera): 添加front相机 2026-02-06 11:53:01 +08:00
gouhanke
66009473ad debug(inference): 添加推理阶段qpos归一化 2026-02-06 09:00:44 +08:00
gouhanke
b0a944f7aa feat(train): 跑通训练脚本 2026-02-05 14:08:43 +08:00
gouhanke
dd2749cb12 feat: 更新框架,新增数据及定义和backbone 2026-02-05 01:44:43 +08:00
gouhanke
92660562fb feat(dataset): 添加统计数据计算脚本 2026-02-05 01:44:43 +08:00
gouhanke
03f10b0c22 feat: 编写状态编码器、动作编码器 2026-02-05 01:44:43 +08:00
gouhanke
8fce9c89ef chore: 删除多余文件 2026-02-05 01:44:43 +08:00
JiajunLI
30b8ff4d7d Merge branch 'main' into 'dev'
# Conflicts:
#   README.md
2026-02-04 09:30:00 +00:00
gouhanke
3f8c3dbf5d chore(readme): 修改readme里的数据结构标准 2026-02-04 14:33:52 +08:00
gouhanke
3465782256 feat: 添加保存模型的功能和推理脚本 2026-02-03 18:03:47 +08:00
gouhanke
f5e2eca809 debug(train): 在siglip和DiffusionHead下跑通训练流程 2026-02-03 17:42:32 +08:00
gouhanke
3b58760469 跑通配置和训练脚本 2026-02-03 16:51:04 +08:00
gouhanke
bd8bbb0cfc debug: 核心骨架伪实现 2026-02-03 16:14:54 +08:00
gouhanke
d3863ea1dd feat(dataset): 定义VLAChunkedDataset类,构建数据可视化工具 2026-02-03 15:24:09 +08:00
gouhanke
57acfd645f feat(vla): vla框架初始化 2026-02-03 14:18:30 +08:00
gouhanke
c1ce560b32 feat(inference): 添加动作平滑器 2026-02-03 10:32:09 +08:00
gouhanke
a977cc4f5e chore(Git LFS): 配置 Git LFS 以支持 .safetensors 模型文件 2026-02-03 10:30:06 +08:00
JiajunLI
fdf4dd8bed feat(policy): 引入gr00t(DiT) 2026-02-02 17:16:28 +08:00
JiajunLI
fd1bd20c4f chore(constants): 修改参与训练和推理的相机
- 现在使用顶部相机、右手腕相机。
2026-01-28 19:32:56 +08:00
Li Zonda
ab1f50cc66 Initial commit 2025-12-08 08:27:37 +00:00
89 changed files with 6621 additions and 4769 deletions

2
.gitignore vendored
View File

@@ -124,3 +124,5 @@ GEMINI.md
# Copilot # Copilot
.github/copilot-instructions.md .github/copilot-instructions.md
.hydra/

View File

@@ -1,36 +0,0 @@
# robo-imi-act
#### Description
{**When you're done, you can delete the content in this README and update the file with details for others getting started with your repository**}
#### Software Architecture
Software architecture description
#### Installation
1. xxxx
2. xxxx
3. xxxx
#### Instructions
1. xxxx
2. xxxx
3. xxxx
#### Contribution
1. Fork the repository
2. Create Feat_xxx branch
3. Commit your code
4. Create Pull Request
#### Gitee Feature
1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
4. The most valuable open source project [GVP](https://gitee.com/gvp)
5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)

223
README.md
View File

@@ -1,39 +1,208 @@
# robo-imi-act # RoboIMI
#### 介绍 基于 MuJoCo 的机器人仿真与模仿学习框架,实现了使用扩散策略的视觉-语言-动作VLA模型用于机器人操作任务。
{**以下是 Gitee 平台说明,您可以替换此简介**
Gitee 是 OSCHINA 推出的基于 Git 的代码托管平台(同时支持 SVN。专为开发者提供稳定、高效、安全的云端软件开发协作平台
无论是个人、团队、或是企业,都能够用 Gitee 实现代码托管、项目管理、协作开发。企业项目请看 [https://gitee.com/enterprises](https://gitee.com/enterprises)}
#### 软件架构 ## 主要特性
软件架构说明
- **多机器人平台支持**:支持 Diana 和 vx300s 机械臂,可扩展至其他机器人
- **扩散策略**采用最先进的扩散模型DDPM/DDIM进行动作序列预测
- **视觉-语言-动作模型**:使用 ResNet-18 视觉骨干网络和空间 softmax 进行视觉特征提取
- **灵活的控制模式**:支持关节空间和末端执行器(笛卡尔)控制
- **Hydra 配置系统**:模块化配置系统,便于实验
- **HDF5 数据集格式**:高效存储和加载演示数据
- **单臂和双臂任务**:支持单臂和双臂操作任务
#### 安装教程 ## 安装
1. xxxx ### 环境要求
2. xxxx
3. xxxx
#### 使用说明 - Python 3.8+
- 支持 CUDA 的 GPU训练时推荐
- Conda 或 Miniconda
1. xxxx ### 安装步骤
2. xxxx
3. xxxx
#### 参与贡献 ```bash
# 克隆仓库
git clone <repository-url>
cd robo-imi-act
1. Fork 本仓库 # 创建并激活 conda 环境
2. 新建 Feat_xxx 分支 conda env create -f environment.yml
3. 提交代码 conda activate roboimi
4. 新建 Pull Request
# 以开发模式安装包
pip install -e .
```
#### 特技 ## 快速开始
1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md ### 1. 数据采集
2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com)
3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 在仿真环境中记录演示轨迹:
4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目
5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) ```bash
6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) # 为 vx300s 机器人记录轨迹
python roboimi/demos/record_sim_episodes.py
# 为 Diana 机器人记录轨迹
python roboimi/demos/diana_record_sim_episodes.py
```
轨迹数据以 HDF5 文件格式保存,包含机器人状态、动作和相机观测。
### 2. 计算数据集统计信息
训练前需要计算归一化统计数据:
```bash
python roboimi/vla/scripts/calculate_stats.py
```
该命令会生成 `data_stats.pkl` 文件,包含动作和观测的均值/标准差或最小值/最大值。
### 3. 训练 VLA 模型
使用采集的数据训练视觉-语言-动作模型:
```bash
# 使用默认配置训练
python roboimi/demos/vla_scripts/train_vla.py
# 覆盖特定参数
python roboimi/demos/vla_scripts/train_vla.py train.batch_size=32 train.lr=5e-5 train.max_steps=50000
# 使用不同的模型架构
python roboimi/demos/vla_scripts/train_vla.py agent=resnet_diffusion data=resnet_dataset
```
训练输出保存至 `outputs/<日期>/<时间>/`,模型检查点保存至 `checkpoints/`
### 4. 评估模型
在仿真环境中评估训练好的模型:
```bash
# 使用默认配置评估(使用最佳检查点)
python roboimi/demos/vla_scripts/eval_vla.py
# 指定检查点和评估轮数
python roboimi/demos/vla_scripts/eval_vla.py eval.ckpt_path=checkpoints/vla_model_step_8000.pt eval.num_episodes=5
# 启用动作平滑以获得更流畅的执行
python roboimi/demos/vla_scripts/eval_vla.py eval.use_smoothing=true eval.smooth_alpha=0.5
```
## 项目结构
```
robo-imi-act/
├── roboimi/
│ ├── assets/ # 机器人模型和资源
│ │ ├── models/manipulators/ # URDF 和 MuJoCo XML 文件
│ │ └── robots/ # 机器人抽象类
│ ├── envs/ # 仿真环境
│ │ ├── mujoco_base.py # MuJoCo 环境基类
│ │ ├── single_base.py # 单臂任务基类
│ │ └── double_base.py # 双臂任务基类
│ ├── vla/ # 视觉-语言-动作模型
│ │ ├── agent.py # VLAAgent训练与推理
│ │ ├── models/
│ │ │ ├── backbones/ # 视觉编码器ResNet 等)
│ │ │ └── heads/ # 策略头(扩散 UNet1D
│ │ ├── conf/ # Hydra 配置文件
│ │ └── scripts/ # 训练和工具脚本
│ └── demos/ # 演示脚本和示例
├── checkpoints/ # 保存的模型检查点
├── outputs/ # 训练输出Hydra
├── environment.yml # Conda 环境配置
└── CLAUDE.md # Claude Code 开发指南
```
## 架构设计
### VLA 训练流程
```
HDF5 轨迹数据 → Dataset → DataLoader → VLAAgent → 模型检查点
```
**模型组件**
- **视觉骨干网络**ResNet-18 + 空间 softmax用于从相机图像中提取视觉特征
- **扩散头**:条件 UNet1D使用 DDPM/DDIM 预测动作序列
- **VLAAgent**:组合视觉编码器和扩散策略,处理训练和推理
### 配置系统
基于 Hydra 的配置文件位于 `roboimi/vla/conf/`
- `config.yaml`:主要训练配置(批次大小、学习率、设备)
- `agent/resnet_diffusion.yaml`:模型架构(动作维度、观测维度、时间窗口)
- `data/resnet_dataset.yaml`:数据集路径、相机名称、归一化类型
- `eval/eval.yaml`:评估设置(检查点路径、轮数、平滑参数)
使用配置插值保持一致性:`${agent.obs_horizon}`
### 数据集格式
HDF5 轨迹文件(`episode_*.hdf5`)包含:
- `action`:机器人动作 `[T, action_dim]`
- `observations/qpos`:关节位置 `[T, obs_dim]`
- `observations/images/<cam_name>`:相机图像 `[T, H, W, C]`
统计文件(`data_stats.pkl`)存储归一化参数(最小值/最大值/均值/标准差)。
## 开发指南
### 添加新机器人
1.`roboimi/assets/models/manipulators/<robot_name>/` 创建 URDF/XML 文件
2.`roboimi/assets/robots/<robot_name>.py` 定义机器人类(继承自 `arm_base.py`
3.`roboimi/envs/<robot_name>_*.py` 创建环境类
4. 如需要,在常量中注册机器人
### 修改 VLA 架构
1. **自定义骨干网络**:在 `roboimi/vla/models/backbones/` 创建新类,继承 `VLABackbone`
2. **自定义头部**:在 `roboimi/vla/models/heads/` 创建新类,继承 `VLAHead`
3. **更新配置**:在 `roboimi/vla/conf/agent/` 添加新的 YAML 文件
4. **接口定义**:参考 `roboimi/vla/core/interfaces.py` 的抽象基类
### 训练最佳实践
- 采集新数据后务必运行 `calculate_stats.py`
- 训练时会归一化输入/输出;推理时使用检查点中保存的统计信息进行反归一化
- 模型预测 `pred_horizon` 步,但只执行前 `action_horizon`
- 推理使用 DDIM10 步)快速采样;训练使用 DDPM100 步)
- 监控验证损失以防止过拟合
## 技术细节
- **坐标系**关节空间qpos或末端执行器空间xyz + rpy + 夹爪)
- **动作时间窗口**`obs_horizon` 为观测窗口,`pred_horizon` 为预测窗口,`action_horizon` 为执行窗口
- **归一化**:对稳定训练至关重要 - 训练前务必计算统计信息
- **推理加速**:使用 DDIM 调度器,比训练时的 DDPM 快 10 倍
- **设备配置**:通过 `train.device` 配置cuda/cpu
## 许可证
[在此添加许可证信息]
## 引用
如果您在研究中使用了本代码库,请引用:
```bibtex
[在此添加引用信息]
```
## 贡献
欢迎贡献!请随时提交 Pull Request 或开启 Issue。
## 致谢
本项目基于以下开源项目构建:
- [MuJoCo](https://mujoco.org/) - 物理仿真引擎
- [PyTorch](https://pytorch.org/) - 深度学习框架
- [Hydra](https://hydra.cc/) - 配置管理系统
- [Diffusers](https://github.com/huggingface/diffusers) - 扩散模型库

91
check_all_episodes.py Normal file
View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
检查所有 episode 的重复帧情况
找出哪些 episode 有问题,需要删除或重新收集
"""
import os
import h5py
import glob
import numpy as np
def check_all_episodes():
"""检查所有 episode 的质量"""
dataset_dir = "roboimi/demos/dataset/sim_transfer"
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
print("="*80)
print("所有 Episode 质量检查")
print("="*80)
good_episodes = []
bad_episodes = []
for ep_idx, ep_file in enumerate(episode_files):
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
try:
with h5py.File(ep_file, 'r') as f:
img_path = '/observations/images/top'
if img_path not in f:
continue
images = f[img_path][:]
# 检查前 50 帧的重复情况
check_frames = min(50, len(images))
duplicate_count = 0
for i in range(check_frames - 1):
img1 = images[i]
img2 = images[i + 1]
diff = np.mean(np.abs(img1.astype(float) - img2.astype(float)))
if diff < 1.0: # 重复
duplicate_count += 1
duplicate_rate = duplicate_count / check_frames * 100
# 判断质量
if duplicate_rate > 10: # 超过10%重复
bad_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
status = ""
else:
good_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
status = ""
print(f"{status} Episode {ep_idx:2d}: {duplicate_rate:5.1f}% 重复 ({duplicate_count:2d}/{check_frames}) - {ep_name}")
except Exception as e:
print(f"❌ Episode {ep_idx}: 错误 - {e}")
# 总结
print("\n" + "="*80)
print("总结")
print("="*80)
print(f"总共检查: {len(episode_files)} 个 episodes")
print(f"正常的: {len(good_episodes)} 个 ✅")
print(f"有问题的: {len(bad_episodes)} 个 ❌")
if bad_episodes:
print(f"\n有问题的 episodes:")
for ep_idx, ep_name, rate, count in bad_episodes:
print(f" - episode_{ep_idx}.hdf5: {rate:.1f}% 重复")
print(f"\n删除命令:")
ep_names = [name for _, name, _, _ in bad_episodes]
print(f" rm " + " ".join([f"{dataset_dir}/{name}.hdf5" for name in ep_names]))
print(f"\n建议:")
if len(bad_episodes) > 0:
print(f" 1. 删除有问题的 {len(bad_episodes)} 个 episodes")
print(f" 2. 重新收集数据,或使用剩余的 {len(good_episodes)} 个正常 episodes")
else:
print(f" ✅ 所有 episodes 都正常,可以直接使用!")
if __name__ == "__main__":
check_all_episodes()

202
check_specific_frames.py Normal file
View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python3
"""
检查特定帧的图像 - 用于验证数据记录问题
功能:
1. 提取每个 episode 的第 0、1、2 帧图像
2. 对比不同 episode 的相同帧号
3. 保存图像供人工检查
"""
import os
import h5py
import glob
import cv2
import numpy as np
def check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10):
"""
检查特定帧的图像和 qpos
Args:
frame_indices: 要检查的帧索引列表
camera: 相机名称
num_episodes: 要检查的 episode 数量
"""
dataset_dir = "roboimi/demos/dataset/sim_transfer"
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
# 按数字排序
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
# 创建输出目录
output_dir = f'/tmp/dataset_frames'
os.makedirs(output_dir, exist_ok=True)
print(f"检查前 {min(num_episodes, len(episode_files))} 个 episode 的特定帧")
print(f"帧索引: {frame_indices}")
print(f"相机: {camera}")
print("="*80)
# 收集所有数据
for ep_idx in range(min(num_episodes, len(episode_files))):
ep_file = episode_files[ep_idx]
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
try:
with h5py.File(ep_file, 'r') as f:
# 读取 qpos
qpos = f['/observations/qpos'][:]
# 读取图像
img_path = f'/observations/images/{camera}'
if img_path not in f:
print(f"Episode {ep_name}: 相机 {camera} 不存在")
continue
images = f[img_path][:]
print(f"\nEpisode {ep_name}:")
print(f" 总帧数: {len(images)}")
# 保存指定帧
for frame_idx in frame_indices:
if frame_idx >= len(images):
print(f"{frame_idx}: 超出范围")
continue
# 保存图像
img = images[frame_idx]
filename = f"{output_dir}/ep{ep_idx:02d}_frame{frame_idx:03d}.png"
cv2.imwrite(filename, img)
# 打印 qpos
q = qpos[frame_idx]
print(f"{frame_idx}: qpos[0:3]=[{q[0]:6.2f}, {q[1]:6.2f}, {q[2]:6.2f}], qpos[3]={q[3]:6.2f}{filename}")
except Exception as e:
print(f"Episode {ep_name}: 错误 - {e}")
print("\n" + "="*80)
print(f"✅ 所有图像已保存到: {output_dir}")
print(f"\n查看方法:")
print(f" eog {output_dir}/*.png")
print(f" ")
print(f" # 或对比特定帧:")
print(f" eog {output_dir}/*_frame000.png # 所有 episode 的第 0 帧")
print(f" eog {output_dir}/*_frame001.png # 所有 episode 的第 1 帧")
print(f" eog {output_dir}/*_frame002.png # 所有 episode 的第 2 帧")
def compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10):
"""
并排对比所有 episode 的某一帧
生成一个大的对比图,包含所有 episode 的指定帧
"""
dataset_dir = "roboimi/demos/dataset/sim_transfer"
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
num_compare = min(num_episodes, len(episode_files))
cols = 5 # 每行 5 个
rows = (num_compare + cols - 1) // cols
# 创建输出目录
output_dir = f'/tmp/dataset_frames'
os.makedirs(output_dir, exist_ok=True)
print(f"生成对比图: 所有 Episode 的第 {frame_idx}")
print("="*80)
# 收集图像
images_compare = []
qpos_list = []
for ep_idx in range(num_compare):
ep_file = episode_files[ep_idx]
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
try:
with h5py.File(ep_file, 'r') as f:
qpos = f['/observations/qpos'][:]
img_path = f'/observations/images/{camera}'
if img_path in f and frame_idx < f[img_path].shape[0]:
img = f[img_path][frame_idx]
images_compare.append(img)
qpos_list.append(qpos[frame_idx])
print(f"Episode {ep_name}: qpos[0:3]=[{qpos[frame_idx][0]:.2f}, {qpos[frame_idx][1]:.2f}, {qpos[frame_idx][2]:.2f}]")
except Exception as e:
print(f"Episode {ep_name}: 错误 - {e}")
if not images_compare:
print("❌ 没有收集到图像")
return
# 获取图像尺寸
h, w = images_compare[0].shape[:2]
# 创建对比图
compare_img = np.zeros((rows * h + 50, cols * w, 3), dtype=np.uint8)
for i, (img, qpos) in enumerate(zip(images_compare, qpos_list)):
row = i // cols
col = i % cols
y_start = row * h + 30
y_end = y_start + h
x_start = col * w
x_end = x_start + w
# 调整大小(如果需要)
if img.shape[:2] != (h, w):
img = cv2.resize(img, (w, h))
compare_img[y_start:y_end, x_start:x_end] = img
# 添加信息
ep_name = f"Ep {i}"
cv2.putText(compare_img, ep_name, (x_start + 10, row * h + 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
cv2.putText(compare_img, f"qpos[3]={qpos[3]:.2f}", (x_start + 10, y_end - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
# 保存对比图
output_path = f"{output_dir}/compare_frame{frame_idx:03d}.png"
cv2.imwrite(output_path, compare_img)
print(f"\n✅ 对比图已保存: {output_path}")
print(f" 查看方法: eog {output_path}")
if __name__ == "__main__":
import sys
print("="*80)
print("特定帧检查工具")
print("="*80)
if len(sys.argv) > 1:
frame_idx = int(sys.argv[1])
compare_frame_across_episodes(frame_idx=frame_idx, camera='top', num_episodes=10)
else:
# 默认检查第 0、1、2 帧
check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10)
print("\n" + "="*80)
print("生成对比图...")
print("="*80)
# 生成第 0 帧的对比图
compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10)
compare_frame_across_episodes(frame_idx=1, camera='top', num_episodes=10)
compare_frame_across_episodes(frame_idx=2, camera='top', num_episodes=10)
print("\n" + "="*80)
print("其他用法:")
print(" python check_specific_frames.py 0 # 只检查第 0 帧")
print(" python check_specific_frames.py 1 # 只检查第 1 帧")
print("="*80)

View File

@@ -0,0 +1,238 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig(PreTrainedConfig):
"""Configuration class for DiffusionPolicy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
network. This is the output dimension of that network, i.e., the embedding dimension.
use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
beta_end: Beta value for the last forward-diffusion step.
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
"epsilon" has been shown to work better in many deep neural network settings.
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
denoising step at inference time. WARNING: you will need to make sure your action-space is
normalized to fit within this range.
clip_sample_range: The magnitude of the clipping range as described above.
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
to False as the original Diffusion Policy implementation does the same.
"""
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
use_separate_rgb_encoder_per_camera: bool = False
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
n_groups: int = 8
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
beta_end: float = 0.02
prediction_type: str = "epsilon"
clip_sample: bool = True
clip_sample_range: float = 1.0
# Inference
num_inference_steps: int | None = None
# Loss computation
do_mask_loss_for_padding: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 500
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
if self.horizon % downsampling_factor != 0:
raise ValueError(
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
)
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,764 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
"""
import math
from collections import deque
from collections.abc import Callable
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class DiffusionPolicy(PreTrainedPolicy):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
"""
config_class = DiffusionConfig
name = "diffusion"
def __init__(
self,
config: DiffusionConfig,
**kwargs,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.diffusion = DiffusionModel(config)
self.reset()
def get_optim_params(self) -> dict:
return self.diffusion.parameters()
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch, noise=noise)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache).
- The diffusion model generates `horizon` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# NOTE: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch, noise=noise)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
self.config = config
# Build observation encoders (depending on which observations are provided).
global_cond_dim = self.config.robot_state_feature.shape[0]
if self.config.image_features:
num_images = len(self.config.image_features)
if self.config.use_separate_rgb_encoder_per_camera:
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
self.rgb_encoder = nn.ModuleList(encoders)
global_cond_dim += encoders[0].feature_dim * num_images
else:
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
)
if config.num_inference_steps is None:
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
else:
self.num_inference_steps = config.num_inference_steps
# ========= inference ============
def conditional_sample(
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
noise: Tensor | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = (
noise
if noise is not None
else torch.randn(
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
)
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
# Predict model output.
model_output = self.unet(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
global_cond_feats = [batch[OBS_STATE]]
# Extract image features.
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
global_cond_feats.append(img_features)
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV_STATE])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, n_obs_steps, environment_dim)
}
"""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch[OBS_STATE].shape[1]
horizon = batch[ACTION].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Forward diffusion.
trajectory = batch[ACTION]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(trajectory.shape[0],),
device=trajectory.device,
).long()
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
# Compute the loss.
# The target is either the original trajectory, or the noise.
if self.config.prediction_type == "epsilon":
target = eps
elif self.config.prediction_type == "sample":
target = batch[ACTION]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
"You need to provide 'action_is_pad' in the batch when "
f"{self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
class SpatialSoftmax(nn.Module):
"""
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
-----------------------------------------------------
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
| ... | ... | ... | ... |
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
-----------------------------------------------------
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
linear mapping (in_channels, H, W) -> (num_kp, H, W).
"""
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
"""
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: Tensor) -> Tensor:
"""
Args:
features: (B, C, H, W) input feature maps.
Returns:
(B, K, 2) image-space coordinates of keypoints.
"""
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class DiffusionRgbEncoder(nn.Module):
"""Encodes an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, config.vision_backbone)(
weights=config.pretrained_backbone_weights
)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if config.use_group_norm:
if config.pretrained_backbone_weights:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer with non-linearity.
x = self.relu(self.out(x))
return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class DiffusionSinusoidalPosEmb(nn.Module):
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class DiffusionConv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class DiffusionConditionalUnet1d(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Note: this removes local conditioning as compared to the original diffusion policy code.
"""
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
super().__init__()
self.config = config
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
)
# Unet encoder.
common_res_block_kwargs = {
"cond_dim": cond_dim,
"kernel_size": config.kernel_size,
"n_groups": config.n_groups,
"use_film_scale_modulation": config.use_film_scale_modulation,
}
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.up_modules.append(
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
"""
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
global_cond: (B, global_cond_dim)
output: (B, T, input_dim)
Returns:
(B, T, input_dim) diffusion model prediction.
"""
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
# Run encoder, keeping track of skip features to pass to the decoder.
encoder_skip_features: list[Tensor] = []
for resnet, resnet2, downsample in self.down_modules:
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
# Run decoder, using the skip features from the encoder.
for resnet, resnet2, upsample in self.up_modules:
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b d t -> b t d")
return x
class DiffusionConditionalResidualBlock1d(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
use_film_scale_modulation: bool = False,
):
super().__init__()
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.use_film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_diffusion_pre_post_processors(
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for a diffusion policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Normalizing the input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving the data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving the data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the diffusion policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

474
environment.yml Normal file
View File

@@ -0,0 +1,474 @@
name: roboimi
channels:
- conda-forge
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=4.5
- _python_abi3_support=1.0
- aiohappyeyeballs=2.6.1
- aiohttp=3.13.3
- aiosignal=1.4.0
- alsa-lib=1.2.9
- anyio=4.12.1
- aom=3.5.0
- async-timeout=5.0.1
- attr=2.5.1
- attrs=25.4.0
- aws-c-auth=0.7.22
- aws-c-cal=0.6.15
- aws-c-common=0.9.23
- aws-c-compression=0.2.18
- aws-c-event-stream=0.4.2
- aws-c-http=0.8.2
- aws-c-io=0.14.9
- aws-c-mqtt=0.10.4
- aws-c-s3=0.5.10
- aws-c-sdkutils=0.1.16
- aws-checksums=0.1.18
- aws-crt-cpp=0.26.12
- aws-sdk-cpp=1.11.329
- box2d-py=2.3.8
- brotli=1.1.0
- brotli-bin=1.1.0
- brotli-python=1.1.0
- bzip2=1.0.8
- c-ares=1.34.6
- ca-certificates=2026.1.4
- cairo=1.16.0
- certifi=2026.1.4
- cffi=1.17.1
- charset-normalizer=3.4.4
- click=8.3.1
- cloudpickle=3.0.0
- contourpy=1.3.0
- cpython=3.10.19
- cuda-cudart=12.6.68
- cuda-cudart_linux-64=12.6.68
- cuda-nvrtc=12.6.68
- cuda-nvtx=12.6.68
- cuda-version=12.6
- cudnn=8.9.7.29
- cycler=0.12.1
- datasets=4.0.0
- dav1d=1.2.1
- dbus=1.13.6
- dill=0.3.8
- eigen=3.4.0
- exceptiongroup=1.3.1
- expat=2.6.3
- farama-notifications=0.0.4
- filelock=3.15.4
- fluidsynth=2.3.3
- font-ttf-dejavu-sans-mono=2.37
- font-ttf-inconsolata=3.000
- font-ttf-source-code-pro=2.038
- font-ttf-ubuntu=0.83
- fontconfig=2.14.2
- fonts-conda-ecosystem=1
- fonts-conda-forge=1
- fonttools=4.53.1
- freetype=2.12.1
- frozenlist=1.7.0
- fsspec=2024.6.1
- gettext=0.22.5
- gettext-tools=0.22.5
- gflags=2.2.2
- git-lfs=3.7.1
- glog=0.7.1
- gmp=6.3.0
- gmpy2=2.1.5
- graphite2=1.3.13
- gym=0.26.1
- gym-box2d=0.26.1
- gym-notices=0.0.8
- gymnasium=0.29.1
- h11=0.16.0
- h2=4.3.0
- harfbuzz=7.3.0
- hf-xet=1.2.1
- hpack=4.1.0
- httpcore=1.0.9
- httpx=0.28.1
- huggingface_hub=1.3.5
- hyperframe=6.1.0
- icu=72.1
- idna=3.11
- jack=1.9.22
- jax-jumpy=1.0.0
- jinja2=3.1.4
- jpeg=9e
- keyutils=1.6.3
- kiwisolver=1.4.9
- krb5=1.21.3
- lame=3.100
- lcms2=2.15
- ld_impl_linux-64=2.40
- lerc=4.0.0
- libabseil=20240116.2
- libarrow=16.1.0
- libarrow-acero=16.1.0
- libarrow-dataset=16.1.0
- libarrow-substrait=16.1.0
- libasprintf=0.22.5
- libasprintf-devel=0.22.5
- libavif=0.11.1
- libblas=3.9.0
- libbrotlicommon=1.1.0
- libbrotlidec=1.1.0
- libbrotlienc=1.1.0
- libcap=2.69
- libcblas=3.9.0
- libcrc32c=1.1.2
- libcublas=12.6.1.4
- libcufft=11.2.6.59
- libcurand=10.3.7.68
- libcurl=8.12.1
- libcusolver=11.6.4.69
- libcusparse=12.5.3.3
- libdb=6.2.32
- libdeflate=1.17
- libedit=3.1.20250104
- libev=4.33
- libevent=2.1.12
- libexpat=2.6.3
- libffi=3.4.2
- libflac=1.4.3
- libgcc=14.1.0
- libgcc-ng=14.1.0
- libgcrypt=1.11.0
- libgettextpo=0.22.5
- libgettextpo-devel=0.22.5
- libgfortran=14.1.0
- libgfortran-ng=14.1.0
- libgfortran5=14.1.0
- libglib=2.80.3
- libgoogle-cloud=2.25.0
- libgoogle-cloud-storage=2.25.0
- libgpg-error=1.50
- libgrpc=1.62.2
- libhwloc=2.9.3
- libiconv=1.17
- libjpeg-turbo=2.1.4
- liblapack=3.9.0
- libmad=0.15.1b
- libmagma=2.8.0
- libmagma_sparse=2.8.0
- libnghttp2=1.67.0
- libnsl=2.0.1
- libnvjitlink=12.6.68
- libogg=1.3.5
- libopenblas=0.3.27
- libopus=1.3.1
- libparquet=16.1.0
- libpng=1.6.43
- libprotobuf=4.25.3
- libre2-11=2023.09.01
- libsndfile=1.2.2
- libsqlite=3.46.0
- libssh2=1.11.1
- libstdcxx=14.1.0
- libstdcxx-ng=14.1.0
- libsystemd0=256.5
- libthrift=0.19.0
- libtiff=4.5.0
- libtorch=2.4.0
- libutf8proc=2.8.0
- libuuid=2.38.1
- libuv=1.48.0
- libvorbis=1.3.7
- libwebp-base=1.4.0
- libxcb=1.13
- libxcrypt=4.4.36
- libxml2=2.11.5
- libzlib=1.3.1
- llvm-openmp=18.1.8
- lz4-c=1.9.4
- markupsafe=2.1.5
- matplotlib-base=3.9.2
- mkl=2023.2.0
- mpc=1.3.1
- mpfr=4.2.1
- mpg123=1.31.3
- mpmath=1.3.0
- multidict=6.7.0
- multiprocess=0.70.16
- munkres=1.1.4
- nccl=2.22.3.1
- ncurses=6.5
- networkx=3.3
- numpy=1.26.4
- openjpeg=2.5.0
- openssl=3.6.1
- opusfile=0.12
- orc=2.0.1
- orocos-kdl=1.5.1
- packaging=24.1
- pandas=2.2.2
- pcre2=10.44
- pillow=9.4.0
- pip=24.2
- pixman=0.43.2
- portaudio=19.6.0
- portmidi=2.0.4
- propcache=0.3.1
- pthread-stubs=0.4
- pulseaudio-client=16.1
- pyarrow=16.1.0
- pyarrow-core=16.1.0
- pybind11=2.13.5
- pybind11-global=2.13.5
- pycparser=2.22
- pygame=2.1.3
- pyparsing=3.1.4
- pysocks=1.7.1
- python=3.10.14
- python-dateutil=2.9.0
- python-gil=3.10.19
- python-orocos-kdl=1.5.1
- python-tzdata=2024.1
- python-xxhash=3.6.0
- python_abi=3.10
- pytorch=2.4.0
- pytz=2024.1
- pyyaml=6.0.3
- qhull=2020.2
- re2=2023.09.01
- readline=8.2
- regex=2026.1.15
- requests=2.32.5
- s2n=1.4.16
- safetensors=0.7.0
- sdl2=2.26.5
- sdl2_image=2.6.3
- sdl2_mixer=2.6.3
- sdl2_ttf=2.20.2
- setuptools=72.2.0
- shellingham=1.5.4
- six=1.16.0
- sleef=3.6.1
- snappy=1.2.2
- sniffio=1.3.1
- stable-baselines3=2.3.2
- sympy=1.13.2
- tbb=2021.11.0
- tk=8.6.13
- tokenizers=0.22.2
- tqdm=4.67.2
- transformers=5.0.0
- typer-slim=0.21.1
- typing-extensions=4.12.2
- typing_extensions=4.12.2
- tzdata=2024a
- unicodedata2=15.1.0
- urllib3=2.5.0
- wheel=0.44.0
- xorg-kbproto=1.0.7
- xorg-libice=1.1.1
- xorg-libsm=1.2.4
- xorg-libx11=1.8.4
- xorg-libxau=1.0.11
- xorg-libxdmcp=1.1.3
- xorg-libxext=1.3.4
- xorg-libxrender=0.9.10
- xorg-renderproto=0.11.1
- xorg-xextproto=7.3.0
- xorg-xproto=7.0.31
- xxhash=0.8.3
- xz=5.2.6
- yaml=0.2.5
- yarl=1.22.0
- zlib=1.3.1
- zstandard=0.23.0
- zstd=1.5.6
- pip:
- GitPython==3.1.46
- Jinja2==3.1.6
- MarkupSafe==3.0.3
- PyOpenGL==3.1.7
- PyYAML==6.0.3
- Pygments==2.19.2
- absl-py==2.1.0
- accelerate==1.12.0
- aiofiles==24.1.0
- aiohappyeyeballs==2.6.1
- aiohttp==3.13.3
- aiosignal==1.4.0
- annotated-doc==0.0.4
- annotated-types==0.7.0
- antlr4-python3-runtime==4.9.3
- anyio==4.12.1
- asciitree==0.3.3
- asttokens==3.0.1
- async-timeout==5.0.1
- attrs==25.4.0
- av==15.1.0
- brotli==1.2.0
- charset-normalizer==3.4.4
- cmake==4.1.3
- cmeel==0.58.0
- cmeel-assimp==5.4.3.1
- cmeel-boost==1.87.0.1
- cmeel-console-bridge==1.0.2.3
- cmeel-octomap==1.10.0
- cmeel-qhull==8.0.2.1
- cmeel-tinyxml==2.6.2.3
- cmeel-tinyxml2==10.0.0
- cmeel-urdfdom==3.1.1.1
- cmeel-zlib==1.3.1
- coal==3.0.2
- coal-library==3.0.1
- colorama==0.4.6
- datasets==4.5.0
- decorator==5.2.1
- deepdiff==8.6.1
- diffusers==0.30.0
- dill==0.4.0
- docstring_parser==0.17.0
- draccus==0.10.0
- eigenpy==3.10.3
- einops==0.8.1
- etils==1.7.0
- evdev==1.9.2
- exceptiongroup==1.3.1
- executing==2.2.1
- fastapi==0.128.0
- fasteners==0.20
- ffmpy==1.0.0
- filelock==3.20.3
- frozenlist==1.8.0
- fsspec==2025.10.0
- gitdb==4.0.12
- glfw==2.7.0
- gradio==6.3.0
- gradio_client==2.0.3
- groovy==0.1.2
- gymnasium==1.2.3
- h11==0.16.0
- h5py==3.15.1
- hf-xet==1.2.0
- hf_transfer==0.1.9
- httpcore==1.0.9
- httpx==0.28.1
- huggingface_hub==1.3.2
- hydra-core==1.3.2
- imageio==2.35.1
- imageio-ffmpeg==0.6.0
- importlib_metadata==8.7.1
- importlib_resources==6.5.2
- inquirerpy==0.3.4
- ipython==8.38.0
- jedi==0.19.2
- jsonargparse==4.45.0
- jsonlines==4.0.0
- kiwisolver==1.4.5
- lerobot==0.4.2
- libcoal==3.0.2
- libpinocchio==3.8.0
- lightning==2.5.0.post0
- lightning-utilities==0.15.2
- lxml==5.3.0
- markdown-it-py==4.0.0
- matplotlib-inline==0.2.1
- mdurl==0.1.2
- mergedeep==1.3.4
- mpmath==1.3.0
- mujoco==3.2.2
- mujoco-python-viewer==0.1.4
- multidict==6.7.0
- multiprocess==0.70.18
- mypy_extensions==1.1.0
- networkx==3.4.2
- numcodecs==0.13.1
- numpy==2.2.6
- nvidia-cublas-cu12==12.4.5.8
- nvidia-cuda-cupti-cu12==12.4.127
- nvidia-cuda-nvrtc-cu12==12.4.127
- nvidia-cuda-runtime-cu12==12.4.127
- nvidia-cudnn-cu12==9.1.0.70
- nvidia-cufft-cu12==11.2.1.3
- nvidia-cufile-cu12==1.11.1.6
- nvidia-curand-cu12==10.3.5.147
- nvidia-cusolver-cu12==11.6.1.9
- nvidia-cusparse-cu12==12.3.1.170
- nvidia-cusparselt-cu12==0.6.3
- nvidia-nccl-cu12==2.21.5
- nvidia-nvjitlink-cu12==12.4.127
- nvidia-nvshmem-cu12==3.3.20
- nvidia-nvtx-cu12==12.4.127
- omegaconf==2.3.0
- opencv-contrib-python==4.10.0.84
- opencv-python==4.13.0.90
- orderly-set==5.5.0
- orjson==3.11.5
- packaging==24.2
- pandas==2.3.3
- parso==0.8.5
- pexpect==4.9.0
- pfzy==0.3.4
- pillow==12.1.0
- pin==3.3.1
- platformdirs==4.5.1
- prompt_toolkit==3.0.52
- propcache==0.4.1
- protobuf==6.33.4
- proxsuite==0.7.2
- psutil==7.2.1
- ptyprocess==0.7.0
- pure_eval==0.2.3
- pyarrow==22.0.0
- pydantic==2.12.5
- pydantic_core==2.41.5
- pydub==0.25.1
- pynput==1.8.1
- pyquaternion==0.9.9
- pyserial==3.5
- python-dateutil==2.9.0.post0
- python-multipart==0.0.21
- python-xlib==0.33
- pytorch-lightning==2.6.0
- pyyaml-include==1.4.1
- qwen-vl-utils==0.0.14
- regex==2026.1.15
- requests==2.32.5
- rerun-sdk==0.26.2
- rich==14.2.0
- ruckig==0.9.2
- safehttpx==0.1.7
- safetensors==0.7.0
- scipy==1.14.1
- semantic-version==2.10.0
- sentry-sdk==2.49.0
- shellingham==1.5.4
- smmap==5.0.2
- stack-data==0.6.3
- starlette==0.50.0
- sympy==1.13.1
- termcolor==3.3.0
- timm==1.0.24
- toml==0.10.2
- tomli==2.4.0
- tomlkit==0.13.3
- torch==2.5.0
- torchcodec==0.5
- torchmetrics==1.8.2
- torchvision==0.20.0
- tqdm==4.67.1
- traitlets==5.14.3
- triton==3.1.0
- typer==0.21.1
- typer-slim==0.21.1
- typeshed_client==2.8.2
- typing-inspect==0.9.0
- typing-inspection==0.4.2
- typing_extensions==4.15.0
- tzdata==2025.3
- urdf_parser_py==0.0.4
- urllib3==2.6.3
- uv==0.9.28
- uvicorn==0.40.0
- wandb==0.24.0
- wcwidth==0.2.14
- xxhash==3.6.0
- yarl==1.22.0
- zarr==2.18.3
- zipp==3.20.1

324
generate_dataset_videos.py Normal file
View File

@@ -0,0 +1,324 @@
#!/usr/bin/env python3
"""
将 HDF5 数据集转换为视频,用于可视化检查
功能:
1. 将单个 episode 转换为视频
2. 对比多个 episode 的视频
3. 放慢播放速度便于观察
"""
import os
import h5py
import glob
import cv2
import numpy as np
def episode_to_video(episode_file, output_path, camera='top', fps=30, slow_factor=1):
"""
将单个 episode 转换为视频
Args:
episode_file: HDF5 文件路径
output_path: 输出视频路径
camera: 要使用的相机名称
fps: 帧率
slow_factor: 慢放倍数1=正常2=半速)
"""
try:
with h5py.File(episode_file, 'r') as f:
# 读取图像序列
img_path = f'/observations/images/{camera}'
if img_path not in f:
print(f" ❌ 相机 {camera} 不存在")
return False
images = f[img_path][:] # shape: (T, H, W, C)
qpos = f['/observations/qpos'][:]
actions = f['/action'][:]
total_frames = len(images)
height, width = images.shape[1], images.shape[2]
# 创建视频写入器
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
actual_fps = fps // slow_factor
out = cv2.VideoWriter(output_path, fourcc, actual_fps, (width, height))
# 逐帧写入
for i in range(total_frames):
frame = images[i].astype(np.uint8)
# 在图像上添加信息
info_text = [
f"Episode: {os.path.basename(episode_file).replace('.hdf5', '')}",
f"Frame: {i}/{total_frames}",
f"qpos[0:3]: [{qpos[i, 0]:.2f}, {qpos[i, 1]:.2f}, {qpos[i, 2]:.2f}]",
]
for j, text in enumerate(info_text):
cv2.putText(frame, text, (10, 30 + j*30),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
out.write(frame)
out.release()
print(f" ✅ 保存: {output_path}")
print(f" 帧数: {total_frames}, 尺寸: {width}x{height}, FPS: {actual_fps}")
return True
except Exception as e:
print(f" ❌ 错误: {e}")
return False
def generate_all_videos(camera='top', num_episodes=5, slow_factor=1):
"""生成前 N 个 episode 的视频"""
dataset_dir = "roboimi/demos/dataset/sim_transfer"
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
if len(episode_files) == 0:
print(f"❌ 没有找到数据文件: {dataset_dir}")
return
# 创建输出目录
output_dir = '/tmp/dataset_videos'
os.makedirs(output_dir, exist_ok=True)
print(f"找到 {len(episode_files)} 个 episode 文件")
print(f"将生成前 {min(num_episodes, len(episode_files))} 个 episode 的视频\n")
# 生成视频
for i in range(min(num_episodes, len(episode_files))):
ep_file = episode_files[i]
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
output_path = f"{output_dir}/{ep_name}_{camera}.mp4"
print(f"[{i+1}/{min(num_episodes, len(episode_files))}] {ep_name}")
episode_to_video(ep_file, output_path, camera=camera, slow_factor=slow_factor)
print()
print(f"✅ 所有视频已保存到: {output_dir}")
print(f"\n播放方法:")
print(f" # 播放单个视频")
print(f" vlc {output_dir}/*.mp4")
print(f" ")
print(f" # 或用文件管理器")
print(f" nautilus {output_dir}")
def generate_multi_camera_video(episode_idx=0, slow_factor=1):
"""生成包含多个相机的视频(分屏显示)"""
dataset_dir = "roboimi/demos/dataset/sim_transfer"
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
if episode_idx >= len(episode_files):
print(f"❌ Episode {episode_idx} 不存在")
return
ep_file = episode_files[episode_idx]
try:
with h5py.File(ep_file, 'r') as f:
# 获取所有相机
cameras = []
for key in f.keys():
if 'images' in key:
for cam_name in f[key].keys():
if cam_name not in cameras:
cameras.append(cam_name)
print(f"Episode {episode_idx} 的相机: {cameras}")
# 读取所有相机的图像
all_images = {}
for cam in cameras:
img_path = f'/observations/images/{cam}'
if img_path in f:
all_images[cam] = f[img_path][:]
if not all_images:
print("❌ 没有找到图像数据")
return
# 获取第一个相机的尺寸
first_cam = list(all_images.keys())[0]
total_frames = len(all_images[first_cam])
height, width = all_images[first_cam].shape[1], all_images[first_cam].shape[2]
# 创建多相机布局
num_cams = len(all_images)
cols = min(2, num_cams)
rows = (num_cams + cols - 1) // cols
canvas_width = width * cols
canvas_height = height * rows
# 创建视频写入器
output_path = f'/tmp/dataset_videos/episode_{episode_idx}_all_cameras.mp4'
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, 30 // slow_factor, (canvas_width, canvas_height))
# 逐帧合成
for i in range(total_frames):
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
for cam_idx, cam_name in enumerate(all_images.keys()):
img = all_images[cam_name][i]
# 计算在画布上的位置
row = cam_idx // cols
col = cam_idx % cols
y_start = row * height
y_end = y_start + height
x_start = col * width
x_end = x_start + width
# 调整大小(如果需要)
if img.shape[:2] != (height, width):
img = cv2.resize(img, (width, height))
# 放到画布上
canvas[y_start:y_end, x_start:x_end] = img
# 添加相机名称
cv2.putText(canvas, cam_name, (x_start + 10, y_start + 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
# 添加帧信息
cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
out.write(canvas)
out.release()
print(f"✅ 保存多相机视频: {output_path}")
except Exception as e:
print(f"❌ 错误: {e}")
def compare_episodes(camera='top', slow_factor=2):
"""并排对比多个 episode 的视频"""
dataset_dir = "roboimi/demos/dataset/sim_transfer"
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
# 选择要对比的 episode
episodes_to_compare = [0, 1, 2, 3, 4] # 对比前 5 个
print(f"对比 Episodes: {episodes_to_compare}")
# 读取所有 episode 的数据
all_data = []
for ep_idx in episodes_to_compare:
if ep_idx >= len(episode_files):
continue
try:
with h5py.File(episode_files[ep_idx], 'r') as f:
img_path = f'/observations/images/{camera}'
if img_path in f:
all_data.append({
'idx': ep_idx,
'images': f[img_path][:],
'qpos': f['/observations/qpos'][:]
})
except:
pass
if len(all_data) == 0:
print("❌ 没有数据")
return
# 获取参数
first_data = all_data[0]
height, width = first_data['images'].shape[1], first_data['images'].shape[2]
total_frames = min([d['images'].shape[0] for d in all_data])
# 创建并排布局
num_compare = len(all_data)
canvas_width = width * num_compare
canvas_height = height
# 创建视频
output_path = f'/tmp/dataset_videos/compare_{camera}.mp4'
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, 30 // slow_factor, (canvas_width, canvas_height))
print(f"生成对比视频,共 {total_frames} 帧...")
# 逐帧对比
for i in range(total_frames):
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
for j, data in enumerate(all_data):
img = data['images'][i]
qpos = data['qpos'][i]
# 调整大小(如果需要)
if img.shape[:2] != (height, width):
img = cv2.resize(img, (width, height))
# 放到画布上
x_start = j * width
x_end = x_start + width
canvas[:, x_start:x_end] = img
# 添加信息
ep_name = f"Ep {data['idx']}"
cv2.putText(canvas, ep_name, (x_start + 10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
cv2.putText(canvas, f"qpos[0:3]: [{qpos[0]:.2f}, {qpos[1]:.2f}, {qpos[2]:.2f}]",
(x_start + 10, height - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
# 添加帧号
cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
out.write(canvas)
if i % 100 == 0:
print(f" 进度: {i}/{total_frames}")
out.release()
print(f"✅ 保存对比视频: {output_path}")
if __name__ == "__main__":
import sys
print("="*60)
print("数据集视频生成工具")
print("="*60)
if len(sys.argv) > 1:
command = sys.argv[1]
if command == 'compare':
# 对比多个 episode
camera = sys.argv[2] if len(sys.argv) > 2 else 'top'
compare_episodes(camera=camera, slow_factor=2)
elif command == 'multi':
# 多相机视频
ep_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
generate_multi_camera_video(episode_idx=ep_idx, slow_factor=1)
else:
print("未知命令")
else:
# 默认:生成前 5 个 episode 的视频
print("\n生成前 5 个 episode 的视频top 相机,慢放 2x...")
print("="*60 + "\n")
generate_all_videos(camera='top', num_episodes=5, slow_factor=2)
print("\n" + "="*60)
print("其他用法:")
print(" python generate_dataset_videos.py compare top # 对比多个 episode")
print(" python generate_dataset_videos.py multi 0 # 多相机视频")
print("="*60)

125
gr00t/main.py Normal file
View File

@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
GR00T (diffusion-based DiT policy) model builder.
This module provides functions to build GR00T models and optimizers
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from .models import build_gr00t_model
def get_args_parser():
"""
Create argument parser for GR00T model configuration.
All parameters can be overridden via args_override dictionary in
build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
"""
parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
# Training parameters
parser.add_argument('--lr', default=1e-5, type=float,
help='Learning rate for main parameters')
parser.add_argument('--lr_backbone', default=1e-5, type=float,
help='Learning rate for backbone parameters')
parser.add_argument('--weight_decay', default=1e-4, type=float,
help='Weight decay for optimizer')
# GR00T model architecture parameters
parser.add_argument('--embed_dim', default=1536, type=int,
help='Embedding dimension for transformer')
parser.add_argument('--hidden_dim', default=1024, type=int,
help='Hidden dimension for MLP layers')
parser.add_argument('--state_dim', default=16, type=int,
help='State (qpos) dimension')
parser.add_argument('--action_dim', default=16, type=int,
help='Action dimension')
parser.add_argument('--num_queries', default=16, type=int,
help='Number of action queries (chunk size)')
# DiT (Diffusion Transformer) parameters
parser.add_argument('--num_layers', default=16, type=int,
help='Number of transformer layers')
parser.add_argument('--nheads', default=32, type=int,
help='Number of attention heads')
parser.add_argument('--mlp_ratio', default=4, type=float,
help='MLP hidden dimension ratio')
parser.add_argument('--dropout', default=0.2, type=float,
help='Dropout rate')
# Backbone parameters
parser.add_argument('--backbone', default='dino_v2', type=str,
help='Backbone architecture (dino_v2, resnet18, resnet34)')
parser.add_argument('--position_embedding', default='sine', type=str,
choices=('sine', 'learned'),
help='Type of positional encoding')
# Camera configuration
parser.add_argument('--camera_names', default=[], nargs='+',
help='List of camera names for observations')
# Other parameters (not directly used but kept for compatibility)
parser.add_argument('--batch_size', default=15, type=int)
parser.add_argument('--epochs', default=20000, type=int)
parser.add_argument('--masks', action='store_true',
help='Use intermediate layer features')
parser.add_argument('--dilation', action='store_false',
help='Use dilated convolution in backbone')
return parser
def build_gr00t_model_and_optimizer(args_override):
"""
Build GR00T model and optimizer from config dictionary.
This function is designed to work with config.yaml loading:
1. Parse default arguments
2. Override with values from args_override (typically from config['gr00t'])
3. Build model and optimizer
Args:
args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
num_queries, nheads, mlp_ratio, dropout, num_layers,
lr, lr_backbone, camera_names, backbone, etc.
Returns:
model: GR00T model on CUDA
optimizer: AdamW optimizer with separate learning rates for backbone and other params
"""
parser = argparse.ArgumentParser('GR00T training and evaluation script',
parents=[get_args_parser()])
args = parser.parse_args()
# Override with config values
for k, v in args_override.items():
setattr(args, k, v)
# Build model
model = build_gr00t_model(args)
model.cuda()
# Create parameter groups with different learning rates
param_dicts = [
{
"params": [p for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad]
},
{
"params": [p for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts,
lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer

3
gr00t/models/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .gr00t import build_gr00t_model
__all__ = ['build_gr00t_model']

142
gr00t/models/dit.py Normal file
View File

@@ -0,0 +1,142 @@
from typing import Optional
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
import torch
from torch import nn
import torch.nn.functional as F
class TimestepEncoder(nn.Module):
def __init__(self, args):
super().__init__()
embedding_dim = args.embed_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
super().__init__()
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(self, args, crosss_attention_dim, use_self_attn=False):
super().__init__()
dim = args.embed_dim
num_heads = args.nheads
mlp_ratio = args.mlp_ratio
dropout = args.dropout
self.norm1 = AdaLayerNorm(dim)
if not use_self_attn:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
kdim=crosss_attention_dim,
vdim=crosss_attention_dim,
batch_first=True,
)
else:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * mlp_ratio),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mlp_ratio, dim),
nn.Dropout(dropout)
)
def forward(self, hidden_states, temb, context=None):
norm_hidden_states = self.norm1(hidden_states, temb)
attn_output = self.attn(
norm_hidden_states,
context if context is not None else norm_hidden_states,
context if context is not None else norm_hidden_states,
)[0]
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_output = self.mlp(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class DiT(nn.Module):
def __init__(self, args, cross_attention_dim):
super().__init__()
inner_dim = args.embed_dim
num_layers = args.num_layers
output_dim = args.hidden_dim
self.timestep_encoder = TimestepEncoder(args)
all_blocks = []
for idx in range(num_layers):
use_self_attn = idx % 2 == 1
if use_self_attn:
block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
else:
block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
all_blocks.append(block)
self.transformer_blocks = nn.ModuleList(all_blocks)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
def forward(self, hidden_states, timestep, encoder_hidden_states):
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(hidden_states, temb)
else:
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(hidden_states)
def build_dit(args, cross_attention_dim):
return DiT(args, cross_attention_dim)

124
gr00t/models/gr00t.py Normal file
View File

@@ -0,0 +1,124 @@
from .modules import (
build_action_decoder,
build_action_encoder,
build_state_encoder,
build_time_sampler,
build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit
import torch
import torch.nn as nn
import torch.nn.functional as F
class gr00t(nn.Module):
def __init__(
self,
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
num_queries,
camera_names,
):
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.dit = dit
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.action_decoder = action_decoder
self.time_sampler = time_sampler
self.noise_scheduler = noise_scheduler
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
else:
raise NotImplementedError
def forward(self, qpos, image, actions=None, is_pad=None):
is_training = actions is not None # train or val
bs, _ = qpos.shape
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
B, C, H, W = features.shape
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
all_cam_features.append(features_seq)
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
if is_training:
# training logic
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
actions, timesteps
)
t_discretized = (timesteps[:, 0, 0] * 1000).long()
action_features = self.action_encoder(noisy_actions, t_discretized)
sa_embs = torch.cat((state_features, action_features), dim=1)
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_actions = pred[:, -actions.shape[1] :]
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
return pred_actions, action_loss
else:
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
k = 5
dt = 1.0 / k
for t in range(k):
t_cont = t / float(k)
t_discretized = int(t_cont * 1000)
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
action_features = self.action_encoder(actions, timesteps)
sa_embs = torch.cat((state_features, action_features), dim=1)
# Create tensor of shape [B] for DiT (consistent with training path)
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_velocity = pred[:, -self.num_queries :]
actions = actions + pred_velocity * dt
return actions, _
def build_gr00t_model(args):
state_dim = args.state_dim
action_dim = args.action_dim
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
cross_attention_dim = backbones[0].num_channels
dit = build_dit(args, cross_attention_dim)
state_encoder = build_state_encoder(args)
action_encoder = build_action_encoder(args)
action_decoder = build_action_decoder(args)
time_sampler = build_time_sampler(args)
noise_scheduler = build_noise_scheduler(args)
model = gr00t(
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
args.num_queries,
args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model

179
gr00t/models/modules.py Normal file
View File

@@ -0,0 +1,179 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# ActionEncoder
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, args):
super().__init__()
self.embed_dim = args.embed_dim
def forward(self, timesteps):
timesteps = timesteps.float()
B, T = timesteps.shape
device = timesteps.device
half_dim = self.embed_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
class ActionEncoder(nn.Module):
def __init__(self, args):
super().__init__()
action_dim = args.action_dim
embed_dim = args.embed_dim
self.W1 = nn.Linear(action_dim, embed_dim)
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
self.W3 = nn.Linear(embed_dim, embed_dim)
self.pos_encoder = SinusoidalPositionalEncoding(args)
def forward(self, actions, timesteps):
B, T, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# Handle different input shapes: (B,), (B, 1), (B, 1, 1)
# Reshape to (B,) then expand to (B, T)
# if timesteps.dim() == 3:
# # Shape (B, 1, 1) or (B, T, 1) -> (B,)
# timesteps = timesteps[:, 0, 0]
# elif timesteps.dim() == 2:
# # Shape (B, 1) or (B, T) -> take first element if needed
# if timesteps.shape[1] == 1:
# timesteps = timesteps[:, 0]
# # else: already (B, T), use as is
# elif timesteps.dim() != 1:
# raise ValueError(
# f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}"
# )
# Now timesteps should be (B,), expand to (B, T)
if timesteps.dim() == 1 and timesteps.shape[0] == B:
timesteps = timesteps.unsqueeze(1).expand(-1, T)
else:
raise ValueError(
"Expected `timesteps` to have shape (B,) so we can replicate across T."
)
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = F.silu(self.W2(x))
# 5) Finally W3 => (B, T, w)
x = self.W3(x)
return x
def build_action_encoder(args):
return ActionEncoder(args)
# StateEncoder
class StateEncoder(nn.Module):
def __init__(self, args):
super().__init__()
input_dim = args.state_dim
hidden_dim = args.hidden_dim
output_dim = args.embed_dim
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, states):
state_emb = self.mlp(states) # [B, emb_dim]
state_emb = state_emb.unsqueeze(1)
return state_emb # [B, 1, emb_dim]
def build_state_encoder(args):
return StateEncoder(args)
# ActionDecoder
class ActionDecoder(nn.Module):
def __init__(self,args):
super().__init__()
input_dim = args.hidden_dim
hidden_dim = args.hidden_dim
output_dim = args.action_dim
self.num_queries = args.num_queries
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, model_output):
pred_actions = self.mlp(model_output)
return pred_actions[:, -self.num_queries:]
def build_action_decoder(args):
return ActionDecoder(args)
# TimeSampler
class TimeSampler(nn.Module):
def __init__(self, noise_s = 0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
super().__init__()
self.noise_s = noise_s
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
def forward(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
sample = (1 - sample) * self.noise_s
return sample[:, None, None]
def build_time_sampler(args):
return TimeSampler()
# NoiseScheduler
import torch
import torch.nn as nn
class FlowMatchingScheduler(nn.Module):
def __init__(self):
super().__init__()
# --- 训练逻辑:加噪并计算目标 ---
def add_noise(self, actions, timesteps):
noise = torch.randn_like(actions)
noisy_samples = actions * timesteps + noise * (1 - timesteps)
target_velocity = actions - noise
return noisy_samples, target_velocity
# --- 推理逻辑:欧拉步 (Euler Step) ---
def step(self, model_output, sample, dt):
prev_sample = sample + model_output * dt
return prev_sample
def build_noise_scheduler(args):
return FlowMatchingScheduler()

90
gr00t/policy.py Normal file
View File

@@ -0,0 +1,90 @@
"""
GR00T Policy wrapper for imitation learning.
This module provides the gr00tPolicy class that wraps the GR00T model
for training and evaluation in the imitation learning framework.
"""
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import v2
import torch
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
class gr00tPolicy(nn.Module):
"""
GR00T Policy for action prediction using diffusion-based DiT architecture.
This policy wraps the GR00T model and handles:
- Image resizing to match DINOv2 patch size requirements
- Image normalization (ImageNet stats)
- Training with action chunks and loss computation
- Inference with diffusion sampling
"""
def __init__(self, args_override):
super().__init__()
model, optimizer = build_gr00t_model_and_optimizer(args_override)
self.model = model
self.optimizer = optimizer
# DINOv2 requires image dimensions to be multiples of patch size (14)
# Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
self.patch_h = 16 # Number of patches vertically
self.patch_w = 22 # Number of patches horizontally
target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308)
# Training transform with data augmentation
self.train_transform = v2.Compose([
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
v2.RandomPerspective(distortion_scale=0.5),
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
v2.Resize(target_size),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# Inference transform (no augmentation)
self.inference_transform = v2.Compose([
v2.Resize(target_size),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
def __call__(self, qpos, image, actions=None, is_pad=None):
"""
Forward pass for training or inference.
Args:
qpos: Joint positions [B, state_dim]
image: Camera images [B, num_cameras, C, H, W]
actions: Ground truth actions [B, chunk_size, action_dim] (training only)
is_pad: Padding mask [B, chunk_size] (training only)
Returns:
Training: dict with 'mse' loss
Inference: predicted actions [B, num_queries, action_dim]
"""
# Apply transforms (resize + normalization)
if actions is not None: # training time
image = self.train_transform(image)
else: # inference time
image = self.inference_transform(image)
if actions is not None: # training time
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
_, action_loss = self.model(qpos, image, actions, is_pad)
# Mask out padded positions
mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
loss_dict = {
'loss': mse_loss
}
return loss_dict
else: # inference time
a_hat, _ = self.model(qpos, image)
return a_hat
def configure_optimizers(self):
"""Return the optimizer for training."""
return self.optimizer

1
roboimi/.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
*.safetensors filter=lfs diff=lfs merge=lfs -text

0
roboimi/__init__.py Normal file
View File

View File

@@ -3,7 +3,7 @@
<body name="box" pos="0.2 1.0 0.47"> <body name="box" pos="0.2 1.0 0.47">
<joint name="red_box_joint" type="free" frictionloss="0.01" /> <joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" /> <inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" /> <geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.018 0.018 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body> </body>
</worldbody> </worldbody>
</mujoco> </mujoco>

View File

@@ -8,5 +8,6 @@
</body> </body>
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/> <camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/> <camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/>
<camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/>
</worldbody> </worldbody>
</mujoco> </mujoco>

View File

@@ -58,8 +58,8 @@ class BiDianaMed(ArmBase):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
name="Bidiana", name="Bidiana",
urdf_path="./assets/models/manipulators/DianaMed/DualDianaMed.urdf", urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf",
xml_path="./assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml", xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml",
gripper=None gripper=None
) )
self.left_arm = self.Arm(self, 'single', self.urdf_path) self.left_arm = self.Arm(self, 'single', self.urdf_path)

View File

@@ -1,112 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DDT 模型构建和优化器配置。
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from .models import build_DDT_model
def get_args_parser():
"""获取 DDT 模型的参数解析器。"""
parser = argparse.ArgumentParser('DDT model configuration', add_help=False)
# 学习率
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--batch_size', default=2, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--lr_drop', default=200, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float)
# Backbone 参数
parser.add_argument('--backbone', default='resnet18', type=str,
help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, replace stride with dilation in the last conv block")
parser.add_argument('--position_embedding', default='sine', type=str,
choices=('sine', 'learned'),
help="Type of positional embedding")
parser.add_argument('--camera_names', default=[], type=list,
help="A list of camera names")
# Transformer 参数
parser.add_argument('--enc_layers', default=4, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers")
parser.add_argument('--hidden_dim', default=512, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads")
parser.add_argument('--num_queries', default=100, type=int,
help="Number of query slots (action horizon)")
parser.add_argument('--pre_norm', action='store_true')
parser.add_argument('--state_dim', default=14, type=int)
parser.add_argument('--action_dim', default=14, type=int)
# DDT 特有参数
parser.add_argument('--num_blocks', default=12, type=int,
help="Total number of transformer blocks in DDT")
parser.add_argument('--mlp_ratio', default=4.0, type=float,
help="MLP hidden dimension ratio")
parser.add_argument('--num_inference_steps', default=10, type=int,
help="Number of diffusion inference steps")
# Segmentation (未使用)
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if provided")
return parser
def build_DDT_model_and_optimizer(args_override):
"""构建 DDT 模型和优化器。
Args:
args_override: 覆盖默认参数的字典
Returns:
model: DDT 模型
optimizer: AdamW 优化器
"""
parser = argparse.ArgumentParser('DDT training script', parents=[get_args_parser()])
args = parser.parse_args([]) # 空列表避免命令行参数干扰
# 应用参数覆盖
for k, v in args_override.items():
setattr(args, k, v)
# 构建模型
model = build_DDT_model(args)
model.cuda()
# 配置优化器backbone 使用较小学习率)
param_dicts = [
{
"params": [p for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad]
},
{
"params": [p for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
param_dicts,
lr=args.lr,
weight_decay=args.weight_decay
)
return model, optimizer

View File

@@ -1,7 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .model import build as build_ddt
from .model import build_ddt
def build_DDT_model(args):
"""构建 DDT 模型的统一入口。"""
return build_ddt(args)

View File

@@ -1,631 +0,0 @@
"""
动作序列扩散 Transformer (Action Decoupled Diffusion Transformer)
基于 DDT 架构修改,用于生成机器人动作序列。
主要改动:
1. 2D RoPE → 1D RoPE (适配时序数据)
2. LabelEmbedder → ObservationEncoder (观测条件)
3. 去除 patchify/unpatchify (动作序列已是 1D)
"""
import math
from typing import Tuple, Optional
import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention
# ============================================================================
# 通用工具函数
# ============================================================================
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""AdaLN 调制函数。
Args:
x: 输入张量。
shift: 偏移量。
scale: 缩放量。
Returns:
调制后的张量: x * (1 + scale) + shift
"""
return x * (1 + scale) + shift
# ============================================================================
# 1D 旋转位置编码 (RoPE)
# ============================================================================
def precompute_freqs_cis_1d(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
"""预计算 1D 旋转位置编码的复数频率。
用于时序数据(如动作序列)的位置编码,相比 2D RoPE 更简单高效。
Args:
dim: 每个注意力头的维度 (head_dim)。
seq_len: 序列长度。
theta: RoPE 的基础频率,默认 10000.0。
Returns:
复数频率张量,形状为 (seq_len, dim//2)。
"""
# 计算频率: 1 / (theta^(2i/dim))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
# 位置索引
t = torch.arange(seq_len).float() # [seq_len]
# 外积得到位置-频率矩阵
freqs = torch.outer(t, freqs) # [seq_len, dim//2]
# 转换为复数形式 (极坐标)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [seq_len, dim//2]
return freqs_cis
def apply_rotary_emb_1d(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""应用 1D 旋转位置编码到 Query 和 Key。
Args:
xq: Query 张量,形状为 (B, N, H, Hc)。
xk: Key 张量,形状为 (B, N, H, Hc)。
freqs_cis: 预计算的复数频率,形状为 (N, Hc//2)。
Returns:
应用 RoPE 后的 (xq, xk),形状不变。
"""
# 调整 freqs_cis 形状以便广播: [1, N, 1, Hc//2]
freqs_cis = freqs_cis[None, :, None, :]
# 将实数张量视为复数: [B, N, H, Hc] -> [B, N, H, Hc//2] (复数)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 复数乘法实现旋转
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [B, N, H, Hc]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# ============================================================================
# 基础组件
# ============================================================================
class Embed(nn.Module):
"""线性嵌入层,将输入投影到隐藏空间。"""
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[nn.Module] = None,
bias: bool = True,
):
"""初始化 Embed。
Args:
in_chans: 输入通道数/维度。
embed_dim: 输出嵌入维度。
norm_layer: 可选的归一化层。
bias: 是否使用偏置。
"""
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
"""扩散时间步嵌入器。
使用正弦位置编码 + MLP 将标量时间步映射到高维向量。
"""
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
"""初始化 TimestepEmbedder。
Args:
hidden_size: 输出嵌入维度。
frequency_embedding_size: 正弦编码的维度。
"""
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor:
"""生成正弦时间步嵌入。
Args:
t: 时间步张量,形状为 (B,)。
dim: 嵌入维度。
max_period: 最大周期。
Returns:
时间步嵌入,形状为 (B, dim)。
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t: torch.Tensor) -> torch.Tensor:
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class ObservationEncoder(nn.Module):
"""观测状态编码器。
将机器人的观测向量(如关节位置、末端位姿、图像特征等)
编码为条件向量,用于条件扩散生成。
Attributes:
encoder: 两层 MLP 编码器。
Example:
>>> encoder = ObservationEncoder(obs_dim=128, hidden_size=512)
>>> obs = torch.randn(2, 128)
>>> cond = encoder(obs) # [2, 512]
"""
def __init__(self, obs_dim: int, hidden_size: int):
"""初始化 ObservationEncoder。
Args:
obs_dim: 观测向量的维度。
hidden_size: 输出条件向量的维度。
"""
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(obs_dim, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size),
)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
"""前向传播。
Args:
obs: 观测向量,形状为 (B, obs_dim)。
Returns:
条件向量,形状为 (B, hidden_size)。
"""
return self.encoder(obs)
class FinalLayer(nn.Module):
"""最终输出层,使用 AdaLN 调制后输出预测结果。"""
def __init__(self, hidden_size: int, out_channels: int):
"""初始化 FinalLayer。
Args:
hidden_size: 输入隐藏维度。
out_channels: 输出通道数/维度。
"""
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
"""前向传播。
Args:
x: 输入张量,形状为 (B, N, hidden_size)。
c: 条件张量,形状为 (B, N, hidden_size) 或 (B, 1, hidden_size)。
Returns:
输出张量,形状为 (B, N, out_channels)。
"""
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
# ============================================================================
# 归一化和前馈网络
# ============================================================================
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization (RMS 归一化层)。
RMSNorm 是 LayerNorm 的简化版本,去掉了均值中心化操作,只保留缩放。
相比 LayerNorm 计算更快,效果相当,被广泛用于 LLaMA、Mistral 等大模型。
数学公式:
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
"""初始化 RMSNorm。
Args:
hidden_size: 输入特征的维度。
eps: 防止除零的小常数。
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
"""SwiGLU 前馈网络 (Feed-Forward Network)。
使用 SwiGLU 门控激活函数的前馈网络,来自 LLaMA 架构。
结构:
output = W2(SiLU(W1(x)) * W3(x))
"""
def __init__(self, dim: int, hidden_dim: int):
"""初始化 FeedForward。
Args:
dim: 输入和输出的特征维度。
hidden_dim: 隐藏层维度(实际使用 2/3 * hidden_dim
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
# ============================================================================
# 注意力机制
# ============================================================================
class RAttention(nn.Module):
"""带有旋转位置编码的多头自注意力 (Rotary Attention)。
集成了以下技术:
- 1D RoPE: 通过复数旋转编码时序位置信息
- QK-Norm: 对 Query 和 Key 进行归一化,稳定训练
- Flash Attention: 使用 scaled_dot_product_attention
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
"""初始化 RAttention。
Args:
dim: 输入特征维度,必须能被 num_heads 整除。
num_heads: 注意力头数。
qkv_bias: QKV 投影是否使用偏置。
qk_norm: 是否对 Q, K 进行归一化。
attn_drop: 注意力权重的 dropout 率。
proj_drop: 输出投影的 dropout 率。
norm_layer: 归一化层类型。
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
pos: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""前向传播。
Args:
x: 输入张量,形状为 (B, N, C)。
pos: 1D RoPE 位置编码,形状为 (N, head_dim//2)。
mask: 可选的注意力掩码。
Returns:
输出张量,形状为 (B, N, C)。
"""
B, N, C = x.shape
# QKV 投影
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # [B, N, H, Hc]
# QK-Norm
q = self.q_norm(q)
k = self.k_norm(k)
# 应用 1D RoPE
q, k = apply_rotary_emb_1d(q, k, freqs_cis=pos)
# 调整维度: [B, N, H, Hc] -> [B, H, N, Hc]
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
# Scaled Dot-Product Attention
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
# 输出投影
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# ============================================================================
# Transformer Block
# ============================================================================
class ActionDDTBlock(nn.Module):
"""动作 DDT Transformer Block。
结构: Pre-Norm + AdaLN + Attention + FFN
"""
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
"""初始化 ActionDDTBlock。
Args:
hidden_size: 隐藏层维度。
num_heads: 注意力头数。
mlp_ratio: FFN 隐藏层倍率。
"""
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=num_heads, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
pos: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""前向传播。
Args:
x: 输入张量,形状为 (B, N, hidden_size)。
c: 条件张量,形状为 (B, 1, hidden_size) 或 (B, N, hidden_size)。
pos: 位置编码。
mask: 可选的注意力掩码。
Returns:
输出张量,形状为 (B, N, hidden_size)。
"""
# AdaLN 调制参数
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(c).chunk(6, dim=-1)
# Attention 分支
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
# FFN 分支
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
# ============================================================================
# 主模型: ActionDDT
# ============================================================================
class ActionDDT(nn.Module):
"""动作序列解耦扩散 Transformer (Action Decoupled Diffusion Transformer)。
基于 DDT 架构,专为机器人动作序列生成设计。
将模型解耦为编码器和解码器两部分,编码器状态可缓存以加速推理。
架构:
- 编码器: 前 num_encoder_blocks 个 block生成状态 s
- 解码器: 剩余 block使用状态 s 对动作序列 x 去噪
- 使用 1D RoPE 进行时序位置编码
- 使用 AdaLN 注入时间步和观测条件
Args:
action_dim: 动作向量维度(如 7-DoF 机械臂为 7
obs_dim: 观测向量维度。
action_horizon: 预测的动作序列长度。
hidden_size: Transformer 隐藏层维度。
num_blocks: Transformer block 总数。
num_encoder_blocks: 编码器 block 数量。
num_heads: 注意力头数。
mlp_ratio: FFN 隐藏层倍率。
输入:
x (Tensor): 带噪声的动作序列,形状为 (B, T, action_dim)。
t (Tensor): 扩散时间步,形状为 (B,),取值范围 [0, 1]。
obs (Tensor): 观测条件,形状为 (B, obs_dim)。
s (Tensor, optional): 缓存的编码器状态。
输出:
x (Tensor): 预测的速度场/噪声,形状为 (B, T, action_dim)。
s (Tensor): 编码器状态,可缓存复用。
Example:
>>> model = ActionDDT(action_dim=7, obs_dim=128, action_horizon=16)
>>> x = torch.randn(2, 16, 7) # 带噪声的动作序列
>>> t = torch.rand(2) # 随机时间步
>>> obs = torch.randn(2, 128) # 观测条件
>>> out, state = model(x, t, obs)
>>> out.shape
torch.Size([2, 16, 7])
"""
def __init__(
self,
action_dim: int = 7,
obs_dim: int = 128,
action_horizon: int = 16,
hidden_size: int = 512,
num_blocks: int = 12,
num_encoder_blocks: int = 4,
num_heads: int = 8,
mlp_ratio: float = 4.0,
):
super().__init__()
# 保存配置
self.action_dim = action_dim
self.obs_dim = obs_dim
self.action_horizon = action_horizon
self.hidden_size = hidden_size
self.num_blocks = num_blocks
self.num_encoder_blocks = num_encoder_blocks
self.num_heads = num_heads
# 动作嵌入层
self.x_embedder = Embed(action_dim, hidden_size, bias=True)
self.s_embedder = Embed(action_dim, hidden_size, bias=True)
# 条件嵌入
self.t_embedder = TimestepEmbedder(hidden_size)
self.obs_encoder = ObservationEncoder(obs_dim, hidden_size)
# 输出层
self.final_layer = FinalLayer(hidden_size, action_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
ActionDDTBlock(hidden_size, num_heads, mlp_ratio)
for _ in range(num_blocks)
])
# 预计算 1D 位置编码
pos = precompute_freqs_cis_1d(hidden_size // num_heads, action_horizon)
self.register_buffer('pos', pos)
# 初始化权重
self.initialize_weights()
def initialize_weights(self):
"""初始化模型权重。"""
# 嵌入层使用 Xavier 初始化
for embedder in [self.x_embedder, self.s_embedder]:
w = embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(embedder.proj.bias, 0)
# 时间步嵌入 MLP
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# 观测编码器
for m in self.obs_encoder.encoder:
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
# 输出层零初始化 (AdaLN-Zero)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
obs: torch.Tensor,
s: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""前向传播。
Args:
x: 带噪声的动作序列 [B, T, action_dim]
t: 扩散时间步 [B] 或 [B, 1],取值范围 [0, 1]
obs: 观测条件 [B, obs_dim]
s: 可选的编码器状态缓存 [B, T, hidden_size]
mask: 可选的注意力掩码
Returns:
x: 预测的速度场/噪声 [B, T, action_dim]
s: 编码器状态 [B, T, hidden_size],可缓存复用
"""
B, T, _ = x.shape
# 1. 时间步嵌入: [B] -> [B, 1, hidden_size]
t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
# 2. 观测条件嵌入: [B, obs_dim] -> [B, 1, hidden_size]
obs_emb = self.obs_encoder(obs).view(B, 1, self.hidden_size)
# 3. 融合条件: c = SiLU(t + obs)
c = nn.functional.silu(t_emb + obs_emb)
# 4. 编码器部分: 生成状态 s
if s is None:
# 状态嵌入: [B, T, action_dim] -> [B, T, hidden_size]
s = self.s_embedder(x)
# 通过编码器 blocks
for i in range(self.num_encoder_blocks):
s = self.blocks[i](s, c, self.pos, mask)
# 融合时间信息
s = nn.functional.silu(t_emb + s)
# 5. 解码器部分: 去噪
# 输入嵌入: [B, T, action_dim] -> [B, T, hidden_size]
x = self.x_embedder(x)
# 通过解码器 blocks以 s 作为条件
for i in range(self.num_encoder_blocks, self.num_blocks):
x = self.blocks[i](x, s, self.pos, None)
# 6. 最终层: [B, T, hidden_size] -> [B, T, action_dim]
x = self.final_layer(x, s)
return x, s

View File

@@ -1,304 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DDT model and criterion classes.
核心组装文件,将 Backbone、Transformer、Diffusion 组件组装为完整模型。
"""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from .backbone import build_backbone
class SpatialSoftmax(nn.Module):
"""Spatial Softmax 层,将特征图转换为关键点坐标。
来自 Diffusion Policy保留空间位置信息。
对每个通道计算软注意力加权的期望坐标。
Args:
num_kp: 关键点数量(等于输入通道数)
temperature: Softmax 温度参数(可学习)
learnable_temperature: 是否学习温度参数
输入: [B, C, H, W]
输出: [B, C * 2] - 每个通道输出 (x, y) 坐标
"""
def __init__(self, num_kp: int = None, temperature: float = 1.0, learnable_temperature: bool = True):
super().__init__()
self.num_kp = num_kp
if learnable_temperature:
self.temperature = nn.Parameter(torch.ones(1) * temperature)
else:
self.register_buffer('temperature', torch.ones(1) * temperature)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
# 生成归一化坐标网格 [-1, 1]
pos_x = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype)
pos_y = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype)
# 展平空间维度
x_flat = x.view(B, C, -1) # [B, C, H*W]
# Softmax 得到注意力权重
attention = F.softmax(x_flat / self.temperature, dim=-1) # [B, C, H*W]
# 计算期望坐标
# pos_x: [W] -> [1, 1, W] -> repeat -> [1, 1, H*W]
pos_x_grid = pos_x.view(1, 1, 1, W).expand(1, 1, H, W).reshape(1, 1, -1)
pos_y_grid = pos_y.view(1, 1, H, 1).expand(1, 1, H, W).reshape(1, 1, -1)
# 加权求和得到期望坐标
expected_x = (attention * pos_x_grid).sum(dim=-1) # [B, C]
expected_y = (attention * pos_y_grid).sum(dim=-1) # [B, C]
# 拼接 x, y 坐标
keypoints = torch.cat([expected_x, expected_y], dim=-1) # [B, C * 2]
return keypoints
from .ddt import ActionDDT
def get_sinusoid_encoding_table(n_position, d_hid):
"""生成正弦位置编码表。"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DDT(nn.Module):
"""DDT (Decoupled Diffusion Transformer) 模型。
将视觉 Backbone 和 ActionDDT 扩散模型组合,实现基于图像观测的动作序列生成。
架构:
1. Backbone: 提取多相机图像特征
2. 特征投影: 将图像特征投影到隐藏空间 (Bottleneck 降维)
3. 状态编码: 编码机器人关节状态
4. ActionDDT: 扩散 Transformer 生成动作序列
Args:
backbones: 视觉骨干网络列表(每个相机一个)
state_dim: 机器人状态维度
action_dim: 动作维度
num_queries: 预测的动作序列长度
camera_names: 相机名称列表
hidden_dim: Transformer 隐藏维度
num_blocks: Transformer block 数量
num_encoder_blocks: 编码器 block 数量
num_heads: 注意力头数
num_kp: Spatial Softmax 的关键点数量 (默认 32)
"""
def __init__(
self,
backbones,
state_dim: int,
action_dim: int,
num_queries: int,
camera_names: list,
hidden_dim: int = 512,
num_blocks: int = 12,
num_encoder_blocks: int = 4,
num_heads: int = 8,
mlp_ratio: float = 4.0,
num_kp: int = 32, # [修改] 新增参数,默认 32
):
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.hidden_dim = hidden_dim
self.state_dim = state_dim
self.action_dim = action_dim
self.num_kp = num_kp
# Backbone 相关
self.backbones = nn.ModuleList(backbones)
# [修改] 投影层: ResNet Channels -> num_kp (32)
# 这是一个 Bottleneck 层,大幅减少特征通道数
self.input_proj = nn.Conv2d(
backbones[0].num_channels, num_kp, kernel_size=1
)
# 状态编码 (2层 MLP与 Diffusion Policy 一致)
# 状态依然映射到 hidden_dim (512),保持信息量
self.input_proj_robot_state = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# [修改] 图像特征聚合 (SpatialSoftmax)
# 输入: [B, num_kp, H, W]
# 输出: [B, num_kp * 2] (每个通道的 x, y 坐标)
self.img_feature_proj = SpatialSoftmax(num_kp=num_kp)
# [修改] 计算观测维度: 图像特征 + 状态
# 图像部分: 关键点数量 * 2(x,y) * 摄像头数量
img_feature_dim = num_kp * 2 * len(camera_names)
obs_dim = img_feature_dim + hidden_dim
# ActionDDT 扩散模型
self.action_ddt = ActionDDT(
action_dim=action_dim,
obs_dim=obs_dim, # 使用新的、更紧凑的维度
action_horizon=num_queries,
hidden_size=hidden_dim,
num_blocks=num_blocks,
num_encoder_blocks=num_encoder_blocks,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
)
def encode_observations(self, qpos, image):
"""编码观测(图像 + 状态)为条件向量。
Args:
qpos: 机器人关节状态 [B, state_dim]
image: 多相机图像 [B, num_cam, C, H, W]
Returns:
obs: 观测条件向量 [B, obs_dim]
"""
bs = qpos.shape[0]
# 编码图像特征
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # 取最后一层特征
# [说明] 这里的 input_proj 现在会将通道压缩到 32
features = self.input_proj(features) # [B, num_kp, H', W']
# [说明] SpatialSoftmax 提取 32 个关键点坐标
features = self.img_feature_proj(features) # [B, num_kp * 2]
all_cam_features.append(features)
# 拼接所有相机特征
img_features = torch.cat(all_cam_features, dim=-1) # [B, num_kp * 2 * num_cam]
# 编码状态
qpos_features = self.input_proj_robot_state(qpos) # [B, hidden_dim]
# 拼接观测
obs = torch.cat([img_features, qpos_features], dim=-1) # [B, obs_dim]
return obs
def forward(
self,
qpos,
image,
env_state,
actions=None,
is_pad=None,
timesteps=None,
):
"""前向传播。
训练时:
输入带噪声的动作序列和时间步,预测噪声/速度场
推理时:
通过扩散采样生成动作序列
Args:
qpos: 机器人关节状态 [B, state_dim]
image: 多相机图像 [B, num_cam, C, H, W]
env_state: 环境状态(未使用)
actions: 动作序列 [B, T, action_dim](训练时为带噪声动作)
is_pad: padding 标记 [B, T](未使用)
timesteps: 扩散时间步 [B](训练时提供)
Returns:
训练时: (noise_pred, encoder_state)
推理时: (action_pred, encoder_state)
"""
# 1. 编码观测
obs = self.encode_observations(qpos, image)
# 2. 扩散模型前向
if actions is not None and timesteps is not None:
# 训练模式: 预测噪声
noise_pred, encoder_state = self.action_ddt(
x=actions,
t=timesteps,
obs=obs,
)
return noise_pred, encoder_state
else:
# 推理模式: 需要在 Policy 层进行扩散采样
# 这里返回编码的观测,供 Policy 层使用
return obs, None
def get_obs_dim(self):
"""返回观测向量的维度。"""
# [修改] 使用 num_kp 重新计算
return self.num_kp * 2 * len(self.camera_names) + self.hidden_dim
def build(args):
"""构建 DDT 模型。
Args:
args: 包含模型配置的参数对象
- state_dim: 状态维度
- action_dim: 动作维度
- camera_names: 相机名称列表
- hidden_dim: 隐藏维度
- num_queries: 动作序列长度
- num_blocks: Transformer block 数量
- enc_layers: 编码器层数
- nheads: 注意力头数
- num_kp: 关键点数量 (可选默认32)
Returns:
model: DDT 模型实例
"""
state_dim = args.state_dim
action_dim = args.action_dim
# 构建 Backbone每个相机一个
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
# 构建 DDT 模型
model = DDT(
backbones=backbones,
state_dim=state_dim,
action_dim=action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
hidden_dim=args.hidden_dim,
num_blocks=getattr(args, 'num_blocks', 12),
num_encoder_blocks=getattr(args, 'enc_layers', 4),
num_heads=args.nheads,
mlp_ratio=getattr(args, 'mlp_ratio', 4.0),
num_kp=getattr(args, 'num_kp', 32), # [修改] 传递 num_kp 参数
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
return model
def build_ddt(args):
"""build 的别名,保持接口一致性。"""
return build(args)

View File

@@ -1,312 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List
import torch
import torch.nn.functional as F
from torch import nn, Tensor
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,
return_intermediate_dec=False):
super().__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

View File

@@ -1,147 +0,0 @@
"""
DDT Policy - 基于扩散模型的动作生成策略。
支持 Flow Matching 训练和推理。
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from torchvision.transforms import v2
import math
from roboimi.ddt.main import build_DDT_model_and_optimizer
class DDTPolicy(nn.Module):
"""DDT (Decoupled Diffusion Transformer) 策略。
使用 Flow Matching 进行训练,支持多步扩散采样推理。
带数据增强,适配 DINOv2 等 ViT backbone。
Args:
args_override: 配置参数字典
- num_inference_steps: 推理时的扩散步数
- qpos_noise_std: qpos 噪声标准差(训练时数据增强)
- patch_h, patch_w: 图像 patch 数量(用于计算目标尺寸)
"""
def __init__(self, args_override):
super().__init__()
model, optimizer = build_DDT_model_and_optimizer(args_override)
self.model = model
self.optimizer = optimizer
self.num_inference_steps = args_override.get('num_inference_steps', 10)
self.qpos_noise_std = args_override.get('qpos_noise_std', 0.0)
# 图像尺寸配置 (适配 DINOv2)
self.patch_h = args_override.get('patch_h', 16)
self.patch_w = args_override.get('patch_w', 22)
print(f'DDT Policy: {self.num_inference_steps} steps, '
f'image size ({self.patch_h*14}, {self.patch_w*14})')
def __call__(self, qpos, image, actions=None, is_pad=None):
"""前向传播。
训练时: 使用 Flow Matching 损失
推理时: 通过扩散采样生成动作
Args:
qpos: 机器人关节状态 [B, state_dim]
image: 多相机图像 [B, num_cam, C, H, W]
actions: 目标动作序列 [B, T, action_dim](训练时提供)
is_pad: padding 标记 [B, T]
Returns:
训练时: loss_dict
推理时: 预测的动作序列 [B, T, action_dim]
"""
env_state = None
# 图像预处理
if actions is not None: # 训练时:数据增强
transform = v2.Compose([
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
v2.RandomPerspective(distortion_scale=0.5),
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
if self.qpos_noise_std > 0:
qpos = qpos + (self.qpos_noise_std ** 0.5) * torch.randn_like(qpos)
else: # 推理时
transform = v2.Compose([
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
image = transform(image)
if actions is not None:
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
loss_dict = self._compute_loss(qpos, image, actions, is_pad)
return loss_dict
else:
a_hat = self._sample(qpos, image)
return a_hat
def _compute_loss(self, qpos, image, actions, is_pad):
"""计算 Flow Matching 损失。
Flow Matching 目标: 学习从噪声到数据的向量场
损失: ||v_theta(x_t, t) - (x_1 - x_0)||^2
其中 x_t = (1-t)*x_0 + t*x_1, x_0 是噪声, x_1 是目标动作
"""
B, T, action_dim = actions.shape
device = actions.device
t = torch.rand(B, device=device)
noise = torch.randn_like(actions)
t_expand = t.view(B, 1, 1).expand(B, T, action_dim)
x_t = (1 - t_expand) * noise + t_expand * actions
target_velocity = actions - noise
pred_velocity, _ = self.model(
qpos=qpos,
image=image,
env_state=None,
actions=x_t,
timesteps=t,
)
all_loss = F.mse_loss(pred_velocity, target_velocity, reduction='none')
loss = (all_loss * ~is_pad.unsqueeze(-1)).mean()
return {'flow_loss': loss, 'loss': loss}
@torch.no_grad()
def _sample(self, qpos, image):
"""通过 ODE 求解进行扩散采样。
使用 Euler 方法从 t=0 积分到 t=1:
x_{t+dt} = x_t + v_theta(x_t, t) * dt
"""
B = qpos.shape[0]
T = self.model.num_queries
action_dim = self.model.action_dim
device = qpos.device
x = torch.randn(B, T, action_dim, device=device)
obs = self.model.encode_observations(qpos, image)
dt = 1.0 / self.num_inference_steps
for i in range(self.num_inference_steps):
t = torch.full((B,), i * dt, device=device)
velocity, _ = self.model.action_ddt(x=x, t=t, obs=obs)
x = x + velocity * dt
return x
def configure_optimizers(self):
"""返回优化器。"""
return self.optimizer

View File

@@ -1 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

View File

@@ -1,88 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)

View File

@@ -1,468 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List
import torch
import torch.distributed as dist
from torch import Tensor
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
sha = 'N/A'
diff = "clean"
branch = 'N/A'
try:
sha = _run(['git', 'rev-parse', 'HEAD'])
subprocess.check_output(['git', 'diff'], cwd=cwd)
diff = _run(['git', 'diff-index', 'HEAD'])
diff = "has uncommited changes" if diff else "clean"
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if version.parse(torchvision.__version__) < version.parse('0.7'):
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -1,107 +0,0 @@
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
'''
Function to plot specific fields from training log(s). Plots both training and test results.
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
- fields = which results to plot from each log file - plots both training and test for each field.
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
- log_name = optional, name of log file if different than default 'log.txt'.
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
- solid lines are training results, dashed lines are test results.
'''
func_name = "plot_utils.py::plot_logs"
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
# convert single Path to list to avoid 'not iterable' error
if not isinstance(logs, list):
if isinstance(logs, PurePath):
logs = [logs]
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
else:
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
Expect list[Path] or single Path obj, received {type(logs)}")
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
for i, dir in enumerate(logs):
if not isinstance(dir, PurePath):
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
if not dir.exists():
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
# verify log_name exists
fn = Path(dir / log_name)
if not fn.exists():
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
print(f"--> full path of missing log file: {fn}")
return
# load log file(s) and plot
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields):
if field == 'mAP':
coco_eval = pd.DataFrame(
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
).ewm(com=ewm_col).mean()
axs[j].plot(coco_eval, c=color)
else:
df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'],
ax=axs[j],
color=[color] * 2,
style=['-', '--']
)
for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs])
ax.set_title(field)
def plot_precision_recall(files, naming_scheme='iter'):
if naming_scheme == 'exp_id':
# name becomes exp_id
names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter':
names = [f.stem for f in files]
else:
raise ValueError(f'not supported {naming_scheme}')
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision']
recall = data['params'].recThrs
scores = data['scores']
# take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
f'score={scores.mean():0.3f}, ' +
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
)
axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color)
axs[0].set_title('Precision / Recall')
axs[0].legend(names)
axs[1].set_title('Scores / Recall')
axs[1].legend(names)
return fig, axs

View File

@@ -8,8 +8,7 @@ temporal_agg: false
# policy_class: "ACT" # policy_class: "ACT"
# backbone: 'resnet18' # backbone: 'resnet18'
policy_class: "ACTTV" policy_class: "GR00T"
# policy_class: "DDT"
backbone: 'dino_v2' backbone: 'dino_v2'
seed: 0 seed: 0
@@ -39,8 +38,13 @@ episode_len: # leave empty here by default
camera_names: [] # leave empty here by default camera_names: [] # leave empty here by default
xml_dir: # leave empty here by default xml_dir: # leave empty here by default
# action smoothing settings (for GR00T)
use_action_smoothing: true
smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none"
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
# transformer settings # transformer settings
batch_size: 32 batch_size: 10
state_dim: 16 state_dim: 16
action_dim: 16 action_dim: 16
lr_backbone: 0.00001 lr_backbone: 0.00001
@@ -52,6 +56,21 @@ nheads: 8
qpos_noise_std: 0 qpos_noise_std: 0
DT: 0.02 DT: 0.02
gr00t:
action_dim: 16
state_dim: 16
embed_dim: 1536
hidden_dim: 1024
num_queries: 8
nheads: 32
mlp_ratio: 4
dropout: 0.2
num_layers: 16
# DO NOT CHANGE IF UNNECESSARY # DO NOT CHANGE IF UNNECESSARY
lr: 0.00001 lr: 0.00001
kl_weight: 100 kl_weight: 100
@@ -59,8 +78,3 @@ chunk_size: 10
hidden_dim: 512 hidden_dim: 512
dim_feedforward: 3200 dim_feedforward: 3200
# DDT 特有参数
num_blocks: 12 # Transformer blocks 数量
mlp_ratio: 4.0 # MLP 维度比例
num_inference_steps: 10 # 扩散推理步数

View File

@@ -1,119 +0,0 @@
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange
from roboimi.utils.utils import set_seed
from roboimi.utils.io_utils import IOUtils
from roboimi.utils.model_interface import ModelInterface
from roboimi.envs.double_pos_ctrl_env import make_sim_env
# from visualize_episodes import save_videos
from roboimi.utils.act_ex_utils import sample_transfer_pose
#should be added into IOUtils
def get_image(obs,camera_names):
curr_images = []
for cam_name in camera_names:
curr_image = rearrange(obs['images'][cam_name], 'h w c -> c h w')
curr_images.append(curr_image)
curr_image = np.stack(curr_images, axis=0)
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
return curr_image
def eval_bc(config, ckpt_name='policy_best.ckpt', save_episode=True):
set_seed(1)
model_interface = ModelInterface(config)
model_interface.setup()
policy = IOUtils.load_policy(config, ckpt_name)
stats = IOUtils.load_stats(config['ckpt_dir'])
num_rollouts = 3
episode_returns = []
highest_rewards = []
run_episode(config, policy, stats,
save_episode,num_rollouts)
# episode_return, episode_highest_reward = run_episode(config, policy, stats,
# save_episode,num_rollouts)
def run_episode(config, policy, stats, save_episode,num_rollouts):
if 'sim_transfer' in config['task_name']:
task_name = 'sim_transfer' #config['task_name']
env = make_sim_env(task_name)
max_timesteps = config['episode_len']
max_timesteps = int(max_timesteps * 1)
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
box_pos = sample_transfer_pose()
for rollout_id in range(num_rollouts):
print("\nrollout_id===",rollout_id,"\n")
image_list = []
rewards = []
query_frequency = config['policy_config'].get('num_queries', 1)
print("query_freq =====",query_frequency)
env.reset(box_pos)
with torch.inference_mode():
for t in range(700):
image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")})
qpos_numpy = np.array(env._get_qpos_obs()['qpos'])
qpos = pre_process(qpos_numpy)
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
curr_image = get_image(env._get_image_obs(), config['camera_names'])
if config['policy_class'] == "ACT" or "ACTTV":
if t % query_frequency == 0:
all_actions = policy(qpos, curr_image)
raw_action = all_actions[:, t % query_frequency]
# raw_action = all_actions[:, t % 1]
raw_action = raw_action.squeeze(0).cpu().numpy()
elif config['policy_class'] == "CNNMLP":
raw_action = policy(qpos, curr_image)
else:
raise NotImplementedError
action = post_process(raw_action)
print("action == ",action)
env.step_jnt(action)
rewards.append(env.rew)
env.render()
rewards = np.array(rewards)
# episode_return = np.sum(rewards[rewards != None])
# episode_highest_reward = np.max(rewards)
# env.viewer.close()
# del env
# return episode_return, episode_highest_reward
def test_env():
try:
env = make_sim_env('sim_transfer')
env.reset()
while True: pass
except KeyboardInterrupt:
del env
print("stop")
if __name__ == '__main__':
# test_env()
io_utils = IOUtils()
config = io_utils.load_config()
eval_bc(config)

View File

@@ -104,8 +104,8 @@ class TestPickAndTransferPolicy(PolicyBase):
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100}, # sleep {"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100}, # sleep
{"t": 75, "xyz": np.array([(0.8+box_xyz[0])*0.5,(1.0+box_xyz[1])*0.5,init_mocap_pose_right[2]]), "quat": gripper_approach_quat.elements, "gripper": 100}, {"t": 75, "xyz": np.array([(0.8+box_xyz[0])*0.5,(1.0+box_xyz[1])*0.5,init_mocap_pose_right[2]]), "quat": gripper_approach_quat.elements, "gripper": 100},
{"t": 225, "xyz": box_xyz + np.array([0, 0, 0.3]), "quat": gripper_pick_quat.elements, "gripper": 100}, # approach the cube {"t": 225, "xyz": box_xyz + np.array([0, 0, 0.3]), "quat": gripper_pick_quat.elements, "gripper": 100}, # approach the cube
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down {"t": 275, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper {"t": 280, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
{"t": 450, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100},# approach wait position {"t": 450, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100},# approach wait position
{"t": 500, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": -100},# approach meet position {"t": 500, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": -100},# approach meet position
{"t": 510, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": 100}, # open gripper {"t": 510, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": 100}, # open gripper
@@ -116,8 +116,8 @@ class TestPickAndTransferPolicy(PolicyBase):
self.left_trajectory = [ self.left_trajectory = [
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": -100},# sleep {"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": -100},# sleep
{"t": 250, "xyz": meet_xyz + np.array([-0.5, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # approach meet position {"t": 250, "xyz": meet_xyz + np.array([-0.5, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # approach meet position
{"t": 500, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position {"t": 500, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
{"t": 505, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper {"t": 505, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
{"t": 675, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # move left {"t": 675, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # move left
{"t": 700, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # stay {"t": 700, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # stay
] ]

View File

@@ -21,7 +21,7 @@ def main():
render_cam_name = 'angle' render_cam_name = 'angle'
episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len'] episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = ['angle','r_vis', 'top'] #SIM_TASK_CONFIGS[task_name]['camera_names'] camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names']
if task_name == 'sim_transfer': if task_name == 'sim_transfer':
policy = TestPickAndTransferPolicy(inject_noise) policy = TestPickAndTransferPolicy(inject_noise)
print(task_name) print(task_name)
@@ -32,6 +32,12 @@ def main():
env = make_sim_env(task_name) env = make_sim_env(task_name)
policy = TestPickAndTransferPolicy(inject_noise) policy = TestPickAndTransferPolicy(inject_noise)
# 等待osmesa完全启动后再开始收集数据
print("等待osmesa线程启动...")
time.sleep(60)
print("osmesa已就绪开始收集数据...")
for episode_idx in range(num_episodes): for episode_idx in range(num_episodes):
obs = [] obs = []
reward_ee = [] reward_ee = []

View File

@@ -1,152 +0,0 @@
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange
from roboimi.utils.utils import set_seed
from roboimi.utils.io_utils import IOUtils
from roboimi.utils.model_interface import ModelInterface
from roboimi.envs.vx300s_jnt import make_sim_env
import time
# from visualize_episodes import save_videos
from roboimi.utils.utils import sample_box_pose, sample_insertion_pose
#should be added into IOUtils
def get_image(obs,camera_names):
curr_images = []
for cam_name in camera_names:
curr_image = rearrange(obs['images'][cam_name], 'h w c -> c h w')
curr_images.append(curr_image)
curr_image = np.stack(curr_images, axis=0)
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
return curr_image
def eval_bc(config, ckpt_name='policy_best.ckpt', save_episode=True):
set_seed(1)
model_interface = ModelInterface(config)
task_name = 'sim_insertion' #config['task_name']
model_interface.setup()
policy = IOUtils.load_policy(config, ckpt_name)
stats = IOUtils.load_stats(config['ckpt_dir'])
num_rollouts = 3
episode_returns = []
highest_rewards = []
for rollout_id in range(num_rollouts):
episode_return, episode_highest_reward = run_episode(config, policy, stats,
save_episode,rollout_id)
def run_episode(config, policy, stats, save_episode,rollout_id):
print("\nrollout_id===",rollout_id,"\n")
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
if 'sim_insertion' in config['task_name']:
peg_pose, socket_pose = sample_insertion_pose()
box_pose = np.hstack((peg_pose[:3],socket_pose[:3])) # used in sim reset
task_name = 'sim_insertion' #config['task_name']
env = make_sim_env(task_name)
env.reset(box_pose)
max_timesteps = config['episode_len']
max_timesteps = int(max_timesteps * 1)
image_list = []
rewards = []
query_frequency = config['policy_config'].get('num_queries', 1)
with torch.inference_mode():
for t in range(700):
# print("obs_img",env.obs['images'])
image_list.append(env.obs['images'] if 'images' in env.obs else {print("img error")})
qpos_numpy = np.array(env.obs['qpos'])
qpos = pre_process(qpos_numpy)
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
curr_image = get_image(env.obs, config['camera_names'])
if config['policy_class'] == "ACT" or "ACTTV":
if t % query_frequency == 0:
all_actions = policy(qpos, curr_image)
elif config['policy_class'] == "CNNMLP":
raw_action = policy(qpos, curr_image)
else:
raise NotImplementedError
raw_action = all_actions[:, t % query_frequency]
raw_action = raw_action.squeeze(0).cpu().numpy()
action = post_process(raw_action)
env.step(action)
rewards.append(env.rew)
env.render()
rewards = np.array(rewards)
episode_return = np.sum(rewards[rewards != None])
episode_highest_reward = np.max(rewards)
env.viewer.close()
del env
return episode_return, episode_highest_reward
def test_env():
try:
env = make_sim_env('sim_insertion')
box_pos = np.concatenate(sample_insertion_pose())
env.reset(box_pos)
while True: pass
except KeyboardInterrupt:
del env
print("stop")
if __name__ == '__main__':
test_env()
# io_utils = IOUtils()
# config = io_utils.load_config()
# eval_bc(config)
# config===== {'onscreen_render': False,
# 'eval': 1,
# 'ckpt_dir': 'ckpt_models',
# 'num_epochs': 3000,
# 'temporal_agg': False,
# 'policy_class': 'ACT',
# 'backbone': 'resnet18',
# 'seed': 0, 'real_robot': 0,
# 'task_name': 'sim_insertion',
# 'images_render_height': 480,
# 'images_render_width': 640,
# 'left_arm_DOF_number': 6,
# 'right_arm_DOF_number': 6,
# 'left_qpos_raw': 8,
# 'right_qpos_raw': 8,
# 'left_qvel_raw': 8,
# 'right_qvel_raw': 8,
# 'dataset_dir': '/home/arm/lzd/act_env/dataset/sim_insertion',
# 'num_episodes': 7,
# 'episode_len': 400,
# 'camera_names': ['top'],
# 'xml_dir': None,
# 'batch_size': 8,
# 'state_dim': 14,
# 'action_dim': 14,
# 'lr_backbone': 1e-05,
# 'enc_layers': 4,
# 'dec_layers': 7,
# 'nheads': 8,
# 'qpos_noise_std': 0,
# 'DT': 0.02,
# 'lr': 1e-05,
# 'kl_weight': 10,
# 'chunk_size': 100,
# 'hidden_dim': 512,
# 'dim_feedforward': 3200,
# 'policy_config': {'lr': 1e-05, 'num_queries': 100, 'kl_weight': 10, 'hidden_dim': 512, 'dim_feedforward': 3200, 'lr_backbone': 1e-05, 'backbone': 'resnet18', 'enc_layers': 4, 'dec_layers': 7, 'nheads': 8, 'camera_names': ['top']}}

View File

@@ -1,179 +0,0 @@
import torch
import os
from tqdm import tqdm
import numpy as np
from copy import deepcopy
from itertools import repeat
import matplotlib.pyplot as plt
import time
from roboimi.utils.utils import set_seed, compute_dict_mean, detach_dict, load_data
from roboimi.utils.io_utils import IOUtils
from roboimi.utils.model_interface import ModelInterface
import matplotlib.pyplot as plt
def train_bc(config):
num_epochs = config['num_epochs']
ckpt_dir = config['ckpt_dir']
seed = config['seed']
os.makedirs(ckpt_dir, exist_ok=True)
set_seed(seed)
model_interface = ModelInterface(config)
model_interface.setup()
policy = model_interface.make_policy()
policy.cuda()
optimizer = model_interface.make_optimizer(policy)
# print("cam names=====",config['camera_names'])
train_dataloader, val_dataloader, stats, _ = load_data(
config['dataset_dir'],
config['num_episodes'],
config['camera_names'],
config['batch_size'],
config['batch_size'])
IOUtils.save_stats(ckpt_dir, stats)
train_history = []
validation_history = []
min_val_loss = np.inf
min_train_loss = np.inf
best_ckpt_info = None
plt.ion()
fig, ax = plt.subplots()
train_losses, val_losses = [], []
train_line, = ax.plot([], [], label='Train Loss')
val_line, = ax.plot([], [], label='Validation Loss')
ax.autoscale_view()
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
train_annotation = ax.annotate('', xy=(0, 0), textcoords='offset points')
val_annotation = ax.annotate('', xy=(0, 0), textcoords='offset points')
min_train_text = ax.text(0.85, 0.5, '', transform=ax.transAxes, fontsize=10, verticalalignment='center', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.5))
min_val_text = ax.text(0.85, 0.45, '', transform=ax.transAxes, fontsize=10, verticalalignment='center', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.5))
for epoch in tqdm(range(num_epochs)):
print(f'\nEpoch {epoch}')
# Validation
epoch_val_loss, epoch_summary = validate(policy, val_dataloader)
validation_history.append(epoch_summary)
val_losses.append(epoch_val_loss.cpu().item())
if epoch_val_loss < min_val_loss:
min_val_loss = epoch_val_loss
min_val_epoch = epoch
best_ckpt_info = (epoch, min_val_loss,
deepcopy(policy.state_dict()))
print(f'Val loss: {epoch_val_loss:.5f}')
print_summary(epoch_summary)
# Training
epoch_train_loss, epoch_summary = train_epoch(
policy, optimizer, train_dataloader)
train_history.append(epoch_summary)
train_losses.append(epoch_train_loss.cpu().item())
if epoch_train_loss < min_train_loss:
min_train_loss = epoch_train_loss
min_train_epoch = epoch
print(f'Train loss: {epoch_train_loss:.5f}')
print_summary(epoch_summary)
# Update the plot with the new data
train_line.set_xdata(range(len(train_losses)))
train_line.set_ydata(train_losses)
val_line.set_xdata(range(len(val_losses)))
val_line.set_ydata(val_losses)
# Update annotations with the latest loss values at their respective positions
train_annotation.set_position((len(train_losses)-1, train_losses[-1]))
train_annotation.xy = (len(train_losses)-1, train_losses[-1])
train_annotation.set_text(f'{train_losses[-1]:.5f}')
val_annotation.set_position((len(val_losses)-1, val_losses[-1]))
val_annotation.xy = (len(val_losses)-1, val_losses[-1])
val_annotation.set_text(f'{val_losses[-1]:.5f}')
# Update text objects with the minimum loss values, fixed on the right side
min_train_text.set_text(f'Min Train Loss: {min_train_loss:.5f} (Epoch {min_train_epoch})')
min_val_text.set_text(f'Min Val Loss: {min_val_loss:.5f} (Epoch {min_val_epoch})')
ax.relim()
ax.autoscale_view()
plt.draw()
plt.pause(0.1)
plt.ioff()
IOUtils.save_checkpoint(policy, 'last', ckpt_dir, seed, 'last')
best_epoch, min_val_loss, best_state_dict = best_ckpt_info
IOUtils.save_checkpoint(best_state_dict, best_epoch,
ckpt_dir, seed, 'best', min_val_loss)
print(
f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')
IOUtils.plot_history(train_history, validation_history,
num_epochs, ckpt_dir, seed)
return best_ckpt_info
def validate(policy, dataloader):
policy.eval()
epoch_dicts = []
with torch.inference_mode():
for data in dataloader:
forward_dict = forward_pass(data, policy)
epoch_dicts.append(forward_dict)
epoch_summary = compute_dict_mean(epoch_dicts)
return epoch_summary['loss'], epoch_summary
def train_epoch(policy, optimizer, dataloader):
policy.train()
epoch_dicts = []
for data in dataloader:
optimizer.zero_grad()
forward_dict = forward_pass(data, policy)
loss = forward_dict['loss']
loss.backward()
optimizer.step()
epoch_dicts.append(detach_dict(forward_dict))
epoch_summary = compute_dict_mean(epoch_dicts)
return epoch_summary['loss'], epoch_summary
def forward_pass(data, policy):
image_data, qpos_data, action_data, is_pad = data
image_data, qpos_data, action_data, is_pad = image_data.cuda(
), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
return policy(qpos_data, image_data, action_data, is_pad)
def print_summary(summary):
summary_string = ' '.join(
[f'{k}: {v.item():.3f}' for k, v in summary.items()])
print(summary_string)
if __name__ == '__main__':
io_utils = IOUtils()
config = io_utils.load_config()
train_bc(config)

View File

@@ -0,0 +1,312 @@
"""
VLA 策略评估脚本(简化版)
该脚本使用 agent 内置的队列管理来评估训练好的 VLA 策略。
无需单独的评估器类 - agent 处理一切!
使用方法:
python roboimi/demos/eval_vla_simple.py
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_final.pt
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_best.pt
"""
import sys
import os
import json
import logging
import time
import torch
import numpy as np
import hydra
from pathlib import Path
from typing import Dict
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from einops import rearrange
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import sample_transfer_pose
sys.path.append(os.getcwd())
log = logging.getLogger(__name__)
if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x))
def load_checkpoint(
ckpt_path: str,
agent_cfg: DictConfig,
device: str = 'cuda'
) -> torch.nn.Module:
"""
从检查点加载训练好的 VLA 模型,使用 Hydra agent 配置。
Args:
ckpt_path: 检查点文件路径 (.pt)
agent_cfg: Hydra agent 配置,用于实例化
device: 加载模型的设备
Returns:
加载后的 VLAAgent 模型
"""
from pathlib import Path as PathLib
ckpt_path = PathLib(ckpt_path).absolute()
if not ckpt_path.exists():
raise FileNotFoundError(f"检查点未找到: {ckpt_path}")
log.info(f"{ckpt_path} 加载检查点")
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
log.info(f"检查点键值: {checkpoint.keys()}")
# 加载数据集统计信息用于归一化
stats = checkpoint.get('dataset_stats', None)
# 使用数据集统计信息从 Hydra 配置实例化 agent
log.info("从配置实例化 agent...")
agent = instantiate(agent_cfg, dataset_stats=stats)
# 加载模型状态
agent.load_state_dict(checkpoint['model_state_dict'])
log.info(f"✅ 模型状态已加载 (步数: {checkpoint.get('step', 'unknown')})")
if stats is not None:
log.info(f"✅ 数据集统计信息已加载 (归一化: {stats.get('normalization_type', 'gaussian')})")
else:
# 后备方案:尝试从外部 JSON 文件加载(兼容旧检查点)
stats_path = ckpt_path.parent / 'dataset_stats.json'
if stats_path.exists():
with open(stats_path, 'r') as f:
stats = json.load(f)
log.info("✅ 数据集统计信息已从外部 JSON 加载(旧版本兼容)")
else:
log.warning("⚠️ 未找到数据集统计信息。动作将无法反归一化!")
agent.eval()
agent.to(device)
log.info(f"✅ 模型已成功加载到 {device}")
return agent, stats
def prepare_observation(obs: Dict, camera_names: list) -> Dict:
"""
将环境观测转换为 agent 格式。
Args:
obs: 环境观测字典,包含图像和 qpos
camera_names: 摄像头名称列表
Returns:
agent 格式的观测字典
"""
import cv2
# 转换图像: numpy -> tensor, HWC -> CHW
images = {}
for cam_name in camera_names:
img = obs['images'][cam_name]
# Resize 到 224x224与训练时一致
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
img = rearrange(img, 'h w c -> c h w')
img = torch.from_numpy(img / 255.0).float()
images[cam_name] = img
# 转换 qpos: numpy -> tensor
qpos = torch.from_numpy(obs['qpos']).float()
return {'qpos': qpos, 'images': images}
class ActionSmoother:
"""
动作平滑器(指数移动平均)
用于平滑执行动作以获得更稳定的控制
"""
def __init__(self, alpha: float = 0.3):
"""
Args:
alpha: 平滑系数 (0-1),值越大越重视当前动作
"""
self.alpha = alpha
self.prev_action = None
def smooth(self, action: np.ndarray) -> np.ndarray:
"""
平滑动作
Args:
action: 当前动作
Returns:
平滑后的动作
"""
if self.prev_action is None:
smoothed = action
else:
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
self.prev_action = smoothed
return smoothed
def reset(self):
"""重置平滑器状态"""
self.prev_action = None
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
"""
使用 agent 内置队列管理的简化版 VLA 评估
所有评估参数来自 vla/conf/eval.yaml合并到 cfg 中。
命令行覆盖: python eval_vla_simple.py eval.ckpt_path=... eval.num_episodes=5
"""
# 打印配置
print("=" * 80)
print("VLA 评估配置:")
print("=" * 80)
print(OmegaConf.to_yaml(cfg))
print("=" * 80)
eval_cfg = cfg.eval
device = eval_cfg.device
camera_names = list(eval_cfg.camera_names)
# =========================================================================
# 加载模型
# =========================================================================
log.info(f"🚀 从 {eval_cfg.ckpt_path} 加载模型...")
agent, dataset_stats = load_checkpoint(
ckpt_path=eval_cfg.ckpt_path,
agent_cfg=cfg.agent,
device=device
)
# 重置 agent 的队列
agent.reset()
# 可选:动作平滑器
smoother = ActionSmoother(alpha=eval_cfg.smooth_alpha) if eval_cfg.use_smoothing else None
# =========================================================================
# 创建环境
# =========================================================================
env = make_sim_env(eval_cfg.task_name)
# =========================================================================
# 运行评估回合
# =========================================================================
all_stats = []
for episode_idx in range(eval_cfg.num_episodes):
print(f"\n{'='*60}")
print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
print(f"{'='*60}\n")
box_pos = sample_transfer_pose()
env.reset(box_pos)
# 为新回合重置 agent 队列
agent.reset()
if smoother:
smoother.reset()
# 计时统计
inference_times = []
total_times = []
with torch.inference_mode():
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
start_total = time.time()
# 从环境获取观测
obs = env._get_image_obs()
qpos_obs = env._get_qpos_obs()
obs['qpos'] = qpos_obs['qpos']
# 准备给 agent 的观测
observation = prepare_observation(obs, camera_names)
# 选择动作agent 内部处理队列管理)
start_inference = time.time()
action = agent.select_action(observation)
if device == 'cuda':
torch.cuda.synchronize()
end_inference = time.time()
# 转换为 numpy
action = action.cpu().numpy()
# 调试:打印当前时间步的动作(由配置控制)
if eval_cfg.get('verbose_action', False):
print(f"\n[Step {t:3d}] 预测动作: {action}")
print(f" - 动作形状: {action.shape}")
print(f" - 动作范围: [{action.min():.4f}, {action.max():.4f}]")
print(f" - 动作均值: {action.mean():.4f}, 标准差: {action.std():.4f}")
# 可选:平滑动作
if smoother:
action = smoother.smooth(action)
# 执行动作
env.step_jnt(action)
env.render()
end_total = time.time()
# 记录计时
inference_times.append(end_inference - start_inference)
total_times.append(end_total - start_total)
# =========================================================================
# 打印回合统计
# =========================================================================
avg_inference_time = np.mean(inference_times)
avg_total_time = np.mean(total_times)
stats = {
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
'avg_inference_time_ms': avg_inference_time * 1000,
'avg_total_time_ms': avg_total_time * 1000,
'num_inferences': len([t for t in inference_times if t > 0.001]), # 统计实际推理次数
'num_steps': len(total_times)
}
all_stats.append(stats)
print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
print(f" 总推理次数: {stats['num_inferences']}")
# =========================================================================
# 总体统计
# =========================================================================
print(f"\n{'='*60}")
print("评估完成!")
print(f"{'='*60}")
if all_stats:
avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
print(f" 平均推理时间: {avg_inference_time:.2f} ms")
print(f" 平均总时间: {avg_total_time:.2f} ms")
print()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,532 @@
import sys
import os
import logging
import json
import pickle
import hydra
import torch
import re
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from pathlib import Path
# 确保正确的导入路径
sys.path.append(os.getcwd())
from hydra.utils import instantiate
log = logging.getLogger(__name__)
# 注册列表长度解析器(用于配置中如 ${len:${data.camera_names}}
if not OmegaConf.has_resolver("len"):
OmegaConf.register_new_resolver("len", lambda x: len(x))
def recursive_to_device(data, device):
"""
递归地将嵌套字典/列表中的张量移动到指定设备。
Args:
data: 字典、列表或张量
device: 目标设备 (例如 'cuda', 'cpu')
Returns:
所有张量已移动到指定设备的数据结构
"""
if isinstance(data, torch.Tensor):
return data.to(device)
elif isinstance(data, dict):
return {k: recursive_to_device(v, device) for k, v in data.items()}
elif isinstance(data, list):
return [recursive_to_device(v, device) for v in data]
return data
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
"""
解析恢复训练用的 checkpoint 路径。
Args:
resume_ckpt: 配置中的 resume_ckpt支持路径或 "auto"
checkpoint_dir: 默认检查点目录
Returns:
Path 或 None
"""
if resume_ckpt is None:
return None
if str(resume_ckpt).lower() != "auto":
return Path(resume_ckpt)
pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
candidates = []
for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
match = pattern.search(ckpt_path.name)
if match:
candidates.append((int(match.group(1)), ckpt_path))
if not candidates:
return None
return max(candidates, key=lambda x: x[0])[1]
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
"""
创建带预热的学习率调度器。
Args:
optimizer: PyTorch 优化器
warmup_steps: 预热步数
max_steps: 总训练步数
scheduler_type: 预热后的调度器类型 ('cosine''constant')
min_lr: 最小学习率(用于余弦衰减)
Returns:
LambdaLR 调度器
"""
import math
# 在 LambdaLR 修改前捕获初始学习率
base_lr = optimizer.param_groups[0]['lr']
min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0
def lr_lambda(step):
# 预热阶段:从 0 线性增加到 1
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
# 预热后阶段
if scheduler_type == 'cosine':
# 从 1 到 min_lr_ratio 的余弦退火
progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
return max(min_lr_ratio, cosine_decay)
else:
# 恒定学习率
return 1.0
return LambdaLR(optimizer, lr_lambda)
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
"""
VLA 训练脚本ResNet 骨干网络 + Diffusion 策略)
该脚本功能:
1. 从 HDF5 文件加载数据集
2. 实例化带 ResNet 视觉编码器的 VLAAgent
3. 训练基于扩散的动作预测模型
4. 定期保存检查点
"""
# 打印配置
print("=" * 80)
print("VLA 训练配置:")
print("=" * 80)
print(OmegaConf.to_yaml(cfg))
print("=" * 80)
log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
# 创建检查点目录
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)
# =========================================================================
# 1. 实例化数据集与 DataLoader
# =========================================================================
log.info("📦 加载数据集...")
try:
dataset = instantiate(cfg.data)
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
except Exception as e:
log.error(f"❌ 数据集加载失败: {e}")
raise
# 训练/验证集划分
val_split = float(cfg.train.get('val_split', 0.1))
seed = int(cfg.train.get('seed', 42))
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
if val_size > 0:
train_dataset, val_dataset = random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(seed)
)
log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
else:
train_dataset, val_dataset = dataset, None
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
train_loader = DataLoader(
train_dataset,
batch_size=cfg.train.batch_size,
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=True # 丢弃不完整批次以稳定训练
)
val_loader = None
if val_dataset is not None:
val_loader = DataLoader(
val_dataset,
batch_size=cfg.train.batch_size,
shuffle=False,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=False
)
log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}")
if val_loader is not None:
log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}")
# =========================================================================
# 2. 加载数据集统计信息(将传递给 agent
# =========================================================================
log.info("💾 加载数据集统计信息...")
dataset_stats = None
try:
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
stats_path = Path(dataset_dir) / 'dataset_stats.pkl'
if stats_path.exists():
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
# 扁平化stats字典嵌套结构→扁平结构以匹配NormalizationModule的期望格式
dataset_stats = {
'action_mean': stats['action_mean'].tolist(),
'action_std': stats['action_std'].tolist(),
'action_min': stats['action_min'].tolist(),
'action_max': stats['action_max'].tolist(),
'qpos_mean': stats['qpos_mean'].tolist(),
'qpos_std': stats['qpos_std'].tolist(),
'qpos_min': stats['qpos_min'].tolist(),
'qpos_max': stats['qpos_max'].tolist(),
}
log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})")
else:
log.warning(f"⚠️ 统计文件未找到: {stats_path}")
log.warning("⚠️ 推理时动作将无法反归一化!")
except Exception as e:
log.warning(f"⚠️ 统计信息加载失败: {e}")
log.warning("⚠️ 训练将继续,但推理可能无法正常工作")
# =========================================================================
# 3. 实例化 VLA Agent
# =========================================================================
log.info("🤖 初始化 VLA Agent...")
try:
# 将 dataset_stats 和 normalization_type 传递给 agent
agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
agent.to(cfg.train.device)
agent.train()
log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")
# 统计参数量
total_params = sum(p.numel() for p in agent.parameters())
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
log.info(f"📊 总参数量: {total_params:,}")
log.info(f"📊 可训练参数量: {trainable_params:,}")
except Exception as e:
log.error(f"❌ Agent 初始化失败: {e}")
raise
# =========================================================================
# 3.1 从预训练 checkpoint 加载权重(微调)
# =========================================================================
pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
if pretrained_ckpt is not None:
ckpt_path = Path(pretrained_ckpt)
if ckpt_path.exists():
log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
try:
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
# 只加载模型权重(不加载 optimizer、scheduler
missing_keys, unexpected_keys = agent.load_state_dict(
checkpoint['model_state_dict'],
strict=False # 允许部分加载(结构不完全匹配时)
)
log.info(f"✅ [Finetune] 模型权重加载成功")
if missing_keys:
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
if unexpected_keys:
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
log.info(f"📈 [Finetune] 使用新的训练配置lr={cfg.train.lr}, max_steps={cfg.train.max_steps}")
except Exception as e:
log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
log.warning("⚠️ 将从头开始训练")
else:
log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
log.warning("⚠️ 将从头开始训练")
# =========================================================================
# 4. 设置优化器与学习率调度器
# =========================================================================
weight_decay = float(cfg.train.get('weight_decay', 1e-5))
grad_clip = float(cfg.train.get('grad_clip', 1.0))
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
# 设置带预热的学習率调度器
warmup_steps = int(cfg.train.get('warmup_steps', 500))
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
min_lr = float(cfg.train.get('min_lr', 1e-6))
scheduler = get_lr_schedule_with_warmup(
optimizer,
warmup_steps=warmup_steps,
max_steps=cfg.train.max_steps,
scheduler_type=scheduler_type,
min_lr=min_lr
)
log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})")
# =========================================================================
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
# =========================================================================
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
resume_ckpt = cfg.train.get('resume_ckpt', None)
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
if resume_ckpt is not None:
if pretrained_ckpt is not None:
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt将优先使用 resume_ckpt 进行断点续训")
if resume_path is None:
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint将从头开始训练")
elif not resume_path.exists():
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
log.warning("⚠️ 将从头开始训练")
else:
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
try:
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
resume_step = int(checkpoint['step'])
start_step = resume_step + 1
loaded_loss = checkpoint.get('loss', None)
loaded_val_loss = checkpoint.get('val_loss', None)
resume_loss = float(loaded_loss) if loaded_loss is not None else None
if loaded_val_loss is not None:
resume_best_loss = float(loaded_val_loss)
elif loaded_loss is not None:
resume_best_loss = float(loaded_loss)
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
except Exception as e:
log.error(f"❌ [Resume] 恢复失败: {e}")
log.warning("⚠️ 将从头开始训练")
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
# =========================================================================
# 5. 训练循环
# =========================================================================
log.info("🏋️ 开始训练循环...")
def build_agent_input(batch_data):
"""构建 agent 输入格式"""
images = {}
# SimpleRobotDataset 返回 observation.{cam_name} 格式
for cam_name in cfg.data.camera_names:
key = f"observation.{cam_name}"
if key in batch_data:
images[cam_name] = batch_data[key]
return {
'images': images,
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
'action': batch_data['action'],
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
}
def run_validation():
"""运行验证"""
if val_loader is None:
return None
agent.eval()
# 设置确定性种子以获得可重现的损失
# 这确保验证损失在不同步骤之间可比较
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for val_batch in val_loader:
val_batch = recursive_to_device(val_batch, cfg.train.device)
val_input = build_agent_input(val_batch)
val_loss = agent.compute_loss(val_input)
total_loss += val_loss.item()
num_batches += 1
agent.train()
return total_loss / max(num_batches, 1)
data_iter = iter(train_loader)
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
best_loss = resume_best_loss
last_loss = resume_loss
if start_step >= cfg.train.max_steps:
log.warning(
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
)
for step in pbar:
try:
batch = next(data_iter)
except StopIteration:
# 轮次结束时重启迭代器
data_iter = iter(train_loader)
batch = next(data_iter)
# =====================================================================
# 将批次移至设备
# =====================================================================
batch = recursive_to_device(batch, cfg.train.device)
# =====================================================================
# 准备 agent 输入
# =====================================================================
# 数据集返回: {action, qpos, image_<cam_name>, ...}
# Agent 期望: {images: dict, qpos: tensor, action: tensor}
# 准备 agent 输入
agent_input = build_agent_input(batch)
# =====================================================================
# 前向传播与损失计算
# =====================================================================
try:
loss = agent.compute_loss(agent_input)
except Exception as e:
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
raise
last_loss = loss.item()
# =====================================================================
# 反向传播与优化
# =====================================================================
optimizer.zero_grad()
loss.backward()
# 梯度裁剪以稳定训练
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
optimizer.step()
scheduler.step()
# =====================================================================
# 日志记录
# =====================================================================
if step % cfg.train.log_freq == 0:
current_lr = optimizer.param_groups[0]['lr']
pbar.set_postfix({
"loss": f"{loss.item():.4f}",
"lr": f"{current_lr:.2e}",
"best_loss": f"{best_loss:.4f}"
})
log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
# =====================================================================
# 检查点保存与验证
# =====================================================================
if step > 0 and step % cfg.train.save_freq == 0:
# 运行验证
val_loss = run_validation()
if val_loss is not None:
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
# 使用agent的归一化统计信息包含normalization_type
agent_stats = agent.get_normalization_stats()
torch.save({
'step': step,
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'val_loss': val_loss,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, checkpoint_path)
log.info(f"💾 检查点已保存: {checkpoint_path}")
# 根据验证损失保存最佳模型
eval_loss = val_loss if val_loss is not None else loss.item()
if eval_loss < best_loss:
best_loss = eval_loss
best_model_path = checkpoint_dir / "vla_model_best.pt"
agent_stats = agent.get_normalization_stats()
torch.save({
'step': step,
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'val_loss': val_loss,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, best_model_path)
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
# =========================================================================
# 6. 保存最终模型
# =========================================================================
final_model_path = checkpoint_dir / "vla_model_final.pt"
agent_stats = agent.get_normalization_stats()
torch.save({
'step': cfg.train.max_steps,
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': last_loss,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path)
log.info(f"💾 最终模型已保存: {final_model_path}")
log.info("✅ 训练成功完成!")
if last_loss is not None:
log.info(f"📊 最终损失: {last_loss:.4f}")
else:
log.info("📊 最终损失: N/A未执行训练步")
if best_loss != float('inf'):
log.info(f"📊 最佳损失: {best_loss:.4f}")
else:
log.info("📊 最佳损失: N/A无有效验证/训练损失)")
if __name__ == "__main__":
main()

View File

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 - present, Facebook, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,9 +0,0 @@
This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
@article{Carion2020EndtoEndOD,
title={End-to-End Object Detection with Transformers},
author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
journal={ArXiv},
year={2020},
volume={abs/2005.12872}
}

View File

@@ -1,106 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
from pathlib import Path
import numpy as np
import torch
from .models import build_ACT_model, build_CNNMLP_model
def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument('--lr', default=1e-4, type=float) # will be overridden
parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden
parser.add_argument('--batch_size', default=2, type=int) # not used
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=300, type=int) # not used
parser.add_argument('--lr_drop', default=200, type=int) # not used
parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used
help='gradient clipping max norm')
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float, help='lr', required=False)
# Model parameters
# * Backbone
parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden
help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
help="A list of camera names")
# * Transformer
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int, # will be overridden
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int, # will be overridden
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=400, type=int, # will be overridden
help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')
parser.add_argument('--state_dim', default=14, type=int)
parser.add_argument('--action_dim', default=14, type=int)
# * Segmentation
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if the flag is provided")
return parser
def build_ACT_model_and_optimizer(args_override):
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
for k, v in args_override.items():
setattr(args, k, v)
model = build_ACT_model(args)
model.cuda()
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer
def build_CNNMLP_model_and_optimizer(args_override):
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
for k, v in args_override.items():
setattr(args, k, v)
model = build_CNNMLP_model(args)
model.cuda()
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer

View File

@@ -1,9 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .detr_vae import build as build_vae
from .detr_vae import build_cnnmlp as build_cnnmlp
def build_ACT_model(args):
return build_vae(args)
def build_CNNMLP_model(args):
return build_cnnmlp(args)

View File

@@ -1,168 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
# class DINOv2BackBone(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# self.body.eval()
# self.num_channels = 384
# @torch.no_grad()
# def forward(self, tensor):
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
# od = OrderedDict()
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
# return od
class DINOv2BackBone(nn.Module):
def __init__(self, return_interm_layers: bool = False) -> None:
super().__init__()
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
self.body.eval()
self.num_channels = 384
self.return_interm_layers = return_interm_layers
@torch.no_grad()
def forward(self, tensor):
features = self.body.forward_features(tensor)
if self.return_interm_layers:
layer1 = features["x_norm_patchtokens"]
layer2 = features["x_norm_patchtokens"]
layer3 = features["x_norm_patchtokens"]
layer4 = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
else:
xs = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
if args.backbone == 'dino_v2':
backbone = DINOv2BackBone()
else:
assert args.backbone in ['resnet18', 'resnet34']
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

View File

@@ -1,300 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
import numpy as np
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
raise NotImplementedError
# input_dim = 14 + 7 # robot_state + env_state
# self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
# self.input_proj_env_state = nn.Linear(7, hidden_dim)
# self.pos = torch.nn.Embedding(2, hidden_dim)
# self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
### Obtain latent z from action sequence
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
# print(f"Image shape: {image.shape}, Number of cameras: {len(self.camera_names)}")
for cam_id, cam_name in enumerate(self.camera_names):
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar]
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5)
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + 14
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
else:
raise NotImplementedError
def forward(self, qpos, image, env_state, actions=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
# Image observation features and position embeddings
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0] # not used
all_cam_features.append(self.backbone_down_projs[cam_id](features))
# flatten everything
flattened_features = []
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
a_hat = self.mlp(features)
return a_hat
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for i in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build(args):
state_dim = args.state_dim
action_dim = args.action_dim
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
# backbone = build_backbone(args)
# backbones.append(backbone)
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
transformer = build_transformer(args)
encoder = build_encoder(args)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim=state_dim,
action_dim=action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model
def build_cnnmlp(args):
state_dim = 14 # TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
model = CNNMLP(
backbones,
state_dim=state_dim,
camera_names=args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model

View File

@@ -1,91 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@@ -1,312 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List
import torch
import torch.nn.functional as F
from torch import nn, Tensor
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,
return_intermediate_dec=False):
super().__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

View File

@@ -1,163 +0,0 @@
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from torchvision.transforms import v2
import torch
from roboimi.detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
class ACTPolicy(nn.Module):
def __init__(self, args_override):
super().__init__()
model, optimizer = build_ACT_model_and_optimizer(args_override)
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = args_override['kl_weight']
print(f'KL Weight {self.kl_weight}')
def __call__(self, qpos, image, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
image = normalize(image)
if actions is not None: # training time
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict['l1'] = l1
loss_dict['kl'] = total_kld[0]
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return a_hat
def configure_optimizers(self):
return self.optimizer
class ACTTVPolicy(nn.Module):
def __init__(self, args_override):
super().__init__()
model, optimizer = build_ACT_model_and_optimizer(args_override)
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = args_override['kl_weight']
self.qpos_noise_std = args_override['qpos_noise_std']
print(f'KL Weight {self.kl_weight}')
def __call__(self, qpos, image, actions=None, is_pad=None):
env_state = None
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# image = normalize(image)
patch_h = 16
patch_w = 22
if actions is not None:
transform = v2.Compose([
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
v2.RandomPerspective(distortion_scale=0.5),
v2.RandomAffine(degrees=10, translate=(0.1,0.1), scale=(0.9,1.1)),
v2.GaussianBlur(kernel_size=(9,9), sigma=(0.1,2.0)),
v2.Resize((patch_h * 14, patch_w * 14)),
# v2.CenterCrop((patch_h * 14, patch_w * 14)),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
qpos += (self.qpos_noise_std**0.5)*torch.randn_like(qpos)
else: # inference time
transform = v2.Compose([
v2.Resize((patch_h * 14, patch_w * 14)),
# v2.CenterCrop((patch_h * 14, patch_w * 14)),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
image = transform(image)
if actions is not None: # training time
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict['l1'] = l1
loss_dict['kl'] = total_kld[0]
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return a_hat
def configure_optimizers(self):
return self.optimizer
class CNNMLPPolicy(nn.Module):
def __init__(self, args_override):
super().__init__()
model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
self.model = model # decoder
self.optimizer = optimizer
def __call__(self, qpos, image, actions=None, is_pad=None):
env_state = None # TODO
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
image = normalize(image)
if actions is not None: # training time
actions = actions[:, 0]
a_hat = self.model(qpos, image, env_state, actions)
mse = F.mse_loss(actions, a_hat)
loss_dict = dict()
loss_dict['mse'] = mse
loss_dict['loss'] = loss_dict['mse']
return loss_dict
else: # inference time
a_hat = self.model(qpos, image, env_state) # no action, sample from prior
return a_hat
def configure_optimizers(self):
return self.optimizer
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld

View File

@@ -1,10 +0,0 @@
from distutils.core import setup
from setuptools import find_packages
setup(
name='detr',
version='0.0.0',
packages=find_packages(),
license='MIT License',
long_description=open('README.md').read(),
)

View File

@@ -1 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

View File

@@ -1,88 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)

View File

@@ -1,468 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List
import torch
import torch.distributed as dist
from torch import Tensor
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
sha = 'N/A'
diff = "clean"
branch = 'N/A'
try:
sha = _run(['git', 'rev-parse', 'HEAD'])
subprocess.check_output(['git', 'diff'], cwd=cwd)
diff = _run(['git', 'diff-index', 'HEAD'])
diff = "has uncommited changes" if diff else "clean"
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if version.parse(torchvision.__version__) < version.parse('0.7'):
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -1,107 +0,0 @@
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
'''
Function to plot specific fields from training log(s). Plots both training and test results.
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
- fields = which results to plot from each log file - plots both training and test for each field.
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
- log_name = optional, name of log file if different than default 'log.txt'.
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
- solid lines are training results, dashed lines are test results.
'''
func_name = "plot_utils.py::plot_logs"
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
# convert single Path to list to avoid 'not iterable' error
if not isinstance(logs, list):
if isinstance(logs, PurePath):
logs = [logs]
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
else:
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
Expect list[Path] or single Path obj, received {type(logs)}")
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
for i, dir in enumerate(logs):
if not isinstance(dir, PurePath):
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
if not dir.exists():
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
# verify log_name exists
fn = Path(dir / log_name)
if not fn.exists():
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
print(f"--> full path of missing log file: {fn}")
return
# load log file(s) and plot
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields):
if field == 'mAP':
coco_eval = pd.DataFrame(
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
).ewm(com=ewm_col).mean()
axs[j].plot(coco_eval, c=color)
else:
df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'],
ax=axs[j],
color=[color] * 2,
style=['-', '--']
)
for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs])
ax.set_title(field)
def plot_precision_recall(files, naming_scheme='iter'):
if naming_scheme == 'exp_id':
# name becomes exp_id
names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter':
names = [f.stem for f in files]
else:
raise ValueError(f'not supported {naming_scheme}')
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision']
recall = data['params'].recThrs
scores = data['scores']
# take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
f'score={scores.mean():0.3f}, ' +
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
)
axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color)
axs[0].set_title('Precision / Recall')
axs[0].legend(names)
axs[1].set_title('Scores / Recall')
axs[1].legend(names)
return fig, axs

View File

@@ -53,6 +53,7 @@ class DualDianaMed(MujocoEnv):
self.l_vis = None self.l_vis = None
self.top = None self.top = None
self.angle = None self.angle = None
self.front = None
self.obs = None self.obs = None
self.rew = None self.rew = None
@@ -168,6 +169,7 @@ class DualDianaMed(MujocoEnv):
obs['images']['angle'] = self.angle obs['images']['angle'] = self.angle
obs['images']['r_vis'] = self.r_vis obs['images']['r_vis'] = self.r_vis
obs['images']['l_vis'] = self.l_vis obs['images']['l_vis'] = self.l_vis
obs['images']['front'] = self.front
return obs return obs
def _get_image_obs(self): def _get_image_obs(self):
@@ -177,6 +179,7 @@ class DualDianaMed(MujocoEnv):
obs['images']['angle'] = self.angle obs['images']['angle'] = self.angle
obs['images']['r_vis'] = self.r_vis obs['images']['r_vis'] = self.r_vis
obs['images']['l_vis'] = self.l_vis obs['images']['l_vis'] = self.l_vis
obs['images']['front'] = self.front
return obs return obs
def _get_qpos_obs(self): def _get_qpos_obs(self):
@@ -202,6 +205,8 @@ class DualDianaMed(MujocoEnv):
return self.r_vis return self.r_vis
elif self.cam == 'l_vis': elif self.cam == 'l_vis':
return self.l_vis return self.l_vis
elif self.cam == 'front':
return self.front
else: else:
raise AttributeError("please input right name") raise AttributeError("please input right name")
@@ -222,6 +227,10 @@ class DualDianaMed(MujocoEnv):
img_renderer.update_scene(self.mj_data,camera="angle") img_renderer.update_scene(self.mj_data,camera="angle")
self.angle = img_renderer.render() self.angle = img_renderer.render()
self.angle = self.angle[:, :, ::-1] self.angle = self.angle[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="front")
self.front = img_renderer.render()
self.front = self.front[:, :, ::-1]
if self.cam_view is not None:
cv2.imshow('Cam view', self.cam_view) cv2.imshow('Cam view', self.cam_view)
cv2.waitKey(1) cv2.waitKey(1)

View File

@@ -72,12 +72,17 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
self.mj_data.joint('red_box_joint').qpos[5] = 0.0 self.mj_data.joint('red_box_joint').qpos[5] = 0.0
self.mj_data.joint('red_box_joint').qpos[6] = 0.0 self.mj_data.joint('red_box_joint').qpos[6] = 0.0
super().reset() super().reset()
self.top = None
self.angle = None
self.r_vis = None
self.front = None
self.cam_flage = True self.cam_flage = True
t=0 t=0
while self.cam_flage: while self.cam_flage:
if(type(self.top)==type(None) if(type(self.top)==type(None)
or type(self.angle)==type(None) or type(self.angle)==type(None)
or type(self.r_vis)==type(None)): or type(self.r_vis)==type(None)
or type(self.front)==type(None)):
time.sleep(0.001) time.sleep(0.001)
t+=1 t+=1
else: else:

View File

@@ -27,8 +27,8 @@ def sample_insertion_pose():
def sample_transfer_pose(): def sample_transfer_pose():
# Box # Box
x_range = [0.0, 0.05] x_range = [-0.2, 0.2]
y_range = [0.95, 1.05] y_range = [0.7, 1.1]
z_range = [0.47, 0.47] z_range = [0.47, 0.47]
ranges = np.vstack([x_range, y_range, z_range]) ranges = np.vstack([x_range, y_range, z_range])

View File

@@ -18,9 +18,9 @@ SIM_TASK_CONFIGS = {
# }, # },
'sim_transfer': { 'sim_transfer': {
'dataset_dir': DATASET_DIR + '/sim_transfer', 'dataset_dir': DATASET_DIR + '/sim_transfer',
'num_episodes': 7, 'num_episodes': 20,
'episode_len': 700, 'episode_len': 700,
'camera_names': ['angle','r_vis'], 'camera_names': ['top','r_vis','front'],
'xml_dir': HOME_PATH + '/assets' 'xml_dir': HOME_PATH + '/assets'
}, },

View File

@@ -2,7 +2,7 @@ import os
import torch import torch
from roboimi.utils.utils import load_data, set_seed from roboimi.utils.utils import load_data, set_seed
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy
from roboimi.ddt.policy import DDTPolicy from roboimi.gr00t.policy import gr00tPolicy
class ModelInterface: class ModelInterface:
def __init__(self, config): def __init__(self, config):
@@ -66,23 +66,25 @@ class ModelInterface:
'num_queries': 1, 'num_queries': 1,
'camera_names': self.config['camera_names'], 'camera_names': self.config['camera_names'],
} }
elif self.config['policy_class'] == 'DDT': elif self.config['policy_class'] == 'GR00T':
# GR00T uses its own config section from config.yaml
gr00t_config = self.config.get('gr00t', {})
self.config['policy_config'] = { self.config['policy_config'] = {
'lr': self.config['lr'], 'lr': gr00t_config.get('lr', self.config['lr']),
'lr_backbone': self.config['lr_backbone'], 'lr_backbone': gr00t_config.get('lr_backbone', self.config['lr_backbone']),
'backbone': self.config.get('backbone', 'dino_v2'), 'weight_decay': gr00t_config.get('weight_decay', 1e-4),
'num_queries': self.config['chunk_size'], 'embed_dim': gr00t_config.get('embed_dim', 1536),
'hidden_dim': self.config['hidden_dim'], 'hidden_dim': gr00t_config.get('hidden_dim', 1024),
'nheads': self.config['nheads'], 'state_dim': gr00t_config.get('state_dim', 16),
'enc_layers': self.config['enc_layers'], 'action_dim': gr00t_config.get('action_dim', 16),
'state_dim': self.config.get('state_dim', 16), 'num_queries': gr00t_config.get('num_queries', 16),
'action_dim': self.config.get('action_dim', 16), 'num_layers': gr00t_config.get('num_layers', 16),
'nheads': gr00t_config.get('nheads', 32),
'mlp_ratio': gr00t_config.get('mlp_ratio', 4),
'dropout': gr00t_config.get('dropout', 0.2),
'backbone': gr00t_config.get('backbone', 'dino_v2'),
'position_embedding': gr00t_config.get('position_embedding', 'sine'),
'camera_names': self.config['camera_names'], 'camera_names': self.config['camera_names'],
'qpos_noise_std': self.config.get('qpos_noise_std', 0),
# DDT 特有参数
'num_blocks': self.config.get('num_blocks', 12),
'mlp_ratio': self.config.get('mlp_ratio', 4.0),
'num_inference_steps': self.config.get('num_inference_steps', 10),
} }
else: else:
raise NotImplementedError raise NotImplementedError
@@ -94,8 +96,8 @@ class ModelInterface:
return ACTTVPolicy(self.config['policy_config']) return ACTTVPolicy(self.config['policy_config'])
elif self.config['policy_class'] == 'CNNMLP': elif self.config['policy_class'] == 'CNNMLP':
return CNNMLPPolicy(self.config['policy_config']) return CNNMLPPolicy(self.config['policy_config'])
elif self.config['policy_class'] == 'DDT': elif self.config['policy_class'] == 'GR00T':
return DDTPolicy(self.config['policy_config']) return gr00tPolicy(self.config['policy_config'])
else: else:
raise NotImplementedError raise NotImplementedError

1
roboimi/vla/__init__.py Normal file
View File

@@ -0,0 +1 @@
# export VLAAgent, VLAModelConfig

401
roboimi/vla/agent.py Normal file
View File

@@ -0,0 +1,401 @@
import torch
import torch.nn as nn
import numpy as np
from collections import deque
from typing import Dict, Optional, Any, Tuple
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgent(nn.Module):
def __init__(
self,
vision_backbone, # 视觉编码器ResNet 等)
state_encoder,
action_encoder,
head,
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
obs_dim, # 本体感知维度 (例如 关节角度)
pred_horizon=16, # 预测未来多少步动作
obs_horizon=4, # 使用多少步历史观测
diffusion_steps=100, # DDPM 加噪步数
inference_steps=10, # DDIM 推理步数
num_cams=3, # 视觉输入的摄像头数量
dataset_stats=None, # 数据集统计信息,用于归一化
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
num_action_steps=8, # 每次推理实际执行多少步动作
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
):
super().__init__()
# 保存参数
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.head_type = head_type # 'unet' 或 'transformer'
# 归一化模块 - 统一训练和推理的归一化逻辑
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type
)
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
# global_cond_dim: 展平后的总维度用于UNet
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim
# per_step_cond_dim: 每步的条件维度用于Transformer
# 注意这里不乘以obs_horizon因为Transformer的输入是序列形式
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
clip_sample=True,
prediction_type='epsilon' # 预测噪声
)
# DDIM 调度器用于快速推理
self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2',
clip_sample=True,
prediction_type='epsilon'
)
# 根据head类型初始化不同的参数
if head_type == 'transformer':
# 如果head已经是nn.Module实例直接使用否则需要初始化
if isinstance(head, nn.Module):
# 已经是实例化的模块测试时直接传入<E4BCA0><E585A5>
self.noise_pred_net = head
else:
# Hydra部分初始化的对象调用时传入参数
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim # 每步的条件维度
)
else: # 'unet' (default)
# UNet接口: input_dim, global_cond_dim
self.noise_pred_net = head(
input_dim=action_dim,
global_cond_dim=self.global_cond_dim
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
# 初始化队列(用于在线推理)
self.reset()
def _get_model_device(self) -> torch.device:
"""获取模型当前所在设备。"""
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
"""递归地将张量数据移动到指定设备。"""
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
# ==========================
# 训练阶段 (Training)
# ==========================
def compute_loss(self, batch):
"""
计算训练损失
Args:
batch: 包含 images, qpos (本体感知), action, action_is_pad 的字典
"""
actions, states, images = batch['action'], batch['qpos'], batch['images']
action_is_pad = batch.get('action_is_pad', None) # 获取padding mask
B = actions.shape[0]
# 归一化 states (qpos) 和 actions
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
state_features = self.state_encoder(states)
# 1. 提取视觉特征
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
action_features = self.action_encoder(actions)
# 2. 采样噪声
noise = torch.randn_like(action_features)
# 3. 随机采样时间步 (Timesteps)
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=action_features.device
).long()
# 4. 给动作加噪 (Forward Diffusion)
noisy_actions = self.noise_scheduler.add_noise(
action_features, noise, timesteps
)
# 拼接全局条件并展平
# visual_features: (B, obs_horizon, vision_dim)
# state_features: (B, obs_horizon, obs_dim)
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond = global_cond.flatten(start_dim=1)
# 5. 网络预测噪声根据head类型选择接口
if self.head_type == 'transformer':
# Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
# 将展平的global_cond reshape回序列格式
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
cond=cond
)
else: # 'unet'
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
global_cond=global_cond
)
# 6. 计算 Loss (MSE),支持 padding mask
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
# 如果提供了 action_is_pad对padding位置进行mask
if action_is_pad is not None:
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype) # 1.0表示有效数据
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
return loss
# ==========================
# 队列管理 (Queue Management)
# ==========================
def reset(self):
"""清空观测和动作队列。应在 env.reset() 时调用"""
self._queues = {
'qpos': deque(maxlen=self.obs_horizon),
'images': deque(maxlen=self.obs_horizon),
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
"""
将新的观测添加到队列中。
Args:
observation: 包含 'qpos''images' 的字典
"""
# 添加本体感知
if 'qpos' in observation:
self._queues['qpos'].append(observation['qpos'].clone())
# 添加图像
if 'images' in observation:
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
"""
从队列中准备用于推理的批量观测。
如果队列未满(首次调用时),用最新观测重复填充。
Returns:
batch: 包含堆叠后的历史观测的字典
"""
# 堆叠历史本体感知
qpos_list = list(self._queues['qpos'])
if len(qpos_list) == 0:
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
# 堆叠历史图像
images_list = list(self._queues['images'])
if len(images_list) == 0:
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
return {'qpos': batch_qpos, 'images': batch_images}
# ==========================
# 在线推理 (Online Inference)
# ==========================
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
根据当前观测选择单个动作。
这个方法维护一个历史观测和生成动作轨迹的缓存。工作流程:
- 缓存 `obs_horizon` 步的历史观测
- Diffusion 模型生成 `pred_horizon` 步的动作
- 实际执行 `num_action_steps` 步动作
示意图:
--------------------------------------------------------------
(图例: o=obs_horizon, h=pred_horizon, a=num_action_steps)
|时间步 | 0 | 1 | ... | o-1 | o | ... | h-1 |
|观测是否使用 | 是 | 是 | 是 | 是 | 否 | 否 | 否 |
|动作是否生成 | 是 | 是 | 是 | 是 | 是 | 是 | 是 |
|动作是否执行 | 否 | 否 | 否 | 否 | 是 | 是 | 是 |
--------------------------------------------------------------
Args:
observation: 包含 'qpos''images' 的字典
Returns:
action: (action_dim,) 单个动作
"""
# 使用模型当前设备作为唯一真值,将输入移动到模型设备
# 避免根据CPU观测把模型错误搬回CPU。
device = self._get_model_device()
observation = self._move_to_device(observation, device)
# 将新观测添加到队列
self._populate_queues(observation)
# 如果动作队列为空,生成新的动作序列
if len(self._queues['action']) == 0:
# 从队列准备批量观测
batch = self._prepare_observation_batch()
# 生成动作块
actions = self.predict_action_chunk(batch) # (1, pred_horizon, action_dim)
# 提取可执行的动作部分
# 从 obs_horizon-1 开始,因为前面的动作对应过去的观测
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end] # (1, num_action_steps, action_dim)
# 将动作添加到队列
for i in range(executable_actions.shape[1]):
self._queues['action'].append(executable_actions[:, i].squeeze(0)) # (action_dim,)
# 从队列中取出一个动作
action = self._queues['action'].popleft() # (action_dim,)
return action
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
预测一个动作块(用于在线推理)。
Args:
batch: 包含 'qpos''images' 的字典
- qpos: (B, obs_horizon, obs_dim)
- images: Dict[str, (B, obs_horizon, C, H, W)]
Returns:
actions: (B, pred_horizon, action_dim) 预测的动作序列
"""
return self.predict_action(batch['images'], batch['qpos'])
# ==========================
# 批量推理 (Batch Inference - 原有方法)
# ==========================
@torch.no_grad()
def predict_action(self, images, proprioception):
"""
批量预测动作序列(用于训练和离线评估)
Args:
images: 图像观测字典
proprioception: 本体感知观测 (qpos)
Returns:
denormalized_actions: 反归一化后的动作序列
"""
B = proprioception.shape[0]
# 归一化 proprioception (qpos)
proprioception = self.normalization.normalize_qpos(proprioception)
# 1. 提取当前观测特征(只提取一次)
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception)
# 拼接条件(只计算一次)
# visual_features: (B, obs_horizon, vision_dim)
# state_features: (B, obs_horizon, obs_dim)
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond_flat = global_cond.flatten(start_dim=1)
if self.head_type == 'transformer':
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
else:
cond = None
# 2. 初始化纯高斯噪声动作
# 形状: (B, pred_horizon, action_dim)
device = visual_features.device
current_actions = torch.randn(
(B, self.pred_horizon, self.action_dim), device=device
)
# 3. 逐步去噪循环 (Reverse Diffusion)
self.infer_scheduler.set_timesteps(self.inference_steps) # DDIM 推理步数
for t in self.infer_scheduler.timesteps:
model_input = current_actions
# 预测噪声根据head类型选择接口
if self.head_type == 'transformer':
noise_pred = self.noise_pred_net(
sample=model_input,
timestep=t,
cond=cond
)
else: # 'unet'
noise_pred = self.noise_pred_net(
sample=model_input,
timestep=t,
global_cond=global_cond_flat
)
# 移除噪声,更新 current_actions
current_actions = self.infer_scheduler.step(
noise_pred, t, current_actions
).prev_sample
# 4. 反归一化动作序列
denormalized_actions = self.normalization.denormalize_action(current_actions)
return denormalized_actions
def get_normalization_stats(self):
"""获取归一化统计信息(用于保存到 checkpoint"""
return self.normalization.get_stats()

View File

@@ -0,0 +1,217 @@
import torch
import torch.nn as nn
from collections import deque
from typing import Dict
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgentGr00tDiT(nn.Module):
"""
VLA Agent variant that swaps Transformer1D head with gr00t DiT head.
Other components (backbone/encoders/scheduler/queue logic) stay aligned
with the existing VLAAgent implementation.
"""
def __init__(
self,
vision_backbone,
state_encoder,
action_encoder,
head,
action_dim,
obs_dim,
pred_horizon=16,
obs_horizon=4,
diffusion_steps=100,
inference_steps=10,
num_cams=3,
dataset_stats=None,
normalization_type="min_max",
num_action_steps=8,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type,
)
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
if isinstance(head, nn.Module):
self.noise_pred_net = head
else:
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim,
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.reset()
def _get_model_device(self) -> torch.device:
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(states)
return torch.cat([visual_features, state_features], dim=-1)
def compute_loss(self, batch):
actions, states, images = batch["action"], batch["qpos"], batch["images"]
action_is_pad = batch.get("action_is_pad", None)
bsz = actions.shape[0]
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
action_features = self.action_encoder(actions)
cond = self._build_cond(images, states)
noise = torch.randn_like(action_features)
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device=action_features.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
cond=cond,
)
loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
return loss
def reset(self):
self._queues = {
"qpos": deque(maxlen=self.obs_horizon),
"images": deque(maxlen=self.obs_horizon),
"action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
if "qpos" in observation:
self._queues["qpos"].append(observation["qpos"].clone())
if "images" in observation:
self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
qpos_list = list(self._queues["qpos"])
if len(qpos_list) == 0:
raise ValueError("observation queue is empty.")
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
images_list = list(self._queues["images"])
if len(images_list) == 0:
raise ValueError("image queue is empty.")
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack(
[img[cam_name] for img in images_list], dim=0
).unsqueeze(0)
return {"qpos": batch_qpos, "images": batch_images}
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
device = self._get_model_device()
observation = self._move_to_device(observation, device)
self._populate_queues(observation)
if len(self._queues["action"]) == 0:
batch = self._prepare_observation_batch()
actions = self.predict_action_chunk(batch)
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end]
for i in range(executable_actions.shape[1]):
self._queues["action"].append(executable_actions[:, i].squeeze(0))
return self._queues["action"].popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
return self.predict_action(batch["images"], batch["qpos"])
@torch.no_grad()
def predict_action(self, images, proprioception):
bsz = proprioception.shape[0]
proprioception = self.normalization.normalize_qpos(proprioception)
cond = self._build_cond(images, proprioception)
device = cond.device
current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
self.infer_scheduler.set_timesteps(self.inference_steps)
for t in self.infer_scheduler.timesteps:
noise_pred = self.noise_pred_net(
sample=current_actions,
timestep=t,
cond=cond,
)
current_actions = self.infer_scheduler.step(
noise_pred, t, current_actions
).prev_sample
return self.normalization.denormalize_action(current_actions)
def get_normalization_stats(self):
return self.normalization.get_stats()

View File

@@ -0,0 +1,39 @@
# @package agent
defaults:
# - /backbone@vision_backbone: resnet
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: conditional_unet1d
- _self_
_target_: roboimi.vla.agent.VLAAgent
# ====================
# 模型维度配置
# ====================
action_dim: 16 # 动作维度(机器人关节数)
obs_dim: 16 # 本体感知维度(关节位置)
# ====================
#
# ====================
normalization_type: "min_max" # "min_max" or "gaussian"
# ====================
# 时间步配置
# ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
# ====================
# 相机配置
# ====================
num_cams: 3 # 摄像头数量 (r_vis, top, front)
# ====================
# 扩散过程配置
# ====================
diffusion_steps: 100 # 扩散训练步数DDPM
inference_steps: 10 # 推理时的去噪步数DDIM固定为 10

View File

@@ -0,0 +1,37 @@
# @package agent
defaults:
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: gr00t_dit1d
- _self_
_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT
# Model dimensions
action_dim: 16
obs_dim: 16
# Normalization
normalization_type: "min_max"
# Horizons
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
# Cameras
num_cams: 3
# Diffusion
diffusion_steps: 100
inference_steps: 10
# Head overrides
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 208

View File

@@ -0,0 +1,54 @@
# @package agent
defaults:
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: transformer1d
- _self_
_target_: roboimi.vla.agent.VLAAgent
# ====================
# 模型维度配置
# ====================
action_dim: 16 # 动作维度(机器人关节数)
obs_dim: 16 # 本体感知维度(关节位置)
# ====================
# 归一化配置
# ====================
normalization_type: "min_max" # "min_max" or "gaussian"
# ====================
# 时间步配置
# ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
# ====================
# 相机配置
# ====================
num_cams: 3 # 摄像头数量 (r_vis, top, front)
# ====================
# 扩散过程配置
# ====================
diffusion_steps: 100 # 扩散训练步数DDPM
inference_steps: 10 # 推理时的去噪步数DDIM<4D><EFBC8C>定为 10
# ====================
# Head 类型标识用于VLAAgent选择调用方式
# ====================
head_type: "transformer" # "unet" 或 "transformer"
# Head 参数覆盖
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
# Transformer的cond_dim是每步的维度
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
cond_dim: 208

View File

@@ -0,0 +1,33 @@
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
# ====================
# 骨干网络选择
# ====================
vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
pretrained_backbone_weights: "IMAGENET1K_V1" # 使用ImageNet预训练权重torchvision>=0.13
# ====================
# 冻结设置
# ====================
freeze_backbone: true # 冻结ResNet参数只训练后面的pool和out层推荐true
# ====================
# 输入配置
# ====================
input_shape: [3, 224, 224] # 输入图像形状 (C, H, W) - ImageNet标准尺寸
crop_shape: null # 裁剪后的图像形状 (H, W) - 设为null禁用裁剪
crop_is_random: true # 训练时使用随机裁剪评估时使用中心裁剪crop_shape=null时无效
# ====================
# 归一化和特征提取
# ====================
use_group_norm: true # 使用 GroupNorm 替代 BatchNorm更适合小批次训练
spatial_softmax_num_keypoints: 32 # Spatial Softmax 关键点数量
# ====================
# 编码器模式
# ====================
# false: 共享编码器(所有摄像头共享一个 ResNet参数少但容量受限推荐
# true: 独立编码器(每个摄像头有独立的 ResNet参数多但容量大
use_separate_rgb_encoder_per_camera: true
num_cameras: 3 # 摄像头数量

View File

@@ -0,0 +1,44 @@
defaults:
- agent: resnet_transformer
- data: simpe_robot_dataset
- eval: eval
- _self_
# ====================
# 训练配置
# ====================
train:
# 基础训练参数
batch_size: 8 # 批次大小
lr: 5e-5 # 学习率Transformer建议更小
max_steps: 100000 # 最大训练步数
device: "cuda" # 设备: "cuda" 或 "cpu"
# 数据加载
num_workers: 8 # DataLoader 工作进程数(调试时设为 0生产环境用 8
val_split: 0.1 # 验证集比例
seed: 42 # 随机种子(用于数据划分)
# 日志和检查点
log_freq: 100 # 日志记录频率(步数)
save_freq: 2000 # 保存检查点频率(步数)
# 学习率调度器(带预热)
warmup_steps: 2000 # 预热步数Transformer建议更长
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
min_lr: 1e-6 # 最小学习率(用于余弦退火)
# 优化器
weight_decay: 1e-5 # 权重衰减L2 正则化)
grad_clip: 1.0 # 梯度裁剪阈值
# 微调配置
pretrained_ckpt: null # 预训练 checkpoint 路径(用于微调),例如: "checkpoints/vla_model_step_8000.pt"
# ====================
# 实验配置
# ====================
experiment:
name: "vla_diffusion" # 实验名称
notes: "" # 实验备注
tags: [] # 实验标签

View File

@@ -0,0 +1,21 @@
# @package data
_target_: roboimi.vla.data.simpe_robot_dataset.SimpleRobotDataset
# ====================
# 数据集路径
# ====================
dataset_dir: "roboimi/demos/dataset/sim_transfer"
# ====================
# 时间步参数(从 agent 配置引用)
# ====================
pred_horizon: ${agent.pred_horizon} # 预测步数
obs_horizon: ${agent.obs_horizon} # 观测步数
# ====================
# 相机配置
# ====================
camera_names:
- r_vis # 机器人视角相机
- top # 顶部相机
- front # 前方相机

View File

@@ -0,0 +1,34 @@
# @package eval
# 评估配置
ckpt_path: "checkpoints/vla_model_best.pt" # 模型检查点路径
num_episodes: 3 # 评估回合数
max_timesteps: 700 # 每回合最大时间步
device: ${train.device} # 与训练保持一致
task_name: "sim_transfer" # 环境任务名称
# ====================
# 策略执行参数
# ====================
# num_queries 已废弃,现在使用 agent 的 select_action() 自动管理队列
# 以下参数仅用于兼容旧代码,实际使用 agent.num_action_steps
num_queries: ${agent.num_action_steps}
obs_horizon: ${agent.obs_horizon}
# ====================
# 相机配置
# ====================
camera_names: ${data.camera_names}
# ====================
# 动作平滑
# ====================
use_smoothing: false
smooth_method: "ema"
smooth_alpha: 0.3
# ====================
# 调试选项
# ====================
verbose_action: true # 是否打印每个时间步的动作信息

View File

@@ -0,0 +1,15 @@
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
_partial_: true
# ====================
# UNet1D 配置
# ====================
kernel_size: 3 # 卷积核大小
cond_predict_scale: false # FiLM 条件化时是否同时预测 scalebias + scale 或仅 bias
# ====================
# 网络架构(默认值,可覆盖)
# ====================
# diffusion_step_embed_dim: 256 # 扩散时间步嵌入维度
# down_dims: [256, 512, 1024] # 下采样各层通道数
# n_groups: 8 # GroupNorm 分组数

View File

@@ -0,0 +1,22 @@
_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
_partial_: true
# DiT architecture
n_layer: 6
n_head: 8
n_emb: 256
hidden_dim: 256
mlp_ratio: 4
dropout: 0.1
# Positional embeddings
add_action_pos_emb: true
add_cond_pos_emb: true
# Supplied by agent interpolation:
# - input_dim
# - output_dim
# - horizon
# - n_obs_steps
# - cond_dim

View File

@@ -0,0 +1,29 @@
# Transformer-based Diffusion Policy Head
_target_: roboimi.vla.models.heads.transformer1d.Transformer1D
_partial_: true
# ====================
# Transformer 架构配置
# ====================
n_layer: 4 # Transformer层数先用小模型提高收敛稳定性
n_head: 4 # 注意力头数
n_emb: 128 # 嵌入维度
p_drop_emb: 0.05 # Embedding dropout
p_drop_attn: 0.05 # Attention dropout
# ====================
# 条件配置
# ====================
causal_attn: false # 是否使用因果注意力(自回归生成)
obs_as_cond: true # 观测作为条件由cond_dim > 0决定
n_cond_layers: 1 # 条件编码器层数1层先做稳定融合
# ====================
# 注意事项
# ====================
# 以下参数将在agent配置中通过interpolation提供
# - input_dim: ${agent.action_dim}
# - output_dim: ${agent.action_dim}
# - horizon: ${agent.pred_horizon}
# - n_obs_steps: ${agent.obs_horizon}
# - cond_dim: 通过agent中的global_cond_dim计算

View File

@@ -0,0 +1 @@
_target_: roboimi.vla.modules.encoders.IdentityActionEncoder

View File

@@ -0,0 +1 @@
_target_: roboimi.vla.modules.encoders.IdentityStateEncoder

View File

View File

@@ -0,0 +1,46 @@
import abc
import torch
import torch.nn as nn
from typing import Dict, Any, Optional
class VLABackbone(nn.Module, abc.ABC):
"""
Contract for Vision/Language Backbones.
Must return a feature tensor of shape (B, Seq, Embed_Dim).
"""
@abc.abstractmethod
def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Args:
obs: Dictionary containing 'image' and optionally 'text'.
Returns:
features: (B, S, D) embedding.
"""
pass
class VLAProjector(nn.Module, abc.ABC):
"""
Contract for the adaptation layer (Projector).
Connects Backbone features to the Policy Head.
"""
@abc.abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass
class VLAHead(nn.Module, abc.ABC):
"""
Contract for Action Generation Heads (Policies).
Handles both training (loss calculation) and inference (action generation).
"""
@abc.abstractmethod
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
"""
Args:
embeddings: (B, S, Hidden) from Projector.
actions: (B, Pred_Horizon, Action_Dim) - Ground truth for training.
Returns:
Dict containing 'loss' (if actions provided) or 'pred_actions'.
"""
pass

View File

View File

@@ -0,0 +1,242 @@
import torch
import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Union
from pathlib import Path
from collections import OrderedDict
class SimpleRobotDataset(Dataset):
"""
HDF5 懒加载数据集 - LeRobotDataset 格式
返回格式:
- observation.state: (obs_horizon, state_dim)
- observation.{cam_name}: (obs_horizon, C, H, W)
- action: (pred_horizon, action_dim)
"""
def __init__(
self,
dataset_dir: Union[str, Path],
obs_horizon: int = 2,
pred_horizon: int = 8,
camera_names: List[str] = None,
max_open_files: int = 64,
):
"""
Args:
dataset_dir: HDF5 文件目录路径
obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
HDF5 文件格式:
- action: [T, action_dim]
- observations/qpos: [T, obs_dim]
- observations/images/{cam_name}: [T, H, W, C]
"""
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.camera_names = camera_names or []
self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")
# 查找 HDF5 文件
self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
if not self.hdf5_files:
self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
if not self.hdf5_files:
raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")
# 构建 episode 索引(只存储元数据,不加载数据)
self.episodes = {}
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
with h5py.File(hdf5_path, 'r') as f:
T = f['action'].shape[0]
start_idx = len(self.frame_meta)
for t in range(T):
self.frame_meta.append({
"ep_idx": ep_idx,
"frame_idx": t,
"hdf5_path": hdf5_path,
})
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")
def __len__(self):
return len(self.frame_meta)
def _close_all_files(self) -> None:
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
for f in self._file_cache.values():
try:
f.close()
except Exception:
pass
self._file_cache.clear()
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
"""
获取 HDF5 文件句柄worker 内 LRU 缓存)。
注意:缓存的是文件句柄,不是帧数据本身。
"""
key = str(hdf5_path)
if key in self._file_cache:
self._file_cache.move_to_end(key)
return self._file_cache[key]
# 超过上限时淘汰最久未使用的句柄
if len(self._file_cache) >= self.max_open_files:
_, old_file = self._file_cache.popitem(last=False)
try:
old_file.close()
except Exception:
pass
f = h5py.File(key, 'r')
self._file_cache[key] = f
return f
def _load_frame(self, idx: int) -> Dict:
"""从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx]
f = self._get_h5_file(meta["hdf5_path"])
frame = {
"episode_index": meta["ep_idx"],
"frame_index": meta["frame_idx"],
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
}
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
for cam_name in self.camera_names:
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
# Resize图像到224x224减少内存和I/O负担
import cv2
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
# 转换为float并归一化到 [0, 1]
img = torch.from_numpy(img).float() / 255.0
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
frame = self._load_frame(idx)
ep_idx = frame["episode_index"]
# 获取当前 episode 的帧索引范围
ep_indices = self.episodes[ep_idx]
ep_start = ep_indices[0]
ep_end = ep_indices[-1]
# ============================================
# 1. 加载观察(过去 obs_horizon 帧)
# ============================================
observations = {
"state": [], # 状态数据
}
# 为每个摄像头初始化独立列表
for cam_name in self.camera_names:
observations[f"observation.{cam_name}"] = []
observation_is_pad = []
for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2
target_idx = idx + delta
# 边界检查
if ep_start <= target_idx <= ep_end:
target_frame = self._load_frame(target_idx)
is_pad = False
else:
# 超出边界,用边界帧填充
if target_idx < ep_start:
target_frame = self._load_frame(ep_start)
else:
target_frame = self._load_frame(ep_end)
is_pad = True
# 收集状态
observations["state"].append(target_frame["observation.state"])
# 收集每个摄像头的图像
for cam_name in self.camera_names:
observations[f"observation.{cam_name}"].append(target_frame[f"observation.{cam_name}"])
observation_is_pad.append(is_pad)
# ============================================
# 2. 加载动作(未来 pred_horizon 帧)
# ============================================
actions = []
action_is_pad = []
for delta in range(self.pred_horizon):
target_idx = idx + delta
if target_idx <= ep_end:
actions.append(self._load_frame(target_idx)["action"])
action_is_pad.append(False)
else:
actions.append(self._load_frame(ep_end)["action"])
action_is_pad.append(True)
# ============================================
# 3. 组装返回数据LeRobotDataset 格式)
# ============================================
result = {
# 状态观察: (obs_horizon, state_dim)
"observation.state": torch.stack(observations["state"]),
"observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
# 动作: (pred_horizon, action_dim)
"action": torch.stack(actions),
"action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
# 任务
"task": frame["task"],
}
# 图像:每个摄像头独立的 key
# 形状: (obs_horizon, C, H, W)
for cam_name in self.camera_names:
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
return result
@property
def camera_keys(self) -> list[str]:
"""获取所有相机键名 (LeRobotDataset 格式)"""
return [f"observation.{cam_name}" for cam_name in self.camera_names]
@property
def camera_info(self) -> dict:
"""获取相机信息"""
if not self.camera_names:
return {}
# 从第一个样本获取形状
sample = self[0]
info = {}
for cam_name in self.camera_names:
key = f"observation.{cam_name}"
if key in sample:
info[key] = {
"shape": sample[key].shape,
"dtype": str(sample[key].dtype),
}
return info
def __del__(self):
self._close_all_files()

View File

View File

@@ -0,0 +1,4 @@
# Backbone models
from .resnet_diffusion import ResNetDiffusionBackbone
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]

View File

@@ -0,0 +1,372 @@
from roboimi.vla.core.interfaces import VLABackbone
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from typing import Callable, Optional, Tuple, Union
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: 需要替换子模块的根模块
predicate: 接受一个模块作为参数,如果该模块需要被替换则返回 True。
func: 接受一个模块作为参数,并返回一个新的模块来替换它。
Returns:
子模块已被替换的根模块。
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# 验证所有 BN 是否已被替换
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class SpatialSoftmax(nn.Module):
"""
Finn 等人在 "Deep Spatial Autoencoders for Visuomotor Learning" 中描述的空间软 Argmax 操作
(https://huggingface.co/papers/1509.06113)。这是 robomimic 实现的一个最小移植版本。
"""
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) 输入特征图形状。
num_kp (int): 输出中的关键点数量。如果为 None输出将具有与输入相同的通道数。
"""
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
# 我们可以直接使用 torch.linspace但这似乎与 numpy 的行为略有不同
# 并且会导致预训练模型的 pc_success 略有下降。
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# 注册为 buffer以便将其移动到正确的设备。
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""
Args:
features: (B, C, H, W) 输入特征图。
Returns:
(B, K, 2) 关键点的图像空间坐标。
"""
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W],其中 K 是关键点数量
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax 归一化
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] 用于 x 和 y 维度的空间坐标均值
expected_xy = attention @ self.pos_grid
# 重塑为 [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class _SingleRgbEncoder(nn.Module):
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
def __init__(
self,
vision_backbone: str,
pretrained_backbone_weights: str | None,
input_shape: Tuple[int, int, int],
crop_shape: Optional[Tuple[int, int]],
crop_is_random: bool,
use_group_norm: bool,
spatial_softmax_num_keypoints: int,
freeze_backbone: bool = True, # 新增是否冻结backbone
):
super().__init__()
# 设置可选的预处理
if crop_shape is not None:
self.do_crop = True
# 评估时始终使用中心裁剪
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
crop_shape = input_shape[1:]
# 设置骨干网络
backbone_model = getattr(torchvision.models, vision_backbone)(
weights=pretrained_backbone_weights
)
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if use_group_norm:
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# 冻结backbone参数可选
if freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad = False
# 设置池化和最终层
# 使用试运行来获取特征图形状
dummy_shape = (1, input_shape[0], *crop_shape)
with torch.no_grad():
dummy_out = self.backbone(torch.zeros(dummy_shape))
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
self.feature_dim = spatial_softmax_num_keypoints * 2
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
# 注册ImageNet标准化参数为buffer会自动移到GPU
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
if self.do_crop:
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
# ImageNet标准化预训练权重期望的输入分布
x = (x - self.mean) / self.std
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
return x
class ResNetDiffusionBackbone(VLABackbone):
def __init__(
self,
vision_backbone: str = "resnet18",
pretrained_backbone_weights: str | None = None,
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
crop_shape: Optional[Tuple[int, int]] = None,
crop_is_random: bool = True,
use_group_norm: bool = True,
spatial_softmax_num_keypoints: int = 32,
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
freeze_backbone: bool = True, # 新增是否冻结ResNet backbone推荐True
):
super().__init__()
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
self.num_cameras = num_cameras
if use_separate_rgb_encoder_per_camera:
# 独立编码器模式:为每个摄像头创建独立的编码器
encoders = [
_SingleRgbEncoder(
vision_backbone=vision_backbone,
pretrained_backbone_weights=pretrained_backbone_weights,
input_shape=input_shape,
crop_shape=crop_shape,
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
freeze_backbone=freeze_backbone,
)
for _ in range(num_cameras)
]
self.rgb_encoder = nn.ModuleList(encoders)
# 重要output_dim 始终表示单个编码器的特征维度(与 lerobot 保持一致)
self.feature_dim = encoders[0].feature_dim
else:
# 共享编码器模式:所有摄像头共享同一个编码器
self.rgb_encoder = _SingleRgbEncoder(
vision_backbone=vision_backbone,
pretrained_backbone_weights=pretrained_backbone_weights,
input_shape=input_shape,
crop_shape=crop_shape,
crop_is_random=crop_is_random,
use_group_norm=use_group_norm,
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
freeze_backbone=freeze_backbone,
)
self.feature_dim = self.rgb_encoder.feature_dim
def forward(self, images):
"""
Args:
images: Dict[str, Tensor], 每个摄像头的图像
形状: {cam_name: (B, T, C, H, W)}
Returns:
Tensor: (B, T, total_feature_dim)
"""
any_tensor = next(iter(images.values()))
B, T = any_tensor.shape[:2]
cam_names = sorted(images.keys())
if self.use_separate_rgb_encoder_per_camera:
# 独立编码器模式:每个摄像头使用对应的编码器
features_all = []
for cam_idx, cam_name in enumerate(cam_names):
img = images[cam_name]
encoder = self.rgb_encoder[cam_idx]
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
else:
# 共享编码器模式:所有摄像头共享同一个编码器
features_all = []
for cam_name in cam_names:
img = images[cam_name]
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
features_all.append(features)
return torch.cat(features_all, dim=1).view(B, T, -1)
@property
def output_dim(self):
return self.feature_dim
if __name__ == "__main__":
print("=" * 60)
print("🚀 Testing ResNetDiffusionBackbone")
print("=" * 60)
# Configuration
B, T = 2, 5
C, H, W = 3, 96, 96
crop_h, crop_w = 84, 84
num_keypoints = 32
feature_dim_per_cam = num_keypoints * 2
# Create dummy input (2 cameras)
images = {
"cam_high": torch.randn(B, T, C, H, W),
"cam_wrist": torch.randn(B, T, C, H, W)
}
num_cameras = len(images)
# ============================================================================
# Test 1: Shared Encoder (默认模式)
# ============================================================================
print("\n[Test 1] Shared Encoder Mode")
print("-" * 60)
backbone_shared = ResNetDiffusionBackbone(
vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W),
crop_shape=(crop_h, crop_w),
crop_is_random=True,
use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints,
use_separate_rgb_encoder_per_camera=False, # 共享编码器
)
print(f"✅ Shared encoder model instantiated")
print(f" Output dim per camera: {feature_dim_per_cam}")
print(f" Number of cameras: {num_cameras}")
print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")
output = backbone_shared(images)
print(f"\n🔄 Forward pass completed")
print(f" Input shapes: {[v.shape for v in images.values()]}")
print(f" Output shape: {output.shape}")
expected_dim = num_cameras * feature_dim_per_cam
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
print(f"✨ Test passed!")
# ============================================================================
# Test 2: Separate Encoders (独立编码器模式)
# ============================================================================
print("\n[Test 2] Separate Encoders Mode")
print("-" * 60)
backbone_separate = ResNetDiffusionBackbone(
vision_backbone="resnet18",
pretrained_backbone_weights=None, # Speed up test
input_shape=(C, H, W),
crop_shape=(crop_h, crop_w),
crop_is_random=True,
use_group_norm=True,
spatial_softmax_num_keypoints=num_keypoints,
use_separate_rgb_encoder_per_camera=True, # 独立编码器
num_cameras=num_cameras,
)
print(f"✅ Separate encoders model instantiated")
print(f" Output dim per camera: {feature_dim_per_cam}")
print(f" Number of cameras: {num_cameras}")
print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")
output = backbone_separate(images)
print(f"\n🔄 Forward pass completed")
print(f" Input shapes: {[v.shape for v in images.values()]}")
print(f" Output shape: {output.shape}")
expected_dim = num_cameras * feature_dim_per_cam
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
print(f"✨ Test passed!")
# ============================================================================
# Test 3: Verify parameters count
# ============================================================================
print("\n[Test 3] Parameter Count Comparison")
print("-" * 60)
shared_params = sum(p.numel() for p in backbone_shared.parameters())
separate_params = sum(p.numel() for p in backbone_separate.parameters())
print(f" Shared encoder parameters: {shared_params:,}")
print(f" Separate encoders parameters: {separate_params:,}")
print(f" Ratio: {separate_params / shared_params:.2f}x")
assert separate_params > shared_params, "Separate encoders should have more parameters"
print(f"✨ Verification passed!")
# ============================================================================
# Test 4: Verify independent parameters
# ============================================================================
print("\n[Test 4] Verify Independent Parameters")
print("-" * 60)
# Check that encoders have independent parameters
encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]
# Modify first encoder's parameter
with torch.no_grad():
encoder_0_first_param += 1.0
# Verify they are not the same tensor
assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
"Encoders should have independent parameters"
print(f"✅ Encoders have independent parameters")
print(f"✨ All tests passed!")
print("\n" + "=" * 60)
print("🎉 All tests completed successfully!")
print("=" * 60)

View File

@@ -0,0 +1,5 @@
# Action Head models
from .conditional_unet1d import ConditionalUnet1D
from .transformer1d import Transformer1D
__all__ = ["ConditionalUnet1D", "Transformer1D"]

View File

@@ -0,0 +1,256 @@
# Diffusion Policy Action Head 实现
import torch
import torch.nn as nn
from typing import Dict, Optional
from diffusers import DDPMScheduler
from roboimi.vla.core.interfaces import VLAHead
from typing import Union
import logging
import torch
import torch.nn as nn
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class ConditionalResidualBlock1D(nn.Module):
def __init__(self,
in_channels,
out_channels,
cond_dim,
kernel_size=3,
n_groups=8,
cond_predict_scale=False):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
])
cond_channels = out_channels
if cond_predict_scale:
cond_channels = out_channels * 2
self.cond_predict_scale = cond_predict_scale
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
Rearrange('batch t -> batch t 1'),
)
# make sure dimensions compatible
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
def forward(self, x, cond):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
'''
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
if self.cond_predict_scale:
embed = embed.reshape(
embed.shape[0], 2, self.out_channels, 1)
scale = embed[:,0,...]
bias = embed[:,1,...]
out = scale * out + bias
else:
out = out + embed
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
def __init__(self,
input_dim,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=[256,512,1024],
kernel_size=3,
n_groups=8,
cond_predict_scale=False
):
super().__init__()
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed
if global_cond_dim is not None:
cond_dim += global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:]))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale
),
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale
),
])
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
ConditionalResidualBlock1D(
dim_out, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_out*2, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
ConditionalResidualBlock1D(
dim_in, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
global_cond=None,
**kwargs):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
"""
sample = einops.rearrange(sample, 'b h t -> b t h')
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([
global_feature, global_cond
], axis=-1)
x = sample
h = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
h.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x

View File

@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
from types import SimpleNamespace
from typing import Optional, Union
from pathlib import Path
import importlib.util
def _load_gr00t_dit():
repo_root = Path(__file__).resolve().parents[4]
dit_path = repo_root / "gr00t" / "models" / "dit.py"
spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load DiT from {dit_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.DiT
DiT = _load_gr00t_dit()
class Gr00tDiT1D(nn.Module):
"""
Adapter that wraps gr00t DiT with the same call signature used by VLA heads.
Expected forward interface:
- sample: (B, T_action, input_dim)
- timestep: (B,) or scalar diffusion timestep
- cond: (B, T_obs, cond_dim)
"""
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int,
cond_dim: int,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
hidden_dim: int = 256,
mlp_ratio: int = 4,
dropout: float = 0.1,
add_action_pos_emb: bool = True,
add_cond_pos_emb: bool = True,
):
super().__init__()
if cond_dim <= 0:
raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
self.horizon = horizon
self.n_obs_steps = n_obs_steps
self.input_proj = nn.Linear(input_dim, n_emb)
self.cond_proj = nn.Linear(cond_dim, n_emb)
self.output_proj = nn.Linear(hidden_dim, output_dim)
self.action_pos_emb = (
nn.Parameter(torch.zeros(1, horizon, n_emb))
if add_action_pos_emb
else None
)
self.cond_pos_emb = (
nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
if add_cond_pos_emb
else None
)
args = SimpleNamespace(
embed_dim=n_emb,
nheads=n_head,
mlp_ratio=mlp_ratio,
dropout=dropout,
num_layers=n_layer,
hidden_dim=hidden_dim,
)
self.dit = DiT(args, cross_attention_dim=n_emb)
self._init_weights()
def _init_weights(self):
if self.action_pos_emb is not None:
nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
def _normalize_timesteps(
self,
timestep: Union[torch.Tensor, float, int],
batch_size: int,
device: torch.device,
) -> torch.Tensor:
if not torch.is_tensor(timestep):
timesteps = torch.tensor([timestep], device=device)
else:
timesteps = timestep.to(device)
if timesteps.ndim == 0:
timesteps = timesteps[None]
if timesteps.shape[0] != batch_size:
timesteps = timesteps.expand(batch_size)
return timesteps.long()
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if cond is None:
raise ValueError("`cond` is required for Gr00tDiT1D forward.")
bsz, t_act, _ = sample.shape
if t_act > self.horizon:
raise ValueError(
f"sample length {t_act} exceeds configured horizon {self.horizon}"
)
hidden_states = self.input_proj(sample)
if self.action_pos_emb is not None:
hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]
encoder_hidden_states = self.cond_proj(cond)
if self.cond_pos_emb is not None:
t_obs = encoder_hidden_states.shape[1]
if t_obs > self.n_obs_steps:
raise ValueError(
f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
)
encoder_hidden_states = (
encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
)
timesteps = self._normalize_timesteps(
timestep, batch_size=bsz, device=sample.device
)
dit_output = self.dit(
hidden_states=hidden_states,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
)
return self.output_proj(dit_output)

View File

@@ -0,0 +1,396 @@
"""
Transformer-based Diffusion Policy Head
使用Transformer架构Encoder-Decoder替代UNet进行噪声预测。
支持通过Cross-Attention注入全局条件观测特征
"""
import math
import torch
import torch.nn as nn
from typing import Optional
class SinusoidalPosEmb(nn.Module):
"""正弦位置编码(用于时间步嵌入)"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Transformer1D(nn.Module):
"""
Transformer-based 1D Diffusion Model
使用Encoder-Decoder架构
- Encoder: 处理条件(观测 + 时间步)
- Decoder: 通过Cross-Attention预测噪声
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon长度
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
p_drop_emb: Embedding dropout
p_drop_attn: Attention dropout
causal_attn: 是否使用因果注意力(自回归)
n_cond_layers: Encoder层数0表示使用MLP
"""
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int = None,
cond_dim: int = 0,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
obs_as_cond: bool = False,
n_cond_layers: int = 0
):
super().__init__()
# 计算序列长度
if n_obs_steps is None:
n_obs_steps = horizon
T = horizon
T_cond = 1 # 时间步token数量
# 确定是否使用观测作为条件
obs_as_cond = cond_dim > 0
if obs_as_cond:
T_cond += n_obs_steps
# 保存配置
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.obs_as_cond = obs_as_cond
self.input_dim = input_dim
self.output_dim = output_dim
# ==================== 输入嵌入 ====================
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
# ==================== 条件编码 ====================
# 时间步嵌入
self.time_emb = SinusoidalPosEmb(n_emb)
# 观测条件嵌入(可选)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
# 条件位置编码
self.cond_pos_emb = None
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
# ==================== Encoder ====================
self.encoder = None
self.encoder_only = False
if T_cond > 0:
if n_cond_layers > 0:
# 使用Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True # Pre-LN更稳定
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers
)
else:
# 使用简单的MLP
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb)
)
else:
# Encoder-only模式BERT风格
self.encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer
)
# ==================== Attention Mask ====================
self.mask = None
self.memory_mask = None
if causal_attn:
# 因果mask确保只关注左侧
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer("mask", mask)
if obs_as_cond:
# 交叉注意力mask
S = T_cond
t, s = torch.meshgrid(
torch.arange(T),
torch.arange(S),
indexing='ij'
)
mask = t >= (s - 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
# ==================== Decoder ====================
if not self.encoder_only:
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer
)
# ==================== 输出头 ====================
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim)
# ==================== 初始化 ====================
self.apply(self._init_weights)
# 打印参数量
total_params = sum(p.numel() for p in self.parameters())
print(f"Transformer1D parameters: {total_params:,}")
def _init_weights(self, module):
"""初始化权重"""
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
# MultiheadAttention的权重初始化
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
weight = getattr(module, name, None)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
bias = getattr(module, name, None)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, Transformer1D):
# 位置编码初始化
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
cond: Optional[torch.Tensor] = None,
**kwargs
):
"""
前向传播
Args:
sample: (B, T, input_dim) 输入序列(加噪动作)
timestep: (B,) 时间步
cond: (B, T', cond_dim) 条件序列(观测特征)
Returns:
(B, T, output_dim) 预测的噪声
"""
# ==================== 处理时间步 ====================
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# 扩展到batch维度
timesteps = timesteps.expand(sample.shape[0])
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
# ==================== 处理输入 ====================
input_emb = self.input_emb(sample) # (B, T, n_emb)
# ==================== Encoder-Decoder模式 ====================
if not self.encoder_only:
# --- Encoder: 处理条件 ---
cond_embeddings = time_emb
if self.obs_as_cond and cond is not None:
# 添加观测条件
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
# 添加位置编码
tc = cond_embeddings.shape[1]
pos_emb = self.cond_pos_emb[:, :tc, :]
x = self.drop(cond_embeddings + pos_emb)
# 通过encoder
memory = self.encoder(x) # (B, T_cond, n_emb)
# --- Decoder: 预测噪声 ---
# 添加位置编码到输入
token_embeddings = input_emb
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
# Cross-Attention: Query来自输入Key/Value来自memory
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask
)
# ==================== Encoder-Only模式 ====================
else:
# BERT风格时间步作为特殊token
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 1:, :] # 移除时间步token
# ==================== 输出头 ====================
x = self.ln_f(x)
x = self.head(x) # (B, T, output_dim)
return x
# ============================================================================
# 便捷函数创建Transformer1D模型
# ============================================================================
def create_transformer1d(
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int,
cond_dim: int,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
**kwargs
) -> Transformer1D:
"""
创建Transformer1D模型的便捷函数
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
**kwargs: 其他参数
Returns:
Transformer1D模型
"""
model = Transformer1D(
input_dim=input_dim,
output_dim=output_dim,
horizon=horizon,
n_obs_steps=n_obs_steps,
cond_dim=cond_dim,
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
**kwargs
)
return model
if __name__ == "__main__":
print("=" * 80)
print("Testing Transformer1D")
print("=" * 80)
# 配置
B = 4
T = 16
action_dim = 16
obs_horizon = 2
cond_dim = 416 # vision + state特征维度
# 创建模型
model = Transformer1D(
input_dim=action_dim,
output_dim=action_dim,
horizon=T,
n_obs_steps=obs_horizon,
cond_dim=cond_dim,
n_layer=4,
n_head=8,
n_emb=256,
causal_attn=False
)
# 测试前向传播
sample = torch.randn(B, T, action_dim)
timestep = torch.randint(0, 100, (B,))
cond = torch.randn(B, obs_horizon, cond_dim)
output = model(sample, timestep, cond)
print(f"\n输入:")
print(f" sample: {sample.shape}")
print(f" timestep: {timestep.shape}")
print(f" cond: {cond.shape}")
print(f"\n输出:")
print(f" output: {output.shape}")
print(f"\n✅ 测试通过!")

View File

@@ -0,0 +1,126 @@
"""
归一化模块 - 统一训练和推理的归一化逻辑
支持两种归一化方式:
1. Gaussian (z-score): (x - mean) / std
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Literal
class NormalizationModule(nn.Module):
"""
统一的归一化模块
用于在 Agent 内部对 qpos 和 action 进行归一化/反归一化
"""
def __init__(
self,
stats: Optional[Dict] = None,
normalization_type: Literal['gaussian', 'min_max'] = None,
):
"""
Args:
stats: 数据集统计信息字典,格式:
{
'qpos_mean': [...],
'qpos_std': [...],
'qpos_min': [...], # 仅 min_max 需要
'qpos_max': [...], # 仅 min_max 需要
'action_mean': [...],
'action_std': [...],
'action_min': [...], # 仅 min_max 需要
'action_max': [...], # 仅 min_max 需要
}
normalization_type: 归一化类型 ('gaussian''min_max')
"""
super().__init__()
self.normalization_type = normalization_type
self.enabled = stats is not None
if self.enabled:
if self.normalization_type == 'gaussian':
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
elif self.normalization_type == 'min_max':
self.register_buffer('qpos_min', torch.tensor(stats['qpos_min'], dtype=torch.float32))
self.register_buffer('qpos_max', torch.tensor(stats['qpos_max'], dtype=torch.float32))
self.register_buffer('action_min', torch.tensor(stats['action_min'], dtype=torch.float32))
self.register_buffer('action_max', torch.tensor(stats['action_max'], dtype=torch.float32))
def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
"""归一化 qpos"""
if not self.enabled:
return qpos
if self.normalization_type == 'gaussian':
return (qpos - self.qpos_mean) / self.qpos_std
elif self.normalization_type == 'min_max':
return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
"""反归一化 qpos"""
if not self.enabled:
return qpos
if self.normalization_type == 'gaussian':
return qpos * self.qpos_std + self.qpos_mean
elif self.normalization_type == 'min_max':
return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
"""归一化 action"""
if not self.enabled:
return action
if self.normalization_type == 'gaussian':
return (action - self.action_mean) / self.action_std
elif self.normalization_type == 'min_max':
return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
"""反归一化 action"""
if not self.enabled:
return action
if self.normalization_type == 'gaussian':
return action * self.action_std + self.action_mean
elif self.normalization_type == 'min_max':
return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min
else:
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
def get_stats(self) -> Optional[Dict]:
"""导出统计信息(用于保存到 checkpoint"""
if not self.enabled:
return None
stats = {
'normalization_type': self.normalization_type,
}
if self.normalization_type == 'gaussian':
stats['qpos_mean'] = self.qpos_mean.cpu().tolist()
stats['qpos_std'] = self.qpos_std.cpu().tolist()
stats['action_mean'] = self.action_mean.cpu().tolist()
stats['action_std'] = self.action_std.cpu().tolist()
elif self.normalization_type == 'min_max':
stats['qpos_min'] = self.qpos_min.cpu().tolist()
stats['qpos_max'] = self.qpos_max.cpu().tolist()
stats['action_min'] = self.action_min.cpu().tolist()
stats['action_max'] = self.action_max.cpu().tolist()
return stats

View File

@@ -0,0 +1,18 @@
from torch import nn
class IdentityStateEncoder(nn.Module):
def __init__(self):
super().__init__()
def forward(self, state):
return state
class IdentityActionEncoder(nn.Module):
def __init__(self):
super().__init__()
def forward(self, action):
return action

View File

@@ -0,0 +1,87 @@
import h5py
import numpy as np
import os
import glob
import pickle
def get_data_stats(dataset_dir):
"""
计算 action 和 qpos 的 Min, Max, Mean, Std
输出扁平化结构(与 NormalizationModule 期望一致):
{
'action_mean': [...],
'action_std': [...],
'action_min': [...],
'action_max': [...],
'qpos_mean': [...],
'qpos_std': [...],
'qpos_min': [...],
'qpos_max': [...],
}
"""
files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
print(f"Found {len(files)} episodes in {dataset_dir}")
all_actions = []
all_qpos = []
print("Reading data...")
for file_path in files:
with h5py.File(file_path, 'r') as f:
action = f['action'][:]
qpos = f['observations']['qpos'][:]
all_actions.append(action)
all_qpos.append(qpos)
# 拼接所有数据
all_actions = np.concatenate(all_actions, axis=0)
all_qpos = np.concatenate(all_qpos, axis=0)
print(f"Total steps: {all_actions.shape[0]}")
# --- 核心计算部分 ---
# 计算统计量
action_mean = np.mean(all_actions, axis=0)
action_std = np.std(all_actions, axis=0)
action_min = np.min(all_actions, axis=0)
action_max = np.max(all_actions, axis=0)
qpos_mean = np.mean(all_qpos, axis=0)
qpos_std = np.std(all_qpos, axis=0)
qpos_min = np.min(all_qpos, axis=0)
qpos_max = np.max(all_qpos, axis=0)
# 修正标准差(防止除以 0
eps = 1e-8
action_std_corrected = np.where(action_std < eps, eps, action_std)
qpos_std_corrected = np.where(qpos_std < eps, eps, qpos_std)
# 转换为扁平化结构(与 NormalizationModule 期望一致)
stats_flat = {
'action_mean': action_mean,
'action_std': action_std_corrected,
'action_min': action_min,
'action_max': action_max,
'qpos_mean': qpos_mean,
'qpos_std': qpos_std_corrected,
'qpos_min': qpos_min,
'qpos_max': qpos_max
}
return stats_flat
if __name__ == "__main__":
DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
OUTPUT_PATH = DATASET_DIR + "/dataset_stats.pkl"
stats_flat = get_data_stats(DATASET_DIR)
# 打印检查
print("\n--- Stats Computed ---")
print(f"Action Mean shape: {stats_flat['action_mean'].shape}")
print(f"Action Std shape: {stats_flat['action_std'].shape}")
# 保存
with open(OUTPUT_PATH, 'wb') as f:
pickle.dump(stats_flat, f)
print(f"\nStats saved to {OUTPUT_PATH}")