Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2376f494d2 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -124,5 +124,3 @@ GEMINI.md
|
||||
|
||||
# Copilot
|
||||
.github/copilot-instructions.md
|
||||
|
||||
.hydra/
|
||||
36
README.en.md
Normal file
36
README.en.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# robo-imi-act
|
||||
|
||||
#### Description
|
||||
{**When you're done, you can delete the content in this README and update the file with details for others getting started with your repository**}
|
||||
|
||||
#### Software Architecture
|
||||
Software architecture description
|
||||
|
||||
#### Installation
|
||||
|
||||
1. xxxx
|
||||
2. xxxx
|
||||
3. xxxx
|
||||
|
||||
#### Instructions
|
||||
|
||||
1. xxxx
|
||||
2. xxxx
|
||||
3. xxxx
|
||||
|
||||
#### Contribution
|
||||
|
||||
1. Fork the repository
|
||||
2. Create Feat_xxx branch
|
||||
3. Commit your code
|
||||
4. Create Pull Request
|
||||
|
||||
|
||||
#### Gitee Feature
|
||||
|
||||
1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
|
||||
2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
|
||||
3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
|
||||
4. The most valuable open source project [GVP](https://gitee.com/gvp)
|
||||
5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
|
||||
6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
|
||||
223
README.md
223
README.md
@@ -1,208 +1,39 @@
|
||||
# RoboIMI
|
||||
# robo-imi-act
|
||||
|
||||
基于 MuJoCo 的机器人仿真与模仿学习框架,实现了使用扩散策略的视觉-语言-动作(VLA)模型,用于机器人操作任务。
|
||||
#### 介绍
|
||||
{**以下是 Gitee 平台说明,您可以替换此简介**
|
||||
Gitee 是 OSCHINA 推出的基于 Git 的代码托管平台(同时支持 SVN)。专为开发者提供稳定、高效、安全的云端软件开发协作平台
|
||||
无论是个人、团队、或是企业,都能够用 Gitee 实现代码托管、项目管理、协作开发。企业项目请看 [https://gitee.com/enterprises](https://gitee.com/enterprises)}
|
||||
|
||||
## 主要特性
|
||||
#### 软件架构
|
||||
软件架构说明
|
||||
|
||||
- **多机器人平台支持**:支持 Diana 和 vx300s 机械臂,可扩展至其他机器人
|
||||
- **扩散策略**:采用最先进的扩散模型(DDPM/DDIM)进行动作序列预测
|
||||
- **视觉-语言-动作模型**:使用 ResNet-18 视觉骨干网络和空间 softmax 进行视觉特征提取
|
||||
- **灵活的控制模式**:支持关节空间和末端执行器(笛卡尔)控制
|
||||
- **Hydra 配置系统**:模块化配置系统,便于实验
|
||||
- **HDF5 数据集格式**:高效存储和加载演示数据
|
||||
- **单臂和双臂任务**:支持单臂和双臂操作任务
|
||||
|
||||
## 安装
|
||||
#### 安装教程
|
||||
|
||||
### 环境要求
|
||||
1. xxxx
|
||||
2. xxxx
|
||||
3. xxxx
|
||||
|
||||
- Python 3.8+
|
||||
- 支持 CUDA 的 GPU(训练时推荐)
|
||||
- Conda 或 Miniconda
|
||||
#### 使用说明
|
||||
|
||||
### 安装步骤
|
||||
1. xxxx
|
||||
2. xxxx
|
||||
3. xxxx
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone <repository-url>
|
||||
cd robo-imi-act
|
||||
#### 参与贡献
|
||||
|
||||
# 创建并激活 conda 环境
|
||||
conda env create -f environment.yml
|
||||
conda activate roboimi
|
||||
1. Fork 本仓库
|
||||
2. 新建 Feat_xxx 分支
|
||||
3. 提交代码
|
||||
4. 新建 Pull Request
|
||||
|
||||
# 以开发模式安装包
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
#### 特技
|
||||
|
||||
### 1. 数据采集
|
||||
|
||||
在仿真环境中记录演示轨迹:
|
||||
|
||||
```bash
|
||||
# 为 vx300s 机器人记录轨迹
|
||||
python roboimi/demos/record_sim_episodes.py
|
||||
|
||||
# 为 Diana 机器人记录轨迹
|
||||
python roboimi/demos/diana_record_sim_episodes.py
|
||||
```
|
||||
|
||||
轨迹数据以 HDF5 文件格式保存,包含机器人状态、动作和相机观测。
|
||||
|
||||
### 2. 计算数据集统计信息
|
||||
|
||||
训练前需要计算归一化统计数据:
|
||||
|
||||
```bash
|
||||
python roboimi/vla/scripts/calculate_stats.py
|
||||
```
|
||||
|
||||
该命令会生成 `data_stats.pkl` 文件,包含动作和观测的均值/标准差或最小值/最大值。
|
||||
|
||||
### 3. 训练 VLA 模型
|
||||
|
||||
使用采集的数据训练视觉-语言-动作模型:
|
||||
|
||||
```bash
|
||||
# 使用默认配置训练
|
||||
python roboimi/demos/vla_scripts/train_vla.py
|
||||
|
||||
# 覆盖特定参数
|
||||
python roboimi/demos/vla_scripts/train_vla.py train.batch_size=32 train.lr=5e-5 train.max_steps=50000
|
||||
|
||||
# 使用不同的模型架构
|
||||
python roboimi/demos/vla_scripts/train_vla.py agent=resnet_diffusion data=resnet_dataset
|
||||
```
|
||||
|
||||
训练输出保存至 `outputs/<日期>/<时间>/`,模型检查点保存至 `checkpoints/`。
|
||||
|
||||
### 4. 评估模型
|
||||
|
||||
在仿真环境中评估训练好的模型:
|
||||
|
||||
```bash
|
||||
# 使用默认配置评估(使用最佳检查点)
|
||||
python roboimi/demos/vla_scripts/eval_vla.py
|
||||
|
||||
# 指定检查点和评估轮数
|
||||
python roboimi/demos/vla_scripts/eval_vla.py eval.ckpt_path=checkpoints/vla_model_step_8000.pt eval.num_episodes=5
|
||||
|
||||
# 启用动作平滑以获得更流畅的执行
|
||||
python roboimi/demos/vla_scripts/eval_vla.py eval.use_smoothing=true eval.smooth_alpha=0.5
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
robo-imi-act/
|
||||
├── roboimi/
|
||||
│ ├── assets/ # 机器人模型和资源
|
||||
│ │ ├── models/manipulators/ # URDF 和 MuJoCo XML 文件
|
||||
│ │ └── robots/ # 机器人抽象类
|
||||
│ ├── envs/ # 仿真环境
|
||||
│ │ ├── mujoco_base.py # MuJoCo 环境基类
|
||||
│ │ ├── single_base.py # 单臂任务基类
|
||||
│ │ └── double_base.py # 双臂任务基类
|
||||
│ ├── vla/ # 视觉-语言-动作模型
|
||||
│ │ ├── agent.py # VLAAgent(训练与推理)
|
||||
│ │ ├── models/
|
||||
│ │ │ ├── backbones/ # 视觉编码器(ResNet 等)
|
||||
│ │ │ └── heads/ # 策略头(扩散 UNet1D)
|
||||
│ │ ├── conf/ # Hydra 配置文件
|
||||
│ │ └── scripts/ # 训练和工具脚本
|
||||
│ └── demos/ # 演示脚本和示例
|
||||
├── checkpoints/ # 保存的模型检查点
|
||||
├── outputs/ # 训练输出(Hydra)
|
||||
├── environment.yml # Conda 环境配置
|
||||
└── CLAUDE.md # Claude Code 开发指南
|
||||
```
|
||||
|
||||
## 架构设计
|
||||
|
||||
### VLA 训练流程
|
||||
|
||||
```
|
||||
HDF5 轨迹数据 → Dataset → DataLoader → VLAAgent → 模型检查点
|
||||
```
|
||||
|
||||
**模型组件**:
|
||||
- **视觉骨干网络**:ResNet-18 + 空间 softmax,用于从相机图像中提取视觉特征
|
||||
- **扩散头**:条件 UNet1D,使用 DDPM/DDIM 预测动作序列
|
||||
- **VLAAgent**:组合视觉编码器和扩散策略,处理训练和推理
|
||||
|
||||
### 配置系统
|
||||
|
||||
基于 Hydra 的配置文件位于 `roboimi/vla/conf/`:
|
||||
- `config.yaml`:主要训练配置(批次大小、学习率、设备)
|
||||
- `agent/resnet_diffusion.yaml`:模型架构(动作维度、观测维度、时间窗口)
|
||||
- `data/resnet_dataset.yaml`:数据集路径、相机名称、归一化类型
|
||||
- `eval/eval.yaml`:评估设置(检查点路径、轮数、平滑参数)
|
||||
|
||||
使用配置插值保持一致性:`${agent.obs_horizon}`
|
||||
|
||||
### 数据集格式
|
||||
|
||||
HDF5 轨迹文件(`episode_*.hdf5`)包含:
|
||||
- `action`:机器人动作 `[T, action_dim]`
|
||||
- `observations/qpos`:关节位置 `[T, obs_dim]`
|
||||
- `observations/images/<cam_name>`:相机图像 `[T, H, W, C]`
|
||||
|
||||
统计文件(`data_stats.pkl`)存储归一化参数(最小值/最大值/均值/标准差)。
|
||||
|
||||
## 开发指南
|
||||
|
||||
### 添加新机器人
|
||||
|
||||
1. 在 `roboimi/assets/models/manipulators/<robot_name>/` 创建 URDF/XML 文件
|
||||
2. 在 `roboimi/assets/robots/<robot_name>.py` 定义机器人类(继承自 `arm_base.py`)
|
||||
3. 在 `roboimi/envs/<robot_name>_*.py` 创建环境类
|
||||
4. 如需要,在常量中注册机器人
|
||||
|
||||
### 修改 VLA 架构
|
||||
|
||||
1. **自定义骨干网络**:在 `roboimi/vla/models/backbones/` 创建新类,继承 `VLABackbone`
|
||||
2. **自定义头部**:在 `roboimi/vla/models/heads/` 创建新类,继承 `VLAHead`
|
||||
3. **更新配置**:在 `roboimi/vla/conf/agent/` 添加新的 YAML 文件
|
||||
4. **接口定义**:参考 `roboimi/vla/core/interfaces.py` 的抽象基类
|
||||
|
||||
### 训练最佳实践
|
||||
|
||||
- 采集新数据后务必运行 `calculate_stats.py`
|
||||
- 训练时会归一化输入/输出;推理时使用检查点中保存的统计信息进行反归一化
|
||||
- 模型预测 `pred_horizon` 步,但只执行前 `action_horizon` 步
|
||||
- 推理使用 DDIM(10 步)快速采样;训练使用 DDPM(100 步)
|
||||
- 监控验证损失以防止过拟合
|
||||
|
||||
## 技术细节
|
||||
|
||||
- **坐标系**:关节空间(qpos)或末端执行器空间(xyz + rpy + 夹爪)
|
||||
- **动作时间窗口**:`obs_horizon` 为观测窗口,`pred_horizon` 为预测窗口,`action_horizon` 为执行窗口
|
||||
- **归一化**:对稳定训练至关重要 - 训练前务必计算统计信息
|
||||
- **推理加速**:使用 DDIM 调度器,比训练时的 DDPM 快 10 倍
|
||||
- **设备配置**:通过 `train.device` 配置(cuda/cpu)
|
||||
|
||||
## 许可证
|
||||
|
||||
[在此添加许可证信息]
|
||||
|
||||
## 引用
|
||||
|
||||
如果您在研究中使用了本代码库,请引用:
|
||||
|
||||
```bibtex
|
||||
[在此添加引用信息]
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎贡献!请随时提交 Pull Request 或开启 Issue。
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目基于以下开源项目构建:
|
||||
- [MuJoCo](https://mujoco.org/) - 物理仿真引擎
|
||||
- [PyTorch](https://pytorch.org/) - 深度学习框架
|
||||
- [Hydra](https://hydra.cc/) - 配置管理系统
|
||||
- [Diffusers](https://github.com/huggingface/diffusers) - 扩散模型库
|
||||
1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md
|
||||
2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com)
|
||||
3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目
|
||||
4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目
|
||||
5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help)
|
||||
6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查所有 episode 的重复帧情况
|
||||
|
||||
找出哪些 episode 有问题,需要删除或重新收集
|
||||
"""
|
||||
import os
|
||||
import h5py
|
||||
import glob
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_all_episodes():
|
||||
"""检查所有 episode 的质量"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
||||
|
||||
print("="*80)
|
||||
print("所有 Episode 质量检查")
|
||||
print("="*80)
|
||||
|
||||
good_episodes = []
|
||||
bad_episodes = []
|
||||
|
||||
for ep_idx, ep_file in enumerate(episode_files):
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
img_path = '/observations/images/top'
|
||||
if img_path not in f:
|
||||
continue
|
||||
|
||||
images = f[img_path][:]
|
||||
|
||||
# 检查前 50 帧的重复情况
|
||||
check_frames = min(50, len(images))
|
||||
duplicate_count = 0
|
||||
|
||||
for i in range(check_frames - 1):
|
||||
img1 = images[i]
|
||||
img2 = images[i + 1]
|
||||
diff = np.mean(np.abs(img1.astype(float) - img2.astype(float)))
|
||||
|
||||
if diff < 1.0: # 重复
|
||||
duplicate_count += 1
|
||||
|
||||
duplicate_rate = duplicate_count / check_frames * 100
|
||||
|
||||
# 判断质量
|
||||
if duplicate_rate > 10: # 超过10%重复
|
||||
bad_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
|
||||
status = "❌"
|
||||
else:
|
||||
good_episodes.append((ep_idx, ep_name, duplicate_rate, duplicate_count))
|
||||
status = "✅"
|
||||
|
||||
print(f"{status} Episode {ep_idx:2d}: {duplicate_rate:5.1f}% 重复 ({duplicate_count:2d}/{check_frames}) - {ep_name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Episode {ep_idx}: 错误 - {e}")
|
||||
|
||||
# 总结
|
||||
print("\n" + "="*80)
|
||||
print("总结")
|
||||
print("="*80)
|
||||
print(f"总共检查: {len(episode_files)} 个 episodes")
|
||||
print(f"正常的: {len(good_episodes)} 个 ✅")
|
||||
print(f"有问题的: {len(bad_episodes)} 个 ❌")
|
||||
|
||||
if bad_episodes:
|
||||
print(f"\n有问题的 episodes:")
|
||||
for ep_idx, ep_name, rate, count in bad_episodes:
|
||||
print(f" - episode_{ep_idx}.hdf5: {rate:.1f}% 重复")
|
||||
|
||||
print(f"\n删除命令:")
|
||||
ep_names = [name for _, name, _, _ in bad_episodes]
|
||||
print(f" rm " + " ".join([f"{dataset_dir}/{name}.hdf5" for name in ep_names]))
|
||||
|
||||
print(f"\n建议:")
|
||||
if len(bad_episodes) > 0:
|
||||
print(f" 1. 删除有问题的 {len(bad_episodes)} 个 episodes")
|
||||
print(f" 2. 重新收集数据,或使用剩余的 {len(good_episodes)} 个正常 episodes")
|
||||
else:
|
||||
print(f" ✅ 所有 episodes 都正常,可以直接使用!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_all_episodes()
|
||||
@@ -1,202 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查特定帧的图像 - 用于验证数据记录问题
|
||||
|
||||
功能:
|
||||
1. 提取每个 episode 的第 0、1、2 帧图像
|
||||
2. 对比不同 episode 的相同帧号
|
||||
3. 保存图像供人工检查
|
||||
"""
|
||||
import os
|
||||
import h5py
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10):
|
||||
"""
|
||||
检查特定帧的图像和 qpos
|
||||
|
||||
Args:
|
||||
frame_indices: 要检查的帧索引列表
|
||||
camera: 相机名称
|
||||
num_episodes: 要检查的 episode 数量
|
||||
"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
# 按数字排序
|
||||
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = f'/tmp/dataset_frames'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"检查前 {min(num_episodes, len(episode_files))} 个 episode 的特定帧")
|
||||
print(f"帧索引: {frame_indices}")
|
||||
print(f"相机: {camera}")
|
||||
print("="*80)
|
||||
|
||||
# 收集所有数据
|
||||
for ep_idx in range(min(num_episodes, len(episode_files))):
|
||||
ep_file = episode_files[ep_idx]
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
# 读取 qpos
|
||||
qpos = f['/observations/qpos'][:]
|
||||
|
||||
# 读取图像
|
||||
img_path = f'/observations/images/{camera}'
|
||||
if img_path not in f:
|
||||
print(f"Episode {ep_name}: 相机 {camera} 不存在")
|
||||
continue
|
||||
|
||||
images = f[img_path][:]
|
||||
|
||||
print(f"\nEpisode {ep_name}:")
|
||||
print(f" 总帧数: {len(images)}")
|
||||
|
||||
# 保存指定帧
|
||||
for frame_idx in frame_indices:
|
||||
if frame_idx >= len(images):
|
||||
print(f" 帧 {frame_idx}: 超出范围")
|
||||
continue
|
||||
|
||||
# 保存图像
|
||||
img = images[frame_idx]
|
||||
filename = f"{output_dir}/ep{ep_idx:02d}_frame{frame_idx:03d}.png"
|
||||
cv2.imwrite(filename, img)
|
||||
|
||||
# 打印 qpos
|
||||
q = qpos[frame_idx]
|
||||
print(f" 帧 {frame_idx}: qpos[0:3]=[{q[0]:6.2f}, {q[1]:6.2f}, {q[2]:6.2f}], qpos[3]={q[3]:6.2f} → {filename}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Episode {ep_name}: 错误 - {e}")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print(f"✅ 所有图像已保存到: {output_dir}")
|
||||
print(f"\n查看方法:")
|
||||
print(f" eog {output_dir}/*.png")
|
||||
print(f" ")
|
||||
print(f" # 或对比特定帧:")
|
||||
print(f" eog {output_dir}/*_frame000.png # 所有 episode 的第 0 帧")
|
||||
print(f" eog {output_dir}/*_frame001.png # 所有 episode 的第 1 帧")
|
||||
print(f" eog {output_dir}/*_frame002.png # 所有 episode 的第 2 帧")
|
||||
|
||||
|
||||
def compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10):
|
||||
"""
|
||||
并排对比所有 episode 的某一帧
|
||||
|
||||
生成一个大的对比图,包含所有 episode 的指定帧
|
||||
"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
episode_files = sorted(episode_files, key=lambda x: int(x.split('_')[-1].replace('.hdf5', '')))
|
||||
|
||||
num_compare = min(num_episodes, len(episode_files))
|
||||
cols = 5 # 每行 5 个
|
||||
rows = (num_compare + cols - 1) // cols
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = f'/tmp/dataset_frames'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"生成对比图: 所有 Episode 的第 {frame_idx} 帧")
|
||||
print("="*80)
|
||||
|
||||
# 收集图像
|
||||
images_compare = []
|
||||
qpos_list = []
|
||||
|
||||
for ep_idx in range(num_compare):
|
||||
ep_file = episode_files[ep_idx]
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
qpos = f['/observations/qpos'][:]
|
||||
img_path = f'/observations/images/{camera}'
|
||||
|
||||
if img_path in f and frame_idx < f[img_path].shape[0]:
|
||||
img = f[img_path][frame_idx]
|
||||
images_compare.append(img)
|
||||
qpos_list.append(qpos[frame_idx])
|
||||
print(f"Episode {ep_name}: qpos[0:3]=[{qpos[frame_idx][0]:.2f}, {qpos[frame_idx][1]:.2f}, {qpos[frame_idx][2]:.2f}]")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Episode {ep_name}: 错误 - {e}")
|
||||
|
||||
if not images_compare:
|
||||
print("❌ 没有收集到图像")
|
||||
return
|
||||
|
||||
# 获取图像尺寸
|
||||
h, w = images_compare[0].shape[:2]
|
||||
|
||||
# 创建对比图
|
||||
compare_img = np.zeros((rows * h + 50, cols * w, 3), dtype=np.uint8)
|
||||
|
||||
for i, (img, qpos) in enumerate(zip(images_compare, qpos_list)):
|
||||
row = i // cols
|
||||
col = i % cols
|
||||
|
||||
y_start = row * h + 30
|
||||
y_end = y_start + h
|
||||
x_start = col * w
|
||||
x_end = x_start + w
|
||||
|
||||
# 调整大小(如果需要)
|
||||
if img.shape[:2] != (h, w):
|
||||
img = cv2.resize(img, (w, h))
|
||||
|
||||
compare_img[y_start:y_end, x_start:x_end] = img
|
||||
|
||||
# 添加信息
|
||||
ep_name = f"Ep {i}"
|
||||
cv2.putText(compare_img, ep_name, (x_start + 10, row * h + 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
||||
cv2.putText(compare_img, f"qpos[3]={qpos[3]:.2f}", (x_start + 10, y_end - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
||||
|
||||
# 保存对比图
|
||||
output_path = f"{output_dir}/compare_frame{frame_idx:03d}.png"
|
||||
cv2.imwrite(output_path, compare_img)
|
||||
|
||||
print(f"\n✅ 对比图已保存: {output_path}")
|
||||
print(f" 查看方法: eog {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
print("="*80)
|
||||
print("特定帧检查工具")
|
||||
print("="*80)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
frame_idx = int(sys.argv[1])
|
||||
compare_frame_across_episodes(frame_idx=frame_idx, camera='top', num_episodes=10)
|
||||
else:
|
||||
# 默认检查第 0、1、2 帧
|
||||
check_specific_frames(frame_indices=[0, 1, 2], camera='top', num_episodes=10)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("生成对比图...")
|
||||
print("="*80)
|
||||
|
||||
# 生成第 0 帧的对比图
|
||||
compare_frame_across_episodes(frame_idx=0, camera='top', num_episodes=10)
|
||||
compare_frame_across_episodes(frame_idx=1, camera='top', num_episodes=10)
|
||||
compare_frame_across_episodes(frame_idx=2, camera='top', num_episodes=10)
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("其他用法:")
|
||||
print(" python check_specific_frames.py 0 # 只检查第 0 帧")
|
||||
print(" python check_specific_frames.py 1 # 只检查第 1 帧")
|
||||
print("="*80)
|
||||
@@ -1,238 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamConfig
|
||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("diffusion")
|
||||
@dataclass
|
||||
class DiffusionConfig(PreTrainedConfig):
|
||||
"""Configuration class for DiffusionPolicy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
- Either:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
AND/OR
|
||||
- The key "observation.environment_state" is required as input.
|
||||
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- "action" is required as an output key.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
|
||||
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
|
||||
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
|
||||
network. This is the output dimension of that network, i.e., the embedding dimension.
|
||||
use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning.
|
||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||
modulation.
|
||||
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||
beta_start: Beta value for the first forward-diffusion step.
|
||||
beta_end: Beta value for the last forward-diffusion step.
|
||||
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
|
||||
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
|
||||
"epsilon" has been shown to work better in many deep neural network settings.
|
||||
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
|
||||
denoising step at inference time. WARNING: you will need to make sure your action-space is
|
||||
normalized to fit within this range.
|
||||
clip_sample_range: The magnitude of the clipping range as described above.
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
||||
to False as the original Diffusion Policy implementation does the same.
|
||||
"""
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
n_groups: int = 8
|
||||
diffusion_step_embed_dim: int = 128
|
||||
use_film_scale_modulation: bool = True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: str = "DDPM"
|
||||
num_train_timesteps: int = 100
|
||||
beta_schedule: str = "squaredcos_cap_v2"
|
||||
beta_start: float = 0.0001
|
||||
beta_end: float = 0.02
|
||||
prediction_type: str = "epsilon"
|
||||
clip_sample: bool = True
|
||||
clip_sample_range: float = 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-6
|
||||
scheduler_name: str = "cosine"
|
||||
scheduler_warmup_steps: int = 500
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||
)
|
||||
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||
raise ValueError(
|
||||
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
if self.horizon % downsampling_factor != 0:
|
||||
raise ValueError(
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
return AdamConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
return DiffuserSchedulerConfig(
|
||||
name=self.scheduler_name,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
if len(self.image_features) > 0:
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,764 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
|
||||
TODO(alexander-soare):
|
||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
config_class = DiffusionConfig
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.diffusion.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
underlying diffusion model. Here's how it works:
|
||||
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||
copied `n_obs_steps` times to fill the cache).
|
||||
- The diffusion model generates `horizon` steps worth of actions.
|
||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||
Schematically this looks like:
|
||||
----------------------------------------------------------------------------------------------
|
||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||
----------------------------------------------------------------------------------------------
|
||||
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch, noise=noise)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues[ACTION].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
"""
|
||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||
to the scheduler.
|
||||
"""
|
||||
if name == "DDPM":
|
||||
return DDPMScheduler(**kwargs)
|
||||
elif name == "DDIM":
|
||||
return DDIMScheduler(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = self.config.robot_state_feature.shape[0]
|
||||
if self.config.image_features:
|
||||
num_images = len(self.config.image_features)
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
global_cond_dim += encoders[0].feature_dim * num_images
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if self.config.env_state_feature:
|
||||
global_cond_dim += self.config.env_state_feature.shape[0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
beta_start=config.beta_start,
|
||||
beta_end=config.beta_end,
|
||||
beta_schedule=config.beta_schedule,
|
||||
clip_sample=config.clip_sample,
|
||||
clip_sample_range=config.clip_sample_range,
|
||||
prediction_type=config.prediction_type,
|
||||
)
|
||||
|
||||
if config.num_inference_steps is None:
|
||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||
else:
|
||||
self.num_inference_steps = config.num_inference_steps
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
noise: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = (
|
||||
noise
|
||||
if noise is not None
|
||||
else torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
)
|
||||
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
for t in self.noise_scheduler.timesteps:
|
||||
# Predict model output.
|
||||
model_output = self.unet(
|
||||
sample,
|
||||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
global_cond_feats = [batch[OBS_STATE]]
|
||||
# Extract image features.
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV_STATE])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
}
|
||||
"""
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# run sampling
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
|
||||
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.config.n_action_steps
|
||||
actions = actions[:, start:end]
|
||||
|
||||
return actions
|
||||
|
||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||
horizon = batch[ACTION].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# Forward diffusion.
|
||||
trajectory = batch[ACTION]
|
||||
# Sample noise to add to the trajectory.
|
||||
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
high=self.noise_scheduler.config.num_train_timesteps,
|
||||
size=(trajectory.shape[0],),
|
||||
device=trajectory.device,
|
||||
).long()
|
||||
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
||||
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
|
||||
|
||||
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
|
||||
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
||||
|
||||
# Compute the loss.
|
||||
# The target is either the original trajectory, or the noise.
|
||||
if self.config.prediction_type == "epsilon":
|
||||
target = eps
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch[ACTION]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
|
||||
if self.config.do_mask_loss_for_padding:
|
||||
if "action_is_pad" not in batch:
|
||||
raise ValueError(
|
||||
"You need to provide 'action_is_pad' in the batch when "
|
||||
f"{self.config.do_mask_loss_for_padding=}."
|
||||
)
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||
|
||||
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||
-----------------------------------------------------
|
||||
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||
| ... | ... | ... | ... |
|
||||
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||
-----------------------------------------------------
|
||||
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||
|
||||
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) input feature map shape.
|
||||
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._out_c = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||
|
||||
def forward(self, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
features: (B, C, H, W) input feature maps.
|
||||
Returns:
|
||||
(B, K, 2) image-space coordinates of keypoints.
|
||||
"""
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(features, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encodes an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if config.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
|
||||
# Set up backbone.
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
weights=config.pretrained_backbone_weights
|
||||
)
|
||||
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||
# TODO(alexander-soare): Use a safer alternative.
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
if config.use_group_norm:
|
||||
if config.pretrained_backbone_weights:
|
||||
raise ValueError(
|
||||
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||
)
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, H, W) image tensor with pixel values in [0, 1].
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
else:
|
||||
# Always use center crop for eval.
|
||||
x = self.center_crop(x)
|
||||
# Extract backbone feature.
|
||||
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
||||
# Final linear layer with non-linearity.
|
||||
x = self.relu(self.out(x))
|
||||
return x
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
root_module: The module for which the submodules need to be replaced
|
||||
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
|
||||
func: Takes a module as an argument and returns a new module to replace it with.
|
||||
Returns:
|
||||
The root module with its submodules replaced.
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parents))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
class DiffusionSinusoidalPosEmb(nn.Module):
|
||||
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class DiffusionConv1dBlock(nn.Module):
|
||||
"""Conv1d --> GroupNorm --> Mish"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class DiffusionConditionalUnet1d(nn.Module):
|
||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||
|
||||
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiffusionConfig, global_cond_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
cond_dim = config.diffusion_step_embed_dim + global_cond_dim
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
|
||||
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
# Unet encoder.
|
||||
common_res_block_kwargs = {
|
||||
"cond_dim": cond_dim,
|
||||
"kernel_size": config.kernel_size,
|
||||
"n_groups": config.n_groups,
|
||||
"use_film_scale_modulation": config.use_film_scale_modulation,
|
||||
}
|
||||
self.down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Processing in the middle of the auto-encoder.
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Unet decoder.
|
||||
self.up_modules = nn.ModuleList([])
|
||||
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
self.up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, T, input_dim) tensor for input to the Unet.
|
||||
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
||||
global_cond: (B, global_cond_dim)
|
||||
output: (B, T, input_dim)
|
||||
Returns:
|
||||
(B, T, input_dim) diffusion model prediction.
|
||||
"""
|
||||
# For 1D convolutions we'll need feature dimension first.
|
||||
x = einops.rearrange(x, "b t d -> b d t")
|
||||
|
||||
timesteps_embed = self.diffusion_step_encoder(timestep)
|
||||
|
||||
# If there is a global conditioning feature, concatenate it to the timestep embedding.
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
|
||||
else:
|
||||
global_feature = timesteps_embed
|
||||
|
||||
# Run encoder, keeping track of skip features to pass to the decoder.
|
||||
encoder_skip_features: list[Tensor] = []
|
||||
for resnet, resnet2, downsample in self.down_modules:
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
encoder_skip_features.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
# Run decoder, using the skip features from the encoder.
|
||||
for resnet, resnet2, upsample in self.up_modules:
|
||||
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, "b d t -> b t d")
|
||||
return x
|
||||
|
||||
|
||||
class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
cond_dim: int,
|
||||
kernel_size: int = 3,
|
||||
n_groups: int = 8,
|
||||
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
||||
# FiLM just modulates bias).
|
||||
use_film_scale_modulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, in_channels, T)
|
||||
cond: (B, cond_dim)
|
||||
Returns:
|
||||
(B, out_channels, T)
|
||||
"""
|
||||
out = self.conv1(x)
|
||||
|
||||
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
||||
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
||||
if self.use_film_scale_modulation:
|
||||
# Treat the embedding as a list of scales and biases.
|
||||
scale = cond_embed[:, : self.out_channels]
|
||||
bias = cond_embed[:, self.out_channels :]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
# Treat the embedding as biases.
|
||||
out = out + cond_embed
|
||||
|
||||
out = self.conv2(out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for a diffusion policy.
|
||||
|
||||
The pre-processing pipeline prepares the input data for the model by:
|
||||
1. Renaming features.
|
||||
2. Normalizing the input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving the data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving the data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the diffusion policy,
|
||||
containing feature definitions, normalization mappings, and device information.
|
||||
dataset_stats: A dictionary of statistics used for normalization.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
474
environment.yml
474
environment.yml
@@ -1,474 +0,0 @@
|
||||
name: roboimi
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1
|
||||
- _openmp_mutex=4.5
|
||||
- _python_abi3_support=1.0
|
||||
- aiohappyeyeballs=2.6.1
|
||||
- aiohttp=3.13.3
|
||||
- aiosignal=1.4.0
|
||||
- alsa-lib=1.2.9
|
||||
- anyio=4.12.1
|
||||
- aom=3.5.0
|
||||
- async-timeout=5.0.1
|
||||
- attr=2.5.1
|
||||
- attrs=25.4.0
|
||||
- aws-c-auth=0.7.22
|
||||
- aws-c-cal=0.6.15
|
||||
- aws-c-common=0.9.23
|
||||
- aws-c-compression=0.2.18
|
||||
- aws-c-event-stream=0.4.2
|
||||
- aws-c-http=0.8.2
|
||||
- aws-c-io=0.14.9
|
||||
- aws-c-mqtt=0.10.4
|
||||
- aws-c-s3=0.5.10
|
||||
- aws-c-sdkutils=0.1.16
|
||||
- aws-checksums=0.1.18
|
||||
- aws-crt-cpp=0.26.12
|
||||
- aws-sdk-cpp=1.11.329
|
||||
- box2d-py=2.3.8
|
||||
- brotli=1.1.0
|
||||
- brotli-bin=1.1.0
|
||||
- brotli-python=1.1.0
|
||||
- bzip2=1.0.8
|
||||
- c-ares=1.34.6
|
||||
- ca-certificates=2026.1.4
|
||||
- cairo=1.16.0
|
||||
- certifi=2026.1.4
|
||||
- cffi=1.17.1
|
||||
- charset-normalizer=3.4.4
|
||||
- click=8.3.1
|
||||
- cloudpickle=3.0.0
|
||||
- contourpy=1.3.0
|
||||
- cpython=3.10.19
|
||||
- cuda-cudart=12.6.68
|
||||
- cuda-cudart_linux-64=12.6.68
|
||||
- cuda-nvrtc=12.6.68
|
||||
- cuda-nvtx=12.6.68
|
||||
- cuda-version=12.6
|
||||
- cudnn=8.9.7.29
|
||||
- cycler=0.12.1
|
||||
- datasets=4.0.0
|
||||
- dav1d=1.2.1
|
||||
- dbus=1.13.6
|
||||
- dill=0.3.8
|
||||
- eigen=3.4.0
|
||||
- exceptiongroup=1.3.1
|
||||
- expat=2.6.3
|
||||
- farama-notifications=0.0.4
|
||||
- filelock=3.15.4
|
||||
- fluidsynth=2.3.3
|
||||
- font-ttf-dejavu-sans-mono=2.37
|
||||
- font-ttf-inconsolata=3.000
|
||||
- font-ttf-source-code-pro=2.038
|
||||
- font-ttf-ubuntu=0.83
|
||||
- fontconfig=2.14.2
|
||||
- fonts-conda-ecosystem=1
|
||||
- fonts-conda-forge=1
|
||||
- fonttools=4.53.1
|
||||
- freetype=2.12.1
|
||||
- frozenlist=1.7.0
|
||||
- fsspec=2024.6.1
|
||||
- gettext=0.22.5
|
||||
- gettext-tools=0.22.5
|
||||
- gflags=2.2.2
|
||||
- git-lfs=3.7.1
|
||||
- glog=0.7.1
|
||||
- gmp=6.3.0
|
||||
- gmpy2=2.1.5
|
||||
- graphite2=1.3.13
|
||||
- gym=0.26.1
|
||||
- gym-box2d=0.26.1
|
||||
- gym-notices=0.0.8
|
||||
- gymnasium=0.29.1
|
||||
- h11=0.16.0
|
||||
- h2=4.3.0
|
||||
- harfbuzz=7.3.0
|
||||
- hf-xet=1.2.1
|
||||
- hpack=4.1.0
|
||||
- httpcore=1.0.9
|
||||
- httpx=0.28.1
|
||||
- huggingface_hub=1.3.5
|
||||
- hyperframe=6.1.0
|
||||
- icu=72.1
|
||||
- idna=3.11
|
||||
- jack=1.9.22
|
||||
- jax-jumpy=1.0.0
|
||||
- jinja2=3.1.4
|
||||
- jpeg=9e
|
||||
- keyutils=1.6.3
|
||||
- kiwisolver=1.4.9
|
||||
- krb5=1.21.3
|
||||
- lame=3.100
|
||||
- lcms2=2.15
|
||||
- ld_impl_linux-64=2.40
|
||||
- lerc=4.0.0
|
||||
- libabseil=20240116.2
|
||||
- libarrow=16.1.0
|
||||
- libarrow-acero=16.1.0
|
||||
- libarrow-dataset=16.1.0
|
||||
- libarrow-substrait=16.1.0
|
||||
- libasprintf=0.22.5
|
||||
- libasprintf-devel=0.22.5
|
||||
- libavif=0.11.1
|
||||
- libblas=3.9.0
|
||||
- libbrotlicommon=1.1.0
|
||||
- libbrotlidec=1.1.0
|
||||
- libbrotlienc=1.1.0
|
||||
- libcap=2.69
|
||||
- libcblas=3.9.0
|
||||
- libcrc32c=1.1.2
|
||||
- libcublas=12.6.1.4
|
||||
- libcufft=11.2.6.59
|
||||
- libcurand=10.3.7.68
|
||||
- libcurl=8.12.1
|
||||
- libcusolver=11.6.4.69
|
||||
- libcusparse=12.5.3.3
|
||||
- libdb=6.2.32
|
||||
- libdeflate=1.17
|
||||
- libedit=3.1.20250104
|
||||
- libev=4.33
|
||||
- libevent=2.1.12
|
||||
- libexpat=2.6.3
|
||||
- libffi=3.4.2
|
||||
- libflac=1.4.3
|
||||
- libgcc=14.1.0
|
||||
- libgcc-ng=14.1.0
|
||||
- libgcrypt=1.11.0
|
||||
- libgettextpo=0.22.5
|
||||
- libgettextpo-devel=0.22.5
|
||||
- libgfortran=14.1.0
|
||||
- libgfortran-ng=14.1.0
|
||||
- libgfortran5=14.1.0
|
||||
- libglib=2.80.3
|
||||
- libgoogle-cloud=2.25.0
|
||||
- libgoogle-cloud-storage=2.25.0
|
||||
- libgpg-error=1.50
|
||||
- libgrpc=1.62.2
|
||||
- libhwloc=2.9.3
|
||||
- libiconv=1.17
|
||||
- libjpeg-turbo=2.1.4
|
||||
- liblapack=3.9.0
|
||||
- libmad=0.15.1b
|
||||
- libmagma=2.8.0
|
||||
- libmagma_sparse=2.8.0
|
||||
- libnghttp2=1.67.0
|
||||
- libnsl=2.0.1
|
||||
- libnvjitlink=12.6.68
|
||||
- libogg=1.3.5
|
||||
- libopenblas=0.3.27
|
||||
- libopus=1.3.1
|
||||
- libparquet=16.1.0
|
||||
- libpng=1.6.43
|
||||
- libprotobuf=4.25.3
|
||||
- libre2-11=2023.09.01
|
||||
- libsndfile=1.2.2
|
||||
- libsqlite=3.46.0
|
||||
- libssh2=1.11.1
|
||||
- libstdcxx=14.1.0
|
||||
- libstdcxx-ng=14.1.0
|
||||
- libsystemd0=256.5
|
||||
- libthrift=0.19.0
|
||||
- libtiff=4.5.0
|
||||
- libtorch=2.4.0
|
||||
- libutf8proc=2.8.0
|
||||
- libuuid=2.38.1
|
||||
- libuv=1.48.0
|
||||
- libvorbis=1.3.7
|
||||
- libwebp-base=1.4.0
|
||||
- libxcb=1.13
|
||||
- libxcrypt=4.4.36
|
||||
- libxml2=2.11.5
|
||||
- libzlib=1.3.1
|
||||
- llvm-openmp=18.1.8
|
||||
- lz4-c=1.9.4
|
||||
- markupsafe=2.1.5
|
||||
- matplotlib-base=3.9.2
|
||||
- mkl=2023.2.0
|
||||
- mpc=1.3.1
|
||||
- mpfr=4.2.1
|
||||
- mpg123=1.31.3
|
||||
- mpmath=1.3.0
|
||||
- multidict=6.7.0
|
||||
- multiprocess=0.70.16
|
||||
- munkres=1.1.4
|
||||
- nccl=2.22.3.1
|
||||
- ncurses=6.5
|
||||
- networkx=3.3
|
||||
- numpy=1.26.4
|
||||
- openjpeg=2.5.0
|
||||
- openssl=3.6.1
|
||||
- opusfile=0.12
|
||||
- orc=2.0.1
|
||||
- orocos-kdl=1.5.1
|
||||
- packaging=24.1
|
||||
- pandas=2.2.2
|
||||
- pcre2=10.44
|
||||
- pillow=9.4.0
|
||||
- pip=24.2
|
||||
- pixman=0.43.2
|
||||
- portaudio=19.6.0
|
||||
- portmidi=2.0.4
|
||||
- propcache=0.3.1
|
||||
- pthread-stubs=0.4
|
||||
- pulseaudio-client=16.1
|
||||
- pyarrow=16.1.0
|
||||
- pyarrow-core=16.1.0
|
||||
- pybind11=2.13.5
|
||||
- pybind11-global=2.13.5
|
||||
- pycparser=2.22
|
||||
- pygame=2.1.3
|
||||
- pyparsing=3.1.4
|
||||
- pysocks=1.7.1
|
||||
- python=3.10.14
|
||||
- python-dateutil=2.9.0
|
||||
- python-gil=3.10.19
|
||||
- python-orocos-kdl=1.5.1
|
||||
- python-tzdata=2024.1
|
||||
- python-xxhash=3.6.0
|
||||
- python_abi=3.10
|
||||
- pytorch=2.4.0
|
||||
- pytz=2024.1
|
||||
- pyyaml=6.0.3
|
||||
- qhull=2020.2
|
||||
- re2=2023.09.01
|
||||
- readline=8.2
|
||||
- regex=2026.1.15
|
||||
- requests=2.32.5
|
||||
- s2n=1.4.16
|
||||
- safetensors=0.7.0
|
||||
- sdl2=2.26.5
|
||||
- sdl2_image=2.6.3
|
||||
- sdl2_mixer=2.6.3
|
||||
- sdl2_ttf=2.20.2
|
||||
- setuptools=72.2.0
|
||||
- shellingham=1.5.4
|
||||
- six=1.16.0
|
||||
- sleef=3.6.1
|
||||
- snappy=1.2.2
|
||||
- sniffio=1.3.1
|
||||
- stable-baselines3=2.3.2
|
||||
- sympy=1.13.2
|
||||
- tbb=2021.11.0
|
||||
- tk=8.6.13
|
||||
- tokenizers=0.22.2
|
||||
- tqdm=4.67.2
|
||||
- transformers=5.0.0
|
||||
- typer-slim=0.21.1
|
||||
- typing-extensions=4.12.2
|
||||
- typing_extensions=4.12.2
|
||||
- tzdata=2024a
|
||||
- unicodedata2=15.1.0
|
||||
- urllib3=2.5.0
|
||||
- wheel=0.44.0
|
||||
- xorg-kbproto=1.0.7
|
||||
- xorg-libice=1.1.1
|
||||
- xorg-libsm=1.2.4
|
||||
- xorg-libx11=1.8.4
|
||||
- xorg-libxau=1.0.11
|
||||
- xorg-libxdmcp=1.1.3
|
||||
- xorg-libxext=1.3.4
|
||||
- xorg-libxrender=0.9.10
|
||||
- xorg-renderproto=0.11.1
|
||||
- xorg-xextproto=7.3.0
|
||||
- xorg-xproto=7.0.31
|
||||
- xxhash=0.8.3
|
||||
- xz=5.2.6
|
||||
- yaml=0.2.5
|
||||
- yarl=1.22.0
|
||||
- zlib=1.3.1
|
||||
- zstandard=0.23.0
|
||||
- zstd=1.5.6
|
||||
- pip:
|
||||
- GitPython==3.1.46
|
||||
- Jinja2==3.1.6
|
||||
- MarkupSafe==3.0.3
|
||||
- PyOpenGL==3.1.7
|
||||
- PyYAML==6.0.3
|
||||
- Pygments==2.19.2
|
||||
- absl-py==2.1.0
|
||||
- accelerate==1.12.0
|
||||
- aiofiles==24.1.0
|
||||
- aiohappyeyeballs==2.6.1
|
||||
- aiohttp==3.13.3
|
||||
- aiosignal==1.4.0
|
||||
- annotated-doc==0.0.4
|
||||
- annotated-types==0.7.0
|
||||
- antlr4-python3-runtime==4.9.3
|
||||
- anyio==4.12.1
|
||||
- asciitree==0.3.3
|
||||
- asttokens==3.0.1
|
||||
- async-timeout==5.0.1
|
||||
- attrs==25.4.0
|
||||
- av==15.1.0
|
||||
- brotli==1.2.0
|
||||
- charset-normalizer==3.4.4
|
||||
- cmake==4.1.3
|
||||
- cmeel==0.58.0
|
||||
- cmeel-assimp==5.4.3.1
|
||||
- cmeel-boost==1.87.0.1
|
||||
- cmeel-console-bridge==1.0.2.3
|
||||
- cmeel-octomap==1.10.0
|
||||
- cmeel-qhull==8.0.2.1
|
||||
- cmeel-tinyxml==2.6.2.3
|
||||
- cmeel-tinyxml2==10.0.0
|
||||
- cmeel-urdfdom==3.1.1.1
|
||||
- cmeel-zlib==1.3.1
|
||||
- coal==3.0.2
|
||||
- coal-library==3.0.1
|
||||
- colorama==0.4.6
|
||||
- datasets==4.5.0
|
||||
- decorator==5.2.1
|
||||
- deepdiff==8.6.1
|
||||
- diffusers==0.30.0
|
||||
- dill==0.4.0
|
||||
- docstring_parser==0.17.0
|
||||
- draccus==0.10.0
|
||||
- eigenpy==3.10.3
|
||||
- einops==0.8.1
|
||||
- etils==1.7.0
|
||||
- evdev==1.9.2
|
||||
- exceptiongroup==1.3.1
|
||||
- executing==2.2.1
|
||||
- fastapi==0.128.0
|
||||
- fasteners==0.20
|
||||
- ffmpy==1.0.0
|
||||
- filelock==3.20.3
|
||||
- frozenlist==1.8.0
|
||||
- fsspec==2025.10.0
|
||||
- gitdb==4.0.12
|
||||
- glfw==2.7.0
|
||||
- gradio==6.3.0
|
||||
- gradio_client==2.0.3
|
||||
- groovy==0.1.2
|
||||
- gymnasium==1.2.3
|
||||
- h11==0.16.0
|
||||
- h5py==3.15.1
|
||||
- hf-xet==1.2.0
|
||||
- hf_transfer==0.1.9
|
||||
- httpcore==1.0.9
|
||||
- httpx==0.28.1
|
||||
- huggingface_hub==1.3.2
|
||||
- hydra-core==1.3.2
|
||||
- imageio==2.35.1
|
||||
- imageio-ffmpeg==0.6.0
|
||||
- importlib_metadata==8.7.1
|
||||
- importlib_resources==6.5.2
|
||||
- inquirerpy==0.3.4
|
||||
- ipython==8.38.0
|
||||
- jedi==0.19.2
|
||||
- jsonargparse==4.45.0
|
||||
- jsonlines==4.0.0
|
||||
- kiwisolver==1.4.5
|
||||
- lerobot==0.4.2
|
||||
- libcoal==3.0.2
|
||||
- libpinocchio==3.8.0
|
||||
- lightning==2.5.0.post0
|
||||
- lightning-utilities==0.15.2
|
||||
- lxml==5.3.0
|
||||
- markdown-it-py==4.0.0
|
||||
- matplotlib-inline==0.2.1
|
||||
- mdurl==0.1.2
|
||||
- mergedeep==1.3.4
|
||||
- mpmath==1.3.0
|
||||
- mujoco==3.2.2
|
||||
- mujoco-python-viewer==0.1.4
|
||||
- multidict==6.7.0
|
||||
- multiprocess==0.70.18
|
||||
- mypy_extensions==1.1.0
|
||||
- networkx==3.4.2
|
||||
- numcodecs==0.13.1
|
||||
- numpy==2.2.6
|
||||
- nvidia-cublas-cu12==12.4.5.8
|
||||
- nvidia-cuda-cupti-cu12==12.4.127
|
||||
- nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
- nvidia-cuda-runtime-cu12==12.4.127
|
||||
- nvidia-cudnn-cu12==9.1.0.70
|
||||
- nvidia-cufft-cu12==11.2.1.3
|
||||
- nvidia-cufile-cu12==1.11.1.6
|
||||
- nvidia-curand-cu12==10.3.5.147
|
||||
- nvidia-cusolver-cu12==11.6.1.9
|
||||
- nvidia-cusparse-cu12==12.3.1.170
|
||||
- nvidia-cusparselt-cu12==0.6.3
|
||||
- nvidia-nccl-cu12==2.21.5
|
||||
- nvidia-nvjitlink-cu12==12.4.127
|
||||
- nvidia-nvshmem-cu12==3.3.20
|
||||
- nvidia-nvtx-cu12==12.4.127
|
||||
- omegaconf==2.3.0
|
||||
- opencv-contrib-python==4.10.0.84
|
||||
- opencv-python==4.13.0.90
|
||||
- orderly-set==5.5.0
|
||||
- orjson==3.11.5
|
||||
- packaging==24.2
|
||||
- pandas==2.3.3
|
||||
- parso==0.8.5
|
||||
- pexpect==4.9.0
|
||||
- pfzy==0.3.4
|
||||
- pillow==12.1.0
|
||||
- pin==3.3.1
|
||||
- platformdirs==4.5.1
|
||||
- prompt_toolkit==3.0.52
|
||||
- propcache==0.4.1
|
||||
- protobuf==6.33.4
|
||||
- proxsuite==0.7.2
|
||||
- psutil==7.2.1
|
||||
- ptyprocess==0.7.0
|
||||
- pure_eval==0.2.3
|
||||
- pyarrow==22.0.0
|
||||
- pydantic==2.12.5
|
||||
- pydantic_core==2.41.5
|
||||
- pydub==0.25.1
|
||||
- pynput==1.8.1
|
||||
- pyquaternion==0.9.9
|
||||
- pyserial==3.5
|
||||
- python-dateutil==2.9.0.post0
|
||||
- python-multipart==0.0.21
|
||||
- python-xlib==0.33
|
||||
- pytorch-lightning==2.6.0
|
||||
- pyyaml-include==1.4.1
|
||||
- qwen-vl-utils==0.0.14
|
||||
- regex==2026.1.15
|
||||
- requests==2.32.5
|
||||
- rerun-sdk==0.26.2
|
||||
- rich==14.2.0
|
||||
- ruckig==0.9.2
|
||||
- safehttpx==0.1.7
|
||||
- safetensors==0.7.0
|
||||
- scipy==1.14.1
|
||||
- semantic-version==2.10.0
|
||||
- sentry-sdk==2.49.0
|
||||
- shellingham==1.5.4
|
||||
- smmap==5.0.2
|
||||
- stack-data==0.6.3
|
||||
- starlette==0.50.0
|
||||
- sympy==1.13.1
|
||||
- termcolor==3.3.0
|
||||
- timm==1.0.24
|
||||
- toml==0.10.2
|
||||
- tomli==2.4.0
|
||||
- tomlkit==0.13.3
|
||||
- torch==2.5.0
|
||||
- torchcodec==0.5
|
||||
- torchmetrics==1.8.2
|
||||
- torchvision==0.20.0
|
||||
- tqdm==4.67.1
|
||||
- traitlets==5.14.3
|
||||
- triton==3.1.0
|
||||
- typer==0.21.1
|
||||
- typer-slim==0.21.1
|
||||
- typeshed_client==2.8.2
|
||||
- typing-inspect==0.9.0
|
||||
- typing-inspection==0.4.2
|
||||
- typing_extensions==4.15.0
|
||||
- tzdata==2025.3
|
||||
- urdf_parser_py==0.0.4
|
||||
- urllib3==2.6.3
|
||||
- uv==0.9.28
|
||||
- uvicorn==0.40.0
|
||||
- wandb==0.24.0
|
||||
- wcwidth==0.2.14
|
||||
- xxhash==3.6.0
|
||||
- yarl==1.22.0
|
||||
- zarr==2.18.3
|
||||
- zipp==3.20.1
|
||||
@@ -1,324 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
将 HDF5 数据集转换为视频,用于可视化检查
|
||||
|
||||
功能:
|
||||
1. 将单个 episode 转换为视频
|
||||
2. 对比多个 episode 的视频
|
||||
3. 放慢播放速度便于观察
|
||||
"""
|
||||
import os
|
||||
import h5py
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def episode_to_video(episode_file, output_path, camera='top', fps=30, slow_factor=1):
|
||||
"""
|
||||
将单个 episode 转换为视频
|
||||
|
||||
Args:
|
||||
episode_file: HDF5 文件路径
|
||||
output_path: 输出视频路径
|
||||
camera: 要使用的相机名称
|
||||
fps: 帧率
|
||||
slow_factor: 慢放倍数(1=正常,2=半速)
|
||||
"""
|
||||
try:
|
||||
with h5py.File(episode_file, 'r') as f:
|
||||
# 读取图像序列
|
||||
img_path = f'/observations/images/{camera}'
|
||||
|
||||
if img_path not in f:
|
||||
print(f" ❌ 相机 {camera} 不存在")
|
||||
return False
|
||||
|
||||
images = f[img_path][:] # shape: (T, H, W, C)
|
||||
qpos = f['/observations/qpos'][:]
|
||||
actions = f['/action'][:]
|
||||
|
||||
total_frames = len(images)
|
||||
height, width = images.shape[1], images.shape[2]
|
||||
|
||||
# 创建视频写入器
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
actual_fps = fps // slow_factor
|
||||
out = cv2.VideoWriter(output_path, fourcc, actual_fps, (width, height))
|
||||
|
||||
# 逐帧写入
|
||||
for i in range(total_frames):
|
||||
frame = images[i].astype(np.uint8)
|
||||
|
||||
# 在图像上添加信息
|
||||
info_text = [
|
||||
f"Episode: {os.path.basename(episode_file).replace('.hdf5', '')}",
|
||||
f"Frame: {i}/{total_frames}",
|
||||
f"qpos[0:3]: [{qpos[i, 0]:.2f}, {qpos[i, 1]:.2f}, {qpos[i, 2]:.2f}]",
|
||||
]
|
||||
|
||||
for j, text in enumerate(info_text):
|
||||
cv2.putText(frame, text, (10, 30 + j*30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||
|
||||
out.write(frame)
|
||||
|
||||
out.release()
|
||||
print(f" ✅ 保存: {output_path}")
|
||||
print(f" 帧数: {total_frames}, 尺寸: {width}x{height}, FPS: {actual_fps}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def generate_all_videos(camera='top', num_episodes=5, slow_factor=1):
|
||||
"""生成前 N 个 episode 的视频"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
|
||||
if len(episode_files) == 0:
|
||||
print(f"❌ 没有找到数据文件: {dataset_dir}")
|
||||
return
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = '/tmp/dataset_videos'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"找到 {len(episode_files)} 个 episode 文件")
|
||||
print(f"将生成前 {min(num_episodes, len(episode_files))} 个 episode 的视频\n")
|
||||
|
||||
# 生成视频
|
||||
for i in range(min(num_episodes, len(episode_files))):
|
||||
ep_file = episode_files[i]
|
||||
ep_name = os.path.basename(ep_file).replace('.hdf5', '')
|
||||
output_path = f"{output_dir}/{ep_name}_{camera}.mp4"
|
||||
|
||||
print(f"[{i+1}/{min(num_episodes, len(episode_files))}] {ep_name}")
|
||||
episode_to_video(ep_file, output_path, camera=camera, slow_factor=slow_factor)
|
||||
print()
|
||||
|
||||
print(f"✅ 所有视频已保存到: {output_dir}")
|
||||
print(f"\n播放方法:")
|
||||
print(f" # 播放单个视频")
|
||||
print(f" vlc {output_dir}/*.mp4")
|
||||
print(f" ")
|
||||
print(f" # 或用文件管理器")
|
||||
print(f" nautilus {output_dir}")
|
||||
|
||||
|
||||
def generate_multi_camera_video(episode_idx=0, slow_factor=1):
|
||||
"""生成包含多个相机的视频(分屏显示)"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
|
||||
if episode_idx >= len(episode_files):
|
||||
print(f"❌ Episode {episode_idx} 不存在")
|
||||
return
|
||||
|
||||
ep_file = episode_files[episode_idx]
|
||||
|
||||
try:
|
||||
with h5py.File(ep_file, 'r') as f:
|
||||
# 获取所有相机
|
||||
cameras = []
|
||||
for key in f.keys():
|
||||
if 'images' in key:
|
||||
for cam_name in f[key].keys():
|
||||
if cam_name not in cameras:
|
||||
cameras.append(cam_name)
|
||||
|
||||
print(f"Episode {episode_idx} 的相机: {cameras}")
|
||||
|
||||
# 读取所有相机的图像
|
||||
all_images = {}
|
||||
for cam in cameras:
|
||||
img_path = f'/observations/images/{cam}'
|
||||
if img_path in f:
|
||||
all_images[cam] = f[img_path][:]
|
||||
|
||||
if not all_images:
|
||||
print("❌ 没有找到图像数据")
|
||||
return
|
||||
|
||||
# 获取第一个相机的尺寸
|
||||
first_cam = list(all_images.keys())[0]
|
||||
total_frames = len(all_images[first_cam])
|
||||
height, width = all_images[first_cam].shape[1], all_images[first_cam].shape[2]
|
||||
|
||||
# 创建多相机布局
|
||||
num_cams = len(all_images)
|
||||
cols = min(2, num_cams)
|
||||
rows = (num_cams + cols - 1) // cols
|
||||
|
||||
canvas_width = width * cols
|
||||
canvas_height = height * rows
|
||||
|
||||
# 创建视频写入器
|
||||
output_path = f'/tmp/dataset_videos/episode_{episode_idx}_all_cameras.mp4'
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(output_path, fourcc, 30 // slow_factor, (canvas_width, canvas_height))
|
||||
|
||||
# 逐帧合成
|
||||
for i in range(total_frames):
|
||||
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
|
||||
for cam_idx, cam_name in enumerate(all_images.keys()):
|
||||
img = all_images[cam_name][i]
|
||||
|
||||
# 计算在画布上的位置
|
||||
row = cam_idx // cols
|
||||
col = cam_idx % cols
|
||||
y_start = row * height
|
||||
y_end = y_start + height
|
||||
x_start = col * width
|
||||
x_end = x_start + width
|
||||
|
||||
# 调整大小(如果需要)
|
||||
if img.shape[:2] != (height, width):
|
||||
img = cv2.resize(img, (width, height))
|
||||
|
||||
# 放到画布上
|
||||
canvas[y_start:y_end, x_start:x_end] = img
|
||||
|
||||
# 添加相机名称
|
||||
cv2.putText(canvas, cam_name, (x_start + 10, y_start + 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
||||
|
||||
# 添加帧信息
|
||||
cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||
|
||||
out.write(canvas)
|
||||
|
||||
out.release()
|
||||
print(f"✅ 保存多相机视频: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 错误: {e}")
|
||||
|
||||
|
||||
def compare_episodes(camera='top', slow_factor=2):
|
||||
"""并排对比多个 episode 的视频"""
|
||||
|
||||
dataset_dir = "roboimi/demos/dataset/sim_transfer"
|
||||
episode_files = sorted(glob.glob(os.path.join(dataset_dir, "episode_*.hdf5")))
|
||||
|
||||
# 选择要对比的 episode
|
||||
episodes_to_compare = [0, 1, 2, 3, 4] # 对比前 5 个
|
||||
|
||||
print(f"对比 Episodes: {episodes_to_compare}")
|
||||
|
||||
# 读取所有 episode 的数据
|
||||
all_data = []
|
||||
for ep_idx in episodes_to_compare:
|
||||
if ep_idx >= len(episode_files):
|
||||
continue
|
||||
|
||||
try:
|
||||
with h5py.File(episode_files[ep_idx], 'r') as f:
|
||||
img_path = f'/observations/images/{camera}'
|
||||
if img_path in f:
|
||||
all_data.append({
|
||||
'idx': ep_idx,
|
||||
'images': f[img_path][:],
|
||||
'qpos': f['/observations/qpos'][:]
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
if len(all_data) == 0:
|
||||
print("❌ 没有数据")
|
||||
return
|
||||
|
||||
# 获取参数
|
||||
first_data = all_data[0]
|
||||
height, width = first_data['images'].shape[1], first_data['images'].shape[2]
|
||||
total_frames = min([d['images'].shape[0] for d in all_data])
|
||||
|
||||
# 创建并排布局
|
||||
num_compare = len(all_data)
|
||||
canvas_width = width * num_compare
|
||||
canvas_height = height
|
||||
|
||||
# 创建视频
|
||||
output_path = f'/tmp/dataset_videos/compare_{camera}.mp4'
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(output_path, fourcc, 30 // slow_factor, (canvas_width, canvas_height))
|
||||
|
||||
print(f"生成对比视频,共 {total_frames} 帧...")
|
||||
|
||||
# 逐帧对比
|
||||
for i in range(total_frames):
|
||||
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
|
||||
for j, data in enumerate(all_data):
|
||||
img = data['images'][i]
|
||||
qpos = data['qpos'][i]
|
||||
|
||||
# 调整大小(如果需要)
|
||||
if img.shape[:2] != (height, width):
|
||||
img = cv2.resize(img, (width, height))
|
||||
|
||||
# 放到画布上
|
||||
x_start = j * width
|
||||
x_end = x_start + width
|
||||
canvas[:, x_start:x_end] = img
|
||||
|
||||
# 添加信息
|
||||
ep_name = f"Ep {data['idx']}"
|
||||
cv2.putText(canvas, ep_name, (x_start + 10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
|
||||
cv2.putText(canvas, f"qpos[0:3]: [{qpos[0]:.2f}, {qpos[1]:.2f}, {qpos[2]:.2f}]",
|
||||
(x_start + 10, height - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
||||
|
||||
# 添加帧号
|
||||
cv2.putText(canvas, f"Frame: {i}/{total_frames}", (10, canvas_height - 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
||||
|
||||
out.write(canvas)
|
||||
|
||||
if i % 100 == 0:
|
||||
print(f" 进度: {i}/{total_frames}")
|
||||
|
||||
out.release()
|
||||
print(f"✅ 保存对比视频: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
print("="*60)
|
||||
print("数据集视频生成工具")
|
||||
print("="*60)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
command = sys.argv[1]
|
||||
|
||||
if command == 'compare':
|
||||
# 对比多个 episode
|
||||
camera = sys.argv[2] if len(sys.argv) > 2 else 'top'
|
||||
compare_episodes(camera=camera, slow_factor=2)
|
||||
|
||||
elif command == 'multi':
|
||||
# 多相机视频
|
||||
ep_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
|
||||
generate_multi_camera_video(episode_idx=ep_idx, slow_factor=1)
|
||||
|
||||
else:
|
||||
print("未知命令")
|
||||
else:
|
||||
# 默认:生成前 5 个 episode 的视频
|
||||
print("\n生成前 5 个 episode 的视频(top 相机,慢放 2x)...")
|
||||
print("="*60 + "\n")
|
||||
generate_all_videos(camera='top', num_episodes=5, slow_factor=2)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("其他用法:")
|
||||
print(" python generate_dataset_videos.py compare top # 对比多个 episode")
|
||||
print(" python generate_dataset_videos.py multi 0 # 多相机视频")
|
||||
print("="*60)
|
||||
125
gr00t/main.py
125
gr00t/main.py
@@ -1,125 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
GR00T (diffusion-based DiT policy) model builder.
|
||||
|
||||
This module provides functions to build GR00T models and optimizers
|
||||
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from .models import build_gr00t_model
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
"""
|
||||
Create argument parser for GR00T model configuration.
|
||||
|
||||
All parameters can be overridden via args_override dictionary in
|
||||
build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
|
||||
"""
|
||||
parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
|
||||
|
||||
# Training parameters
|
||||
parser.add_argument('--lr', default=1e-5, type=float,
|
||||
help='Learning rate for main parameters')
|
||||
parser.add_argument('--lr_backbone', default=1e-5, type=float,
|
||||
help='Learning rate for backbone parameters')
|
||||
parser.add_argument('--weight_decay', default=1e-4, type=float,
|
||||
help='Weight decay for optimizer')
|
||||
|
||||
# GR00T model architecture parameters
|
||||
parser.add_argument('--embed_dim', default=1536, type=int,
|
||||
help='Embedding dimension for transformer')
|
||||
parser.add_argument('--hidden_dim', default=1024, type=int,
|
||||
help='Hidden dimension for MLP layers')
|
||||
parser.add_argument('--state_dim', default=16, type=int,
|
||||
help='State (qpos) dimension')
|
||||
parser.add_argument('--action_dim', default=16, type=int,
|
||||
help='Action dimension')
|
||||
parser.add_argument('--num_queries', default=16, type=int,
|
||||
help='Number of action queries (chunk size)')
|
||||
|
||||
# DiT (Diffusion Transformer) parameters
|
||||
parser.add_argument('--num_layers', default=16, type=int,
|
||||
help='Number of transformer layers')
|
||||
parser.add_argument('--nheads', default=32, type=int,
|
||||
help='Number of attention heads')
|
||||
parser.add_argument('--mlp_ratio', default=4, type=float,
|
||||
help='MLP hidden dimension ratio')
|
||||
parser.add_argument('--dropout', default=0.2, type=float,
|
||||
help='Dropout rate')
|
||||
|
||||
# Backbone parameters
|
||||
parser.add_argument('--backbone', default='dino_v2', type=str,
|
||||
help='Backbone architecture (dino_v2, resnet18, resnet34)')
|
||||
parser.add_argument('--position_embedding', default='sine', type=str,
|
||||
choices=('sine', 'learned'),
|
||||
help='Type of positional encoding')
|
||||
|
||||
# Camera configuration
|
||||
parser.add_argument('--camera_names', default=[], nargs='+',
|
||||
help='List of camera names for observations')
|
||||
|
||||
# Other parameters (not directly used but kept for compatibility)
|
||||
parser.add_argument('--batch_size', default=15, type=int)
|
||||
parser.add_argument('--epochs', default=20000, type=int)
|
||||
parser.add_argument('--masks', action='store_true',
|
||||
help='Use intermediate layer features')
|
||||
parser.add_argument('--dilation', action='store_false',
|
||||
help='Use dilated convolution in backbone')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def build_gr00t_model_and_optimizer(args_override):
|
||||
"""
|
||||
Build GR00T model and optimizer from config dictionary.
|
||||
|
||||
This function is designed to work with config.yaml loading:
|
||||
1. Parse default arguments
|
||||
2. Override with values from args_override (typically from config['gr00t'])
|
||||
3. Build model and optimizer
|
||||
|
||||
Args:
|
||||
args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
|
||||
Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
|
||||
num_queries, nheads, mlp_ratio, dropout, num_layers,
|
||||
lr, lr_backbone, camera_names, backbone, etc.
|
||||
|
||||
Returns:
|
||||
model: GR00T model on CUDA
|
||||
optimizer: AdamW optimizer with separate learning rates for backbone and other params
|
||||
"""
|
||||
parser = argparse.ArgumentParser('GR00T training and evaluation script',
|
||||
parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Override with config values
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
# Build model
|
||||
model = build_gr00t_model(args)
|
||||
model.cuda()
|
||||
|
||||
# Create parameter groups with different learning rates
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad]
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = torch.optim.AdamW(param_dicts,
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
@@ -1,3 +0,0 @@
|
||||
from .gr00t import build_gr00t_model
|
||||
|
||||
__all__ = ['build_gr00t_model']
|
||||
@@ -1,142 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
embedding_dim = args.embed_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps):
|
||||
dtype = next(self.parameters()).dtype
|
||||
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
|
||||
super().__init__()
|
||||
|
||||
output_dim = embedding_dim * 2
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
temb = self.linear(self.silu(temb))
|
||||
scale, shift = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, args, crosss_attention_dim, use_self_attn=False):
|
||||
super().__init__()
|
||||
dim = args.embed_dim
|
||||
num_heads = args.nheads
|
||||
mlp_ratio = args.mlp_ratio
|
||||
dropout = args.dropout
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
|
||||
if not use_self_attn:
|
||||
self.attn = nn.MultiheadAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
kdim=crosss_attention_dim,
|
||||
vdim=crosss_attention_dim,
|
||||
batch_first=True,
|
||||
)
|
||||
else:
|
||||
self.attn = nn.MultiheadAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim * mlp_ratio),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim * mlp_ratio, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, temb, context=None):
|
||||
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||
|
||||
attn_output = self.attn(
|
||||
norm_hidden_states,
|
||||
context if context is not None else norm_hidden_states,
|
||||
context if context is not None else norm_hidden_states,
|
||||
)[0]
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
ff_output = self.mlp(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
class DiT(nn.Module):
|
||||
def __init__(self, args, cross_attention_dim):
|
||||
super().__init__()
|
||||
inner_dim = args.embed_dim
|
||||
num_layers = args.num_layers
|
||||
output_dim = args.hidden_dim
|
||||
|
||||
self.timestep_encoder = TimestepEncoder(args)
|
||||
|
||||
all_blocks = []
|
||||
for idx in range(num_layers):
|
||||
use_self_attn = idx % 2 == 1
|
||||
if use_self_attn:
|
||||
block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
|
||||
else:
|
||||
block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
|
||||
all_blocks.append(block)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
|
||||
|
||||
def forward(self, hidden_states, timestep, encoder_hidden_states):
|
||||
temb = self.timestep_encoder(timestep)
|
||||
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1:
|
||||
hidden_states = block(hidden_states, temb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
|
||||
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
def build_dit(args, cross_attention_dim):
|
||||
return DiT(args, cross_attention_dim)
|
||||
@@ -1,124 +0,0 @@
|
||||
|
||||
from .modules import (
|
||||
build_action_decoder,
|
||||
build_action_encoder,
|
||||
build_state_encoder,
|
||||
build_time_sampler,
|
||||
build_noise_scheduler,
|
||||
)
|
||||
from .backbone import build_backbone
|
||||
from .dit import build_dit
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class gr00t(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
backbones,
|
||||
dit,
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
action_decoder,
|
||||
time_sampler,
|
||||
noise_scheduler,
|
||||
num_queries,
|
||||
camera_names,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.dit = dit
|
||||
self.state_encoder = state_encoder
|
||||
self.action_encoder = action_encoder
|
||||
self.action_decoder = action_decoder
|
||||
self.time_sampler = time_sampler
|
||||
self.noise_scheduler = noise_scheduler
|
||||
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, actions=None, is_pad=None):
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
B, C, H, W = features.shape
|
||||
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
|
||||
all_cam_features.append(features_seq)
|
||||
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
|
||||
|
||||
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
|
||||
|
||||
if is_training:
|
||||
# training logic
|
||||
|
||||
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
|
||||
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
|
||||
actions, timesteps
|
||||
)
|
||||
t_discretized = (timesteps[:, 0, 0] * 1000).long()
|
||||
action_features = self.action_encoder(noisy_actions, t_discretized)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
|
||||
pred = self.action_decoder(model_output)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
|
||||
return pred_actions, action_loss
|
||||
else:
|
||||
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
|
||||
k = 5
|
||||
dt = 1.0 / k
|
||||
for t in range(k):
|
||||
t_cont = t / float(k)
|
||||
t_discretized = int(t_cont * 1000)
|
||||
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
# Create tensor of shape [B] for DiT (consistent with training path)
|
||||
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
|
||||
pred = self.action_decoder(model_output)
|
||||
pred_velocity = pred[:, -self.num_queries :]
|
||||
actions = actions + pred_velocity * dt
|
||||
return actions, _
|
||||
def build_gr00t_model(args):
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
cross_attention_dim = backbones[0].num_channels
|
||||
|
||||
dit = build_dit(args, cross_attention_dim)
|
||||
|
||||
state_encoder = build_state_encoder(args)
|
||||
action_encoder = build_action_encoder(args)
|
||||
action_decoder = build_action_decoder(args)
|
||||
time_sampler = build_time_sampler(args)
|
||||
noise_scheduler = build_noise_scheduler(args)
|
||||
model = gr00t(
|
||||
backbones,
|
||||
dit,
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
action_decoder,
|
||||
time_sampler,
|
||||
noise_scheduler,
|
||||
args.num_queries,
|
||||
args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||
return model
|
||||
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# ActionEncoder
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.embed_dim = args.embed_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
timesteps = timesteps.float()
|
||||
B, T = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embed_dim // 2
|
||||
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
action_dim = args.action_dim
|
||||
embed_dim = args.embed_dim
|
||||
|
||||
self.W1 = nn.Linear(action_dim, embed_dim)
|
||||
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
|
||||
self.W3 = nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
self.pos_encoder = SinusoidalPositionalEncoding(args)
|
||||
|
||||
def forward(self, actions, timesteps):
|
||||
B, T, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# Handle different input shapes: (B,), (B, 1), (B, 1, 1)
|
||||
# Reshape to (B,) then expand to (B, T)
|
||||
# if timesteps.dim() == 3:
|
||||
# # Shape (B, 1, 1) or (B, T, 1) -> (B,)
|
||||
# timesteps = timesteps[:, 0, 0]
|
||||
# elif timesteps.dim() == 2:
|
||||
# # Shape (B, 1) or (B, T) -> take first element if needed
|
||||
# if timesteps.shape[1] == 1:
|
||||
# timesteps = timesteps[:, 0]
|
||||
# # else: already (B, T), use as is
|
||||
# elif timesteps.dim() != 1:
|
||||
# raise ValueError(
|
||||
# f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}"
|
||||
# )
|
||||
|
||||
# Now timesteps should be (B,), expand to (B, T)
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == B:
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, T)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected `timesteps` to have shape (B,) so we can replicate across T."
|
||||
)
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = F.silu(self.W2(x))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_action_encoder(args):
|
||||
return ActionEncoder(args)
|
||||
|
||||
|
||||
# StateEncoder
|
||||
class StateEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
input_dim = args.state_dim
|
||||
hidden_dim = args.hidden_dim
|
||||
output_dim = args.embed_dim
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, states):
|
||||
state_emb = self.mlp(states) # [B, emb_dim]
|
||||
state_emb = state_emb.unsqueeze(1)
|
||||
return state_emb # [B, 1, emb_dim]
|
||||
|
||||
|
||||
def build_state_encoder(args):
|
||||
return StateEncoder(args)
|
||||
|
||||
|
||||
# ActionDecoder
|
||||
class ActionDecoder(nn.Module):
|
||||
def __init__(self,args):
|
||||
super().__init__()
|
||||
input_dim = args.hidden_dim
|
||||
hidden_dim = args.hidden_dim
|
||||
output_dim = args.action_dim
|
||||
|
||||
self.num_queries = args.num_queries
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, model_output):
|
||||
pred_actions = self.mlp(model_output)
|
||||
return pred_actions[:, -self.num_queries:]
|
||||
|
||||
|
||||
def build_action_decoder(args):
|
||||
return ActionDecoder(args)
|
||||
|
||||
|
||||
# TimeSampler
|
||||
class TimeSampler(nn.Module):
|
||||
def __init__(self, noise_s = 0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
|
||||
super().__init__()
|
||||
self.noise_s = noise_s
|
||||
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
|
||||
|
||||
def forward(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
sample = (1 - sample) * self.noise_s
|
||||
return sample[:, None, None]
|
||||
|
||||
|
||||
def build_time_sampler(args):
|
||||
return TimeSampler()
|
||||
|
||||
|
||||
# NoiseScheduler
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class FlowMatchingScheduler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# --- 训练逻辑:加噪并计算目标 ---
|
||||
def add_noise(self, actions, timesteps):
|
||||
noise = torch.randn_like(actions)
|
||||
noisy_samples = actions * timesteps + noise * (1 - timesteps)
|
||||
target_velocity = actions - noise
|
||||
|
||||
return noisy_samples, target_velocity
|
||||
|
||||
# --- 推理逻辑:欧拉步 (Euler Step) ---
|
||||
def step(self, model_output, sample, dt):
|
||||
prev_sample = sample + model_output * dt
|
||||
return prev_sample
|
||||
|
||||
def build_noise_scheduler(args):
|
||||
return FlowMatchingScheduler()
|
||||
@@ -1,90 +0,0 @@
|
||||
"""
|
||||
GR00T Policy wrapper for imitation learning.
|
||||
|
||||
This module provides the gr00tPolicy class that wraps the GR00T model
|
||||
for training and evaluation in the imitation learning framework.
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms import v2
|
||||
import torch
|
||||
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
|
||||
|
||||
|
||||
class gr00tPolicy(nn.Module):
|
||||
"""
|
||||
GR00T Policy for action prediction using diffusion-based DiT architecture.
|
||||
|
||||
This policy wraps the GR00T model and handles:
|
||||
- Image resizing to match DINOv2 patch size requirements
|
||||
- Image normalization (ImageNet stats)
|
||||
- Training with action chunks and loss computation
|
||||
- Inference with diffusion sampling
|
||||
"""
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_gr00t_model_and_optimizer(args_override)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
# DINOv2 requires image dimensions to be multiples of patch size (14)
|
||||
# Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
|
||||
self.patch_h = 16 # Number of patches vertically
|
||||
self.patch_w = 22 # Number of patches horizontally
|
||||
target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308)
|
||||
|
||||
# Training transform with data augmentation
|
||||
self.train_transform = v2.Compose([
|
||||
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
||||
v2.RandomPerspective(distortion_scale=0.5),
|
||||
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
|
||||
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
|
||||
v2.Resize(target_size),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
# Inference transform (no augmentation)
|
||||
self.inference_transform = v2.Compose([
|
||||
v2.Resize(target_size),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
Forward pass for training or inference.
|
||||
|
||||
Args:
|
||||
qpos: Joint positions [B, state_dim]
|
||||
image: Camera images [B, num_cameras, C, H, W]
|
||||
actions: Ground truth actions [B, chunk_size, action_dim] (training only)
|
||||
is_pad: Padding mask [B, chunk_size] (training only)
|
||||
|
||||
Returns:
|
||||
Training: dict with 'mse' loss
|
||||
Inference: predicted actions [B, num_queries, action_dim]
|
||||
"""
|
||||
# Apply transforms (resize + normalization)
|
||||
if actions is not None: # training time
|
||||
image = self.train_transform(image)
|
||||
else: # inference time
|
||||
image = self.inference_transform(image)
|
||||
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, :self.model.num_queries]
|
||||
is_pad = is_pad[:, :self.model.num_queries]
|
||||
_, action_loss = self.model(qpos, image, actions, is_pad)
|
||||
|
||||
# Mask out padded positions
|
||||
mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {
|
||||
'loss': mse_loss
|
||||
}
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat, _ = self.model(qpos, image)
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Return the optimizer for training."""
|
||||
return self.optimizer
|
||||
1
roboimi/.gitattributes
vendored
1
roboimi/.gitattributes
vendored
@@ -1 +0,0 @@
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
@@ -3,7 +3,7 @@
|
||||
<body name="box" pos="0.2 1.0 0.47">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.018 0.018 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
</worldbody>
|
||||
</mujoco>
|
||||
|
||||
@@ -8,6 +8,5 @@
|
||||
</body>
|
||||
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/>
|
||||
<camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/>
|
||||
</worldbody>
|
||||
</mujoco>
|
||||
|
||||
@@ -58,8 +58,8 @@ class BiDianaMed(ArmBase):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Bidiana",
|
||||
urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf",
|
||||
xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml",
|
||||
urdf_path="./assets/models/manipulators/DianaMed/DualDianaMed.urdf",
|
||||
xml_path="./assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml",
|
||||
gripper=None
|
||||
)
|
||||
self.left_arm = self.Arm(self, 'single', self.urdf_path)
|
||||
|
||||
112
roboimi/ddt/main.py
Normal file
112
roboimi/ddt/main.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DDT 模型构建和优化器配置。
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from .models import build_DDT_model
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
"""获取 DDT 模型的参数解析器。"""
|
||||
parser = argparse.ArgumentParser('DDT model configuration', add_help=False)
|
||||
|
||||
# 学习率
|
||||
parser.add_argument('--lr', default=1e-4, type=float)
|
||||
parser.add_argument('--lr_backbone', default=1e-5, type=float)
|
||||
parser.add_argument('--batch_size', default=2, type=int)
|
||||
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
||||
parser.add_argument('--epochs', default=300, type=int)
|
||||
parser.add_argument('--lr_drop', default=200, type=int)
|
||||
parser.add_argument('--clip_max_norm', default=0.1, type=float,
|
||||
help='gradient clipping max norm')
|
||||
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float)
|
||||
|
||||
# Backbone 参数
|
||||
parser.add_argument('--backbone', default='resnet18', type=str,
|
||||
help="Name of the convolutional backbone to use")
|
||||
parser.add_argument('--dilation', action='store_true',
|
||||
help="If true, replace stride with dilation in the last conv block")
|
||||
parser.add_argument('--position_embedding', default='sine', type=str,
|
||||
choices=('sine', 'learned'),
|
||||
help="Type of positional embedding")
|
||||
parser.add_argument('--camera_names', default=[], type=list,
|
||||
help="A list of camera names")
|
||||
|
||||
# Transformer 参数
|
||||
parser.add_argument('--enc_layers', default=4, type=int,
|
||||
help="Number of encoding layers in the transformer")
|
||||
parser.add_argument('--dec_layers', default=6, type=int,
|
||||
help="Number of decoding layers in the transformer")
|
||||
parser.add_argument('--dim_feedforward', default=2048, type=int,
|
||||
help="Intermediate size of the feedforward layers")
|
||||
parser.add_argument('--hidden_dim', default=512, type=int,
|
||||
help="Size of the embeddings (dimension of the transformer)")
|
||||
parser.add_argument('--dropout', default=0.1, type=float,
|
||||
help="Dropout applied in the transformer")
|
||||
parser.add_argument('--nheads', default=8, type=int,
|
||||
help="Number of attention heads")
|
||||
parser.add_argument('--num_queries', default=100, type=int,
|
||||
help="Number of query slots (action horizon)")
|
||||
parser.add_argument('--pre_norm', action='store_true')
|
||||
parser.add_argument('--state_dim', default=14, type=int)
|
||||
parser.add_argument('--action_dim', default=14, type=int)
|
||||
|
||||
# DDT 特有参数
|
||||
parser.add_argument('--num_blocks', default=12, type=int,
|
||||
help="Total number of transformer blocks in DDT")
|
||||
parser.add_argument('--mlp_ratio', default=4.0, type=float,
|
||||
help="MLP hidden dimension ratio")
|
||||
parser.add_argument('--num_inference_steps', default=10, type=int,
|
||||
help="Number of diffusion inference steps")
|
||||
|
||||
# Segmentation (未使用)
|
||||
parser.add_argument('--masks', action='store_true',
|
||||
help="Train segmentation head if provided")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def build_DDT_model_and_optimizer(args_override):
|
||||
"""构建 DDT 模型和优化器。
|
||||
|
||||
Args:
|
||||
args_override: 覆盖默认参数的字典
|
||||
|
||||
Returns:
|
||||
model: DDT 模型
|
||||
optimizer: AdamW 优化器
|
||||
"""
|
||||
parser = argparse.ArgumentParser('DDT training script', parents=[get_args_parser()])
|
||||
args = parser.parse_args([]) # 空列表避免命令行参数干扰
|
||||
|
||||
# 应用参数覆盖
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
# 构建模型
|
||||
model = build_DDT_model(args)
|
||||
model.cuda()
|
||||
|
||||
# 配置优化器(backbone 使用较小学习率)
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad]
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
param_dicts,
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
return model, optimizer
|
||||
7
roboimi/ddt/models/__init__.py
Normal file
7
roboimi/ddt/models/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .model import build as build_ddt
|
||||
from .model import build_ddt
|
||||
|
||||
def build_DDT_model(args):
|
||||
"""构建 DDT 模型的统一入口。"""
|
||||
return build_ddt(args)
|
||||
631
roboimi/ddt/models/ddt.py
Normal file
631
roboimi/ddt/models/ddt.py
Normal file
@@ -0,0 +1,631 @@
|
||||
"""
|
||||
动作序列扩散 Transformer (Action Decoupled Diffusion Transformer)
|
||||
|
||||
基于 DDT 架构修改,用于生成机器人动作序列。
|
||||
主要改动:
|
||||
1. 2D RoPE → 1D RoPE (适配时序数据)
|
||||
2. LabelEmbedder → ObservationEncoder (观测条件)
|
||||
3. 去除 patchify/unpatchify (动作序列已是 1D)
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 通用工具函数
|
||||
# ============================================================================
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""AdaLN 调制函数。
|
||||
|
||||
Args:
|
||||
x: 输入张量。
|
||||
shift: 偏移量。
|
||||
scale: 缩放量。
|
||||
|
||||
Returns:
|
||||
调制后的张量: x * (1 + scale) + shift
|
||||
"""
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 1D 旋转位置编码 (RoPE)
|
||||
# ============================================================================
|
||||
|
||||
def precompute_freqs_cis_1d(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
|
||||
"""预计算 1D 旋转位置编码的复数频率。
|
||||
|
||||
用于时序数据(如动作序列)的位置编码,相比 2D RoPE 更简单高效。
|
||||
|
||||
Args:
|
||||
dim: 每个注意力头的维度 (head_dim)。
|
||||
seq_len: 序列长度。
|
||||
theta: RoPE 的基础频率,默认 10000.0。
|
||||
|
||||
Returns:
|
||||
复数频率张量,形状为 (seq_len, dim//2)。
|
||||
"""
|
||||
# 计算频率: 1 / (theta^(2i/dim))
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
|
||||
# 位置索引
|
||||
t = torch.arange(seq_len).float() # [seq_len]
|
||||
# 外积得到位置-频率矩阵
|
||||
freqs = torch.outer(t, freqs) # [seq_len, dim//2]
|
||||
# 转换为复数形式 (极坐标)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [seq_len, dim//2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb_1d(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""应用 1D 旋转位置编码到 Query 和 Key。
|
||||
|
||||
Args:
|
||||
xq: Query 张量,形状为 (B, N, H, Hc)。
|
||||
xk: Key 张量,形状为 (B, N, H, Hc)。
|
||||
freqs_cis: 预计算的复数频率,形状为 (N, Hc//2)。
|
||||
|
||||
Returns:
|
||||
应用 RoPE 后的 (xq, xk),形状不变。
|
||||
"""
|
||||
# 调整 freqs_cis 形状以便广播: [1, N, 1, Hc//2]
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
|
||||
# 将实数张量视为复数: [B, N, H, Hc] -> [B, N, H, Hc//2] (复数)
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
|
||||
# 复数乘法实现旋转
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [B, N, H, Hc]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 基础组件
|
||||
# ============================================================================
|
||||
|
||||
class Embed(nn.Module):
|
||||
"""线性嵌入层,将输入投影到隐藏空间。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[nn.Module] = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
"""初始化 Embed。
|
||||
|
||||
Args:
|
||||
in_chans: 输入通道数/维度。
|
||||
embed_dim: 输出嵌入维度。
|
||||
norm_layer: 可选的归一化层。
|
||||
bias: 是否使用偏置。
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""扩散时间步嵌入器。
|
||||
|
||||
使用正弦位置编码 + MLP 将标量时间步映射到高维向量。
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
|
||||
"""初始化 TimestepEmbedder。
|
||||
|
||||
Args:
|
||||
hidden_size: 输出嵌入维度。
|
||||
frequency_embedding_size: 正弦编码的维度。
|
||||
"""
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor:
|
||||
"""生成正弦时间步嵌入。
|
||||
|
||||
Args:
|
||||
t: 时间步张量,形状为 (B,)。
|
||||
dim: 嵌入维度。
|
||||
max_period: 最大周期。
|
||||
|
||||
Returns:
|
||||
时间步嵌入,形状为 (B, dim)。
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class ObservationEncoder(nn.Module):
|
||||
"""观测状态编码器。
|
||||
|
||||
将机器人的观测向量(如关节位置、末端位姿、图像特征等)
|
||||
编码为条件向量,用于条件扩散生成。
|
||||
|
||||
Attributes:
|
||||
encoder: 两层 MLP 编码器。
|
||||
|
||||
Example:
|
||||
>>> encoder = ObservationEncoder(obs_dim=128, hidden_size=512)
|
||||
>>> obs = torch.randn(2, 128)
|
||||
>>> cond = encoder(obs) # [2, 512]
|
||||
"""
|
||||
|
||||
def __init__(self, obs_dim: int, hidden_size: int):
|
||||
"""初始化 ObservationEncoder。
|
||||
|
||||
Args:
|
||||
obs_dim: 观测向量的维度。
|
||||
hidden_size: 输出条件向量的维度。
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(obs_dim, hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
"""前向传播。
|
||||
|
||||
Args:
|
||||
obs: 观测向量,形状为 (B, obs_dim)。
|
||||
|
||||
Returns:
|
||||
条件向量,形状为 (B, hidden_size)。
|
||||
"""
|
||||
return self.encoder(obs)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""最终输出层,使用 AdaLN 调制后输出预测结果。"""
|
||||
|
||||
def __init__(self, hidden_size: int, out_channels: int):
|
||||
"""初始化 FinalLayer。
|
||||
|
||||
Args:
|
||||
hidden_size: 输入隐藏维度。
|
||||
out_channels: 输出通道数/维度。
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
"""前向传播。
|
||||
|
||||
Args:
|
||||
x: 输入张量,形状为 (B, N, hidden_size)。
|
||||
c: 条件张量,形状为 (B, N, hidden_size) 或 (B, 1, hidden_size)。
|
||||
|
||||
Returns:
|
||||
输出张量,形状为 (B, N, out_channels)。
|
||||
"""
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 归一化和前馈网络
|
||||
# ============================================================================
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""Root Mean Square Layer Normalization (RMS 归一化层)。
|
||||
|
||||
RMSNorm 是 LayerNorm 的简化版本,去掉了均值中心化操作,只保留缩放。
|
||||
相比 LayerNorm 计算更快,效果相当,被广泛用于 LLaMA、Mistral 等大模型。
|
||||
|
||||
数学公式:
|
||||
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
"""初始化 RMSNorm。
|
||||
|
||||
Args:
|
||||
hidden_size: 输入特征的维度。
|
||||
eps: 防止除零的小常数。
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""SwiGLU 前馈网络 (Feed-Forward Network)。
|
||||
|
||||
使用 SwiGLU 门控激活函数的前馈网络,来自 LLaMA 架构。
|
||||
|
||||
结构:
|
||||
output = W2(SiLU(W1(x)) * W3(x))
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
"""初始化 FeedForward。
|
||||
|
||||
Args:
|
||||
dim: 输入和输出的特征维度。
|
||||
hidden_dim: 隐藏层维度(实际使用 2/3 * hidden_dim)。
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 注意力机制
|
||||
# ============================================================================
|
||||
|
||||
class RAttention(nn.Module):
|
||||
"""带有旋转位置编码的多头自注意力 (Rotary Attention)。
|
||||
|
||||
集成了以下技术:
|
||||
- 1D RoPE: 通过复数旋转编码时序位置信息
|
||||
- QK-Norm: 对 Query 和 Key 进行归一化,稳定训练
|
||||
- Flash Attention: 使用 scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
"""初始化 RAttention。
|
||||
|
||||
Args:
|
||||
dim: 输入特征维度,必须能被 num_heads 整除。
|
||||
num_heads: 注意力头数。
|
||||
qkv_bias: QKV 投影是否使用偏置。
|
||||
qk_norm: 是否对 Q, K 进行归一化。
|
||||
attn_drop: 注意力权重的 dropout 率。
|
||||
proj_drop: 输出投影的 dropout 率。
|
||||
norm_layer: 归一化层类型。
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""前向传播。
|
||||
|
||||
Args:
|
||||
x: 输入张量,形状为 (B, N, C)。
|
||||
pos: 1D RoPE 位置编码,形状为 (N, head_dim//2)。
|
||||
mask: 可选的注意力掩码。
|
||||
|
||||
Returns:
|
||||
输出张量,形状为 (B, N, C)。
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
|
||||
# QKV 投影
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # [B, N, H, Hc]
|
||||
|
||||
# QK-Norm
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# 应用 1D RoPE
|
||||
q, k = apply_rotary_emb_1d(q, k, freqs_cis=pos)
|
||||
|
||||
# 调整维度: [B, N, H, Hc] -> [B, H, N, Hc]
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
# Scaled Dot-Product Attention
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
# 输出投影
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transformer Block
|
||||
# ============================================================================
|
||||
|
||||
class ActionDDTBlock(nn.Module):
|
||||
"""动作 DDT Transformer Block。
|
||||
|
||||
结构: Pre-Norm + AdaLN + Attention + FFN
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
|
||||
"""初始化 ActionDDTBlock。
|
||||
|
||||
Args:
|
||||
hidden_size: 隐藏层维度。
|
||||
num_heads: 注意力头数。
|
||||
mlp_ratio: FFN 隐藏层倍率。
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=num_heads, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""前向传播。
|
||||
|
||||
Args:
|
||||
x: 输入张量,形状为 (B, N, hidden_size)。
|
||||
c: 条件张量,形状为 (B, 1, hidden_size) 或 (B, N, hidden_size)。
|
||||
pos: 位置编码。
|
||||
mask: 可选的注意力掩码。
|
||||
|
||||
Returns:
|
||||
输出张量,形状为 (B, N, hidden_size)。
|
||||
"""
|
||||
# AdaLN 调制参数
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
|
||||
self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
|
||||
# Attention 分支
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
# FFN 分支
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 主模型: ActionDDT
|
||||
# ============================================================================
|
||||
|
||||
class ActionDDT(nn.Module):
|
||||
"""动作序列解耦扩散 Transformer (Action Decoupled Diffusion Transformer)。
|
||||
|
||||
基于 DDT 架构,专为机器人动作序列生成设计。
|
||||
将模型解耦为编码器和解码器两部分,编码器状态可缓存以加速推理。
|
||||
|
||||
架构:
|
||||
- 编码器: 前 num_encoder_blocks 个 block,生成状态 s
|
||||
- 解码器: 剩余 block,使用状态 s 对动作序列 x 去噪
|
||||
- 使用 1D RoPE 进行时序位置编码
|
||||
- 使用 AdaLN 注入时间步和观测条件
|
||||
|
||||
Args:
|
||||
action_dim: 动作向量维度(如 7-DoF 机械臂为 7)。
|
||||
obs_dim: 观测向量维度。
|
||||
action_horizon: 预测的动作序列长度。
|
||||
hidden_size: Transformer 隐藏层维度。
|
||||
num_blocks: Transformer block 总数。
|
||||
num_encoder_blocks: 编码器 block 数量。
|
||||
num_heads: 注意力头数。
|
||||
mlp_ratio: FFN 隐藏层倍率。
|
||||
|
||||
输入:
|
||||
x (Tensor): 带噪声的动作序列,形状为 (B, T, action_dim)。
|
||||
t (Tensor): 扩散时间步,形状为 (B,),取值范围 [0, 1]。
|
||||
obs (Tensor): 观测条件,形状为 (B, obs_dim)。
|
||||
s (Tensor, optional): 缓存的编码器状态。
|
||||
|
||||
输出:
|
||||
x (Tensor): 预测的速度场/噪声,形状为 (B, T, action_dim)。
|
||||
s (Tensor): 编码器状态,可缓存复用。
|
||||
|
||||
Example:
|
||||
>>> model = ActionDDT(action_dim=7, obs_dim=128, action_horizon=16)
|
||||
>>> x = torch.randn(2, 16, 7) # 带噪声的动作序列
|
||||
>>> t = torch.rand(2) # 随机时间步
|
||||
>>> obs = torch.randn(2, 128) # 观测条件
|
||||
>>> out, state = model(x, t, obs)
|
||||
>>> out.shape
|
||||
torch.Size([2, 16, 7])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_dim: int = 7,
|
||||
obs_dim: int = 128,
|
||||
action_horizon: int = 16,
|
||||
hidden_size: int = 512,
|
||||
num_blocks: int = 12,
|
||||
num_encoder_blocks: int = 4,
|
||||
num_heads: int = 8,
|
||||
mlp_ratio: float = 4.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 保存配置
|
||||
self.action_dim = action_dim
|
||||
self.obs_dim = obs_dim
|
||||
self.action_horizon = action_horizon
|
||||
self.hidden_size = hidden_size
|
||||
self.num_blocks = num_blocks
|
||||
self.num_encoder_blocks = num_encoder_blocks
|
||||
self.num_heads = num_heads
|
||||
|
||||
# 动作嵌入层
|
||||
self.x_embedder = Embed(action_dim, hidden_size, bias=True)
|
||||
self.s_embedder = Embed(action_dim, hidden_size, bias=True)
|
||||
|
||||
# 条件嵌入
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.obs_encoder = ObservationEncoder(obs_dim, hidden_size)
|
||||
|
||||
# 输出层
|
||||
self.final_layer = FinalLayer(hidden_size, action_dim)
|
||||
|
||||
# Transformer blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
ActionDDTBlock(hidden_size, num_heads, mlp_ratio)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
# 预计算 1D 位置编码
|
||||
pos = precompute_freqs_cis_1d(hidden_size // num_heads, action_horizon)
|
||||
self.register_buffer('pos', pos)
|
||||
|
||||
# 初始化权重
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""初始化模型权重。"""
|
||||
# 嵌入层使用 Xavier 初始化
|
||||
for embedder in [self.x_embedder, self.s_embedder]:
|
||||
w = embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(embedder.proj.bias, 0)
|
||||
|
||||
# 时间步嵌入 MLP
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# 观测编码器
|
||||
for m in self.obs_encoder.encoder:
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# 输出层零初始化 (AdaLN-Zero)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
obs: torch.Tensor,
|
||||
s: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""前向传播。
|
||||
|
||||
Args:
|
||||
x: 带噪声的动作序列 [B, T, action_dim]
|
||||
t: 扩散时间步 [B] 或 [B, 1],取值范围 [0, 1]
|
||||
obs: 观测条件 [B, obs_dim]
|
||||
s: 可选的编码器状态缓存 [B, T, hidden_size]
|
||||
mask: 可选的注意力掩码
|
||||
|
||||
Returns:
|
||||
x: 预测的速度场/噪声 [B, T, action_dim]
|
||||
s: 编码器状态 [B, T, hidden_size],可缓存复用
|
||||
"""
|
||||
B, T, _ = x.shape
|
||||
|
||||
# 1. 时间步嵌入: [B] -> [B, 1, hidden_size]
|
||||
t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
|
||||
|
||||
# 2. 观测条件嵌入: [B, obs_dim] -> [B, 1, hidden_size]
|
||||
obs_emb = self.obs_encoder(obs).view(B, 1, self.hidden_size)
|
||||
|
||||
# 3. 融合条件: c = SiLU(t + obs)
|
||||
c = nn.functional.silu(t_emb + obs_emb)
|
||||
|
||||
# 4. 编码器部分: 生成状态 s
|
||||
if s is None:
|
||||
# 状态嵌入: [B, T, action_dim] -> [B, T, hidden_size]
|
||||
s = self.s_embedder(x)
|
||||
# 通过编码器 blocks
|
||||
for i in range(self.num_encoder_blocks):
|
||||
s = self.blocks[i](s, c, self.pos, mask)
|
||||
# 融合时间信息
|
||||
s = nn.functional.silu(t_emb + s)
|
||||
|
||||
# 5. 解码器部分: 去噪
|
||||
# 输入嵌入: [B, T, action_dim] -> [B, T, hidden_size]
|
||||
x = self.x_embedder(x)
|
||||
# 通过解码器 blocks,以 s 作为条件
|
||||
for i in range(self.num_encoder_blocks, self.num_blocks):
|
||||
x = self.blocks[i](x, s, self.pos, None)
|
||||
|
||||
# 6. 最终层: [B, T, hidden_size] -> [B, T, action_dim]
|
||||
x = self.final_layer(x, s)
|
||||
|
||||
return x, s
|
||||
304
roboimi/ddt/models/model.py
Normal file
304
roboimi/ddt/models/model.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DDT model and criterion classes.
|
||||
|
||||
核心组装文件,将 Backbone、Transformer、Diffusion 组件组装为完整模型。
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from .backbone import build_backbone
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""Spatial Softmax 层,将特征图转换为关键点坐标。
|
||||
|
||||
来自 Diffusion Policy,保留空间位置信息。
|
||||
对每个通道计算软注意力加权的期望坐标。
|
||||
|
||||
Args:
|
||||
num_kp: 关键点数量(等于输入通道数)
|
||||
temperature: Softmax 温度参数(可学习)
|
||||
learnable_temperature: 是否学习温度参数
|
||||
|
||||
输入: [B, C, H, W]
|
||||
输出: [B, C * 2] - 每个通道输出 (x, y) 坐标
|
||||
"""
|
||||
|
||||
def __init__(self, num_kp: int = None, temperature: float = 1.0, learnable_temperature: bool = True):
|
||||
super().__init__()
|
||||
self.num_kp = num_kp
|
||||
if learnable_temperature:
|
||||
self.temperature = nn.Parameter(torch.ones(1) * temperature)
|
||||
else:
|
||||
self.register_buffer('temperature', torch.ones(1) * temperature)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# 生成归一化坐标网格 [-1, 1]
|
||||
pos_x = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype)
|
||||
pos_y = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype)
|
||||
|
||||
# 展平空间维度
|
||||
x_flat = x.view(B, C, -1) # [B, C, H*W]
|
||||
|
||||
# Softmax 得到注意力权重
|
||||
attention = F.softmax(x_flat / self.temperature, dim=-1) # [B, C, H*W]
|
||||
|
||||
# 计算期望坐标
|
||||
# pos_x: [W] -> [1, 1, W] -> repeat -> [1, 1, H*W]
|
||||
pos_x_grid = pos_x.view(1, 1, 1, W).expand(1, 1, H, W).reshape(1, 1, -1)
|
||||
pos_y_grid = pos_y.view(1, 1, H, 1).expand(1, 1, H, W).reshape(1, 1, -1)
|
||||
|
||||
# 加权求和得到期望坐标
|
||||
expected_x = (attention * pos_x_grid).sum(dim=-1) # [B, C]
|
||||
expected_y = (attention * pos_y_grid).sum(dim=-1) # [B, C]
|
||||
|
||||
# 拼接 x, y 坐标
|
||||
keypoints = torch.cat([expected_x, expected_y], dim=-1) # [B, C * 2]
|
||||
|
||||
return keypoints
|
||||
|
||||
from .ddt import ActionDDT
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
"""生成正弦位置编码表。"""
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DDT(nn.Module):
|
||||
"""DDT (Decoupled Diffusion Transformer) 模型。
|
||||
|
||||
将视觉 Backbone 和 ActionDDT 扩散模型组合,实现基于图像观测的动作序列生成。
|
||||
|
||||
架构:
|
||||
1. Backbone: 提取多相机图像特征
|
||||
2. 特征投影: 将图像特征投影到隐藏空间 (Bottleneck 降维)
|
||||
3. 状态编码: 编码机器人关节状态
|
||||
4. ActionDDT: 扩散 Transformer 生成动作序列
|
||||
|
||||
Args:
|
||||
backbones: 视觉骨干网络列表(每个相机一个)
|
||||
state_dim: 机器人状态维度
|
||||
action_dim: 动作维度
|
||||
num_queries: 预测的动作序列长度
|
||||
camera_names: 相机名称列表
|
||||
hidden_dim: Transformer 隐藏维度
|
||||
num_blocks: Transformer block 数量
|
||||
num_encoder_blocks: 编码器 block 数量
|
||||
num_heads: 注意力头数
|
||||
num_kp: Spatial Softmax 的关键点数量 (默认 32)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbones,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
num_queries: int,
|
||||
camera_names: list,
|
||||
hidden_dim: int = 512,
|
||||
num_blocks: int = 12,
|
||||
num_encoder_blocks: int = 4,
|
||||
num_heads: int = 8,
|
||||
mlp_ratio: float = 4.0,
|
||||
num_kp: int = 32, # [修改] 新增参数,默认 32
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.hidden_dim = hidden_dim
|
||||
self.state_dim = state_dim
|
||||
self.action_dim = action_dim
|
||||
self.num_kp = num_kp
|
||||
|
||||
# Backbone 相关
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
|
||||
# [修改] 投影层: ResNet Channels -> num_kp (32)
|
||||
# 这是一个 Bottleneck 层,大幅减少特征通道数
|
||||
self.input_proj = nn.Conv2d(
|
||||
backbones[0].num_channels, num_kp, kernel_size=1
|
||||
)
|
||||
|
||||
# 状态编码 (2层 MLP,与 Diffusion Policy 一致)
|
||||
# 状态依然映射到 hidden_dim (512),保持信息量
|
||||
self.input_proj_robot_state = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
)
|
||||
|
||||
# [修改] 图像特征聚合 (SpatialSoftmax)
|
||||
# 输入: [B, num_kp, H, W]
|
||||
# 输出: [B, num_kp * 2] (每个通道的 x, y 坐标)
|
||||
self.img_feature_proj = SpatialSoftmax(num_kp=num_kp)
|
||||
|
||||
# [修改] 计算观测维度: 图像特征 + 状态
|
||||
# 图像部分: 关键点数量 * 2(x,y) * 摄像头数量
|
||||
img_feature_dim = num_kp * 2 * len(camera_names)
|
||||
obs_dim = img_feature_dim + hidden_dim
|
||||
|
||||
# ActionDDT 扩散模型
|
||||
self.action_ddt = ActionDDT(
|
||||
action_dim=action_dim,
|
||||
obs_dim=obs_dim, # 使用新的、更紧凑的维度
|
||||
action_horizon=num_queries,
|
||||
hidden_size=hidden_dim,
|
||||
num_blocks=num_blocks,
|
||||
num_encoder_blocks=num_encoder_blocks,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
)
|
||||
|
||||
def encode_observations(self, qpos, image):
|
||||
"""编码观测(图像 + 状态)为条件向量。
|
||||
|
||||
Args:
|
||||
qpos: 机器人关节状态 [B, state_dim]
|
||||
image: 多相机图像 [B, num_cam, C, H, W]
|
||||
|
||||
Returns:
|
||||
obs: 观测条件向量 [B, obs_dim]
|
||||
"""
|
||||
bs = qpos.shape[0]
|
||||
|
||||
# 编码图像特征
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # 取最后一层特征
|
||||
|
||||
# [说明] 这里的 input_proj 现在会将通道压缩到 32
|
||||
features = self.input_proj(features) # [B, num_kp, H', W']
|
||||
|
||||
# [说明] SpatialSoftmax 提取 32 个关键点坐标
|
||||
features = self.img_feature_proj(features) # [B, num_kp * 2]
|
||||
|
||||
all_cam_features.append(features)
|
||||
|
||||
# 拼接所有相机特征
|
||||
img_features = torch.cat(all_cam_features, dim=-1) # [B, num_kp * 2 * num_cam]
|
||||
|
||||
# 编码状态
|
||||
qpos_features = self.input_proj_robot_state(qpos) # [B, hidden_dim]
|
||||
|
||||
# 拼接观测
|
||||
obs = torch.cat([img_features, qpos_features], dim=-1) # [B, obs_dim]
|
||||
|
||||
return obs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
qpos,
|
||||
image,
|
||||
env_state,
|
||||
actions=None,
|
||||
is_pad=None,
|
||||
timesteps=None,
|
||||
):
|
||||
"""前向传播。
|
||||
|
||||
训练时:
|
||||
输入带噪声的动作序列和时间步,预测噪声/速度场
|
||||
推理时:
|
||||
通过扩散采样生成动作序列
|
||||
|
||||
Args:
|
||||
qpos: 机器人关节状态 [B, state_dim]
|
||||
image: 多相机图像 [B, num_cam, C, H, W]
|
||||
env_state: 环境状态(未使用)
|
||||
actions: 动作序列 [B, T, action_dim](训练时为带噪声动作)
|
||||
is_pad: padding 标记 [B, T](未使用)
|
||||
timesteps: 扩散时间步 [B](训练时提供)
|
||||
|
||||
Returns:
|
||||
训练时: (noise_pred, encoder_state)
|
||||
推理时: (action_pred, encoder_state)
|
||||
"""
|
||||
# 1. 编码观测
|
||||
obs = self.encode_observations(qpos, image)
|
||||
|
||||
# 2. 扩散模型前向
|
||||
if actions is not None and timesteps is not None:
|
||||
# 训练模式: 预测噪声
|
||||
noise_pred, encoder_state = self.action_ddt(
|
||||
x=actions,
|
||||
t=timesteps,
|
||||
obs=obs,
|
||||
)
|
||||
return noise_pred, encoder_state
|
||||
else:
|
||||
# 推理模式: 需要在 Policy 层进行扩散采样
|
||||
# 这里返回编码的观测,供 Policy 层使用
|
||||
return obs, None
|
||||
|
||||
def get_obs_dim(self):
|
||||
"""返回观测向量的维度。"""
|
||||
# [修改] 使用 num_kp 重新计算
|
||||
return self.num_kp * 2 * len(self.camera_names) + self.hidden_dim
|
||||
|
||||
|
||||
def build(args):
|
||||
"""构建 DDT 模型。
|
||||
|
||||
Args:
|
||||
args: 包含模型配置的参数对象
|
||||
- state_dim: 状态维度
|
||||
- action_dim: 动作维度
|
||||
- camera_names: 相机名称列表
|
||||
- hidden_dim: 隐藏维度
|
||||
- num_queries: 动作序列长度
|
||||
- num_blocks: Transformer block 数量
|
||||
- enc_layers: 编码器层数
|
||||
- nheads: 注意力头数
|
||||
- num_kp: 关键点数量 (可选,默认32)
|
||||
|
||||
Returns:
|
||||
model: DDT 模型实例
|
||||
"""
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
# 构建 Backbone(每个相机一个)
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
# 构建 DDT 模型
|
||||
model = DDT(
|
||||
backbones=backbones,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
hidden_dim=args.hidden_dim,
|
||||
num_blocks=getattr(args, 'num_blocks', 12),
|
||||
num_encoder_blocks=getattr(args, 'enc_layers', 4),
|
||||
num_heads=args.nheads,
|
||||
mlp_ratio=getattr(args, 'mlp_ratio', 4.0),
|
||||
num_kp=getattr(args, 'num_kp', 32), # [修改] 传递 num_kp 参数
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def build_ddt(args):
|
||||
"""build 的别名,保持接口一致性。"""
|
||||
return build(args)
|
||||
312
roboimi/ddt/models/transformer.py
Normal file
312
roboimi/ddt/models/transformer.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
||||
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False,
|
||||
return_intermediate_dec=False):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
||||
return_intermediate=return_intermediate_dec)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
||||
pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(self, src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, memory, tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos, query_pos=query_pos)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(self, src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(self, src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
||||
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
147
roboimi/ddt/policy.py
Normal file
147
roboimi/ddt/policy.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
DDT Policy - 基于扩散模型的动作生成策略。
|
||||
|
||||
支持 Flow Matching 训练和推理。
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision.transforms import v2
|
||||
import math
|
||||
|
||||
from roboimi.ddt.main import build_DDT_model_and_optimizer
|
||||
|
||||
|
||||
class DDTPolicy(nn.Module):
|
||||
"""DDT (Decoupled Diffusion Transformer) 策略。
|
||||
|
||||
使用 Flow Matching 进行训练,支持多步扩散采样推理。
|
||||
带数据增强,适配 DINOv2 等 ViT backbone。
|
||||
|
||||
Args:
|
||||
args_override: 配置参数字典
|
||||
- num_inference_steps: 推理时的扩散步数
|
||||
- qpos_noise_std: qpos 噪声标准差(训练时数据增强)
|
||||
- patch_h, patch_w: 图像 patch 数量(用于计算目标尺寸)
|
||||
"""
|
||||
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_DDT_model_and_optimizer(args_override)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.num_inference_steps = args_override.get('num_inference_steps', 10)
|
||||
self.qpos_noise_std = args_override.get('qpos_noise_std', 0.0)
|
||||
|
||||
# 图像尺寸配置 (适配 DINOv2)
|
||||
self.patch_h = args_override.get('patch_h', 16)
|
||||
self.patch_w = args_override.get('patch_w', 22)
|
||||
|
||||
print(f'DDT Policy: {self.num_inference_steps} steps, '
|
||||
f'image size ({self.patch_h*14}, {self.patch_w*14})')
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""前向传播。
|
||||
|
||||
训练时: 使用 Flow Matching 损失
|
||||
推理时: 通过扩散采样生成动作
|
||||
|
||||
Args:
|
||||
qpos: 机器人关节状态 [B, state_dim]
|
||||
image: 多相机图像 [B, num_cam, C, H, W]
|
||||
actions: 目标动作序列 [B, T, action_dim](训练时提供)
|
||||
is_pad: padding 标记 [B, T]
|
||||
|
||||
Returns:
|
||||
训练时: loss_dict
|
||||
推理时: 预测的动作序列 [B, T, action_dim]
|
||||
"""
|
||||
env_state = None
|
||||
|
||||
# 图像预处理
|
||||
if actions is not None: # 训练时:数据增强
|
||||
transform = v2.Compose([
|
||||
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
||||
v2.RandomPerspective(distortion_scale=0.5),
|
||||
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
|
||||
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
|
||||
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
if self.qpos_noise_std > 0:
|
||||
qpos = qpos + (self.qpos_noise_std ** 0.5) * torch.randn_like(qpos)
|
||||
else: # 推理时
|
||||
transform = v2.Compose([
|
||||
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
image = transform(image)
|
||||
|
||||
if actions is not None:
|
||||
actions = actions[:, :self.model.num_queries]
|
||||
is_pad = is_pad[:, :self.model.num_queries]
|
||||
loss_dict = self._compute_loss(qpos, image, actions, is_pad)
|
||||
return loss_dict
|
||||
else:
|
||||
a_hat = self._sample(qpos, image)
|
||||
return a_hat
|
||||
|
||||
def _compute_loss(self, qpos, image, actions, is_pad):
|
||||
"""计算 Flow Matching 损失。
|
||||
|
||||
Flow Matching 目标: 学习从噪声到数据的向量场
|
||||
损失: ||v_theta(x_t, t) - (x_1 - x_0)||^2
|
||||
其中 x_t = (1-t)*x_0 + t*x_1, x_0 是噪声, x_1 是目标动作
|
||||
"""
|
||||
B, T, action_dim = actions.shape
|
||||
device = actions.device
|
||||
|
||||
t = torch.rand(B, device=device)
|
||||
noise = torch.randn_like(actions)
|
||||
|
||||
t_expand = t.view(B, 1, 1).expand(B, T, action_dim)
|
||||
x_t = (1 - t_expand) * noise + t_expand * actions
|
||||
target_velocity = actions - noise
|
||||
|
||||
pred_velocity, _ = self.model(
|
||||
qpos=qpos,
|
||||
image=image,
|
||||
env_state=None,
|
||||
actions=x_t,
|
||||
timesteps=t,
|
||||
)
|
||||
|
||||
all_loss = F.mse_loss(pred_velocity, target_velocity, reduction='none')
|
||||
loss = (all_loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
return {'flow_loss': loss, 'loss': loss}
|
||||
|
||||
@torch.no_grad()
|
||||
def _sample(self, qpos, image):
|
||||
"""通过 ODE 求解进行扩散采样。
|
||||
|
||||
使用 Euler 方法从 t=0 积分到 t=1:
|
||||
x_{t+dt} = x_t + v_theta(x_t, t) * dt
|
||||
"""
|
||||
B = qpos.shape[0]
|
||||
T = self.model.num_queries
|
||||
action_dim = self.model.action_dim
|
||||
device = qpos.device
|
||||
|
||||
x = torch.randn(B, T, action_dim, device=device)
|
||||
obs = self.model.encode_observations(qpos, image)
|
||||
|
||||
dt = 1.0 / self.num_inference_steps
|
||||
for i in range(self.num_inference_steps):
|
||||
t = torch.full((B,), i * dt, device=device)
|
||||
velocity, _ = self.model.action_ddt(x=x, t=t, obs=obs)
|
||||
x = x + velocity * dt
|
||||
|
||||
return x
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""返回优化器。"""
|
||||
return self.optimizer
|
||||
1
roboimi/ddt/util/__init__.py
Normal file
1
roboimi/ddt/util/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
88
roboimi/ddt/util/box_ops.py
Normal file
88
roboimi/ddt/util/box_ops.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Utilities for bounding box manipulation and GIoU.
|
||||
"""
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||
and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
area = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
def masks_to_boxes(masks):
|
||||
"""Compute the bounding boxes around the provided masks
|
||||
|
||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||
|
||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||
"""
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device)
|
||||
|
||||
h, w = masks.shape[-2:]
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
|
||||
x_mask = (masks * x.unsqueeze(0))
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
y_mask = (masks * y.unsqueeze(0))
|
||||
y_max = y_mask.flatten(1).max(-1)[0]
|
||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||
468
roboimi/ddt/util/misc.py
Normal file
468
roboimi/ddt/util/misc.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import pickle
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], :img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError('not supported')
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
107
roboimi/ddt/util/plot_utils.py
Normal file
107
roboimi/ddt/util/plot_utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Plotting utilities to visualize training logs.
|
||||
"""
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
|
||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
||||
'''
|
||||
Function to plot specific fields from training log(s). Plots both training and test results.
|
||||
|
||||
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
||||
- fields = which results to plot from each log file - plots both training and test for each field.
|
||||
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
||||
- log_name = optional, name of log file if different than default 'log.txt'.
|
||||
|
||||
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
||||
- solid lines are training results, dashed lines are test results.
|
||||
|
||||
'''
|
||||
func_name = "plot_utils.py::plot_logs"
|
||||
|
||||
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
||||
# convert single Path to list to avoid 'not iterable' error
|
||||
|
||||
if not isinstance(logs, list):
|
||||
if isinstance(logs, PurePath):
|
||||
logs = [logs]
|
||||
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
||||
else:
|
||||
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
||||
Expect list[Path] or single Path obj, received {type(logs)}")
|
||||
|
||||
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
||||
for i, dir in enumerate(logs):
|
||||
if not isinstance(dir, PurePath):
|
||||
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
||||
if not dir.exists():
|
||||
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
||||
# verify log_name exists
|
||||
fn = Path(dir / log_name)
|
||||
if not fn.exists():
|
||||
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
||||
print(f"--> full path of missing log file: {fn}")
|
||||
return
|
||||
|
||||
# load log file(s) and plot
|
||||
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
||||
|
||||
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
||||
|
||||
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
||||
for j, field in enumerate(fields):
|
||||
if field == 'mAP':
|
||||
coco_eval = pd.DataFrame(
|
||||
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
||||
).ewm(com=ewm_col).mean()
|
||||
axs[j].plot(coco_eval, c=color)
|
||||
else:
|
||||
df.interpolate().ewm(com=ewm_col).mean().plot(
|
||||
y=[f'train_{field}', f'test_{field}'],
|
||||
ax=axs[j],
|
||||
color=[color] * 2,
|
||||
style=['-', '--']
|
||||
)
|
||||
for ax, field in zip(axs, fields):
|
||||
ax.legend([Path(p).name for p in logs])
|
||||
ax.set_title(field)
|
||||
|
||||
|
||||
def plot_precision_recall(files, naming_scheme='iter'):
|
||||
if naming_scheme == 'exp_id':
|
||||
# name becomes exp_id
|
||||
names = [f.parts[-3] for f in files]
|
||||
elif naming_scheme == 'iter':
|
||||
names = [f.stem for f in files]
|
||||
else:
|
||||
raise ValueError(f'not supported {naming_scheme}')
|
||||
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
||||
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
||||
data = torch.load(f)
|
||||
# precision is n_iou, n_points, n_cat, n_area, max_det
|
||||
precision = data['precision']
|
||||
recall = data['params'].recThrs
|
||||
scores = data['scores']
|
||||
# take precision for all classes, all areas and 100 detections
|
||||
precision = precision[0, :, :, 0, -1].mean(1)
|
||||
scores = scores[0, :, :, 0, -1].mean(1)
|
||||
prec = precision.mean()
|
||||
rec = data['recall'][0, :, 0, -1].mean()
|
||||
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
||||
f'score={scores.mean():0.3f}, ' +
|
||||
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
||||
)
|
||||
axs[0].plot(recall, precision, c=color)
|
||||
axs[1].plot(recall, scores, c=color)
|
||||
|
||||
axs[0].set_title('Precision / Recall')
|
||||
axs[0].legend(names)
|
||||
axs[1].set_title('Scores / Recall')
|
||||
axs[1].legend(names)
|
||||
return fig, axs
|
||||
@@ -8,7 +8,8 @@ temporal_agg: false
|
||||
|
||||
# policy_class: "ACT"
|
||||
# backbone: 'resnet18'
|
||||
policy_class: "GR00T"
|
||||
policy_class: "ACTTV"
|
||||
# policy_class: "DDT"
|
||||
backbone: 'dino_v2'
|
||||
|
||||
seed: 0
|
||||
@@ -38,13 +39,8 @@ episode_len: # leave empty here by default
|
||||
camera_names: [] # leave empty here by default
|
||||
xml_dir: # leave empty here by default
|
||||
|
||||
# action smoothing settings (for GR00T)
|
||||
use_action_smoothing: true
|
||||
smooth_method: "ema" # Options: "ema", "moving_avg", "lowpass", "none"
|
||||
smooth_alpha: 0.3 # Smoothing factor (0-1), smaller = smoother
|
||||
|
||||
# transformer settings
|
||||
batch_size: 10
|
||||
batch_size: 32
|
||||
state_dim: 16
|
||||
action_dim: 16
|
||||
lr_backbone: 0.00001
|
||||
@@ -56,21 +52,6 @@ nheads: 8
|
||||
qpos_noise_std: 0
|
||||
DT: 0.02
|
||||
|
||||
gr00t:
|
||||
action_dim: 16
|
||||
state_dim: 16
|
||||
embed_dim: 1536
|
||||
hidden_dim: 1024
|
||||
num_queries: 8
|
||||
|
||||
nheads: 32
|
||||
mlp_ratio: 4
|
||||
dropout: 0.2
|
||||
|
||||
num_layers: 16
|
||||
|
||||
|
||||
|
||||
# DO NOT CHANGE IF UNNECESSARY
|
||||
lr: 0.00001
|
||||
kl_weight: 100
|
||||
@@ -78,3 +59,8 @@ chunk_size: 10
|
||||
hidden_dim: 512
|
||||
dim_feedforward: 3200
|
||||
|
||||
# DDT 特有参数
|
||||
num_blocks: 12 # Transformer blocks 数量
|
||||
mlp_ratio: 4.0 # MLP 维度比例
|
||||
num_inference_steps: 10 # 扩散推理步数
|
||||
|
||||
|
||||
119
roboimi/demos/diana_eval.py
Normal file
119
roboimi/demos/diana_eval.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
from roboimi.utils.utils import set_seed
|
||||
from roboimi.utils.io_utils import IOUtils
|
||||
from roboimi.utils.model_interface import ModelInterface
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
# from visualize_episodes import save_videos
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
|
||||
|
||||
|
||||
#should be added into IOUtils
|
||||
def get_image(obs,camera_names):
|
||||
curr_images = []
|
||||
for cam_name in camera_names:
|
||||
curr_image = rearrange(obs['images'][cam_name], 'h w c -> c h w')
|
||||
curr_images.append(curr_image)
|
||||
curr_image = np.stack(curr_images, axis=0)
|
||||
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
|
||||
return curr_image
|
||||
|
||||
|
||||
def eval_bc(config, ckpt_name='policy_best.ckpt', save_episode=True):
|
||||
set_seed(1)
|
||||
model_interface = ModelInterface(config)
|
||||
model_interface.setup()
|
||||
policy = IOUtils.load_policy(config, ckpt_name)
|
||||
stats = IOUtils.load_stats(config['ckpt_dir'])
|
||||
num_rollouts = 3
|
||||
episode_returns = []
|
||||
highest_rewards = []
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
run_episode(config, policy, stats,
|
||||
save_episode,num_rollouts)
|
||||
# episode_return, episode_highest_reward = run_episode(config, policy, stats,
|
||||
# save_episode,num_rollouts)
|
||||
|
||||
|
||||
|
||||
|
||||
def run_episode(config, policy, stats, save_episode,num_rollouts):
|
||||
|
||||
if 'sim_transfer' in config['task_name']:
|
||||
task_name = 'sim_transfer' #config['task_name']
|
||||
env = make_sim_env(task_name)
|
||||
|
||||
max_timesteps = config['episode_len']
|
||||
max_timesteps = int(max_timesteps * 1)
|
||||
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
|
||||
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
|
||||
box_pos = sample_transfer_pose()
|
||||
for rollout_id in range(num_rollouts):
|
||||
print("\nrollout_id===",rollout_id,"\n")
|
||||
image_list = []
|
||||
rewards = []
|
||||
query_frequency = config['policy_config'].get('num_queries', 1)
|
||||
print("query_freq =====",query_frequency)
|
||||
env.reset(box_pos)
|
||||
with torch.inference_mode():
|
||||
for t in range(700):
|
||||
image_list.append(env._get_image_obs()['images'] if 'images' in env._get_image_obs() else {print("img error")})
|
||||
qpos_numpy = np.array(env._get_qpos_obs()['qpos'])
|
||||
qpos = pre_process(qpos_numpy)
|
||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||
curr_image = get_image(env._get_image_obs(), config['camera_names'])
|
||||
if config['policy_class'] == "ACT" or "ACTTV":
|
||||
if t % query_frequency == 0:
|
||||
all_actions = policy(qpos, curr_image)
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
# raw_action = all_actions[:, t % 1]
|
||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||
elif config['policy_class'] == "CNNMLP":
|
||||
raw_action = policy(qpos, curr_image)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
action = post_process(raw_action)
|
||||
print("action == ",action)
|
||||
env.step_jnt(action)
|
||||
rewards.append(env.rew)
|
||||
env.render()
|
||||
|
||||
|
||||
rewards = np.array(rewards)
|
||||
# episode_return = np.sum(rewards[rewards != None])
|
||||
# episode_highest_reward = np.max(rewards)
|
||||
# env.viewer.close()
|
||||
|
||||
# del env
|
||||
# return episode_return, episode_highest_reward
|
||||
|
||||
|
||||
|
||||
|
||||
def test_env():
|
||||
try:
|
||||
env = make_sim_env('sim_transfer')
|
||||
env.reset()
|
||||
while True: pass
|
||||
except KeyboardInterrupt:
|
||||
del env
|
||||
print("stop")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_env()
|
||||
io_utils = IOUtils()
|
||||
config = io_utils.load_config()
|
||||
eval_bc(config)
|
||||
|
||||
|
||||
@@ -104,8 +104,8 @@ class TestPickAndTransferPolicy(PolicyBase):
|
||||
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100}, # sleep
|
||||
{"t": 75, "xyz": np.array([(0.8+box_xyz[0])*0.5,(1.0+box_xyz[1])*0.5,init_mocap_pose_right[2]]), "quat": gripper_approach_quat.elements, "gripper": 100},
|
||||
{"t": 225, "xyz": box_xyz + np.array([0, 0, 0.3]), "quat": gripper_pick_quat.elements, "gripper": 100}, # approach the cube
|
||||
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
|
||||
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
|
||||
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
|
||||
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
|
||||
{"t": 450, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100},# approach wait position
|
||||
{"t": 500, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": -100},# approach meet position
|
||||
{"t": 510, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": 100}, # open gripper
|
||||
@@ -116,8 +116,8 @@ class TestPickAndTransferPolicy(PolicyBase):
|
||||
self.left_trajectory = [
|
||||
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": -100},# sleep
|
||||
{"t": 250, "xyz": meet_xyz + np.array([-0.5, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # approach meet position
|
||||
{"t": 500, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
|
||||
{"t": 505, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
|
||||
{"t": 500, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
|
||||
{"t": 505, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
|
||||
{"t": 675, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # move left
|
||||
{"t": 700, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # stay
|
||||
]
|
||||
|
||||
@@ -21,7 +21,7 @@ def main():
|
||||
render_cam_name = 'angle'
|
||||
|
||||
episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||
camera_names = ['angle','r_vis', 'top'] #SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||
if task_name == 'sim_transfer':
|
||||
policy = TestPickAndTransferPolicy(inject_noise)
|
||||
print(task_name)
|
||||
@@ -32,12 +32,6 @@ def main():
|
||||
|
||||
env = make_sim_env(task_name)
|
||||
policy = TestPickAndTransferPolicy(inject_noise)
|
||||
|
||||
# 等待osmesa完全启动后再开始收集数据
|
||||
print("等待osmesa线程启动...")
|
||||
time.sleep(60)
|
||||
print("osmesa已就绪,开始收集数据...")
|
||||
|
||||
for episode_idx in range(num_episodes):
|
||||
obs = []
|
||||
reward_ee = []
|
||||
|
||||
152
roboimi/demos/eval.py
Normal file
152
roboimi/demos/eval.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
from roboimi.utils.utils import set_seed
|
||||
from roboimi.utils.io_utils import IOUtils
|
||||
from roboimi.utils.model_interface import ModelInterface
|
||||
from roboimi.envs.vx300s_jnt import make_sim_env
|
||||
import time
|
||||
|
||||
# from visualize_episodes import save_videos
|
||||
from roboimi.utils.utils import sample_box_pose, sample_insertion_pose
|
||||
|
||||
|
||||
|
||||
#should be added into IOUtils
|
||||
def get_image(obs,camera_names):
|
||||
curr_images = []
|
||||
for cam_name in camera_names:
|
||||
curr_image = rearrange(obs['images'][cam_name], 'h w c -> c h w')
|
||||
curr_images.append(curr_image)
|
||||
curr_image = np.stack(curr_images, axis=0)
|
||||
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
|
||||
return curr_image
|
||||
|
||||
|
||||
def eval_bc(config, ckpt_name='policy_best.ckpt', save_episode=True):
|
||||
set_seed(1)
|
||||
model_interface = ModelInterface(config)
|
||||
task_name = 'sim_insertion' #config['task_name']
|
||||
model_interface.setup()
|
||||
policy = IOUtils.load_policy(config, ckpt_name)
|
||||
stats = IOUtils.load_stats(config['ckpt_dir'])
|
||||
num_rollouts = 3
|
||||
episode_returns = []
|
||||
highest_rewards = []
|
||||
for rollout_id in range(num_rollouts):
|
||||
episode_return, episode_highest_reward = run_episode(config, policy, stats,
|
||||
save_episode,rollout_id)
|
||||
|
||||
|
||||
|
||||
|
||||
def run_episode(config, policy, stats, save_episode,rollout_id):
|
||||
print("\nrollout_id===",rollout_id,"\n")
|
||||
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
|
||||
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
|
||||
if 'sim_insertion' in config['task_name']:
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
box_pose = np.hstack((peg_pose[:3],socket_pose[:3])) # used in sim reset
|
||||
task_name = 'sim_insertion' #config['task_name']
|
||||
env = make_sim_env(task_name)
|
||||
env.reset(box_pose)
|
||||
max_timesteps = config['episode_len']
|
||||
max_timesteps = int(max_timesteps * 1)
|
||||
|
||||
image_list = []
|
||||
rewards = []
|
||||
query_frequency = config['policy_config'].get('num_queries', 1)
|
||||
|
||||
with torch.inference_mode():
|
||||
for t in range(700):
|
||||
# print("obs_img",env.obs['images'])
|
||||
image_list.append(env.obs['images'] if 'images' in env.obs else {print("img error")})
|
||||
qpos_numpy = np.array(env.obs['qpos'])
|
||||
qpos = pre_process(qpos_numpy)
|
||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||
curr_image = get_image(env.obs, config['camera_names'])
|
||||
if config['policy_class'] == "ACT" or "ACTTV":
|
||||
if t % query_frequency == 0:
|
||||
all_actions = policy(qpos, curr_image)
|
||||
elif config['policy_class'] == "CNNMLP":
|
||||
raw_action = policy(qpos, curr_image)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||
action = post_process(raw_action)
|
||||
|
||||
env.step(action)
|
||||
rewards.append(env.rew)
|
||||
env.render()
|
||||
|
||||
|
||||
rewards = np.array(rewards)
|
||||
episode_return = np.sum(rewards[rewards != None])
|
||||
episode_highest_reward = np.max(rewards)
|
||||
env.viewer.close()
|
||||
|
||||
del env
|
||||
return episode_return, episode_highest_reward
|
||||
|
||||
|
||||
def test_env():
|
||||
try:
|
||||
env = make_sim_env('sim_insertion')
|
||||
box_pos = np.concatenate(sample_insertion_pose())
|
||||
env.reset(box_pos)
|
||||
while True: pass
|
||||
except KeyboardInterrupt:
|
||||
del env
|
||||
print("stop")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_env()
|
||||
# io_utils = IOUtils()
|
||||
# config = io_utils.load_config()
|
||||
# eval_bc(config)
|
||||
|
||||
|
||||
|
||||
|
||||
# config===== {'onscreen_render': False,
|
||||
# 'eval': 1,
|
||||
# 'ckpt_dir': 'ckpt_models',
|
||||
# 'num_epochs': 3000,
|
||||
# 'temporal_agg': False,
|
||||
# 'policy_class': 'ACT',
|
||||
# 'backbone': 'resnet18',
|
||||
# 'seed': 0, 'real_robot': 0,
|
||||
# 'task_name': 'sim_insertion',
|
||||
# 'images_render_height': 480,
|
||||
# 'images_render_width': 640,
|
||||
# 'left_arm_DOF_number': 6,
|
||||
# 'right_arm_DOF_number': 6,
|
||||
# 'left_qpos_raw': 8,
|
||||
# 'right_qpos_raw': 8,
|
||||
# 'left_qvel_raw': 8,
|
||||
# 'right_qvel_raw': 8,
|
||||
# 'dataset_dir': '/home/arm/lzd/act_env/dataset/sim_insertion',
|
||||
# 'num_episodes': 7,
|
||||
# 'episode_len': 400,
|
||||
# 'camera_names': ['top'],
|
||||
# 'xml_dir': None,
|
||||
# 'batch_size': 8,
|
||||
# 'state_dim': 14,
|
||||
# 'action_dim': 14,
|
||||
# 'lr_backbone': 1e-05,
|
||||
# 'enc_layers': 4,
|
||||
# 'dec_layers': 7,
|
||||
# 'nheads': 8,
|
||||
# 'qpos_noise_std': 0,
|
||||
# 'DT': 0.02,
|
||||
# 'lr': 1e-05,
|
||||
# 'kl_weight': 10,
|
||||
# 'chunk_size': 100,
|
||||
# 'hidden_dim': 512,
|
||||
# 'dim_feedforward': 3200,
|
||||
# 'policy_config': {'lr': 1e-05, 'num_queries': 100, 'kl_weight': 10, 'hidden_dim': 512, 'dim_feedforward': 3200, 'lr_backbone': 1e-05, 'backbone': 'resnet18', 'enc_layers': 4, 'dec_layers': 7, 'nheads': 8, 'camera_names': ['top']}}
|
||||
179
roboimi/demos/training.py
Normal file
179
roboimi/demos/training.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from itertools import repeat
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
from roboimi.utils.utils import set_seed, compute_dict_mean, detach_dict, load_data
|
||||
from roboimi.utils.io_utils import IOUtils
|
||||
from roboimi.utils.model_interface import ModelInterface
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def train_bc(config):
|
||||
num_epochs = config['num_epochs']
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
seed = config['seed']
|
||||
|
||||
os.makedirs(ckpt_dir, exist_ok=True)
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
model_interface = ModelInterface(config)
|
||||
model_interface.setup()
|
||||
|
||||
policy = model_interface.make_policy()
|
||||
policy.cuda()
|
||||
optimizer = model_interface.make_optimizer(policy)
|
||||
# print("cam names=====",config['camera_names'])
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(
|
||||
config['dataset_dir'],
|
||||
config['num_episodes'],
|
||||
config['camera_names'],
|
||||
config['batch_size'],
|
||||
config['batch_size'])
|
||||
|
||||
IOUtils.save_stats(ckpt_dir, stats)
|
||||
|
||||
train_history = []
|
||||
validation_history = []
|
||||
min_val_loss = np.inf
|
||||
min_train_loss = np.inf
|
||||
best_ckpt_info = None
|
||||
|
||||
plt.ion()
|
||||
fig, ax = plt.subplots()
|
||||
train_losses, val_losses = [], []
|
||||
train_line, = ax.plot([], [], label='Train Loss')
|
||||
val_line, = ax.plot([], [], label='Validation Loss')
|
||||
ax.autoscale_view()
|
||||
ax.set_xlabel('Epoch')
|
||||
ax.set_ylabel('Loss')
|
||||
ax.legend()
|
||||
ax.grid(True)
|
||||
|
||||
|
||||
train_annotation = ax.annotate('', xy=(0, 0), textcoords='offset points')
|
||||
val_annotation = ax.annotate('', xy=(0, 0), textcoords='offset points')
|
||||
|
||||
|
||||
min_train_text = ax.text(0.85, 0.5, '', transform=ax.transAxes, fontsize=10, verticalalignment='center', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.5))
|
||||
min_val_text = ax.text(0.85, 0.45, '', transform=ax.transAxes, fontsize=10, verticalalignment='center', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.5))
|
||||
|
||||
for epoch in tqdm(range(num_epochs)):
|
||||
print(f'\nEpoch {epoch}')
|
||||
|
||||
# Validation
|
||||
epoch_val_loss, epoch_summary = validate(policy, val_dataloader)
|
||||
validation_history.append(epoch_summary)
|
||||
val_losses.append(epoch_val_loss.cpu().item())
|
||||
|
||||
if epoch_val_loss < min_val_loss:
|
||||
min_val_loss = epoch_val_loss
|
||||
min_val_epoch = epoch
|
||||
best_ckpt_info = (epoch, min_val_loss,
|
||||
deepcopy(policy.state_dict()))
|
||||
|
||||
print(f'Val loss: {epoch_val_loss:.5f}')
|
||||
print_summary(epoch_summary)
|
||||
|
||||
# Training
|
||||
epoch_train_loss, epoch_summary = train_epoch(
|
||||
policy, optimizer, train_dataloader)
|
||||
train_history.append(epoch_summary)
|
||||
train_losses.append(epoch_train_loss.cpu().item())
|
||||
|
||||
if epoch_train_loss < min_train_loss:
|
||||
min_train_loss = epoch_train_loss
|
||||
min_train_epoch = epoch
|
||||
|
||||
print(f'Train loss: {epoch_train_loss:.5f}')
|
||||
print_summary(epoch_summary)
|
||||
|
||||
# Update the plot with the new data
|
||||
train_line.set_xdata(range(len(train_losses)))
|
||||
train_line.set_ydata(train_losses)
|
||||
val_line.set_xdata(range(len(val_losses)))
|
||||
val_line.set_ydata(val_losses)
|
||||
|
||||
# Update annotations with the latest loss values at their respective positions
|
||||
train_annotation.set_position((len(train_losses)-1, train_losses[-1]))
|
||||
train_annotation.xy = (len(train_losses)-1, train_losses[-1])
|
||||
train_annotation.set_text(f'{train_losses[-1]:.5f}')
|
||||
|
||||
val_annotation.set_position((len(val_losses)-1, val_losses[-1]))
|
||||
val_annotation.xy = (len(val_losses)-1, val_losses[-1])
|
||||
val_annotation.set_text(f'{val_losses[-1]:.5f}')
|
||||
|
||||
# Update text objects with the minimum loss values, fixed on the right side
|
||||
min_train_text.set_text(f'Min Train Loss: {min_train_loss:.5f} (Epoch {min_train_epoch})')
|
||||
min_val_text.set_text(f'Min Val Loss: {min_val_loss:.5f} (Epoch {min_val_epoch})')
|
||||
|
||||
ax.relim()
|
||||
ax.autoscale_view()
|
||||
plt.draw()
|
||||
plt.pause(0.1)
|
||||
|
||||
|
||||
plt.ioff()
|
||||
IOUtils.save_checkpoint(policy, 'last', ckpt_dir, seed, 'last')
|
||||
|
||||
best_epoch, min_val_loss, best_state_dict = best_ckpt_info
|
||||
IOUtils.save_checkpoint(best_state_dict, best_epoch,
|
||||
ckpt_dir, seed, 'best', min_val_loss)
|
||||
print(
|
||||
f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')
|
||||
|
||||
IOUtils.plot_history(train_history, validation_history,
|
||||
num_epochs, ckpt_dir, seed)
|
||||
|
||||
return best_ckpt_info
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def validate(policy, dataloader):
|
||||
policy.eval()
|
||||
epoch_dicts = []
|
||||
with torch.inference_mode():
|
||||
for data in dataloader:
|
||||
forward_dict = forward_pass(data, policy)
|
||||
epoch_dicts.append(forward_dict)
|
||||
epoch_summary = compute_dict_mean(epoch_dicts)
|
||||
return epoch_summary['loss'], epoch_summary
|
||||
|
||||
|
||||
def train_epoch(policy, optimizer, dataloader):
|
||||
policy.train()
|
||||
epoch_dicts = []
|
||||
for data in dataloader:
|
||||
optimizer.zero_grad()
|
||||
forward_dict = forward_pass(data, policy)
|
||||
loss = forward_dict['loss']
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
epoch_dicts.append(detach_dict(forward_dict))
|
||||
epoch_summary = compute_dict_mean(epoch_dicts)
|
||||
return epoch_summary['loss'], epoch_summary
|
||||
|
||||
|
||||
def forward_pass(data, policy):
|
||||
image_data, qpos_data, action_data, is_pad = data
|
||||
image_data, qpos_data, action_data, is_pad = image_data.cuda(
|
||||
), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
|
||||
return policy(qpos_data, image_data, action_data, is_pad)
|
||||
|
||||
|
||||
def print_summary(summary):
|
||||
summary_string = ' '.join(
|
||||
[f'{k}: {v.item():.3f}' for k, v in summary.items()])
|
||||
print(summary_string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
io_utils = IOUtils()
|
||||
config = io_utils.load_config()
|
||||
train_bc(config)
|
||||
@@ -1,312 +0,0 @@
|
||||
"""
|
||||
VLA 策略评估脚本(简化版)
|
||||
|
||||
该脚本使用 agent 内置的队列管理来评估训练好的 VLA 策略。
|
||||
无需单独的评估器类 - agent 处理一切!
|
||||
|
||||
使用方法:
|
||||
python roboimi/demos/eval_vla_simple.py
|
||||
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_final.pt
|
||||
python roboimi/demos/eval_vla_simple.py eval.ckpt_path=checkpoints/vla_model_best.pt
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import hydra
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from hydra.utils import instantiate
|
||||
from einops import rearrange
|
||||
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
if not OmegaConf.has_resolver("len"):
|
||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
ckpt_path: str,
|
||||
agent_cfg: DictConfig,
|
||||
device: str = 'cuda'
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
从检查点加载训练好的 VLA 模型,使用 Hydra agent 配置。
|
||||
|
||||
Args:
|
||||
ckpt_path: 检查点文件路径 (.pt)
|
||||
agent_cfg: Hydra agent 配置,用于实例化
|
||||
device: 加载模型的设备
|
||||
|
||||
Returns:
|
||||
加载后的 VLAAgent 模型
|
||||
"""
|
||||
from pathlib import Path as PathLib
|
||||
|
||||
ckpt_path = PathLib(ckpt_path).absolute()
|
||||
if not ckpt_path.exists():
|
||||
raise FileNotFoundError(f"检查点未找到: {ckpt_path}")
|
||||
|
||||
log.info(f"从 {ckpt_path} 加载检查点")
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
|
||||
log.info(f"检查点键值: {checkpoint.keys()}")
|
||||
|
||||
# 加载数据集统计信息用于归一化
|
||||
stats = checkpoint.get('dataset_stats', None)
|
||||
|
||||
# 使用数据集统计信息从 Hydra 配置实例化 agent
|
||||
log.info("从配置实例化 agent...")
|
||||
agent = instantiate(agent_cfg, dataset_stats=stats)
|
||||
|
||||
# 加载模型状态
|
||||
agent.load_state_dict(checkpoint['model_state_dict'])
|
||||
log.info(f"✅ 模型状态已加载 (步数: {checkpoint.get('step', 'unknown')})")
|
||||
|
||||
if stats is not None:
|
||||
log.info(f"✅ 数据集统计信息已加载 (归一化: {stats.get('normalization_type', 'gaussian')})")
|
||||
else:
|
||||
# 后备方案:尝试从外部 JSON 文件加载(兼容旧检查点)
|
||||
stats_path = ckpt_path.parent / 'dataset_stats.json'
|
||||
if stats_path.exists():
|
||||
with open(stats_path, 'r') as f:
|
||||
stats = json.load(f)
|
||||
log.info("✅ 数据集统计信息已从外部 JSON 加载(旧版本兼容)")
|
||||
else:
|
||||
log.warning("⚠️ 未找到数据集统计信息。动作将无法反归一化!")
|
||||
|
||||
agent.eval()
|
||||
agent.to(device)
|
||||
|
||||
log.info(f"✅ 模型已成功加载到 {device}")
|
||||
return agent, stats
|
||||
|
||||
|
||||
def prepare_observation(obs: Dict, camera_names: list) -> Dict:
|
||||
"""
|
||||
将环境观测转换为 agent 格式。
|
||||
|
||||
Args:
|
||||
obs: 环境观测字典,包含图像和 qpos
|
||||
camera_names: 摄像头名称列表
|
||||
|
||||
Returns:
|
||||
agent 格式的观测字典
|
||||
"""
|
||||
import cv2
|
||||
|
||||
# 转换图像: numpy -> tensor, HWC -> CHW
|
||||
images = {}
|
||||
for cam_name in camera_names:
|
||||
img = obs['images'][cam_name]
|
||||
# Resize 到 224x224(与训练时一致)
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
img = rearrange(img, 'h w c -> c h w')
|
||||
img = torch.from_numpy(img / 255.0).float()
|
||||
images[cam_name] = img
|
||||
|
||||
# 转换 qpos: numpy -> tensor
|
||||
qpos = torch.from_numpy(obs['qpos']).float()
|
||||
|
||||
return {'qpos': qpos, 'images': images}
|
||||
|
||||
|
||||
class ActionSmoother:
|
||||
"""
|
||||
动作平滑器(指数移动平均)
|
||||
用于平滑执行动作以获得更稳定的控制
|
||||
"""
|
||||
|
||||
def __init__(self, alpha: float = 0.3):
|
||||
"""
|
||||
Args:
|
||||
alpha: 平滑系数 (0-1),值越大越重视当前动作
|
||||
"""
|
||||
self.alpha = alpha
|
||||
self.prev_action = None
|
||||
|
||||
def smooth(self, action: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
平滑动作
|
||||
|
||||
Args:
|
||||
action: 当前动作
|
||||
|
||||
Returns:
|
||||
平滑后的动作
|
||||
"""
|
||||
if self.prev_action is None:
|
||||
smoothed = action
|
||||
else:
|
||||
smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
|
||||
self.prev_action = smoothed
|
||||
return smoothed
|
||||
|
||||
def reset(self):
|
||||
"""重置平滑器状态"""
|
||||
self.prev_action = None
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||
def main(cfg: DictConfig):
|
||||
"""
|
||||
使用 agent 内置队列管理的简化版 VLA 评估
|
||||
|
||||
所有评估参数来自 vla/conf/eval.yaml,合并到 cfg 中。
|
||||
命令行覆盖: python eval_vla_simple.py eval.ckpt_path=... eval.num_episodes=5
|
||||
"""
|
||||
|
||||
# 打印配置
|
||||
print("=" * 80)
|
||||
print("VLA 评估配置:")
|
||||
print("=" * 80)
|
||||
print(OmegaConf.to_yaml(cfg))
|
||||
print("=" * 80)
|
||||
|
||||
eval_cfg = cfg.eval
|
||||
device = eval_cfg.device
|
||||
camera_names = list(eval_cfg.camera_names)
|
||||
|
||||
# =========================================================================
|
||||
# 加载模型
|
||||
# =========================================================================
|
||||
log.info(f"🚀 从 {eval_cfg.ckpt_path} 加载模型...")
|
||||
agent, dataset_stats = load_checkpoint(
|
||||
ckpt_path=eval_cfg.ckpt_path,
|
||||
agent_cfg=cfg.agent,
|
||||
device=device
|
||||
)
|
||||
|
||||
# 重置 agent 的队列
|
||||
agent.reset()
|
||||
|
||||
# 可选:动作平滑器
|
||||
smoother = ActionSmoother(alpha=eval_cfg.smooth_alpha) if eval_cfg.use_smoothing else None
|
||||
|
||||
# =========================================================================
|
||||
# 创建环境
|
||||
# =========================================================================
|
||||
env = make_sim_env(eval_cfg.task_name)
|
||||
|
||||
# =========================================================================
|
||||
# 运行评估回合
|
||||
# =========================================================================
|
||||
all_stats = []
|
||||
|
||||
for episode_idx in range(eval_cfg.num_episodes):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
box_pos = sample_transfer_pose()
|
||||
env.reset(box_pos)
|
||||
|
||||
# 为新回合重置 agent 队列
|
||||
agent.reset()
|
||||
if smoother:
|
||||
smoother.reset()
|
||||
|
||||
# 计时统计
|
||||
inference_times = []
|
||||
total_times = []
|
||||
|
||||
with torch.inference_mode():
|
||||
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
|
||||
start_total = time.time()
|
||||
|
||||
# 从环境获取观测
|
||||
obs = env._get_image_obs()
|
||||
qpos_obs = env._get_qpos_obs()
|
||||
obs['qpos'] = qpos_obs['qpos']
|
||||
|
||||
# 准备给 agent 的观测
|
||||
observation = prepare_observation(obs, camera_names)
|
||||
|
||||
# 选择动作(agent 内部处理队列管理)
|
||||
start_inference = time.time()
|
||||
action = agent.select_action(observation)
|
||||
|
||||
if device == 'cuda':
|
||||
torch.cuda.synchronize()
|
||||
end_inference = time.time()
|
||||
|
||||
# 转换为 numpy
|
||||
action = action.cpu().numpy()
|
||||
|
||||
# 调试:打印当前时间步的动作(由配置控制)
|
||||
if eval_cfg.get('verbose_action', False):
|
||||
print(f"\n[Step {t:3d}] 预测动作: {action}")
|
||||
print(f" - 动作形状: {action.shape}")
|
||||
print(f" - 动作范围: [{action.min():.4f}, {action.max():.4f}]")
|
||||
print(f" - 动作均值: {action.mean():.4f}, 标准差: {action.std():.4f}")
|
||||
|
||||
# 可选:平滑动作
|
||||
if smoother:
|
||||
action = smoother.smooth(action)
|
||||
|
||||
# 执行动作
|
||||
env.step_jnt(action)
|
||||
env.render()
|
||||
|
||||
end_total = time.time()
|
||||
|
||||
# 记录计时
|
||||
inference_times.append(end_inference - start_inference)
|
||||
total_times.append(end_total - start_total)
|
||||
|
||||
# =========================================================================
|
||||
# 打印回合统计
|
||||
# =========================================================================
|
||||
avg_inference_time = np.mean(inference_times)
|
||||
avg_total_time = np.mean(total_times)
|
||||
|
||||
stats = {
|
||||
'inference_fps': 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0,
|
||||
'control_fps': 1.0 / avg_total_time if avg_total_time > 0 else 0.0,
|
||||
'avg_inference_time_ms': avg_inference_time * 1000,
|
||||
'avg_total_time_ms': avg_total_time * 1000,
|
||||
'num_inferences': len([t for t in inference_times if t > 0.001]), # 统计实际推理次数
|
||||
'num_steps': len(total_times)
|
||||
}
|
||||
all_stats.append(stats)
|
||||
|
||||
print(f"\n回合 {episode_idx + 1} 完成 ({eval_cfg.max_timesteps} 时间步)")
|
||||
print(f" 模型推理 FPS: {stats['inference_fps']:.2f} Hz")
|
||||
print(f" 控制循环 FPS: {stats['control_fps']:.2f} Hz")
|
||||
print(f" 平均推理时间: {stats['avg_inference_time_ms']:.2f} ms")
|
||||
print(f" 平均总时间: {stats['avg_total_time_ms']:.2f} ms")
|
||||
print(f" 总推理次数: {stats['num_inferences']}")
|
||||
|
||||
# =========================================================================
|
||||
# 总体统计
|
||||
# =========================================================================
|
||||
print(f"\n{'='*60}")
|
||||
print("评估完成!")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if all_stats:
|
||||
avg_inference_fps = np.mean([s['inference_fps'] for s in all_stats])
|
||||
avg_control_fps = np.mean([s['control_fps'] for s in all_stats])
|
||||
avg_inference_time = np.mean([s['avg_inference_time_ms'] for s in all_stats])
|
||||
avg_total_time = np.mean([s['avg_total_time_ms'] for s in all_stats])
|
||||
|
||||
print(f"\n总体统计 ({eval_cfg.num_episodes} 个回合):")
|
||||
print(f" 平均模型推理 FPS: {avg_inference_fps:.2f} Hz")
|
||||
print(f" 平均控制循环 FPS: {avg_control_fps:.2f} Hz")
|
||||
print(f" 平均推理时间: {avg_inference_time:.2f} ms")
|
||||
print(f" 平均总时间: {avg_total_time:.2f} ms")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,532 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
import pickle
|
||||
import hydra
|
||||
import torch
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from pathlib import Path
|
||||
|
||||
# 确保正确的导入路径
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from hydra.utils import instantiate
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# 注册列表长度解析器(用于配置中如 ${len:${data.camera_names}})
|
||||
if not OmegaConf.has_resolver("len"):
|
||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||
|
||||
|
||||
def recursive_to_device(data, device):
|
||||
"""
|
||||
递归地将嵌套字典/列表中的张量移动到指定设备。
|
||||
|
||||
Args:
|
||||
data: 字典、列表或张量
|
||||
device: 目标设备 (例如 'cuda', 'cpu')
|
||||
|
||||
Returns:
|
||||
所有张量已移动到指定设备的数据结构
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.to(device)
|
||||
elif isinstance(data, dict):
|
||||
return {k: recursive_to_device(v, device) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [recursive_to_device(v, device) for v in data]
|
||||
return data
|
||||
|
||||
|
||||
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
|
||||
"""
|
||||
解析恢复训练用的 checkpoint 路径。
|
||||
|
||||
Args:
|
||||
resume_ckpt: 配置中的 resume_ckpt,支持路径或 "auto"
|
||||
checkpoint_dir: 默认检查点目录
|
||||
|
||||
Returns:
|
||||
Path 或 None
|
||||
"""
|
||||
if resume_ckpt is None:
|
||||
return None
|
||||
|
||||
if str(resume_ckpt).lower() != "auto":
|
||||
return Path(resume_ckpt)
|
||||
|
||||
pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
|
||||
candidates = []
|
||||
for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
|
||||
match = pattern.search(ckpt_path.name)
|
||||
if match:
|
||||
candidates.append((int(match.group(1)), ckpt_path))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
return max(candidates, key=lambda x: x[0])[1]
|
||||
|
||||
|
||||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
||||
"""
|
||||
创建带预热的学习率调度器。
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch 优化器
|
||||
warmup_steps: 预热步数
|
||||
max_steps: 总训练步数
|
||||
scheduler_type: 预热后的调度器类型 ('cosine' 或 'constant')
|
||||
min_lr: 最小学习率(用于余弦衰减)
|
||||
|
||||
Returns:
|
||||
LambdaLR 调度器
|
||||
"""
|
||||
import math
|
||||
# 在 LambdaLR 修改前捕获初始学习率
|
||||
base_lr = optimizer.param_groups[0]['lr']
|
||||
min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0
|
||||
|
||||
def lr_lambda(step):
|
||||
# 预热阶段:从 0 线性增加到 1
|
||||
if step < warmup_steps:
|
||||
return float(step) / float(max(1, warmup_steps))
|
||||
|
||||
# 预热后阶段
|
||||
if scheduler_type == 'cosine':
|
||||
# 从 1 到 min_lr_ratio 的余弦退火
|
||||
progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
|
||||
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||
return max(min_lr_ratio, cosine_decay)
|
||||
else:
|
||||
# 恒定学习率
|
||||
return 1.0
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||
def main(cfg: DictConfig):
|
||||
"""
|
||||
VLA 训练脚本(ResNet 骨干网络 + Diffusion 策略)
|
||||
|
||||
该脚本功能:
|
||||
1. 从 HDF5 文件加载数据集
|
||||
2. 实例化带 ResNet 视觉编码器的 VLAAgent
|
||||
3. 训练基于扩散的动作预测模型
|
||||
4. 定期保存检查点
|
||||
"""
|
||||
|
||||
# 打印配置
|
||||
print("=" * 80)
|
||||
print("VLA 训练配置:")
|
||||
print("=" * 80)
|
||||
print(OmegaConf.to_yaml(cfg))
|
||||
print("=" * 80)
|
||||
|
||||
log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
|
||||
|
||||
# 创建检查点目录
|
||||
checkpoint_dir = Path("checkpoints")
|
||||
checkpoint_dir.mkdir(exist_ok=True)
|
||||
|
||||
# =========================================================================
|
||||
# 1. 实例化数据集与 DataLoader
|
||||
# =========================================================================
|
||||
log.info("📦 加载数据集...")
|
||||
try:
|
||||
dataset = instantiate(cfg.data)
|
||||
log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}")
|
||||
except Exception as e:
|
||||
log.error(f"❌ 数据集加载失败: {e}")
|
||||
raise
|
||||
|
||||
# 训练/验证集划分
|
||||
val_split = float(cfg.train.get('val_split', 0.1))
|
||||
seed = int(cfg.train.get('seed', 42))
|
||||
val_size = int(len(dataset) * val_split)
|
||||
train_size = len(dataset) - val_size
|
||||
if val_size > 0:
|
||||
train_dataset, val_dataset = random_split(
|
||||
dataset,
|
||||
[train_size, val_size],
|
||||
generator=torch.Generator().manual_seed(seed)
|
||||
)
|
||||
log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})")
|
||||
else:
|
||||
train_dataset, val_dataset = dataset, None
|
||||
log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)")
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.train.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=cfg.train.num_workers,
|
||||
pin_memory=(cfg.train.device != "cpu"),
|
||||
persistent_workers=(cfg.train.num_workers > 0),
|
||||
drop_last=True # 丢弃不完整批次以稳定训练
|
||||
)
|
||||
|
||||
val_loader = None
|
||||
if val_dataset is not None:
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.train.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.train.num_workers,
|
||||
pin_memory=(cfg.train.device != "cpu"),
|
||||
persistent_workers=(cfg.train.num_workers > 0),
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}")
|
||||
if val_loader is not None:
|
||||
log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}")
|
||||
|
||||
# =========================================================================
|
||||
# 2. 加载数据集统计信息(将传递给 agent)
|
||||
# =========================================================================
|
||||
log.info("💾 加载数据集统计信息...")
|
||||
dataset_stats = None
|
||||
try:
|
||||
dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
|
||||
stats_path = Path(dataset_dir) / 'dataset_stats.pkl'
|
||||
|
||||
if stats_path.exists():
|
||||
with open(stats_path, 'rb') as f:
|
||||
stats = pickle.load(f)
|
||||
|
||||
# 扁平化stats字典(嵌套结构→扁平结构)以匹配NormalizationModule的期望格式
|
||||
dataset_stats = {
|
||||
'action_mean': stats['action_mean'].tolist(),
|
||||
'action_std': stats['action_std'].tolist(),
|
||||
'action_min': stats['action_min'].tolist(),
|
||||
'action_max': stats['action_max'].tolist(),
|
||||
'qpos_mean': stats['qpos_mean'].tolist(),
|
||||
'qpos_std': stats['qpos_std'].tolist(),
|
||||
'qpos_min': stats['qpos_min'].tolist(),
|
||||
'qpos_max': stats['qpos_max'].tolist(),
|
||||
}
|
||||
log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})")
|
||||
else:
|
||||
log.warning(f"⚠️ 统计文件未找到: {stats_path}")
|
||||
log.warning("⚠️ 推理时动作将无法反归一化!")
|
||||
|
||||
except Exception as e:
|
||||
log.warning(f"⚠️ 统计信息加载失败: {e}")
|
||||
log.warning("⚠️ 训练将继续,但推理可能无法正常工作")
|
||||
|
||||
# =========================================================================
|
||||
# 3. 实例化 VLA Agent
|
||||
# =========================================================================
|
||||
log.info("🤖 初始化 VLA Agent...")
|
||||
try:
|
||||
# 将 dataset_stats 和 normalization_type 传递给 agent
|
||||
agent = instantiate(cfg.agent, dataset_stats=dataset_stats)
|
||||
agent.to(cfg.train.device)
|
||||
agent.train()
|
||||
log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}")
|
||||
|
||||
# 统计参数量
|
||||
total_params = sum(p.numel() for p in agent.parameters())
|
||||
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
|
||||
log.info(f"📊 总参数量: {total_params:,}")
|
||||
log.info(f"📊 可训练参数量: {trainable_params:,}")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"❌ Agent 初始化失败: {e}")
|
||||
raise
|
||||
|
||||
# =========================================================================
|
||||
# 3.1 从预训练 checkpoint 加载权重(微调)
|
||||
# =========================================================================
|
||||
pretrained_ckpt = cfg.train.get('pretrained_ckpt', None)
|
||||
if pretrained_ckpt is not None:
|
||||
ckpt_path = Path(pretrained_ckpt)
|
||||
if ckpt_path.exists():
|
||||
log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}")
|
||||
try:
|
||||
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
|
||||
|
||||
# 只加载模型权重(不加载 optimizer、scheduler)
|
||||
missing_keys, unexpected_keys = agent.load_state_dict(
|
||||
checkpoint['model_state_dict'],
|
||||
strict=False # 允许部分加载(结构不完全匹配时)
|
||||
)
|
||||
|
||||
log.info(f"✅ [Finetune] 模型权重加载成功")
|
||||
|
||||
if missing_keys:
|
||||
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
|
||||
if unexpected_keys:
|
||||
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
|
||||
|
||||
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
|
||||
log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
else:
|
||||
log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
|
||||
# =========================================================================
|
||||
# 4. 设置优化器与学习率调度器
|
||||
# =========================================================================
|
||||
weight_decay = float(cfg.train.get('weight_decay', 1e-5))
|
||||
grad_clip = float(cfg.train.get('grad_clip', 1.0))
|
||||
|
||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
|
||||
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
|
||||
|
||||
# 设置带预热的学習率调度器
|
||||
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
||||
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
|
||||
min_lr = float(cfg.train.get('min_lr', 1e-6))
|
||||
|
||||
scheduler = get_lr_schedule_with_warmup(
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
max_steps=cfg.train.max_steps,
|
||||
scheduler_type=scheduler_type,
|
||||
min_lr=min_lr
|
||||
)
|
||||
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
||||
|
||||
# =========================================================================
|
||||
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
|
||||
# =========================================================================
|
||||
start_step = 0
|
||||
resume_loss = None
|
||||
resume_best_loss = float('inf')
|
||||
|
||||
resume_ckpt = cfg.train.get('resume_ckpt', None)
|
||||
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
|
||||
if resume_ckpt is not None:
|
||||
if pretrained_ckpt is not None:
|
||||
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训")
|
||||
if resume_path is None:
|
||||
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练")
|
||||
elif not resume_path.exists():
|
||||
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
else:
|
||||
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
|
||||
try:
|
||||
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
|
||||
|
||||
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
resume_step = int(checkpoint['step'])
|
||||
start_step = resume_step + 1
|
||||
|
||||
loaded_loss = checkpoint.get('loss', None)
|
||||
loaded_val_loss = checkpoint.get('val_loss', None)
|
||||
resume_loss = float(loaded_loss) if loaded_loss is not None else None
|
||||
if loaded_val_loss is not None:
|
||||
resume_best_loss = float(loaded_val_loss)
|
||||
elif loaded_loss is not None:
|
||||
resume_best_loss = float(loaded_loss)
|
||||
|
||||
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
|
||||
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
|
||||
except Exception as e:
|
||||
log.error(f"❌ [Resume] 恢复失败: {e}")
|
||||
log.warning("⚠️ 将从头开始训练")
|
||||
start_step = 0
|
||||
resume_loss = None
|
||||
resume_best_loss = float('inf')
|
||||
|
||||
# =========================================================================
|
||||
# 5. 训练循环
|
||||
# =========================================================================
|
||||
log.info("🏋️ 开始训练循环...")
|
||||
|
||||
def build_agent_input(batch_data):
|
||||
"""构建 agent 输入格式"""
|
||||
images = {}
|
||||
# SimpleRobotDataset 返回 observation.{cam_name} 格式
|
||||
for cam_name in cfg.data.camera_names:
|
||||
key = f"observation.{cam_name}"
|
||||
if key in batch_data:
|
||||
images[cam_name] = batch_data[key]
|
||||
|
||||
return {
|
||||
'images': images,
|
||||
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
||||
'action': batch_data['action'],
|
||||
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
||||
}
|
||||
|
||||
def run_validation():
|
||||
"""运行验证"""
|
||||
if val_loader is None:
|
||||
return None
|
||||
agent.eval()
|
||||
|
||||
# 设置确定性种子以获得可重现的损失
|
||||
# 这确保验证损失在不同步骤之间可比较
|
||||
torch.manual_seed(42)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
with torch.no_grad():
|
||||
for val_batch in val_loader:
|
||||
val_batch = recursive_to_device(val_batch, cfg.train.device)
|
||||
val_input = build_agent_input(val_batch)
|
||||
val_loss = agent.compute_loss(val_input)
|
||||
total_loss += val_loss.item()
|
||||
num_batches += 1
|
||||
agent.train()
|
||||
return total_loss / max(num_batches, 1)
|
||||
|
||||
data_iter = iter(train_loader)
|
||||
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
|
||||
|
||||
best_loss = resume_best_loss
|
||||
last_loss = resume_loss
|
||||
|
||||
if start_step >= cfg.train.max_steps:
|
||||
log.warning(
|
||||
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
|
||||
)
|
||||
|
||||
for step in pbar:
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
except StopIteration:
|
||||
# 轮次结束时重启迭代器
|
||||
data_iter = iter(train_loader)
|
||||
batch = next(data_iter)
|
||||
|
||||
# =====================================================================
|
||||
# 将批次移至设备
|
||||
# =====================================================================
|
||||
batch = recursive_to_device(batch, cfg.train.device)
|
||||
|
||||
# =====================================================================
|
||||
# 准备 agent 输入
|
||||
# =====================================================================
|
||||
# 数据集返回: {action, qpos, image_<cam_name>, ...}
|
||||
# Agent 期望: {images: dict, qpos: tensor, action: tensor}
|
||||
|
||||
# 准备 agent 输入
|
||||
agent_input = build_agent_input(batch)
|
||||
|
||||
# =====================================================================
|
||||
# 前向传播与损失计算
|
||||
# =====================================================================
|
||||
try:
|
||||
loss = agent.compute_loss(agent_input)
|
||||
except Exception as e:
|
||||
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
||||
raise
|
||||
|
||||
last_loss = loss.item()
|
||||
|
||||
# =====================================================================
|
||||
# 反向传播与优化
|
||||
# =====================================================================
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# 梯度裁剪以稳定训练
|
||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
# =====================================================================
|
||||
# 日志记录
|
||||
# =====================================================================
|
||||
if step % cfg.train.log_freq == 0:
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
pbar.set_postfix({
|
||||
"loss": f"{loss.item():.4f}",
|
||||
"lr": f"{current_lr:.2e}",
|
||||
"best_loss": f"{best_loss:.4f}"
|
||||
})
|
||||
log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}")
|
||||
|
||||
# =====================================================================
|
||||
# 检查点保存与验证
|
||||
# =====================================================================
|
||||
if step > 0 and step % cfg.train.save_freq == 0:
|
||||
# 运行验证
|
||||
val_loss = run_validation()
|
||||
if val_loss is not None:
|
||||
log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}")
|
||||
|
||||
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||
# 使用agent的归一化统计信息(包含normalization_type)
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, checkpoint_path)
|
||||
log.info(f"💾 检查点已保存: {checkpoint_path}")
|
||||
|
||||
# 根据验证损失保存最佳模型
|
||||
eval_loss = val_loss if val_loss is not None else loss.item()
|
||||
if eval_loss < best_loss:
|
||||
best_loss = eval_loss
|
||||
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, best_model_path)
|
||||
log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})")
|
||||
|
||||
# =========================================================================
|
||||
# 6. 保存最终模型
|
||||
# =========================================================================
|
||||
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
'step': cfg.train.max_steps,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': last_loss,
|
||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, final_model_path)
|
||||
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||
|
||||
log.info("✅ 训练成功完成!")
|
||||
if last_loss is not None:
|
||||
log.info(f"📊 最终损失: {last_loss:.4f}")
|
||||
else:
|
||||
log.info("📊 最终损失: N/A(未执行训练步)")
|
||||
if best_loss != float('inf'):
|
||||
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
||||
else:
|
||||
log.info("📊 最佳损失: N/A(无有效验证/训练损失)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
201
roboimi/detr/LICENSE
Normal file
201
roboimi/detr/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 - present, Facebook, Inc
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
9
roboimi/detr/README.md
Normal file
9
roboimi/detr/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
|
||||
|
||||
@article{Carion2020EndtoEndOD,
|
||||
title={End-to-End Object Detection with Transformers},
|
||||
author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
|
||||
journal={ArXiv},
|
||||
year={2020},
|
||||
volume={abs/2005.12872}
|
||||
}
|
||||
106
roboimi/detr/main.py
Normal file
106
roboimi/detr/main.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from .models import build_ACT_model, build_CNNMLP_model
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
||||
parser.add_argument('--lr', default=1e-4, type=float) # will be overridden
|
||||
parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden
|
||||
parser.add_argument('--batch_size', default=2, type=int) # not used
|
||||
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
||||
parser.add_argument('--epochs', default=300, type=int) # not used
|
||||
parser.add_argument('--lr_drop', default=200, type=int) # not used
|
||||
parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used
|
||||
help='gradient clipping max norm')
|
||||
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float, help='lr', required=False)
|
||||
|
||||
# Model parameters
|
||||
# * Backbone
|
||||
parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden
|
||||
help="Name of the convolutional backbone to use")
|
||||
parser.add_argument('--dilation', action='store_true',
|
||||
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
|
||||
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
|
||||
help="Type of positional embedding to use on top of the image features")
|
||||
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
||||
help="A list of camera names")
|
||||
|
||||
# * Transformer
|
||||
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
||||
help="Number of encoding layers in the transformer")
|
||||
parser.add_argument('--dec_layers', default=6, type=int, # will be overridden
|
||||
help="Number of decoding layers in the transformer")
|
||||
parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden
|
||||
help="Intermediate size of the feedforward layers in the transformer blocks")
|
||||
parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden
|
||||
help="Size of the embeddings (dimension of the transformer)")
|
||||
parser.add_argument('--dropout', default=0.1, type=float,
|
||||
help="Dropout applied in the transformer")
|
||||
parser.add_argument('--nheads', default=8, type=int, # will be overridden
|
||||
help="Number of attention heads inside the transformer's attentions")
|
||||
parser.add_argument('--num_queries', default=400, type=int, # will be overridden
|
||||
help="Number of query slots")
|
||||
parser.add_argument('--pre_norm', action='store_true')
|
||||
parser.add_argument('--state_dim', default=14, type=int)
|
||||
parser.add_argument('--action_dim', default=14, type=int)
|
||||
|
||||
|
||||
# * Segmentation
|
||||
parser.add_argument('--masks', action='store_true',
|
||||
help="Train segmentation head if the flag is provided")
|
||||
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def build_ACT_model_and_optimizer(args_override):
|
||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
model = build_ACT_model(args)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def build_CNNMLP_model_and_optimizer(args_override):
|
||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
model = build_CNNMLP_model(args)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
9
roboimi/detr/models/__init__.py
Normal file
9
roboimi/detr/models/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .detr_vae import build as build_vae
|
||||
from .detr_vae import build_cnnmlp as build_cnnmlp
|
||||
|
||||
def build_ACT_model(args):
|
||||
return build_vae(args)
|
||||
|
||||
def build_CNNMLP_model(args):
|
||||
return build_cnnmlp(args)
|
||||
168
roboimi/detr/models/backbone.py
Normal file
168
roboimi/detr/models/backbone.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Backbone modules.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from typing import Dict, List
|
||||
|
||||
from util.misc import NestedTensor, is_main_process
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super(FrozenBatchNorm2d, self).__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
num_batches_tracked_key = prefix + 'num_batches_tracked'
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
|
||||
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {'layer4': "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
def __init__(self, name: str,
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
# class DINOv2BackBone(nn.Module):
|
||||
# def __init__(self) -> None:
|
||||
# super().__init__()
|
||||
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
||||
# self.body.eval()
|
||||
# self.num_channels = 384
|
||||
|
||||
# @torch.no_grad()
|
||||
# def forward(self, tensor):
|
||||
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
|
||||
# od = OrderedDict()
|
||||
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
# return od
|
||||
|
||||
class DINOv2BackBone(nn.Module):
|
||||
def __init__(self, return_interm_layers: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
||||
self.body.eval()
|
||||
self.num_channels = 384
|
||||
self.return_interm_layers = return_interm_layers
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, tensor):
|
||||
features = self.body.forward_features(tensor)
|
||||
|
||||
if self.return_interm_layers:
|
||||
|
||||
layer1 = features["x_norm_patchtokens"]
|
||||
layer2 = features["x_norm_patchtokens"]
|
||||
layer3 = features["x_norm_patchtokens"]
|
||||
layer4 = features["x_norm_patchtokens"]
|
||||
|
||||
od = OrderedDict()
|
||||
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
return od
|
||||
else:
|
||||
xs = features["x_norm_patchtokens"]
|
||||
od = OrderedDict()
|
||||
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
return od
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for name, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
if args.backbone == 'dino_v2':
|
||||
backbone = DINOv2BackBone()
|
||||
else:
|
||||
assert args.backbone in ['resnet18', 'resnet34']
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
300
roboimi/detr/models/detr_vae.py
Normal file
300
roboimi/detr/models/detr_vae.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR model and criterion classes.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from .backbone import build_backbone
|
||||
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
""" This is the DETR module that performs object detection """
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
# self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
# self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||
# self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
# self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
||||
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, :self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim:]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
else:
|
||||
mu = logvar = None
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
|
||||
|
||||
|
||||
|
||||
# print(f"Image shape: {image.shape}, Number of cameras: {len(self.camera_names)}")
|
||||
|
||||
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
for backbone in backbones:
|
||||
down_proj = nn.Sequential(
|
||||
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
||||
nn.Conv2d(128, 64, kernel_size=5),
|
||||
nn.Conv2d(64, 32, kernel_size=5)
|
||||
)
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + 14
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0] # not used
|
||||
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
||||
# flatten everything
|
||||
flattened_features = []
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for i in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build(args):
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
# backbone = build_backbone(args)
|
||||
# backbones.append(backbone)
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
|
||||
encoder = build_encoder(args)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||
|
||||
return model
|
||||
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||
|
||||
return model
|
||||
|
||||
91
roboimi/detr/models/position_encoding.py
Normal file
91
roboimi/detr/models/position_encoding.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from util.misc import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = torch.cat([
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
N_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ('v2', 'sine'):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif args.position_embedding in ('v3', 'learned'):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
312
roboimi/detr/models/transformer.py
Normal file
312
roboimi/detr/models/transformer.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
||||
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False,
|
||||
return_intermediate_dec=False):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
||||
return_intermediate=return_intermediate_dec)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
||||
pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(self, src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, memory, tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos, query_pos=query_pos)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(self, src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(self, src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
||||
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
163
roboimi/detr/policy.py
Normal file
163
roboimi/detr/policy.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision.transforms import v2
|
||||
import torch
|
||||
from roboimi.detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_ACT_model_and_optimizer(args_override)
|
||||
self.model = model # CVAE decoder
|
||||
self.optimizer = optimizer
|
||||
self.kl_weight = args_override['kl_weight']
|
||||
print(f'KL Weight {self.kl_weight}')
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, :self.model.num_queries]
|
||||
is_pad = is_pad[:, :self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict = dict()
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
loss_dict['l1'] = l1
|
||||
loss_dict['kl'] = total_kld[0]
|
||||
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer
|
||||
|
||||
class ACTTVPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_ACT_model_and_optimizer(args_override)
|
||||
self.model = model # CVAE decoder
|
||||
self.optimizer = optimizer
|
||||
self.kl_weight = args_override['kl_weight']
|
||||
self.qpos_noise_std = args_override['qpos_noise_std']
|
||||
print(f'KL Weight {self.kl_weight}')
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
# std=[0.229, 0.224, 0.225])
|
||||
# image = normalize(image)
|
||||
|
||||
|
||||
patch_h = 16
|
||||
patch_w = 22
|
||||
if actions is not None:
|
||||
transform = v2.Compose([
|
||||
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
||||
v2.RandomPerspective(distortion_scale=0.5),
|
||||
v2.RandomAffine(degrees=10, translate=(0.1,0.1), scale=(0.9,1.1)),
|
||||
v2.GaussianBlur(kernel_size=(9,9), sigma=(0.1,2.0)),
|
||||
v2.Resize((patch_h * 14, patch_w * 14)),
|
||||
# v2.CenterCrop((patch_h * 14, patch_w * 14)),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
qpos += (self.qpos_noise_std**0.5)*torch.randn_like(qpos)
|
||||
else: # inference time
|
||||
transform = v2.Compose([
|
||||
v2.Resize((patch_h * 14, patch_w * 14)),
|
||||
# v2.CenterCrop((patch_h * 14, patch_w * 14)),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
|
||||
])
|
||||
|
||||
image = transform(image)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, :self.model.num_queries]
|
||||
is_pad = is_pad[:, :self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict = dict()
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
loss_dict['l1'] = l1
|
||||
loss_dict['kl'] = total_kld[0]
|
||||
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer
|
||||
|
||||
|
||||
class CNNMLPPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
|
||||
self.model = model # decoder
|
||||
self.optimizer = optimizer
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None # TODO
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, 0]
|
||||
a_hat = self.model(qpos, image, env_state, actions)
|
||||
mse = F.mse_loss(actions, a_hat)
|
||||
loss_dict = dict()
|
||||
loss_dict['mse'] = mse
|
||||
loss_dict['loss'] = loss_dict['mse']
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
10
roboimi/detr/setup.py
Normal file
10
roboimi/detr/setup.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from distutils.core import setup
|
||||
from setuptools import find_packages
|
||||
|
||||
setup(
|
||||
name='detr',
|
||||
version='0.0.0',
|
||||
packages=find_packages(),
|
||||
license='MIT License',
|
||||
long_description=open('README.md').read(),
|
||||
)
|
||||
1
roboimi/detr/util/__init__.py
Normal file
1
roboimi/detr/util/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
88
roboimi/detr/util/box_ops.py
Normal file
88
roboimi/detr/util/box_ops.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Utilities for bounding box manipulation and GIoU.
|
||||
"""
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||
and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
area = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
def masks_to_boxes(masks):
|
||||
"""Compute the bounding boxes around the provided masks
|
||||
|
||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||
|
||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||
"""
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device)
|
||||
|
||||
h, w = masks.shape[-2:]
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
|
||||
x_mask = (masks * x.unsqueeze(0))
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
y_mask = (masks * y.unsqueeze(0))
|
||||
y_max = y_mask.flatten(1).max(-1)[0]
|
||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||
468
roboimi/detr/util/misc.py
Normal file
468
roboimi/detr/util/misc.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import pickle
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], :img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError('not supported')
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
107
roboimi/detr/util/plot_utils.py
Normal file
107
roboimi/detr/util/plot_utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Plotting utilities to visualize training logs.
|
||||
"""
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
|
||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
||||
'''
|
||||
Function to plot specific fields from training log(s). Plots both training and test results.
|
||||
|
||||
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
||||
- fields = which results to plot from each log file - plots both training and test for each field.
|
||||
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
||||
- log_name = optional, name of log file if different than default 'log.txt'.
|
||||
|
||||
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
||||
- solid lines are training results, dashed lines are test results.
|
||||
|
||||
'''
|
||||
func_name = "plot_utils.py::plot_logs"
|
||||
|
||||
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
||||
# convert single Path to list to avoid 'not iterable' error
|
||||
|
||||
if not isinstance(logs, list):
|
||||
if isinstance(logs, PurePath):
|
||||
logs = [logs]
|
||||
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
||||
else:
|
||||
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
||||
Expect list[Path] or single Path obj, received {type(logs)}")
|
||||
|
||||
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
||||
for i, dir in enumerate(logs):
|
||||
if not isinstance(dir, PurePath):
|
||||
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
||||
if not dir.exists():
|
||||
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
||||
# verify log_name exists
|
||||
fn = Path(dir / log_name)
|
||||
if not fn.exists():
|
||||
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
||||
print(f"--> full path of missing log file: {fn}")
|
||||
return
|
||||
|
||||
# load log file(s) and plot
|
||||
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
||||
|
||||
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
||||
|
||||
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
||||
for j, field in enumerate(fields):
|
||||
if field == 'mAP':
|
||||
coco_eval = pd.DataFrame(
|
||||
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
||||
).ewm(com=ewm_col).mean()
|
||||
axs[j].plot(coco_eval, c=color)
|
||||
else:
|
||||
df.interpolate().ewm(com=ewm_col).mean().plot(
|
||||
y=[f'train_{field}', f'test_{field}'],
|
||||
ax=axs[j],
|
||||
color=[color] * 2,
|
||||
style=['-', '--']
|
||||
)
|
||||
for ax, field in zip(axs, fields):
|
||||
ax.legend([Path(p).name for p in logs])
|
||||
ax.set_title(field)
|
||||
|
||||
|
||||
def plot_precision_recall(files, naming_scheme='iter'):
|
||||
if naming_scheme == 'exp_id':
|
||||
# name becomes exp_id
|
||||
names = [f.parts[-3] for f in files]
|
||||
elif naming_scheme == 'iter':
|
||||
names = [f.stem for f in files]
|
||||
else:
|
||||
raise ValueError(f'not supported {naming_scheme}')
|
||||
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
||||
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
||||
data = torch.load(f)
|
||||
# precision is n_iou, n_points, n_cat, n_area, max_det
|
||||
precision = data['precision']
|
||||
recall = data['params'].recThrs
|
||||
scores = data['scores']
|
||||
# take precision for all classes, all areas and 100 detections
|
||||
precision = precision[0, :, :, 0, -1].mean(1)
|
||||
scores = scores[0, :, :, 0, -1].mean(1)
|
||||
prec = precision.mean()
|
||||
rec = data['recall'][0, :, 0, -1].mean()
|
||||
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
||||
f'score={scores.mean():0.3f}, ' +
|
||||
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
||||
)
|
||||
axs[0].plot(recall, precision, c=color)
|
||||
axs[1].plot(recall, scores, c=color)
|
||||
|
||||
axs[0].set_title('Precision / Recall')
|
||||
axs[0].legend(names)
|
||||
axs[1].set_title('Scores / Recall')
|
||||
axs[1].legend(names)
|
||||
return fig, axs
|
||||
@@ -53,7 +53,6 @@ class DualDianaMed(MujocoEnv):
|
||||
self.l_vis = None
|
||||
self.top = None
|
||||
self.angle = None
|
||||
self.front = None
|
||||
self.obs = None
|
||||
|
||||
self.rew = None
|
||||
@@ -169,7 +168,6 @@ class DualDianaMed(MujocoEnv):
|
||||
obs['images']['angle'] = self.angle
|
||||
obs['images']['r_vis'] = self.r_vis
|
||||
obs['images']['l_vis'] = self.l_vis
|
||||
obs['images']['front'] = self.front
|
||||
return obs
|
||||
|
||||
def _get_image_obs(self):
|
||||
@@ -179,7 +177,6 @@ class DualDianaMed(MujocoEnv):
|
||||
obs['images']['angle'] = self.angle
|
||||
obs['images']['r_vis'] = self.r_vis
|
||||
obs['images']['l_vis'] = self.l_vis
|
||||
obs['images']['front'] = self.front
|
||||
return obs
|
||||
|
||||
def _get_qpos_obs(self):
|
||||
@@ -205,8 +202,6 @@ class DualDianaMed(MujocoEnv):
|
||||
return self.r_vis
|
||||
elif self.cam == 'l_vis':
|
||||
return self.l_vis
|
||||
elif self.cam == 'front':
|
||||
return self.front
|
||||
else:
|
||||
raise AttributeError("please input right name")
|
||||
|
||||
@@ -227,11 +222,7 @@ class DualDianaMed(MujocoEnv):
|
||||
img_renderer.update_scene(self.mj_data,camera="angle")
|
||||
self.angle = img_renderer.render()
|
||||
self.angle = self.angle[:, :, ::-1]
|
||||
img_renderer.update_scene(self.mj_data,camera="front")
|
||||
self.front = img_renderer.render()
|
||||
self.front = self.front[:, :, ::-1]
|
||||
if self.cam_view is not None:
|
||||
cv2.imshow('Cam view', self.cam_view)
|
||||
cv2.imshow('Cam view', self.cam_view)
|
||||
cv2.waitKey(1)
|
||||
|
||||
|
||||
|
||||
@@ -72,17 +72,12 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
|
||||
self.mj_data.joint('red_box_joint').qpos[5] = 0.0
|
||||
self.mj_data.joint('red_box_joint').qpos[6] = 0.0
|
||||
super().reset()
|
||||
self.top = None
|
||||
self.angle = None
|
||||
self.r_vis = None
|
||||
self.front = None
|
||||
self.cam_flage = True
|
||||
t=0
|
||||
while self.cam_flage:
|
||||
if(type(self.top)==type(None)
|
||||
or type(self.angle)==type(None)
|
||||
or type(self.r_vis)==type(None)
|
||||
or type(self.front)==type(None)):
|
||||
or type(self.r_vis)==type(None)):
|
||||
time.sleep(0.001)
|
||||
t+=1
|
||||
else:
|
||||
|
||||
@@ -27,8 +27,8 @@ def sample_insertion_pose():
|
||||
|
||||
def sample_transfer_pose():
|
||||
# Box
|
||||
x_range = [-0.2, 0.2]
|
||||
y_range = [0.7, 1.1]
|
||||
x_range = [0.0, 0.05]
|
||||
y_range = [0.95, 1.05]
|
||||
z_range = [0.47, 0.47]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
|
||||
@@ -18,9 +18,9 @@ SIM_TASK_CONFIGS = {
|
||||
# },
|
||||
'sim_transfer': {
|
||||
'dataset_dir': DATASET_DIR + '/sim_transfer',
|
||||
'num_episodes': 20,
|
||||
'num_episodes': 7,
|
||||
'episode_len': 700,
|
||||
'camera_names': ['top','r_vis','front'],
|
||||
'camera_names': ['angle','r_vis'],
|
||||
'xml_dir': HOME_PATH + '/assets'
|
||||
},
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import torch
|
||||
from roboimi.utils.utils import load_data, set_seed
|
||||
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy
|
||||
from roboimi.gr00t.policy import gr00tPolicy
|
||||
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy,ACTTVPolicy
|
||||
from roboimi.ddt.policy import DDTPolicy
|
||||
|
||||
class ModelInterface:
|
||||
def __init__(self, config):
|
||||
@@ -66,25 +66,23 @@ class ModelInterface:
|
||||
'num_queries': 1,
|
||||
'camera_names': self.config['camera_names'],
|
||||
}
|
||||
elif self.config['policy_class'] == 'GR00T':
|
||||
# GR00T uses its own config section from config.yaml
|
||||
gr00t_config = self.config.get('gr00t', {})
|
||||
elif self.config['policy_class'] == 'DDT':
|
||||
self.config['policy_config'] = {
|
||||
'lr': gr00t_config.get('lr', self.config['lr']),
|
||||
'lr_backbone': gr00t_config.get('lr_backbone', self.config['lr_backbone']),
|
||||
'weight_decay': gr00t_config.get('weight_decay', 1e-4),
|
||||
'embed_dim': gr00t_config.get('embed_dim', 1536),
|
||||
'hidden_dim': gr00t_config.get('hidden_dim', 1024),
|
||||
'state_dim': gr00t_config.get('state_dim', 16),
|
||||
'action_dim': gr00t_config.get('action_dim', 16),
|
||||
'num_queries': gr00t_config.get('num_queries', 16),
|
||||
'num_layers': gr00t_config.get('num_layers', 16),
|
||||
'nheads': gr00t_config.get('nheads', 32),
|
||||
'mlp_ratio': gr00t_config.get('mlp_ratio', 4),
|
||||
'dropout': gr00t_config.get('dropout', 0.2),
|
||||
'backbone': gr00t_config.get('backbone', 'dino_v2'),
|
||||
'position_embedding': gr00t_config.get('position_embedding', 'sine'),
|
||||
'lr': self.config['lr'],
|
||||
'lr_backbone': self.config['lr_backbone'],
|
||||
'backbone': self.config.get('backbone', 'dino_v2'),
|
||||
'num_queries': self.config['chunk_size'],
|
||||
'hidden_dim': self.config['hidden_dim'],
|
||||
'nheads': self.config['nheads'],
|
||||
'enc_layers': self.config['enc_layers'],
|
||||
'state_dim': self.config.get('state_dim', 16),
|
||||
'action_dim': self.config.get('action_dim', 16),
|
||||
'camera_names': self.config['camera_names'],
|
||||
'qpos_noise_std': self.config.get('qpos_noise_std', 0),
|
||||
# DDT 特有参数
|
||||
'num_blocks': self.config.get('num_blocks', 12),
|
||||
'mlp_ratio': self.config.get('mlp_ratio', 4.0),
|
||||
'num_inference_steps': self.config.get('num_inference_steps', 10),
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -96,8 +94,8 @@ class ModelInterface:
|
||||
return ACTTVPolicy(self.config['policy_config'])
|
||||
elif self.config['policy_class'] == 'CNNMLP':
|
||||
return CNNMLPPolicy(self.config['policy_config'])
|
||||
elif self.config['policy_class'] == 'GR00T':
|
||||
return gr00tPolicy(self.config['policy_config'])
|
||||
elif self.config['policy_class'] == 'DDT':
|
||||
return DDTPolicy(self.config['policy_config'])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# export VLAAgent, VLAModelConfig
|
||||
@@ -1,401 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
||||
from roboimi.vla.models.normalization import NormalizationModule
|
||||
|
||||
class VLAAgent(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone, # 视觉编码器(ResNet 等)
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
head,
|
||||
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
|
||||
obs_dim, # 本体感知维度 (例如 关节角度)
|
||||
pred_horizon=16, # 预测未来多少步动作
|
||||
obs_horizon=4, # 使用多少步历史观测
|
||||
diffusion_steps=100, # DDPM 加噪步数
|
||||
inference_steps=10, # DDIM 推理步数
|
||||
num_cams=3, # 视觉输入的摄像头数量
|
||||
dataset_stats=None, # 数据集统计信息,用于归一化
|
||||
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
||||
):
|
||||
super().__init__()
|
||||
# 保存参数
|
||||
self.action_dim = action_dim
|
||||
self.obs_dim = obs_dim
|
||||
self.pred_horizon = pred_horizon
|
||||
self.obs_horizon = obs_horizon
|
||||
self.num_cams = num_cams
|
||||
self.num_action_steps = num_action_steps
|
||||
self.inference_steps = inference_steps
|
||||
self.head_type = head_type # 'unet' 或 'transformer'
|
||||
|
||||
|
||||
# 归一化模块 - 统一训练和推理的归一化逻辑
|
||||
self.normalization = NormalizationModule(
|
||||
stats=dataset_stats,
|
||||
normalization_type=normalization_type
|
||||
)
|
||||
|
||||
self.vision_encoder = vision_backbone
|
||||
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||
# global_cond_dim: 展平后的总维度(用于UNet)
|
||||
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
||||
total_prop_dim = obs_dim * obs_horizon
|
||||
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||
|
||||
# per_step_cond_dim: 每步的条件维度(用于Transformer)
|
||||
# 注意:这里不乘以obs_horizon,因为Transformer的输入是序列形式
|
||||
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=diffusion_steps,
|
||||
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
|
||||
clip_sample=True,
|
||||
prediction_type='epsilon' # 预测噪声
|
||||
)
|
||||
|
||||
# DDIM 调度器用于快速推理
|
||||
self.infer_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=diffusion_steps,
|
||||
beta_schedule='squaredcos_cap_v2',
|
||||
clip_sample=True,
|
||||
prediction_type='epsilon'
|
||||
)
|
||||
|
||||
# 根据head类型初始化不同的参数
|
||||
if head_type == 'transformer':
|
||||
# 如果head已经是nn.Module实例,直接使用;否则需要初始化
|
||||
if isinstance(head, nn.Module):
|
||||
# 已经是实例化的模块(测试时直接传入<E4BCA0><E585A5>
|
||||
self.noise_pred_net = head
|
||||
else:
|
||||
# Hydra部分初始化的对象,调用时传入参数
|
||||
self.noise_pred_net = head(
|
||||
input_dim=action_dim,
|
||||
output_dim=action_dim,
|
||||
horizon=pred_horizon,
|
||||
n_obs_steps=obs_horizon,
|
||||
cond_dim=self.per_step_cond_dim # 每步的条件维度
|
||||
)
|
||||
else: # 'unet' (default)
|
||||
# UNet接口: input_dim, global_cond_dim
|
||||
self.noise_pred_net = head(
|
||||
input_dim=action_dim,
|
||||
global_cond_dim=self.global_cond_dim
|
||||
)
|
||||
|
||||
self.state_encoder = state_encoder
|
||||
self.action_encoder = action_encoder
|
||||
|
||||
# 初始化队列(用于在线推理)
|
||||
self.reset()
|
||||
|
||||
def _get_model_device(self) -> torch.device:
|
||||
"""获取模型当前所在设备。"""
|
||||
return next(self.parameters()).device
|
||||
|
||||
def _move_to_device(self, data, device: torch.device):
|
||||
"""递归地将张量数据移动到指定设备。"""
|
||||
if torch.is_tensor(data):
|
||||
return data.to(device)
|
||||
if isinstance(data, dict):
|
||||
return {k: self._move_to_device(v, device) for k, v in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [self._move_to_device(v, device) for v in data]
|
||||
if isinstance(data, tuple):
|
||||
return tuple(self._move_to_device(v, device) for v in data)
|
||||
return data
|
||||
|
||||
|
||||
# ==========================
|
||||
# 训练阶段 (Training)
|
||||
# ==========================
|
||||
def compute_loss(self, batch):
|
||||
"""
|
||||
计算训练损失
|
||||
|
||||
Args:
|
||||
batch: 包含 images, qpos (本体感知), action, action_is_pad 的字典
|
||||
"""
|
||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||
action_is_pad = batch.get('action_is_pad', None) # 获取padding mask
|
||||
B = actions.shape[0]
|
||||
|
||||
# 归一化 states (qpos) 和 actions
|
||||
states = self.normalization.normalize_qpos(states)
|
||||
actions = self.normalization.normalize_action(actions)
|
||||
|
||||
state_features = self.state_encoder(states)
|
||||
|
||||
# 1. 提取视觉特征
|
||||
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
|
||||
action_features = self.action_encoder(actions)
|
||||
|
||||
# 2. 采样噪声
|
||||
noise = torch.randn_like(action_features)
|
||||
|
||||
# 3. 随机采样时间步 (Timesteps)
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps,
|
||||
(B,), device=action_features.device
|
||||
).long()
|
||||
|
||||
# 4. 给动作加噪 (Forward Diffusion)
|
||||
noisy_actions = self.noise_scheduler.add_noise(
|
||||
action_features, noise, timesteps
|
||||
)
|
||||
|
||||
# 拼接全局条件并展平
|
||||
# visual_features: (B, obs_horizon, vision_dim)
|
||||
# state_features: (B, obs_horizon, obs_dim)
|
||||
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
|
||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
global_cond = global_cond.flatten(start_dim=1)
|
||||
|
||||
# 5. 网络预测噪声(根据head类型选择接口)
|
||||
if self.head_type == 'transformer':
|
||||
# Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
|
||||
# 将展平的global_cond reshape回序列格式
|
||||
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||
pred_noise = self.noise_pred_net(
|
||||
sample=noisy_actions,
|
||||
timestep=timesteps,
|
||||
cond=cond
|
||||
)
|
||||
else: # 'unet'
|
||||
pred_noise = self.noise_pred_net(
|
||||
sample=noisy_actions,
|
||||
timestep=timesteps,
|
||||
global_cond=global_cond
|
||||
)
|
||||
|
||||
# 6. 计算 Loss (MSE),支持 padding mask
|
||||
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
|
||||
|
||||
# 如果提供了 action_is_pad,对padding位置进行mask
|
||||
if action_is_pad is not None:
|
||||
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
|
||||
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype) # 1.0表示有效数据
|
||||
valid_count = mask.sum() * loss.shape[-1]
|
||||
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||
else:
|
||||
loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
# ==========================
|
||||
# 队列管理 (Queue Management)
|
||||
# ==========================
|
||||
def reset(self):
|
||||
"""清空观测和动作队列。应在 env.reset() 时调用"""
|
||||
self._queues = {
|
||||
'qpos': deque(maxlen=self.obs_horizon),
|
||||
'images': deque(maxlen=self.obs_horizon),
|
||||
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
|
||||
}
|
||||
|
||||
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
将新的观测添加到队列中。
|
||||
|
||||
Args:
|
||||
observation: 包含 'qpos' 和 'images' 的字典
|
||||
"""
|
||||
# 添加本体感知
|
||||
if 'qpos' in observation:
|
||||
self._queues['qpos'].append(observation['qpos'].clone())
|
||||
|
||||
# 添加图像
|
||||
if 'images' in observation:
|
||||
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
|
||||
|
||||
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
从队列中准备用于推理的批量观测。
|
||||
如果队列未满(首次调用时),用最新观测重复填充。
|
||||
|
||||
Returns:
|
||||
batch: 包含堆叠后的历史观测的字典
|
||||
"""
|
||||
# 堆叠历史本体感知
|
||||
qpos_list = list(self._queues['qpos'])
|
||||
if len(qpos_list) == 0:
|
||||
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
|
||||
# 如果队列未满,用最后一个观测填充
|
||||
while len(qpos_list) < self.obs_horizon:
|
||||
qpos_list.append(qpos_list[-1])
|
||||
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
|
||||
|
||||
# 堆叠历史图像
|
||||
images_list = list(self._queues['images'])
|
||||
if len(images_list) == 0:
|
||||
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
|
||||
# 如果队列未满,用最后一个观测填充
|
||||
while len(images_list) < self.obs_horizon:
|
||||
images_list.append(images_list[-1])
|
||||
|
||||
batch_images = {}
|
||||
for cam_name in images_list[0].keys():
|
||||
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
|
||||
|
||||
return {'qpos': batch_qpos, 'images': batch_images}
|
||||
|
||||
# ==========================
|
||||
# 在线推理 (Online Inference)
|
||||
# ==========================
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
根据当前观测选择单个动作。
|
||||
|
||||
这个方法维护一个历史观测和生成动作轨迹的缓存。工作流程:
|
||||
- 缓存 `obs_horizon` 步的历史观测
|
||||
- Diffusion 模型生成 `pred_horizon` 步的动作
|
||||
- 实际执行 `num_action_steps` 步动作
|
||||
|
||||
示意图:
|
||||
--------------------------------------------------------------
|
||||
(图例: o=obs_horizon, h=pred_horizon, a=num_action_steps)
|
||||
|时间步 | 0 | 1 | ... | o-1 | o | ... | h-1 |
|
||||
|观测是否使用 | 是 | 是 | 是 | 是 | 否 | 否 | 否 |
|
||||
|动作是否生成 | 是 | 是 | 是 | 是 | 是 | 是 | 是 |
|
||||
|动作是否执行 | 否 | 否 | 否 | 否 | 是 | 是 | 是 |
|
||||
--------------------------------------------------------------
|
||||
|
||||
Args:
|
||||
observation: 包含 'qpos' 和 'images' 的字典
|
||||
|
||||
Returns:
|
||||
action: (action_dim,) 单个动作
|
||||
"""
|
||||
# 使用模型当前设备作为唯一真值,将输入移动到模型设备
|
||||
# 避免根据CPU观测把模型错误搬回CPU。
|
||||
device = self._get_model_device()
|
||||
observation = self._move_to_device(observation, device)
|
||||
|
||||
# 将新观测添加到队列
|
||||
self._populate_queues(observation)
|
||||
|
||||
# 如果动作队列为空,生成新的动作序列
|
||||
if len(self._queues['action']) == 0:
|
||||
# 从队列准备批量观测
|
||||
batch = self._prepare_observation_batch()
|
||||
|
||||
# 生成动作块
|
||||
actions = self.predict_action_chunk(batch) # (1, pred_horizon, action_dim)
|
||||
|
||||
# 提取可执行的动作部分
|
||||
# 从 obs_horizon-1 开始,因为前面的动作对应过去的观测
|
||||
start = self.obs_horizon - 1
|
||||
end = start + self.num_action_steps
|
||||
executable_actions = actions[:, start:end] # (1, num_action_steps, action_dim)
|
||||
|
||||
# 将动作添加到队列
|
||||
for i in range(executable_actions.shape[1]):
|
||||
self._queues['action'].append(executable_actions[:, i].squeeze(0)) # (action_dim,)
|
||||
|
||||
# 从队列中取出一个动作
|
||||
action = self._queues['action'].popleft() # (action_dim,)
|
||||
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
预测一个动作块(用于在线推理)。
|
||||
|
||||
Args:
|
||||
batch: 包含 'qpos' 和 'images' 的字典
|
||||
- qpos: (B, obs_horizon, obs_dim)
|
||||
- images: Dict[str, (B, obs_horizon, C, H, W)]
|
||||
|
||||
Returns:
|
||||
actions: (B, pred_horizon, action_dim) 预测的动作序列
|
||||
"""
|
||||
return self.predict_action(batch['images'], batch['qpos'])
|
||||
|
||||
# ==========================
|
||||
# 批量推理 (Batch Inference - 原有方法)
|
||||
# ==========================
|
||||
@torch.no_grad()
|
||||
def predict_action(self, images, proprioception):
|
||||
"""
|
||||
批量预测动作序列(用于训练和离线评估)
|
||||
|
||||
Args:
|
||||
images: 图像观测字典
|
||||
proprioception: 本体感知观测 (qpos)
|
||||
|
||||
Returns:
|
||||
denormalized_actions: 反归一化后的动作序列
|
||||
"""
|
||||
B = proprioception.shape[0]
|
||||
|
||||
# 归一化 proprioception (qpos)
|
||||
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||
|
||||
# 1. 提取当前观测特征(只提取一次)
|
||||
visual_features = self.vision_encoder(images)
|
||||
state_features = self.state_encoder(proprioception)
|
||||
|
||||
# 拼接条件(只计算一次)
|
||||
# visual_features: (B, obs_horizon, vision_dim)
|
||||
# state_features: (B, obs_horizon, obs_dim)
|
||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||
global_cond_flat = global_cond.flatten(start_dim=1)
|
||||
if self.head_type == 'transformer':
|
||||
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||
else:
|
||||
cond = None
|
||||
|
||||
# 2. 初始化纯高斯噪声动作
|
||||
# 形状: (B, pred_horizon, action_dim)
|
||||
device = visual_features.device
|
||||
current_actions = torch.randn(
|
||||
(B, self.pred_horizon, self.action_dim), device=device
|
||||
)
|
||||
|
||||
# 3. 逐步去噪循环 (Reverse Diffusion)
|
||||
self.infer_scheduler.set_timesteps(self.inference_steps) # DDIM 推理步数
|
||||
|
||||
for t in self.infer_scheduler.timesteps:
|
||||
model_input = current_actions
|
||||
|
||||
# 预测噪声(根据head类型选择接口)
|
||||
if self.head_type == 'transformer':
|
||||
noise_pred = self.noise_pred_net(
|
||||
sample=model_input,
|
||||
timestep=t,
|
||||
cond=cond
|
||||
)
|
||||
else: # 'unet'
|
||||
noise_pred = self.noise_pred_net(
|
||||
sample=model_input,
|
||||
timestep=t,
|
||||
global_cond=global_cond_flat
|
||||
)
|
||||
|
||||
# 移除噪声,更新 current_actions
|
||||
current_actions = self.infer_scheduler.step(
|
||||
noise_pred, t, current_actions
|
||||
).prev_sample
|
||||
|
||||
# 4. 反归一化动作序列
|
||||
denormalized_actions = self.normalization.denormalize_action(current_actions)
|
||||
|
||||
return denormalized_actions
|
||||
|
||||
def get_normalization_stats(self):
|
||||
"""获取归一化统计信息(用于保存到 checkpoint)"""
|
||||
return self.normalization.get_stats()
|
||||
@@ -1,217 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import deque
|
||||
from typing import Dict
|
||||
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
|
||||
from roboimi.vla.models.normalization import NormalizationModule
|
||||
|
||||
|
||||
class VLAAgentGr00tDiT(nn.Module):
|
||||
"""
|
||||
VLA Agent variant that swaps Transformer1D head with gr00t DiT head.
|
||||
Other components (backbone/encoders/scheduler/queue logic) stay aligned
|
||||
with the existing VLAAgent implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone,
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
head,
|
||||
action_dim,
|
||||
obs_dim,
|
||||
pred_horizon=16,
|
||||
obs_horizon=4,
|
||||
diffusion_steps=100,
|
||||
inference_steps=10,
|
||||
num_cams=3,
|
||||
dataset_stats=None,
|
||||
normalization_type="min_max",
|
||||
num_action_steps=8,
|
||||
):
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.obs_dim = obs_dim
|
||||
self.pred_horizon = pred_horizon
|
||||
self.obs_horizon = obs_horizon
|
||||
self.num_cams = num_cams
|
||||
self.num_action_steps = num_action_steps
|
||||
self.inference_steps = inference_steps
|
||||
|
||||
self.normalization = NormalizationModule(
|
||||
stats=dataset_stats,
|
||||
normalization_type=normalization_type,
|
||||
)
|
||||
|
||||
self.vision_encoder = vision_backbone
|
||||
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=diffusion_steps,
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
self.infer_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=diffusion_steps,
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
|
||||
if isinstance(head, nn.Module):
|
||||
self.noise_pred_net = head
|
||||
else:
|
||||
self.noise_pred_net = head(
|
||||
input_dim=action_dim,
|
||||
output_dim=action_dim,
|
||||
horizon=pred_horizon,
|
||||
n_obs_steps=obs_horizon,
|
||||
cond_dim=self.per_step_cond_dim,
|
||||
)
|
||||
|
||||
self.state_encoder = state_encoder
|
||||
self.action_encoder = action_encoder
|
||||
self.reset()
|
||||
|
||||
def _get_model_device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def _move_to_device(self, data, device: torch.device):
|
||||
if torch.is_tensor(data):
|
||||
return data.to(device)
|
||||
if isinstance(data, dict):
|
||||
return {k: self._move_to_device(v, device) for k, v in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [self._move_to_device(v, device) for v in data]
|
||||
if isinstance(data, tuple):
|
||||
return tuple(self._move_to_device(v, device) for v in data)
|
||||
return data
|
||||
|
||||
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
|
||||
visual_features = self.vision_encoder(images)
|
||||
state_features = self.state_encoder(states)
|
||||
return torch.cat([visual_features, state_features], dim=-1)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
actions, states, images = batch["action"], batch["qpos"], batch["images"]
|
||||
action_is_pad = batch.get("action_is_pad", None)
|
||||
bsz = actions.shape[0]
|
||||
|
||||
states = self.normalization.normalize_qpos(states)
|
||||
actions = self.normalization.normalize_action(actions)
|
||||
|
||||
action_features = self.action_encoder(actions)
|
||||
cond = self._build_cond(images, states)
|
||||
|
||||
noise = torch.randn_like(action_features)
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=action_features.device,
|
||||
).long()
|
||||
noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
|
||||
|
||||
pred_noise = self.noise_pred_net(
|
||||
sample=noisy_actions,
|
||||
timestep=timesteps,
|
||||
cond=cond,
|
||||
)
|
||||
loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
|
||||
|
||||
if action_is_pad is not None:
|
||||
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
|
||||
valid_count = mask.sum() * loss.shape[-1]
|
||||
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||
else:
|
||||
loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
def reset(self):
|
||||
self._queues = {
|
||||
"qpos": deque(maxlen=self.obs_horizon),
|
||||
"images": deque(maxlen=self.obs_horizon),
|
||||
"action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
|
||||
}
|
||||
|
||||
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||
if "qpos" in observation:
|
||||
self._queues["qpos"].append(observation["qpos"].clone())
|
||||
if "images" in observation:
|
||||
self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
|
||||
|
||||
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||
qpos_list = list(self._queues["qpos"])
|
||||
if len(qpos_list) == 0:
|
||||
raise ValueError("observation queue is empty.")
|
||||
while len(qpos_list) < self.obs_horizon:
|
||||
qpos_list.append(qpos_list[-1])
|
||||
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
|
||||
|
||||
images_list = list(self._queues["images"])
|
||||
if len(images_list) == 0:
|
||||
raise ValueError("image queue is empty.")
|
||||
while len(images_list) < self.obs_horizon:
|
||||
images_list.append(images_list[-1])
|
||||
|
||||
batch_images = {}
|
||||
for cam_name in images_list[0].keys():
|
||||
batch_images[cam_name] = torch.stack(
|
||||
[img[cam_name] for img in images_list], dim=0
|
||||
).unsqueeze(0)
|
||||
|
||||
return {"qpos": batch_qpos, "images": batch_images}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
device = self._get_model_device()
|
||||
observation = self._move_to_device(observation, device)
|
||||
self._populate_queues(observation)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = self._prepare_observation_batch()
|
||||
actions = self.predict_action_chunk(batch)
|
||||
start = self.obs_horizon - 1
|
||||
end = start + self.num_action_steps
|
||||
executable_actions = actions[:, start:end]
|
||||
for i in range(executable_actions.shape[1]):
|
||||
self._queues["action"].append(executable_actions[:, i].squeeze(0))
|
||||
|
||||
return self._queues["action"].popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
return self.predict_action(batch["images"], batch["qpos"])
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(self, images, proprioception):
|
||||
bsz = proprioception.shape[0]
|
||||
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||
cond = self._build_cond(images, proprioception)
|
||||
|
||||
device = cond.device
|
||||
current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
|
||||
self.infer_scheduler.set_timesteps(self.inference_steps)
|
||||
|
||||
for t in self.infer_scheduler.timesteps:
|
||||
noise_pred = self.noise_pred_net(
|
||||
sample=current_actions,
|
||||
timestep=t,
|
||||
cond=cond,
|
||||
)
|
||||
current_actions = self.infer_scheduler.step(
|
||||
noise_pred, t, current_actions
|
||||
).prev_sample
|
||||
|
||||
return self.normalization.denormalize_action(current_actions)
|
||||
|
||||
def get_normalization_stats(self):
|
||||
return self.normalization.get_stats()
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# @package agent
|
||||
defaults:
|
||||
# - /backbone@vision_backbone: resnet
|
||||
- /backbone@vision_backbone: resnet_diffusion
|
||||
- /modules@state_encoder: identity_state_encoder
|
||||
- /modules@action_encoder: identity_action_encoder
|
||||
- /head: conditional_unet1d
|
||||
- _self_
|
||||
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# ====================
|
||||
# 模型维度配置
|
||||
# ====================
|
||||
action_dim: 16 # 动作维度(机器人关节数)
|
||||
obs_dim: 16 # 本体感知维度(关节位置)
|
||||
|
||||
# ====================
|
||||
#
|
||||
# ====================
|
||||
normalization_type: "min_max" # "min_max" or "gaussian"
|
||||
|
||||
# ====================
|
||||
# 时间步配置
|
||||
# ====================
|
||||
pred_horizon: 16 # 预测未来多少步动作
|
||||
obs_horizon: 2 # 使用多少步历史观测
|
||||
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
||||
|
||||
# ====================
|
||||
# 相机配置
|
||||
# ====================
|
||||
num_cams: 3 # 摄像头数量 (r_vis, top, front)
|
||||
|
||||
# ====================
|
||||
# 扩散过程配置
|
||||
# ====================
|
||||
diffusion_steps: 100 # 扩散训练步数(DDPM)
|
||||
inference_steps: 10 # 推理时的去噪步数(DDIM,固定为 10)
|
||||
@@ -1,37 +0,0 @@
|
||||
# @package agent
|
||||
defaults:
|
||||
- /backbone@vision_backbone: resnet_diffusion
|
||||
- /modules@state_encoder: identity_state_encoder
|
||||
- /modules@action_encoder: identity_action_encoder
|
||||
- /head: gr00t_dit1d
|
||||
- _self_
|
||||
|
||||
_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT
|
||||
|
||||
# Model dimensions
|
||||
action_dim: 16
|
||||
obs_dim: 16
|
||||
|
||||
# Normalization
|
||||
normalization_type: "min_max"
|
||||
|
||||
# Horizons
|
||||
pred_horizon: 16
|
||||
obs_horizon: 2
|
||||
num_action_steps: 8
|
||||
|
||||
# Cameras
|
||||
num_cams: 3
|
||||
|
||||
# Diffusion
|
||||
diffusion_steps: 100
|
||||
inference_steps: 10
|
||||
|
||||
# Head overrides
|
||||
head:
|
||||
input_dim: ${agent.action_dim}
|
||||
output_dim: ${agent.action_dim}
|
||||
horizon: ${agent.pred_horizon}
|
||||
n_obs_steps: ${agent.obs_horizon}
|
||||
cond_dim: 208
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# @package agent
|
||||
defaults:
|
||||
- /backbone@vision_backbone: resnet_diffusion
|
||||
- /modules@state_encoder: identity_state_encoder
|
||||
- /modules@action_encoder: identity_action_encoder
|
||||
- /head: transformer1d
|
||||
- _self_
|
||||
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# ====================
|
||||
# 模型维度配置
|
||||
# ====================
|
||||
action_dim: 16 # 动作维度(机器人关节数)
|
||||
obs_dim: 16 # 本体感知维度(关节位置)
|
||||
|
||||
# ====================
|
||||
# 归一化配置
|
||||
# ====================
|
||||
normalization_type: "min_max" # "min_max" or "gaussian"
|
||||
|
||||
# ====================
|
||||
# 时间步配置
|
||||
# ====================
|
||||
pred_horizon: 16 # 预测未来多少步动作
|
||||
obs_horizon: 2 # 使用多少步历史观测
|
||||
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
||||
|
||||
# ====================
|
||||
# 相机配置
|
||||
# ====================
|
||||
num_cams: 3 # 摄像头数量 (r_vis, top, front)
|
||||
|
||||
# ====================
|
||||
# 扩散过程配置
|
||||
# ====================
|
||||
diffusion_steps: 100 # 扩散训练步数(DDPM)
|
||||
inference_steps: 10 # 推理时的去噪步数(DDIM,<4D><EFBC8C>定为 10)
|
||||
|
||||
# ====================
|
||||
# Head 类型标识(用于VLAAgent选择调用方式)
|
||||
# ====================
|
||||
head_type: "transformer" # "unet" 或 "transformer"
|
||||
|
||||
# Head 参数覆盖
|
||||
head:
|
||||
input_dim: ${agent.action_dim}
|
||||
output_dim: ${agent.action_dim}
|
||||
horizon: ${agent.pred_horizon}
|
||||
n_obs_steps: ${agent.obs_horizon}
|
||||
# Transformer的cond_dim是每步的维度
|
||||
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
|
||||
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
|
||||
cond_dim: 208
|
||||
@@ -1,33 +0,0 @@
|
||||
_target_: roboimi.vla.models.backbones.resnet_diffusion.ResNetDiffusionBackbone
|
||||
|
||||
# ====================
|
||||
# 骨干网络选择
|
||||
# ====================
|
||||
vision_backbone: "resnet18" # torchvision 模型名称: resnet18, resnet34, resnet50
|
||||
pretrained_backbone_weights: "IMAGENET1K_V1" # 使用ImageNet预训练权重(torchvision>=0.13)
|
||||
|
||||
# ====================
|
||||
# 冻结设置
|
||||
# ====================
|
||||
freeze_backbone: true # 冻结ResNet参数,只训练后面的pool和out层(推荐:true)
|
||||
|
||||
# ====================
|
||||
# 输入配置
|
||||
# ====================
|
||||
input_shape: [3, 224, 224] # 输入图像形状 (C, H, W) - ImageNet标准尺寸
|
||||
crop_shape: null # 裁剪后的图像形状 (H, W) - 设为null禁用裁剪
|
||||
crop_is_random: true # 训练时使用随机裁剪,评估时使用中心裁剪(crop_shape=null时无效)
|
||||
|
||||
# ====================
|
||||
# 归一化和特征提取
|
||||
# ====================
|
||||
use_group_norm: true # 使用 GroupNorm 替代 BatchNorm(更适合小批次训练)
|
||||
spatial_softmax_num_keypoints: 32 # Spatial Softmax 关键点数量
|
||||
|
||||
# ====================
|
||||
# 编码器模式
|
||||
# ====================
|
||||
# false: 共享编码器(所有摄像头共享一个 ResNet,参数少但容量受限)推荐!
|
||||
# true: 独立编码器(每个摄像头有独立的 ResNet,参数多但容量大)
|
||||
use_separate_rgb_encoder_per_camera: true
|
||||
num_cameras: 3 # 摄像头数量
|
||||
@@ -1,44 +0,0 @@
|
||||
defaults:
|
||||
- agent: resnet_transformer
|
||||
- data: simpe_robot_dataset
|
||||
- eval: eval
|
||||
- _self_
|
||||
|
||||
# ====================
|
||||
# 训练配置
|
||||
# ====================
|
||||
train:
|
||||
# 基础训练参数
|
||||
batch_size: 8 # 批次大小
|
||||
lr: 5e-5 # 学习率(Transformer建议更小)
|
||||
max_steps: 100000 # 最大训练步数
|
||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||
|
||||
# 数据加载
|
||||
num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||
val_split: 0.1 # 验证集比例
|
||||
seed: 42 # 随机种子(用于数据划分)
|
||||
|
||||
# 日志和检查点
|
||||
log_freq: 100 # 日志记录频率(步数)
|
||||
save_freq: 2000 # 保存检查点频率(步数)
|
||||
|
||||
# 学习率调度器(带预热)
|
||||
warmup_steps: 2000 # 预热步数(Transformer建议更长)
|
||||
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
|
||||
min_lr: 1e-6 # 最小学习率(用于余弦退火)
|
||||
|
||||
# 优化器
|
||||
weight_decay: 1e-5 # 权重衰减(L2 正则化)
|
||||
grad_clip: 1.0 # 梯度裁剪阈值
|
||||
|
||||
# 微调配置
|
||||
pretrained_ckpt: null # 预训练 checkpoint 路径(用于微调),例如: "checkpoints/vla_model_step_8000.pt"
|
||||
|
||||
# ====================
|
||||
# 实验配置
|
||||
# ====================
|
||||
experiment:
|
||||
name: "vla_diffusion" # 实验名称
|
||||
notes: "" # 实验备注
|
||||
tags: [] # 实验标签
|
||||
@@ -1,21 +0,0 @@
|
||||
# @package data
|
||||
_target_: roboimi.vla.data.simpe_robot_dataset.SimpleRobotDataset
|
||||
|
||||
# ====================
|
||||
# 数据集路径
|
||||
# ====================
|
||||
dataset_dir: "roboimi/demos/dataset/sim_transfer"
|
||||
|
||||
# ====================
|
||||
# 时间步参数(从 agent 配置引用)
|
||||
# ====================
|
||||
pred_horizon: ${agent.pred_horizon} # 预测步数
|
||||
obs_horizon: ${agent.obs_horizon} # 观测步数
|
||||
|
||||
# ====================
|
||||
# 相机配置
|
||||
# ====================
|
||||
camera_names:
|
||||
- r_vis # 机器人视角相机
|
||||
- top # 顶部相机
|
||||
- front # 前方相机
|
||||
@@ -1,34 +0,0 @@
|
||||
# @package eval
|
||||
# 评估配置
|
||||
ckpt_path: "checkpoints/vla_model_best.pt" # 模型检查点路径
|
||||
num_episodes: 3 # 评估回合数
|
||||
max_timesteps: 700 # 每回合最大时间步
|
||||
device: ${train.device} # 与训练保持一致
|
||||
task_name: "sim_transfer" # 环境任务名称
|
||||
|
||||
# ====================
|
||||
# 策略执行参数
|
||||
# ====================
|
||||
# num_queries 已废弃,现在使用 agent 的 select_action() 自动管理队列
|
||||
# 以下参数仅用于兼容旧代码,实际使用 agent.num_action_steps
|
||||
num_queries: ${agent.num_action_steps}
|
||||
obs_horizon: ${agent.obs_horizon}
|
||||
|
||||
# ====================
|
||||
# 相机配置
|
||||
# ====================
|
||||
camera_names: ${data.camera_names}
|
||||
|
||||
# ====================
|
||||
# 动作平滑
|
||||
# ====================
|
||||
use_smoothing: false
|
||||
smooth_method: "ema"
|
||||
smooth_alpha: 0.3
|
||||
|
||||
# ====================
|
||||
# 调试选项
|
||||
# ====================
|
||||
verbose_action: true # 是否打印每个时间步的动作信息
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
|
||||
_partial_: true
|
||||
|
||||
# ====================
|
||||
# UNet1D 配置
|
||||
# ====================
|
||||
kernel_size: 3 # 卷积核大小
|
||||
cond_predict_scale: false # FiLM 条件化时是否同时预测 scale(bias + scale 或仅 bias)
|
||||
|
||||
# ====================
|
||||
# 网络架构(默认值,可覆盖)
|
||||
# ====================
|
||||
# diffusion_step_embed_dim: 256 # 扩散时间步嵌入维度
|
||||
# down_dims: [256, 512, 1024] # 下采样各层通道数
|
||||
# n_groups: 8 # GroupNorm 分组数
|
||||
@@ -1,22 +0,0 @@
|
||||
_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
|
||||
_partial_: true
|
||||
|
||||
# DiT architecture
|
||||
n_layer: 6
|
||||
n_head: 8
|
||||
n_emb: 256
|
||||
hidden_dim: 256
|
||||
mlp_ratio: 4
|
||||
dropout: 0.1
|
||||
|
||||
# Positional embeddings
|
||||
add_action_pos_emb: true
|
||||
add_cond_pos_emb: true
|
||||
|
||||
# Supplied by agent interpolation:
|
||||
# - input_dim
|
||||
# - output_dim
|
||||
# - horizon
|
||||
# - n_obs_steps
|
||||
# - cond_dim
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Transformer-based Diffusion Policy Head
|
||||
_target_: roboimi.vla.models.heads.transformer1d.Transformer1D
|
||||
_partial_: true
|
||||
|
||||
# ====================
|
||||
# Transformer 架构配置
|
||||
# ====================
|
||||
n_layer: 4 # Transformer层数(先用小模型提高收敛稳定性)
|
||||
n_head: 4 # 注意力头数
|
||||
n_emb: 128 # 嵌入维度
|
||||
p_drop_emb: 0.05 # Embedding dropout
|
||||
p_drop_attn: 0.05 # Attention dropout
|
||||
|
||||
# ====================
|
||||
# 条件配置
|
||||
# ====================
|
||||
causal_attn: false # 是否使用因果注意力(自回归生成)
|
||||
obs_as_cond: true # 观测作为条件(由cond_dim > 0决定)
|
||||
n_cond_layers: 1 # 条件编码器层数(1层先做稳定融合)
|
||||
|
||||
# ====================
|
||||
# 注意事项
|
||||
# ====================
|
||||
# 以下参数将在agent配置中通过interpolation提供:
|
||||
# - input_dim: ${agent.action_dim}
|
||||
# - output_dim: ${agent.action_dim}
|
||||
# - horizon: ${agent.pred_horizon}
|
||||
# - n_obs_steps: ${agent.obs_horizon}
|
||||
# - cond_dim: 通过agent中的global_cond_dim计算
|
||||
@@ -1 +0,0 @@
|
||||
_target_: roboimi.vla.modules.encoders.IdentityActionEncoder
|
||||
@@ -1 +0,0 @@
|
||||
_target_: roboimi.vla.modules.encoders.IdentityStateEncoder
|
||||
@@ -1,46 +0,0 @@
|
||||
import abc
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
class VLABackbone(nn.Module, abc.ABC):
|
||||
"""
|
||||
Contract for Vision/Language Backbones.
|
||||
Must return a feature tensor of shape (B, Seq, Embed_Dim).
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
obs: Dictionary containing 'image' and optionally 'text'.
|
||||
Returns:
|
||||
features: (B, S, D) embedding.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class VLAProjector(nn.Module, abc.ABC):
|
||||
"""
|
||||
Contract for the adaptation layer (Projector).
|
||||
Connects Backbone features to the Policy Head.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class VLAHead(nn.Module, abc.ABC):
|
||||
"""
|
||||
Contract for Action Generation Heads (Policies).
|
||||
Handles both training (loss calculation) and inference (action generation).
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
embeddings: (B, S, Hidden) from Projector.
|
||||
actions: (B, Pred_Horizon, Action_Dim) - Ground truth for training.
|
||||
Returns:
|
||||
Dict containing 'loss' (if actions provided) or 'pred_actions'.
|
||||
"""
|
||||
pass
|
||||
@@ -1,242 +0,0 @@
|
||||
import torch
|
||||
import h5py
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Dict, Union
|
||||
from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class SimpleRobotDataset(Dataset):
|
||||
"""
|
||||
HDF5 懒加载数据集 - LeRobotDataset 格式
|
||||
|
||||
返回格式:
|
||||
- observation.state: (obs_horizon, state_dim)
|
||||
- observation.{cam_name}: (obs_horizon, C, H, W)
|
||||
- action: (pred_horizon, action_dim)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir: Union[str, Path],
|
||||
obs_horizon: int = 2,
|
||||
pred_horizon: int = 8,
|
||||
camera_names: List[str] = None,
|
||||
max_open_files: int = 64,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_dir: HDF5 文件目录路径
|
||||
obs_horizon: 观察过去多少帧
|
||||
pred_horizon: 预测未来多少帧动作
|
||||
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
||||
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
|
||||
|
||||
HDF5 文件格式:
|
||||
- action: [T, action_dim]
|
||||
- observations/qpos: [T, obs_dim]
|
||||
- observations/images/{cam_name}: [T, H, W, C]
|
||||
"""
|
||||
self.obs_horizon = obs_horizon
|
||||
self.pred_horizon = pred_horizon
|
||||
self.camera_names = camera_names or []
|
||||
self.max_open_files = max(1, int(max_open_files))
|
||||
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
|
||||
|
||||
self.dataset_dir = Path(dataset_dir)
|
||||
if not self.dataset_dir.exists():
|
||||
raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")
|
||||
|
||||
# 查找 HDF5 文件
|
||||
self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
|
||||
if not self.hdf5_files:
|
||||
self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
|
||||
if not self.hdf5_files:
|
||||
raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")
|
||||
|
||||
# 构建 episode 索引(只存储元数据,不加载数据)
|
||||
self.episodes = {}
|
||||
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
|
||||
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
|
||||
with h5py.File(hdf5_path, 'r') as f:
|
||||
T = f['action'].shape[0]
|
||||
start_idx = len(self.frame_meta)
|
||||
for t in range(T):
|
||||
self.frame_meta.append({
|
||||
"ep_idx": ep_idx,
|
||||
"frame_idx": t,
|
||||
"hdf5_path": hdf5_path,
|
||||
})
|
||||
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
|
||||
|
||||
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frame_meta)
|
||||
|
||||
def _close_all_files(self) -> None:
|
||||
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
|
||||
for f in self._file_cache.values():
|
||||
try:
|
||||
f.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._file_cache.clear()
|
||||
|
||||
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
|
||||
"""
|
||||
获取 HDF5 文件句柄(worker 内 LRU 缓存)。
|
||||
注意:缓存的是文件句柄,不是帧数据本身。
|
||||
"""
|
||||
key = str(hdf5_path)
|
||||
if key in self._file_cache:
|
||||
self._file_cache.move_to_end(key)
|
||||
return self._file_cache[key]
|
||||
|
||||
# 超过上限时淘汰最久未使用的句柄
|
||||
if len(self._file_cache) >= self.max_open_files:
|
||||
_, old_file = self._file_cache.popitem(last=False)
|
||||
try:
|
||||
old_file.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
f = h5py.File(key, 'r')
|
||||
self._file_cache[key] = f
|
||||
return f
|
||||
|
||||
def _load_frame(self, idx: int) -> Dict:
|
||||
"""从 HDF5 文件懒加载单帧数据"""
|
||||
meta = self.frame_meta[idx]
|
||||
f = self._get_h5_file(meta["hdf5_path"])
|
||||
frame = {
|
||||
"episode_index": meta["ep_idx"],
|
||||
"frame_index": meta["frame_idx"],
|
||||
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
|
||||
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
|
||||
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
|
||||
}
|
||||
|
||||
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
|
||||
for cam_name in self.camera_names:
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
# Resize图像到224x224(减少内存和I/O负担)
|
||||
import cv2
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
return frame
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
frame = self._load_frame(idx)
|
||||
ep_idx = frame["episode_index"]
|
||||
|
||||
# 获取当前 episode 的帧索引范围
|
||||
ep_indices = self.episodes[ep_idx]
|
||||
ep_start = ep_indices[0]
|
||||
ep_end = ep_indices[-1]
|
||||
|
||||
# ============================================
|
||||
# 1. 加载观察(过去 obs_horizon 帧)
|
||||
# ============================================
|
||||
observations = {
|
||||
"state": [], # 状态数据
|
||||
}
|
||||
# 为每个摄像头初始化独立列表
|
||||
for cam_name in self.camera_names:
|
||||
observations[f"observation.{cam_name}"] = []
|
||||
|
||||
observation_is_pad = []
|
||||
|
||||
for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2
|
||||
target_idx = idx + delta
|
||||
|
||||
# 边界检查
|
||||
if ep_start <= target_idx <= ep_end:
|
||||
target_frame = self._load_frame(target_idx)
|
||||
is_pad = False
|
||||
else:
|
||||
# 超出边界,用边界帧填充
|
||||
if target_idx < ep_start:
|
||||
target_frame = self._load_frame(ep_start)
|
||||
else:
|
||||
target_frame = self._load_frame(ep_end)
|
||||
is_pad = True
|
||||
|
||||
# 收集状态
|
||||
observations["state"].append(target_frame["observation.state"])
|
||||
|
||||
# 收集每个摄像头的图像
|
||||
for cam_name in self.camera_names:
|
||||
observations[f"observation.{cam_name}"].append(target_frame[f"observation.{cam_name}"])
|
||||
|
||||
observation_is_pad.append(is_pad)
|
||||
|
||||
# ============================================
|
||||
# 2. 加载动作(未来 pred_horizon 帧)
|
||||
# ============================================
|
||||
actions = []
|
||||
action_is_pad = []
|
||||
|
||||
for delta in range(self.pred_horizon):
|
||||
target_idx = idx + delta
|
||||
|
||||
if target_idx <= ep_end:
|
||||
actions.append(self._load_frame(target_idx)["action"])
|
||||
action_is_pad.append(False)
|
||||
else:
|
||||
actions.append(self._load_frame(ep_end)["action"])
|
||||
action_is_pad.append(True)
|
||||
|
||||
# ============================================
|
||||
# 3. 组装返回数据(LeRobotDataset 格式)
|
||||
# ============================================
|
||||
result = {
|
||||
# 状态观察: (obs_horizon, state_dim)
|
||||
"observation.state": torch.stack(observations["state"]),
|
||||
"observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
|
||||
|
||||
# 动作: (pred_horizon, action_dim)
|
||||
"action": torch.stack(actions),
|
||||
"action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
|
||||
|
||||
# 任务
|
||||
"task": frame["task"],
|
||||
}
|
||||
|
||||
# 图像:每个摄像头独立的 key
|
||||
# 形状: (obs_horizon, C, H, W)
|
||||
for cam_name in self.camera_names:
|
||||
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""获取所有相机键名 (LeRobotDataset 格式)"""
|
||||
return [f"observation.{cam_name}" for cam_name in self.camera_names]
|
||||
|
||||
@property
|
||||
def camera_info(self) -> dict:
|
||||
"""获取相机信息"""
|
||||
if not self.camera_names:
|
||||
return {}
|
||||
|
||||
# 从第一个样本获取形状
|
||||
sample = self[0]
|
||||
info = {}
|
||||
for cam_name in self.camera_names:
|
||||
key = f"observation.{cam_name}"
|
||||
if key in sample:
|
||||
info[key] = {
|
||||
"shape": sample[key].shape,
|
||||
"dtype": str(sample[key].dtype),
|
||||
}
|
||||
return info
|
||||
|
||||
def __del__(self):
|
||||
self._close_all_files()
|
||||
@@ -1,4 +0,0 @@
|
||||
# Backbone models
|
||||
from .resnet_diffusion import ResNetDiffusionBackbone
|
||||
|
||||
__all__ = ["ResNetBackbone", "ResNetDiffusionBackbone"]
|
||||
@@ -1,372 +0,0 @@
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
root_module: 需要替换子模块的根模块
|
||||
predicate: 接受一个模块作为参数,如果该模块需要被替换则返回 True。
|
||||
func: 接受一个模块作为参数,并返回一个新的模块来替换它。
|
||||
Returns:
|
||||
子模块已被替换的根模块。
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parents))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# 验证所有 BN 是否已被替换
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Finn 等人在 "Deep Spatial Autoencoders for Visuomotor Learning" 中描述的空间软 Argmax 操作
|
||||
(https://huggingface.co/papers/1509.06113)。这是 robomimic 实现的一个最小移植版本。
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) 输入特征图形状。
|
||||
num_kp (int): 输出中的关键点数量。如果为 None,输出将具有与输入相同的通道数。
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._out_c = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
# 我们可以直接使用 torch.linspace,但这似乎与 numpy 的行为略有不同
|
||||
# 并且会导致预训练模型的 pc_success 略有下降。
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# 注册为 buffer,以便将其移动到正确的设备。
|
||||
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
features: (B, C, H, W) 输入特征图。
|
||||
Returns:
|
||||
(B, K, 2) 关键点的图像空间坐标。
|
||||
"""
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W],其中 K 是关键点数量
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax 归一化
|
||||
attention = F.softmax(features, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] 用于 x 和 y 维度的空间坐标均值
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# 重塑为 [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
class _SingleRgbEncoder(nn.Module):
|
||||
"""单个摄像头的 RGB 编码器,支持独立或共享使用"""
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone: str,
|
||||
pretrained_backbone_weights: str | None,
|
||||
input_shape: Tuple[int, int, int],
|
||||
crop_shape: Optional[Tuple[int, int]],
|
||||
crop_is_random: bool,
|
||||
use_group_norm: bool,
|
||||
spatial_softmax_num_keypoints: int,
|
||||
freeze_backbone: bool = True, # 新增:是否冻结backbone
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 设置可选的预处理
|
||||
if crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# 评估时始终使用中心裁剪
|
||||
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
||||
if crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
crop_shape = input_shape[1:]
|
||||
|
||||
# 设置骨干网络
|
||||
backbone_model = getattr(torchvision.models, vision_backbone)(
|
||||
weights=pretrained_backbone_weights
|
||||
)
|
||||
|
||||
# 移除 AvgPool 和 FC (假设 layer4 是 children()[-3])
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
|
||||
if use_group_norm:
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# 冻结backbone参数(可选)
|
||||
if freeze_backbone:
|
||||
for param in self.backbone.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# 设置池化和最终层
|
||||
# 使用试运行来获取特征图形状
|
||||
dummy_shape = (1, input_shape[0], *crop_shape)
|
||||
with torch.no_grad():
|
||||
dummy_out = self.backbone(torch.zeros(dummy_shape))
|
||||
feature_map_shape = dummy_out.shape[1:] # (C, H, W)
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=spatial_softmax_num_keypoints)
|
||||
self.feature_dim = spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# 注册ImageNet标准化参数为buffer(会自动移到GPU)
|
||||
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.do_crop:
|
||||
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
|
||||
|
||||
# ImageNet标准化(预训练权重期望的输入分布)
|
||||
x = (x - self.mean) / self.std
|
||||
|
||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
return x
|
||||
|
||||
|
||||
class ResNetDiffusionBackbone(VLABackbone):
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone: str = "resnet18",
|
||||
pretrained_backbone_weights: str | None = None,
|
||||
input_shape: Tuple[int, int, int] = (3, 84, 84), # (C, H, W)
|
||||
crop_shape: Optional[Tuple[int, int]] = None,
|
||||
crop_is_random: bool = True,
|
||||
use_group_norm: bool = True,
|
||||
spatial_softmax_num_keypoints: int = 32,
|
||||
use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器
|
||||
num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用)
|
||||
freeze_backbone: bool = True, # 新增:是否冻结ResNet backbone(推荐True)
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
|
||||
self.num_cameras = num_cameras
|
||||
|
||||
if use_separate_rgb_encoder_per_camera:
|
||||
# 独立编码器模式:为每个摄像头创建独立的编码器
|
||||
encoders = [
|
||||
_SingleRgbEncoder(
|
||||
vision_backbone=vision_backbone,
|
||||
pretrained_backbone_weights=pretrained_backbone_weights,
|
||||
input_shape=input_shape,
|
||||
crop_shape=crop_shape,
|
||||
crop_is_random=crop_is_random,
|
||||
use_group_norm=use_group_norm,
|
||||
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
|
||||
freeze_backbone=freeze_backbone,
|
||||
)
|
||||
for _ in range(num_cameras)
|
||||
]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
# 重要:output_dim 始终表示单个编码器的特征维度(与 lerobot 保持一致)
|
||||
self.feature_dim = encoders[0].feature_dim
|
||||
else:
|
||||
# 共享编码器模式:所有摄像头共享同一个编码器
|
||||
self.rgb_encoder = _SingleRgbEncoder(
|
||||
vision_backbone=vision_backbone,
|
||||
pretrained_backbone_weights=pretrained_backbone_weights,
|
||||
input_shape=input_shape,
|
||||
crop_shape=crop_shape,
|
||||
crop_is_random=crop_is_random,
|
||||
use_group_norm=use_group_norm,
|
||||
spatial_softmax_num_keypoints=spatial_softmax_num_keypoints,
|
||||
freeze_backbone=freeze_backbone,
|
||||
)
|
||||
self.feature_dim = self.rgb_encoder.feature_dim
|
||||
|
||||
def forward(self, images):
|
||||
"""
|
||||
Args:
|
||||
images: Dict[str, Tensor], 每个摄像头的图像
|
||||
形状: {cam_name: (B, T, C, H, W)}
|
||||
|
||||
Returns:
|
||||
Tensor: (B, T, total_feature_dim)
|
||||
"""
|
||||
any_tensor = next(iter(images.values()))
|
||||
B, T = any_tensor.shape[:2]
|
||||
cam_names = sorted(images.keys())
|
||||
|
||||
if self.use_separate_rgb_encoder_per_camera:
|
||||
# 独立编码器模式:每个摄像头使用对应的编码器
|
||||
features_all = []
|
||||
for cam_idx, cam_name in enumerate(cam_names):
|
||||
img = images[cam_name]
|
||||
encoder = self.rgb_encoder[cam_idx]
|
||||
features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
else:
|
||||
# 共享编码器模式:所有摄像头共享同一个编码器
|
||||
features_all = []
|
||||
for cam_name in cam_names:
|
||||
img = images[cam_name]
|
||||
features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
|
||||
features_all.append(features)
|
||||
return torch.cat(features_all, dim=1).view(B, T, -1)
|
||||
|
||||
@property
|
||||
def output_dim(self):
|
||||
return self.feature_dim
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("🚀 Testing ResNetDiffusionBackbone")
|
||||
print("=" * 60)
|
||||
|
||||
# Configuration
|
||||
B, T = 2, 5
|
||||
C, H, W = 3, 96, 96
|
||||
crop_h, crop_w = 84, 84
|
||||
num_keypoints = 32
|
||||
feature_dim_per_cam = num_keypoints * 2
|
||||
|
||||
# Create dummy input (2 cameras)
|
||||
images = {
|
||||
"cam_high": torch.randn(B, T, C, H, W),
|
||||
"cam_wrist": torch.randn(B, T, C, H, W)
|
||||
}
|
||||
num_cameras = len(images)
|
||||
|
||||
# ============================================================================
|
||||
# Test 1: Shared Encoder (默认模式)
|
||||
# ============================================================================
|
||||
print("\n[Test 1] Shared Encoder Mode")
|
||||
print("-" * 60)
|
||||
backbone_shared = ResNetDiffusionBackbone(
|
||||
vision_backbone="resnet18",
|
||||
pretrained_backbone_weights=None, # Speed up test
|
||||
input_shape=(C, H, W),
|
||||
crop_shape=(crop_h, crop_w),
|
||||
crop_is_random=True,
|
||||
use_group_norm=True,
|
||||
spatial_softmax_num_keypoints=num_keypoints,
|
||||
use_separate_rgb_encoder_per_camera=False, # 共享编码器
|
||||
)
|
||||
|
||||
print(f"✅ Shared encoder model instantiated")
|
||||
print(f" Output dim per camera: {feature_dim_per_cam}")
|
||||
print(f" Number of cameras: {num_cameras}")
|
||||
print(f" Expected total dim: {num_cameras * feature_dim_per_cam}")
|
||||
|
||||
output = backbone_shared(images)
|
||||
print(f"\n🔄 Forward pass completed")
|
||||
print(f" Input shapes: {[v.shape for v in images.values()]}")
|
||||
print(f" Output shape: {output.shape}")
|
||||
|
||||
expected_dim = num_cameras * feature_dim_per_cam
|
||||
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
|
||||
print(f"✨ Test passed!")
|
||||
|
||||
# ============================================================================
|
||||
# Test 2: Separate Encoders (独立编码器模式)
|
||||
# ============================================================================
|
||||
print("\n[Test 2] Separate Encoders Mode")
|
||||
print("-" * 60)
|
||||
backbone_separate = ResNetDiffusionBackbone(
|
||||
vision_backbone="resnet18",
|
||||
pretrained_backbone_weights=None, # Speed up test
|
||||
input_shape=(C, H, W),
|
||||
crop_shape=(crop_h, crop_w),
|
||||
crop_is_random=True,
|
||||
use_group_norm=True,
|
||||
spatial_softmax_num_keypoints=num_keypoints,
|
||||
use_separate_rgb_encoder_per_camera=True, # 独立编码器
|
||||
num_cameras=num_cameras,
|
||||
)
|
||||
|
||||
print(f"✅ Separate encoders model instantiated")
|
||||
print(f" Output dim per camera: {feature_dim_per_cam}")
|
||||
print(f" Number of cameras: {num_cameras}")
|
||||
print(f" Number of encoders: {len(backbone_separate.rgb_encoder)}")
|
||||
|
||||
output = backbone_separate(images)
|
||||
print(f"\n🔄 Forward pass completed")
|
||||
print(f" Input shapes: {[v.shape for v in images.values()]}")
|
||||
print(f" Output shape: {output.shape}")
|
||||
|
||||
expected_dim = num_cameras * feature_dim_per_cam
|
||||
assert output.shape == (B, T, expected_dim), f"Expected shape {(B, T, expected_dim)}, got {output.shape}"
|
||||
print(f"✨ Test passed!")
|
||||
|
||||
# ============================================================================
|
||||
# Test 3: Verify parameters count
|
||||
# ============================================================================
|
||||
print("\n[Test 3] Parameter Count Comparison")
|
||||
print("-" * 60)
|
||||
shared_params = sum(p.numel() for p in backbone_shared.parameters())
|
||||
separate_params = sum(p.numel() for p in backbone_separate.parameters())
|
||||
|
||||
print(f" Shared encoder parameters: {shared_params:,}")
|
||||
print(f" Separate encoders parameters: {separate_params:,}")
|
||||
print(f" Ratio: {separate_params / shared_params:.2f}x")
|
||||
|
||||
assert separate_params > shared_params, "Separate encoders should have more parameters"
|
||||
print(f"✨ Verification passed!")
|
||||
|
||||
# ============================================================================
|
||||
# Test 4: Verify independent parameters
|
||||
# ============================================================================
|
||||
print("\n[Test 4] Verify Independent Parameters")
|
||||
print("-" * 60)
|
||||
# Check that encoders have independent parameters
|
||||
encoder_0_first_param = list(backbone_separate.rgb_encoder[0].parameters())[0]
|
||||
encoder_1_first_param = list(backbone_separate.rgb_encoder[1].parameters())[0]
|
||||
|
||||
# Modify first encoder's parameter
|
||||
with torch.no_grad():
|
||||
encoder_0_first_param += 1.0
|
||||
|
||||
# Verify they are not the same tensor
|
||||
assert not torch.allclose(encoder_0_first_param, encoder_1_first_param), \
|
||||
"Encoders should have independent parameters"
|
||||
|
||||
print(f"✅ Encoders have independent parameters")
|
||||
print(f"✨ All tests passed!")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 All tests completed successfully!")
|
||||
print("=" * 60)
|
||||
@@ -1,5 +0,0 @@
|
||||
# Action Head models
|
||||
from .conditional_unet1d import ConditionalUnet1D
|
||||
from .transformer1d import Transformer1D
|
||||
|
||||
__all__ = ["ConditionalUnet1D", "Transformer1D"]
|
||||
@@ -1,256 +0,0 @@
|
||||
# Diffusion Policy Action Head 实现
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional
|
||||
from diffusers import DDPMScheduler
|
||||
from roboimi.vla.core.interfaces import VLAHead
|
||||
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.layers.torch import Rearrange
|
||||
import math
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
'''
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
'''
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
cond_dim,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
])
|
||||
|
||||
|
||||
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
Rearrange('batch t -> batch t 1'),
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
|
||||
if in_channels != out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x, cond):
|
||||
'''
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
'''
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.cond_predict_scale:
|
||||
embed = embed.reshape(
|
||||
embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:,0,...]
|
||||
bias = embed[:,1,...]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=[256,512,1024],
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False
|
||||
):
|
||||
super().__init__()
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed
|
||||
if global_cond_dim is not None:
|
||||
cond_dim += global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim, mid_dim, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim, mid_dim, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale
|
||||
),
|
||||
])
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out*2, dim_in, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_in, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
global_cond=None,
|
||||
**kwargs):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([
|
||||
global_feature, global_cond
|
||||
], axis=-1)
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, 'b t h -> b h t')
|
||||
return x
|
||||
@@ -1,146 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
|
||||
def _load_gr00t_dit():
|
||||
repo_root = Path(__file__).resolve().parents[4]
|
||||
dit_path = repo_root / "gr00t" / "models" / "dit.py"
|
||||
spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Unable to load DiT from {dit_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module.DiT
|
||||
|
||||
|
||||
DiT = _load_gr00t_dit()
|
||||
|
||||
|
||||
class Gr00tDiT1D(nn.Module):
|
||||
"""
|
||||
Adapter that wraps gr00t DiT with the same call signature used by VLA heads.
|
||||
|
||||
Expected forward interface:
|
||||
- sample: (B, T_action, input_dim)
|
||||
- timestep: (B,) or scalar diffusion timestep
|
||||
- cond: (B, T_obs, cond_dim)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int,
|
||||
cond_dim: int,
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
n_emb: int = 256,
|
||||
hidden_dim: int = 256,
|
||||
mlp_ratio: int = 4,
|
||||
dropout: float = 0.1,
|
||||
add_action_pos_emb: bool = True,
|
||||
add_cond_pos_emb: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if cond_dim <= 0:
|
||||
raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
|
||||
|
||||
self.horizon = horizon
|
||||
self.n_obs_steps = n_obs_steps
|
||||
|
||||
self.input_proj = nn.Linear(input_dim, n_emb)
|
||||
self.cond_proj = nn.Linear(cond_dim, n_emb)
|
||||
self.output_proj = nn.Linear(hidden_dim, output_dim)
|
||||
|
||||
self.action_pos_emb = (
|
||||
nn.Parameter(torch.zeros(1, horizon, n_emb))
|
||||
if add_action_pos_emb
|
||||
else None
|
||||
)
|
||||
self.cond_pos_emb = (
|
||||
nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
|
||||
if add_cond_pos_emb
|
||||
else None
|
||||
)
|
||||
|
||||
args = SimpleNamespace(
|
||||
embed_dim=n_emb,
|
||||
nheads=n_head,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dropout=dropout,
|
||||
num_layers=n_layer,
|
||||
hidden_dim=hidden_dim,
|
||||
)
|
||||
self.dit = DiT(args, cross_attention_dim=n_emb)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
if self.action_pos_emb is not None:
|
||||
nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
|
||||
if self.cond_pos_emb is not None:
|
||||
nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
|
||||
|
||||
def _normalize_timesteps(
|
||||
self,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if not torch.is_tensor(timestep):
|
||||
timesteps = torch.tensor([timestep], device=device)
|
||||
else:
|
||||
timesteps = timestep.to(device)
|
||||
|
||||
if timesteps.ndim == 0:
|
||||
timesteps = timesteps[None]
|
||||
if timesteps.shape[0] != batch_size:
|
||||
timesteps = timesteps.expand(batch_size)
|
||||
|
||||
return timesteps.long()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cond is None:
|
||||
raise ValueError("`cond` is required for Gr00tDiT1D forward.")
|
||||
|
||||
bsz, t_act, _ = sample.shape
|
||||
if t_act > self.horizon:
|
||||
raise ValueError(
|
||||
f"sample length {t_act} exceeds configured horizon {self.horizon}"
|
||||
)
|
||||
|
||||
hidden_states = self.input_proj(sample)
|
||||
if self.action_pos_emb is not None:
|
||||
hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]
|
||||
|
||||
encoder_hidden_states = self.cond_proj(cond)
|
||||
if self.cond_pos_emb is not None:
|
||||
t_obs = encoder_hidden_states.shape[1]
|
||||
if t_obs > self.n_obs_steps:
|
||||
raise ValueError(
|
||||
f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
|
||||
)
|
||||
encoder_hidden_states = (
|
||||
encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
|
||||
)
|
||||
|
||||
timesteps = self._normalize_timesteps(
|
||||
timestep, batch_size=bsz, device=sample.device
|
||||
)
|
||||
dit_output = self.dit(
|
||||
hidden_states=hidden_states,
|
||||
timestep=timesteps,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
return self.output_proj(dit_output)
|
||||
@@ -1,396 +0,0 @@
|
||||
"""
|
||||
Transformer-based Diffusion Policy Head
|
||||
|
||||
使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。
|
||||
支持通过Cross-Attention注入全局条件(观测特征)。
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
"""正弦位置编码(用于时间步嵌入)"""
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class Transformer1D(nn.Module):
|
||||
"""
|
||||
Transformer-based 1D Diffusion Model
|
||||
|
||||
使用Encoder-Decoder架构:
|
||||
- Encoder: 处理条件(观测 + 时间步)
|
||||
- Decoder: 通过Cross-Attention预测噪声
|
||||
|
||||
Args:
|
||||
input_dim: 输入动作维度
|
||||
output_dim: 输出动作维度
|
||||
horizon: 预测horizon长度
|
||||
n_obs_steps: 观测步数
|
||||
cond_dim: 条件维度
|
||||
n_layer: Transformer层数
|
||||
n_head: 注意力头数
|
||||
n_emb: 嵌入维度
|
||||
p_drop_emb: Embedding dropout
|
||||
p_drop_attn: Attention dropout
|
||||
causal_attn: 是否使用因果注意力(自回归)
|
||||
n_cond_layers: Encoder层数(0表示使用MLP)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int = None,
|
||||
cond_dim: int = 0,
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
n_emb: int = 256,
|
||||
p_drop_emb: float = 0.1,
|
||||
p_drop_attn: float = 0.1,
|
||||
causal_attn: bool = False,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 计算序列长度
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
|
||||
T = horizon
|
||||
T_cond = 1 # 时间步token数量
|
||||
|
||||
# 确定是否使用观测作为条件
|
||||
obs_as_cond = cond_dim > 0
|
||||
if obs_as_cond:
|
||||
T_cond += n_obs_steps
|
||||
|
||||
# 保存配置
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
# ==================== 输入嵌入 ====================
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
|
||||
# ==================== 条件编码 ====================
|
||||
# 时间步嵌入
|
||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||
|
||||
# 观测条件嵌入(可选)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
||||
|
||||
# 条件位置编码
|
||||
self.cond_pos_emb = None
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
|
||||
# ==================== Encoder ====================
|
||||
self.encoder = None
|
||||
self.encoder_only = False
|
||||
|
||||
if T_cond > 0:
|
||||
if n_cond_layers > 0:
|
||||
# 使用Transformer Encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True # Pre-LN更稳定
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers
|
||||
)
|
||||
else:
|
||||
# 使用简单的MLP
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb)
|
||||
)
|
||||
else:
|
||||
# Encoder-only模式(BERT风格)
|
||||
self.encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer
|
||||
)
|
||||
|
||||
# ==================== Attention Mask ====================
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
if causal_attn:
|
||||
# 因果mask:确保只关注左侧
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
if obs_as_cond:
|
||||
# 交叉注意力mask
|
||||
S = T_cond
|
||||
t, s = torch.meshgrid(
|
||||
torch.arange(T),
|
||||
torch.arange(S),
|
||||
indexing='ij'
|
||||
)
|
||||
mask = t >= (s - 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
|
||||
# ==================== Decoder ====================
|
||||
if not self.encoder_only:
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer
|
||||
)
|
||||
|
||||
# ==================== 输出头 ====================
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
self.head = nn.Linear(n_emb, output_dim)
|
||||
|
||||
# ==================== 初始化 ====================
|
||||
self.apply(self._init_weights)
|
||||
|
||||
# 打印参数量
|
||||
total_params = sum(p.numel() for p in self.parameters())
|
||||
print(f"Transformer1D parameters: {total_params:,}")
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""初始化权重"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# MultiheadAttention的权重初始化
|
||||
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
|
||||
weight = getattr(module, name, None)
|
||||
if weight is not None:
|
||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||
|
||||
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
|
||||
bias = getattr(module, name, None)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, Transformer1D):
|
||||
# 位置编码初始化
|
||||
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
|
||||
if self.cond_pos_emb is not None:
|
||||
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
sample: (B, T, input_dim) 输入序列(加噪动作)
|
||||
timestep: (B,) 时间步
|
||||
cond: (B, T', cond_dim) 条件序列(观测特征)
|
||||
|
||||
Returns:
|
||||
(B, T, output_dim) 预测的噪声
|
||||
"""
|
||||
# ==================== 处理时间步 ====================
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# 扩展到batch维度
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
|
||||
|
||||
# ==================== 处理输入 ====================
|
||||
input_emb = self.input_emb(sample) # (B, T, n_emb)
|
||||
|
||||
# ==================== Encoder-Decoder模式 ====================
|
||||
if not self.encoder_only:
|
||||
# --- Encoder: 处理条件 ---
|
||||
cond_embeddings = time_emb
|
||||
|
||||
if self.obs_as_cond and cond is not None:
|
||||
# 添加观测条件
|
||||
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
|
||||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
||||
|
||||
# 添加位置编码
|
||||
tc = cond_embeddings.shape[1]
|
||||
pos_emb = self.cond_pos_emb[:, :tc, :]
|
||||
x = self.drop(cond_embeddings + pos_emb)
|
||||
|
||||
# 通过encoder
|
||||
memory = self.encoder(x) # (B, T_cond, n_emb)
|
||||
|
||||
# --- Decoder: 预测噪声 ---
|
||||
# 添加位置编码到输入
|
||||
token_embeddings = input_emb
|
||||
t = token_embeddings.shape[1]
|
||||
pos_emb = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + pos_emb)
|
||||
|
||||
# Cross-Attention: Query来自输入,Key/Value来自memory
|
||||
x = self.decoder(
|
||||
tgt=x,
|
||||
memory=memory,
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask
|
||||
)
|
||||
|
||||
# ==================== Encoder-Only模式 ====================
|
||||
else:
|
||||
# BERT风格:时间步作为特殊token
|
||||
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
|
||||
t = token_embeddings.shape[1]
|
||||
pos_emb = self.pos_emb[:, :t, :]
|
||||
x = self.drop(token_embeddings + pos_emb)
|
||||
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 1:, :] # 移除时间步token
|
||||
|
||||
# ==================== 输出头 ====================
|
||||
x = self.ln_f(x)
|
||||
x = self.head(x) # (B, T, output_dim)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 便捷函数:创建Transformer1D模型
|
||||
# ============================================================================
|
||||
def create_transformer1d(
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int,
|
||||
cond_dim: int,
|
||||
n_layer: int = 8,
|
||||
n_head: int = 8,
|
||||
n_emb: int = 256,
|
||||
**kwargs
|
||||
) -> Transformer1D:
|
||||
"""
|
||||
创建Transformer1D模型的便捷函数
|
||||
|
||||
Args:
|
||||
input_dim: 输入动作维度
|
||||
output_dim: 输出动作维度
|
||||
horizon: 预测horizon
|
||||
n_obs_steps: 观测步数
|
||||
cond_dim: 条件维度
|
||||
n_layer: Transformer层数
|
||||
n_head: 注意力头数
|
||||
n_emb: 嵌入维度
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
Transformer1D模型
|
||||
"""
|
||||
model = Transformer1D(
|
||||
input_dim=input_dim,
|
||||
output_dim=output_dim,
|
||||
horizon=horizon,
|
||||
n_obs_steps=n_obs_steps,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 80)
|
||||
print("Testing Transformer1D")
|
||||
print("=" * 80)
|
||||
|
||||
# 配置
|
||||
B = 4
|
||||
T = 16
|
||||
action_dim = 16
|
||||
obs_horizon = 2
|
||||
cond_dim = 416 # vision + state特征维度
|
||||
|
||||
# 创建模型
|
||||
model = Transformer1D(
|
||||
input_dim=action_dim,
|
||||
output_dim=action_dim,
|
||||
horizon=T,
|
||||
n_obs_steps=obs_horizon,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=4,
|
||||
n_head=8,
|
||||
n_emb=256,
|
||||
causal_attn=False
|
||||
)
|
||||
|
||||
# 测试前向传播
|
||||
sample = torch.randn(B, T, action_dim)
|
||||
timestep = torch.randint(0, 100, (B,))
|
||||
cond = torch.randn(B, obs_horizon, cond_dim)
|
||||
|
||||
output = model(sample, timestep, cond)
|
||||
|
||||
print(f"\n输入:")
|
||||
print(f" sample: {sample.shape}")
|
||||
print(f" timestep: {timestep.shape}")
|
||||
print(f" cond: {cond.shape}")
|
||||
print(f"\n输出:")
|
||||
print(f" output: {output.shape}")
|
||||
print(f"\n✅ 测试通过!")
|
||||
@@ -1,126 +0,0 @@
|
||||
"""
|
||||
归一化模块 - 统一训练和推理的归一化逻辑
|
||||
|
||||
支持两种归一化方式:
|
||||
1. Gaussian (z-score): (x - mean) / std
|
||||
2. MinMax: 2 * (x - min) / (max - min) - 1 -> [-1, 1]
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Dict, Literal
|
||||
|
||||
|
||||
class NormalizationModule(nn.Module):
|
||||
"""
|
||||
统一的归一化模块
|
||||
用于在 Agent 内部对 qpos 和 action 进行归一化/反归一化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: Optional[Dict] = None,
|
||||
normalization_type: Literal['gaussian', 'min_max'] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
stats: 数据集统计信息字典,格式:
|
||||
{
|
||||
'qpos_mean': [...],
|
||||
'qpos_std': [...],
|
||||
'qpos_min': [...], # 仅 min_max 需要
|
||||
'qpos_max': [...], # 仅 min_max 需要
|
||||
'action_mean': [...],
|
||||
'action_std': [...],
|
||||
'action_min': [...], # 仅 min_max 需要
|
||||
'action_max': [...], # 仅 min_max 需要
|
||||
}
|
||||
normalization_type: 归一化类型 ('gaussian' 或 'min_max')
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.normalization_type = normalization_type
|
||||
self.enabled = stats is not None
|
||||
|
||||
if self.enabled:
|
||||
if self.normalization_type == 'gaussian':
|
||||
self.register_buffer('qpos_mean', torch.tensor(stats['qpos_mean'], dtype=torch.float32))
|
||||
self.register_buffer('qpos_std', torch.tensor(stats['qpos_std'], dtype=torch.float32))
|
||||
self.register_buffer('action_mean', torch.tensor(stats['action_mean'], dtype=torch.float32))
|
||||
self.register_buffer('action_std', torch.tensor(stats['action_std'], dtype=torch.float32))
|
||||
|
||||
elif self.normalization_type == 'min_max':
|
||||
self.register_buffer('qpos_min', torch.tensor(stats['qpos_min'], dtype=torch.float32))
|
||||
self.register_buffer('qpos_max', torch.tensor(stats['qpos_max'], dtype=torch.float32))
|
||||
self.register_buffer('action_min', torch.tensor(stats['action_min'], dtype=torch.float32))
|
||||
self.register_buffer('action_max', torch.tensor(stats['action_max'], dtype=torch.float32))
|
||||
|
||||
def normalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||
"""归一化 qpos"""
|
||||
if not self.enabled:
|
||||
return qpos
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return (qpos - self.qpos_mean) / self.qpos_std
|
||||
elif self.normalization_type == 'min_max':
|
||||
return 2 * (qpos - self.qpos_min) / (self.qpos_max - self.qpos_min) - 1
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def denormalize_qpos(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||
"""反归一化 qpos"""
|
||||
if not self.enabled:
|
||||
return qpos
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return qpos * self.qpos_std + self.qpos_mean
|
||||
elif self.normalization_type == 'min_max':
|
||||
return (qpos + 1) / 2 * (self.qpos_max - self.qpos_min) + self.qpos_min
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def normalize_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""归一化 action"""
|
||||
if not self.enabled:
|
||||
return action
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return (action - self.action_mean) / self.action_std
|
||||
elif self.normalization_type == 'min_max':
|
||||
return 2 * (action - self.action_min) / (self.action_max - self.action_min) - 1
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""反归一化 action"""
|
||||
if not self.enabled:
|
||||
return action
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
return action * self.action_std + self.action_mean
|
||||
elif self.normalization_type == 'min_max':
|
||||
return (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
||||
|
||||
def get_stats(self) -> Optional[Dict]:
|
||||
"""导出统计信息(用于保存到 checkpoint)"""
|
||||
if not self.enabled:
|
||||
return None
|
||||
|
||||
stats = {
|
||||
'normalization_type': self.normalization_type,
|
||||
}
|
||||
|
||||
if self.normalization_type == 'gaussian':
|
||||
stats['qpos_mean'] = self.qpos_mean.cpu().tolist()
|
||||
stats['qpos_std'] = self.qpos_std.cpu().tolist()
|
||||
stats['action_mean'] = self.action_mean.cpu().tolist()
|
||||
stats['action_std'] = self.action_std.cpu().tolist()
|
||||
elif self.normalization_type == 'min_max':
|
||||
stats['qpos_min'] = self.qpos_min.cpu().tolist()
|
||||
stats['qpos_max'] = self.qpos_max.cpu().tolist()
|
||||
stats['action_min'] = self.action_min.cpu().tolist()
|
||||
stats['action_max'] = self.action_max.cpu().tolist()
|
||||
|
||||
return stats
|
||||
@@ -1,18 +0,0 @@
|
||||
from torch import nn
|
||||
|
||||
class IdentityStateEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, state):
|
||||
return state
|
||||
|
||||
|
||||
class IdentityActionEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, action):
|
||||
return action
|
||||
@@ -1,87 +0,0 @@
|
||||
import h5py
|
||||
import numpy as np
|
||||
import os
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
def get_data_stats(dataset_dir):
|
||||
"""
|
||||
计算 action 和 qpos 的 Min, Max, Mean, Std
|
||||
|
||||
输出扁平化结构(与 NormalizationModule 期望一致):
|
||||
{
|
||||
'action_mean': [...],
|
||||
'action_std': [...],
|
||||
'action_min': [...],
|
||||
'action_max': [...],
|
||||
'qpos_mean': [...],
|
||||
'qpos_std': [...],
|
||||
'qpos_min': [...],
|
||||
'qpos_max': [...],
|
||||
}
|
||||
"""
|
||||
files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
|
||||
print(f"Found {len(files)} episodes in {dataset_dir}")
|
||||
|
||||
all_actions = []
|
||||
all_qpos = []
|
||||
|
||||
print("Reading data...")
|
||||
for file_path in files:
|
||||
with h5py.File(file_path, 'r') as f:
|
||||
action = f['action'][:]
|
||||
qpos = f['observations']['qpos'][:]
|
||||
all_actions.append(action)
|
||||
all_qpos.append(qpos)
|
||||
|
||||
# 拼接所有数据
|
||||
all_actions = np.concatenate(all_actions, axis=0)
|
||||
all_qpos = np.concatenate(all_qpos, axis=0)
|
||||
|
||||
print(f"Total steps: {all_actions.shape[0]}")
|
||||
|
||||
# --- 核心计算部分 ---
|
||||
# 计算统计量
|
||||
action_mean = np.mean(all_actions, axis=0)
|
||||
action_std = np.std(all_actions, axis=0)
|
||||
action_min = np.min(all_actions, axis=0)
|
||||
action_max = np.max(all_actions, axis=0)
|
||||
|
||||
qpos_mean = np.mean(all_qpos, axis=0)
|
||||
qpos_std = np.std(all_qpos, axis=0)
|
||||
qpos_min = np.min(all_qpos, axis=0)
|
||||
qpos_max = np.max(all_qpos, axis=0)
|
||||
|
||||
# 修正标准差(防止除以 0)
|
||||
eps = 1e-8
|
||||
action_std_corrected = np.where(action_std < eps, eps, action_std)
|
||||
qpos_std_corrected = np.where(qpos_std < eps, eps, qpos_std)
|
||||
|
||||
# 转换为扁平化结构(与 NormalizationModule 期望一致)
|
||||
stats_flat = {
|
||||
'action_mean': action_mean,
|
||||
'action_std': action_std_corrected,
|
||||
'action_min': action_min,
|
||||
'action_max': action_max,
|
||||
'qpos_mean': qpos_mean,
|
||||
'qpos_std': qpos_std_corrected,
|
||||
'qpos_min': qpos_min,
|
||||
'qpos_max': qpos_max
|
||||
}
|
||||
return stats_flat
|
||||
|
||||
if __name__ == "__main__":
|
||||
DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
|
||||
OUTPUT_PATH = DATASET_DIR + "/dataset_stats.pkl"
|
||||
|
||||
stats_flat = get_data_stats(DATASET_DIR)
|
||||
|
||||
# 打印检查
|
||||
print("\n--- Stats Computed ---")
|
||||
print(f"Action Mean shape: {stats_flat['action_mean'].shape}")
|
||||
print(f"Action Std shape: {stats_flat['action_std'].shape}")
|
||||
|
||||
# 保存
|
||||
with open(OUTPUT_PATH, 'wb') as f:
|
||||
pickle.dump(stats_flat, f)
|
||||
print(f"\nStats saved to {OUTPUT_PATH}")
|
||||
Reference in New Issue
Block a user