chore: 删除detr和gr00t
This commit is contained in:
@@ -1,206 +0,0 @@
|
|||||||
import torch
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from tqdm import tqdm
|
|
||||||
from einops import rearrange
|
|
||||||
from roboimi.utils.utils import set_seed
|
|
||||||
from roboimi.utils.io_utils import IOUtils
|
|
||||||
from roboimi.utils.model_interface import ModelInterface
|
|
||||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
|
||||||
# from visualize_episodes import save_videos
|
|
||||||
from roboimi.utils.act_ex_utils import sample_transfer_pose
|
|
||||||
|
|
||||||
|
|
||||||
class ActionSmoother:
    """Smooths a stream of action vectors with a configurable strategy.

    Supported methods:
        'ema'        exponential moving average
        'lowpass'    first-order low-pass filter (mathematically identical
                     to EMA: prev + alpha * (action - prev))
        'moving_avg' arithmetic mean over a sliding window
        'none'       pass-through
    """

    def __init__(self, action_dim, method='ema', alpha=0.3, window_size=5):
        """
        Args:
            action_dim: dimensionality of the action vector (kept for
                introspection; the filters themselves are shape-agnostic).
            method: one of 'ema', 'moving_avg', 'lowpass', 'none'.
            alpha: blend factor in (0, 1]; smaller means smoother output.
            window_size: number of past actions averaged by 'moving_avg'.
        """
        self.action_dim = action_dim
        self.method = method
        self.alpha = alpha
        self.window_size = window_size
        self.history = []        # sliding window for 'moving_avg'
        self.prev_action = None  # previous output for 'ema' / 'lowpass'

    def smooth(self, action):
        """Return the smoothed version of ``action`` ([action_dim] array).

        Args:
            action: current action, a numpy array of shape [action_dim].

        Returns:
            The smoothed action.

        Raises:
            ValueError: if the configured method is unknown.
        """
        if self.method == 'none':
            return action

        if self.method in ('ema', 'lowpass'):
            # EMA and a first-order low-pass filter are the same update:
            # alpha*action + (1-alpha)*prev == prev + alpha*(action - prev),
            # so both methods share one implementation.
            if self.prev_action is None:
                smoothed = action
            else:
                smoothed = self.alpha * action + (1 - self.alpha) * self.prev_action
            # Bug fix: store a copy so later in-place mutation of the
            # caller's array cannot corrupt the filter state (the previous
            # implementation aliased the caller's buffer on the first call).
            self.prev_action = np.array(smoothed, dtype=float)
            return smoothed

        if self.method == 'moving_avg':
            # Sliding-window arithmetic mean over the last `window_size` actions.
            self.history.append(action.copy())
            if len(self.history) > self.window_size:
                self.history.pop(0)
            return np.mean(self.history, axis=0)

        raise ValueError(f"Unknown smoothing method: {self.method}")

    def reset(self):
        """Clear all filter state (call between episodes)."""
        self.history = []
        self.prev_action = None
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: should be folded into IOUtils
def get_image(obs, camera_names):
    """Stack the requested camera images into one CUDA batch tensor.

    Args:
        obs: observation dict with obs['images'][cam_name] as (H, W, C) arrays.
        camera_names: iterable of camera keys to extract, in order.

    Returns:
        Float tensor of shape (1, num_cams, C, H, W) on the GPU, scaled to [0, 1].
    """
    frames = [rearrange(obs['images'][name], 'h w c -> c h w')
              for name in camera_names]
    stacked = np.stack(frames, axis=0)
    return torch.from_numpy(stacked / 255.0).float().cuda().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def eval_bc(config, ckpt_name='policy_best.ckpt', save_episode=True):
    """Load a trained policy and roll it out in the simulated environment.

    Args:
        config: experiment configuration dict (see IOUtils.load_config).
        ckpt_name: checkpoint file name inside config['ckpt_dir'].
        save_episode: forwarded to run_episode.
    """
    set_seed(1)

    model_interface = ModelInterface(config)
    model_interface.setup()

    policy = IOUtils.load_policy(config, ckpt_name)
    stats = IOUtils.load_stats(config['ckpt_dir'])

    num_rollouts = 3
    episode_returns = []   # NOTE(review): currently unused — run_episode
    highest_rewards = []   # returns nothing; kept for future bookkeeping.

    run_episode(config, policy, stats,
                save_episode, num_rollouts)
    # episode_return, episode_highest_reward = run_episode(config, policy, stats,
    #                                                      save_episode, num_rollouts)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_episode(config, policy, stats, save_episode, num_rollouts):
    """Roll the policy out for ``num_rollouts`` episodes in the sim env.

    Args:
        config: experiment configuration dict.
        policy: callable mapping (qpos, image) to an action (chunk).
        stats: normalization statistics saved at training time.
        save_episode: unused here; kept for interface compatibility.
        num_rollouts: number of episodes to run.

    Raises:
        ValueError: if config['task_name'] is not a sim_transfer task.
        NotImplementedError: for an unknown policy class.
    """
    # Bug fix: `env` used to be created only inside the `if` branch, so any
    # other task name crashed later with a NameError; fail fast instead.
    if 'sim_transfer' not in config['task_name']:
        raise ValueError(f"Unsupported task for this script: {config['task_name']}")
    task_name = 'sim_transfer'  # config['task_name']
    env = make_sim_env(task_name)

    max_timesteps = config['episode_len']
    max_timesteps = int(max_timesteps * 1)

    # Normalize observations with training statistics; map network output
    # back to the raw action space.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']
    box_pos = sample_transfer_pose()

    # Optional action smoothing (enabled only for the GR00T policy class).
    action_dim = config['action_dim']
    use_smoothing = config.get('use_action_smoothing', False)
    smooth_method = config.get('smooth_method', 'ema')
    smooth_alpha = config.get('smooth_alpha', 0.3)

    if use_smoothing and config['policy_class'] == "GR00T":
        smoother = ActionSmoother(action_dim, method=smooth_method, alpha=smooth_alpha)
        print(f"Action smoothing enabled: method={smooth_method}, alpha={smooth_alpha}")
    else:
        smoother = None

    for rollout_id in range(num_rollouts):
        print("\nrollout_id===", rollout_id, "\n")
        image_list = []
        rewards = []
        query_frequency = config['policy_config'].get('num_queries', 1)
        print("query_freq =====", query_frequency)
        env.reset(box_pos)

        # Reset smoothing state between episodes.
        if smoother is not None:
            smoother.reset()

        with torch.inference_mode():
            for t in range(700):
                # Bug fix: the old conditional-expression form appended a
                # junk `{None}` set to image_list on the error path.
                obs_images = env._get_image_obs()
                if 'images' in obs_images:
                    image_list.append(obs_images['images'])
                else:
                    print("img error")

                qpos_numpy = np.array(env._get_qpos_obs()['qpos'])
                qpos = pre_process(qpos_numpy)
                qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
                curr_image = get_image(env._get_image_obs(), config['camera_names'])

                if config['policy_class'] in ["ACT", "ACTTV", "GR00T"]:
                    # Re-query the policy every `query_frequency` steps and
                    # replay the predicted chunk in between.
                    if t % query_frequency == 0:
                        all_actions = policy(qpos, curr_image)
                    raw_action = all_actions[:, t % query_frequency]
                    raw_action = raw_action.squeeze(0).cpu().numpy()
                elif config['policy_class'] == "CNNMLP":
                    raw_action = policy(qpos, curr_image)
                else:
                    raise NotImplementedError

                action = post_process(raw_action)

                # Apply action smoothing (GR00T only; see setup above).
                if smoother is not None:
                    action = smoother.smooth(action)

                print("action == ", action)
                env.step_jnt(action)
                rewards.append(env.rew)
                env.render()

        rewards = np.array(rewards)
        # episode_return = np.sum(rewards[rewards != None])
        # episode_highest_reward = np.max(rewards)
        # env.viewer.close()

    # del env
    # return episode_return, episode_highest_reward
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_env():
    """Smoke-test the simulator: create it, reset, then idle until Ctrl-C."""
    try:
        env = make_sim_env('sim_transfer')
        env.reset()
        # Spin so the viewer stays alive; interrupt with Ctrl-C to exit.
        while True:
            pass
    except KeyboardInterrupt:
        del env
        print("stop")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # test_env()
    # Load the experiment configuration and run evaluation.
    io_utils = IOUtils()
    config = io_utils.load_config()
    eval_bc(config)
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
import torch
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from tqdm import tqdm
|
|
||||||
from einops import rearrange
|
|
||||||
from roboimi.utils.utils import set_seed
|
|
||||||
from roboimi.utils.io_utils import IOUtils
|
|
||||||
from roboimi.utils.model_interface import ModelInterface
|
|
||||||
from roboimi.envs.vx300s_jnt import make_sim_env
|
|
||||||
import time
|
|
||||||
|
|
||||||
# from visualize_episodes import save_videos
|
|
||||||
from roboimi.utils.utils import sample_box_pose, sample_insertion_pose
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: should be folded into IOUtils
def get_image(obs, camera_names):
    """Build a (1, num_cams, C, H, W) float CUDA tensor in [0, 1] from obs.

    Args:
        obs: observation dict with obs['images'][cam_name] as (H, W, C) arrays.
        camera_names: iterable of camera keys to extract, in order.
    """
    per_cam = []
    for name in camera_names:
        per_cam.append(rearrange(obs['images'][name], 'h w c -> c h w'))
    batch = np.stack(per_cam, axis=0)
    return torch.from_numpy(batch / 255.0).float().cuda().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def eval_bc(config, ckpt_name='policy_best.ckpt', save_episode=True):
    """Evaluate a trained policy over several rollouts in the sim env.

    Args:
        config: experiment configuration dict (see IOUtils.load_config).
        ckpt_name: checkpoint file name inside config['ckpt_dir'].
        save_episode: forwarded to run_episode.
    """
    set_seed(1)
    model_interface = ModelInterface(config)
    task_name = 'sim_insertion'  # config['task_name']  (hard-coded override)
    model_interface.setup()
    policy = IOUtils.load_policy(config, ckpt_name)
    stats = IOUtils.load_stats(config['ckpt_dir'])
    num_rollouts = 3
    episode_returns = []   # NOTE(review): per-rollout results are not
    highest_rewards = []   # accumulated yet — kept for future bookkeeping.
    for rollout_id in range(num_rollouts):
        episode_return, episode_highest_reward = run_episode(
            config, policy, stats, save_episode, rollout_id)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_episode(config, policy, stats, save_episode, rollout_id):
    """Run one evaluation episode and return its reward statistics.

    Args:
        config: experiment configuration dict.
        policy: callable mapping (qpos, image) to an action (chunk).
        stats: normalization statistics saved at training time.
        save_episode: unused here; kept for interface compatibility.
        rollout_id: index of this rollout (logging only).

    Returns:
        (episode_return, episode_highest_reward)

    Raises:
        ValueError: if config['task_name'] is not a sim_insertion task.
        NotImplementedError: for an unknown policy class.
    """
    print("\nrollout_id===", rollout_id, "\n")
    # Normalize observations with training statistics; map network output
    # back to the raw action space.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    # Bug fix: `box_pose` used to be bound only inside the `if` branch, so
    # any other task name crashed at env.reset(); fail fast instead.
    if 'sim_insertion' not in config['task_name']:
        raise ValueError(f"Unsupported task for this script: {config['task_name']}")
    peg_pose, socket_pose = sample_insertion_pose()
    box_pose = np.hstack((peg_pose[:3], socket_pose[:3]))  # used in sim reset
    task_name = 'sim_insertion'  # config['task_name']
    env = make_sim_env(task_name)
    env.reset(box_pose)
    max_timesteps = config['episode_len']
    max_timesteps = int(max_timesteps * 1)

    image_list = []
    rewards = []
    query_frequency = config['policy_config'].get('num_queries', 1)

    with torch.inference_mode():
        for t in range(700):
            # print("obs_img",env.obs['images'])
            # Bug fix: the old conditional-expression form appended a junk
            # `{None}` set to image_list on the error path.
            if 'images' in env.obs:
                image_list.append(env.obs['images'])
            else:
                print("img error")
            qpos_numpy = np.array(env.obs['qpos'])
            qpos = pre_process(qpos_numpy)
            qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
            curr_image = get_image(env.obs, config['camera_names'])

            # Bug fix: the original test was
            #   `config['policy_class'] == "ACT" or "ACTTV"`
            # which is always truthy ("ACTTV" is a non-empty string), so the
            # CNNMLP branch was unreachable.  The chunk indexing below also
            # only applies to chunked (ACT-style) policies, so it moved
            # inside that branch.
            if config['policy_class'] in ("ACT", "ACTTV"):
                # Re-query the policy every `query_frequency` steps and
                # replay the predicted chunk in between.
                if t % query_frequency == 0:
                    all_actions = policy(qpos, curr_image)
                raw_action = all_actions[:, t % query_frequency]
                raw_action = raw_action.squeeze(0).cpu().numpy()
            elif config['policy_class'] == "CNNMLP":
                raw_action = policy(qpos, curr_image)
            else:
                raise NotImplementedError
            action = post_process(raw_action)

            env.step(action)
            rewards.append(env.rew)
            env.render()

    rewards = np.array(rewards)
    episode_return = np.sum(rewards[rewards != None])
    episode_highest_reward = np.max(rewards)
    env.viewer.close()

    del env
    return episode_return, episode_highest_reward
|
|
||||||
|
|
||||||
|
|
||||||
def test_env():
    """Smoke-test the insertion simulator: reset once, idle until Ctrl-C."""
    try:
        env = make_sim_env('sim_insertion')
        box_pos = np.concatenate(sample_insertion_pose())
        env.reset(box_pos)
        # Spin so the viewer stays alive; interrupt with Ctrl-C to exit.
        while True:
            pass
    except KeyboardInterrupt:
        del env
        print("stop")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Only the environment smoke-test is active; evaluation is disabled.
    test_env()
    # io_utils = IOUtils()
    # config = io_utils.load_config()
    # eval_bc(config)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# config===== {'onscreen_render': False,
|
|
||||||
# 'eval': 1,
|
|
||||||
# 'ckpt_dir': 'ckpt_models',
|
|
||||||
# 'num_epochs': 3000,
|
|
||||||
# 'temporal_agg': False,
|
|
||||||
# 'policy_class': 'ACT',
|
|
||||||
# 'backbone': 'resnet18',
|
|
||||||
# 'seed': 0, 'real_robot': 0,
|
|
||||||
# 'task_name': 'sim_insertion',
|
|
||||||
# 'images_render_height': 480,
|
|
||||||
# 'images_render_width': 640,
|
|
||||||
# 'left_arm_DOF_number': 6,
|
|
||||||
# 'right_arm_DOF_number': 6,
|
|
||||||
# 'left_qpos_raw': 8,
|
|
||||||
# 'right_qpos_raw': 8,
|
|
||||||
# 'left_qvel_raw': 8,
|
|
||||||
# 'right_qvel_raw': 8,
|
|
||||||
# 'dataset_dir': '/home/arm/lzd/act_env/dataset/sim_insertion',
|
|
||||||
# 'num_episodes': 7,
|
|
||||||
# 'episode_len': 400,
|
|
||||||
# 'camera_names': ['top'],
|
|
||||||
# 'xml_dir': None,
|
|
||||||
# 'batch_size': 8,
|
|
||||||
# 'state_dim': 14,
|
|
||||||
# 'action_dim': 14,
|
|
||||||
# 'lr_backbone': 1e-05,
|
|
||||||
# 'enc_layers': 4,
|
|
||||||
# 'dec_layers': 7,
|
|
||||||
# 'nheads': 8,
|
|
||||||
# 'qpos_noise_std': 0,
|
|
||||||
# 'DT': 0.02,
|
|
||||||
# 'lr': 1e-05,
|
|
||||||
# 'kl_weight': 10,
|
|
||||||
# 'chunk_size': 100,
|
|
||||||
# 'hidden_dim': 512,
|
|
||||||
# 'dim_feedforward': 3200,
|
|
||||||
# 'policy_config': {'lr': 1e-05, 'num_queries': 100, 'kl_weight': 10, 'hidden_dim': 512, 'dim_feedforward': 3200, 'lr_backbone': 1e-05, 'backbone': 'resnet18', 'enc_layers': 4, 'dec_layers': 7, 'nheads': 8, 'camera_names': ['top']}}
|
|
||||||
@@ -1,179 +0,0 @@
|
|||||||
import torch
|
|
||||||
import os
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
from copy import deepcopy
|
|
||||||
from itertools import repeat
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import time
|
|
||||||
from roboimi.utils.utils import set_seed, compute_dict_mean, detach_dict, load_data
|
|
||||||
from roboimi.utils.io_utils import IOUtils
|
|
||||||
from roboimi.utils.model_interface import ModelInterface
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
def train_bc(config):
    """Train a behavior-cloning policy with live loss plotting.

    Runs validation then training once per epoch, tracks the best
    validation checkpoint, updates an interactive matplotlib loss plot,
    and saves both the last and the best checkpoints.

    Args:
        config: experiment configuration dict; uses 'num_epochs',
            'ckpt_dir', 'seed', 'dataset_dir', 'num_episodes',
            'camera_names', and 'batch_size'.

    Returns:
        best_ckpt_info: (best_epoch, min_val_loss, best_state_dict) tuple.
            NOTE(review): stays None if num_epochs == 0, which would crash
            the unpack below — confirm callers never pass 0.
    """
    num_epochs = config['num_epochs']
    ckpt_dir = config['ckpt_dir']
    seed = config['seed']

    os.makedirs(ckpt_dir, exist_ok=True)

    set_seed(seed)

    model_interface = ModelInterface(config)
    model_interface.setup()

    policy = model_interface.make_policy()
    policy.cuda()
    optimizer = model_interface.make_optimizer(policy)
    # print("cam names=====",config['camera_names'])
    # Same batch size is used for both the train and validation loaders.
    train_dataloader, val_dataloader, stats, _ = load_data(
        config['dataset_dir'],
        config['num_episodes'],
        config['camera_names'],
        config['batch_size'],
        config['batch_size'])

    # Persist normalization stats so evaluation can reuse them.
    IOUtils.save_stats(ckpt_dir, stats)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    min_train_loss = np.inf
    best_ckpt_info = None

    # Interactive loss plot, refreshed once per epoch.
    plt.ion()
    fig, ax = plt.subplots()
    train_losses, val_losses = [], []
    train_line, = ax.plot([], [], label='Train Loss')
    val_line, = ax.plot([], [], label='Validation Loss')
    ax.autoscale_view()
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

    # Annotations that track the latest loss value of each curve.
    train_annotation = ax.annotate('', xy=(0, 0), textcoords='offset points')
    val_annotation = ax.annotate('', xy=(0, 0), textcoords='offset points')

    # Fixed text boxes on the right showing the minimum losses so far.
    min_train_text = ax.text(0.85, 0.5, '', transform=ax.transAxes, fontsize=10, verticalalignment='center', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.5))
    min_val_text = ax.text(0.85, 0.45, '', transform=ax.transAxes, fontsize=10, verticalalignment='center', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.5))

    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch}')

        # Validation (runs before training, so epoch 0 validates the
        # untrained policy).
        epoch_val_loss, epoch_summary = validate(policy, val_dataloader)
        validation_history.append(epoch_summary)
        val_losses.append(epoch_val_loss.cpu().item())

        if epoch_val_loss < min_val_loss:
            min_val_loss = epoch_val_loss
            min_val_epoch = epoch
            # Snapshot the weights so later epochs cannot overwrite the best.
            best_ckpt_info = (epoch, min_val_loss,
                              deepcopy(policy.state_dict()))

        print(f'Val loss: {epoch_val_loss:.5f}')
        print_summary(epoch_summary)

        # Training
        epoch_train_loss, epoch_summary = train_epoch(
            policy, optimizer, train_dataloader)
        train_history.append(epoch_summary)
        train_losses.append(epoch_train_loss.cpu().item())

        if epoch_train_loss < min_train_loss:
            min_train_loss = epoch_train_loss
            min_train_epoch = epoch

        print(f'Train loss: {epoch_train_loss:.5f}')
        print_summary(epoch_summary)

        # Update the plot with the new data
        train_line.set_xdata(range(len(train_losses)))
        train_line.set_ydata(train_losses)
        val_line.set_xdata(range(len(val_losses)))
        val_line.set_ydata(val_losses)

        # Update annotations with the latest loss values at their respective positions
        train_annotation.set_position((len(train_losses)-1, train_losses[-1]))
        train_annotation.xy = (len(train_losses)-1, train_losses[-1])
        train_annotation.set_text(f'{train_losses[-1]:.5f}')

        val_annotation.set_position((len(val_losses)-1, val_losses[-1]))
        val_annotation.xy = (len(val_losses)-1, val_losses[-1])
        val_annotation.set_text(f'{val_losses[-1]:.5f}')

        # Update text objects with the minimum loss values, fixed on the right side
        min_train_text.set_text(f'Min Train Loss: {min_train_loss:.5f} (Epoch {min_train_epoch})')
        min_val_text.set_text(f'Min Val Loss: {min_val_loss:.5f} (Epoch {min_val_epoch})')

        ax.relim()
        ax.autoscale_view()
        plt.draw()
        plt.pause(0.1)

    plt.ioff()
    # Save the final weights regardless of validation performance.
    IOUtils.save_checkpoint(policy, 'last', ckpt_dir, seed, 'last')

    # Save the best-validation snapshot captured during the loop.
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    IOUtils.save_checkpoint(best_state_dict, best_epoch,
                            ckpt_dir, seed, 'best', min_val_loss)
    print(
        f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')

    IOUtils.plot_history(train_history, validation_history,
                         num_epochs, ckpt_dir, seed)

    return best_ckpt_info
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def validate(policy, dataloader):
    """Run one evaluation pass over ``dataloader``.

    Args:
        policy: model with .eval() and the forward_pass calling convention.
        dataloader: iterable of (image, qpos, action, is_pad) batches.

    Returns:
        (mean loss, dict of mean metrics over all batches).
    """
    policy.eval()
    with torch.inference_mode():
        batch_dicts = [forward_pass(batch, policy) for batch in dataloader]
    summary = compute_dict_mean(batch_dicts)
    return summary['loss'], summary
|
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(policy, optimizer, dataloader):
    """Run one optimization pass over ``dataloader``.

    Args:
        policy: model with .train() and the forward_pass calling convention.
        optimizer: torch optimizer for the policy's parameters.
        dataloader: iterable of (image, qpos, action, is_pad) batches.

    Returns:
        (mean loss, dict of mean metrics over all batches).
    """
    policy.train()
    batch_dicts = []
    for batch in dataloader:
        optimizer.zero_grad()
        result = forward_pass(batch, policy)
        result['loss'].backward()
        optimizer.step()
        # Detach so accumulated metrics don't retain the autograd graph.
        batch_dicts.append(detach_dict(result))
    summary = compute_dict_mean(batch_dicts)
    return summary['loss'], summary
|
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(data, policy):
    """Move one (image, qpos, action, is_pad) batch to the GPU and run the policy.

    Returns whatever the policy returns (a loss/metric dict in training).
    """
    image_data, qpos_data, action_data, is_pad = (t.cuda() for t in data)
    return policy(qpos_data, image_data, action_data, is_pad)
|
|
||||||
|
|
||||||
|
|
||||||
def print_summary(summary):
    """Print ``summary`` (name -> scalar tensor) as one 'key: v.vvv' line."""
    line = ' '.join(f'{key}: {value.item():.3f}' for key, value in summary.items())
    print(line)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Load the experiment configuration and start training.
    io_utils = IOUtils()
    config = io_utils.load_config()
    train_bc(config)
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright 2020 - present, Facebook, Inc
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
|
|
||||||
|
|
||||||
@article{Carion2020EndtoEndOD,
|
|
||||||
title={End-to-End Object Detection with Transformers},
|
|
||||||
author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
|
|
||||||
journal={ArXiv},
|
|
||||||
year={2020},
|
|
||||||
volume={abs/2005.12872}
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from .models import build_ACT_model, build_CNNMLP_model
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
|
|
||||||
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
|
||||||
parser.add_argument('--lr', default=1e-4, type=float) # will be overridden
|
|
||||||
parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden
|
|
||||||
parser.add_argument('--batch_size', default=2, type=int) # not used
|
|
||||||
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
|
||||||
parser.add_argument('--epochs', default=300, type=int) # not used
|
|
||||||
parser.add_argument('--lr_drop', default=200, type=int) # not used
|
|
||||||
parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used
|
|
||||||
help='gradient clipping max norm')
|
|
||||||
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float, help='lr', required=False)
|
|
||||||
|
|
||||||
# Model parameters
|
|
||||||
# * Backbone
|
|
||||||
parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden
|
|
||||||
help="Name of the convolutional backbone to use")
|
|
||||||
parser.add_argument('--dilation', action='store_true',
|
|
||||||
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
|
|
||||||
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
|
|
||||||
help="Type of positional embedding to use on top of the image features")
|
|
||||||
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
|
||||||
help="A list of camera names")
|
|
||||||
|
|
||||||
# * Transformer
|
|
||||||
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
|
||||||
help="Number of encoding layers in the transformer")
|
|
||||||
parser.add_argument('--dec_layers', default=6, type=int, # will be overridden
|
|
||||||
help="Number of decoding layers in the transformer")
|
|
||||||
parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden
|
|
||||||
help="Intermediate size of the feedforward layers in the transformer blocks")
|
|
||||||
parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden
|
|
||||||
help="Size of the embeddings (dimension of the transformer)")
|
|
||||||
parser.add_argument('--dropout', default=0.1, type=float,
|
|
||||||
help="Dropout applied in the transformer")
|
|
||||||
parser.add_argument('--nheads', default=8, type=int, # will be overridden
|
|
||||||
help="Number of attention heads inside the transformer's attentions")
|
|
||||||
parser.add_argument('--num_queries', default=400, type=int, # will be overridden
|
|
||||||
help="Number of query slots")
|
|
||||||
parser.add_argument('--pre_norm', action='store_true')
|
|
||||||
parser.add_argument('--state_dim', default=14, type=int)
|
|
||||||
parser.add_argument('--action_dim', default=14, type=int)
|
|
||||||
|
|
||||||
|
|
||||||
# * Segmentation
|
|
||||||
parser.add_argument('--masks', action='store_true',
|
|
||||||
help="Train segmentation head if the flag is provided")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def build_ACT_model_and_optimizer(args_override):
|
|
||||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
for k, v in args_override.items():
|
|
||||||
setattr(args, k, v)
|
|
||||||
|
|
||||||
model = build_ACT_model(args)
|
|
||||||
model.cuda()
|
|
||||||
|
|
||||||
param_dicts = [
|
|
||||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
|
||||||
"lr": args.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
|
||||||
weight_decay=args.weight_decay)
|
|
||||||
|
|
||||||
return model, optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def build_CNNMLP_model_and_optimizer(args_override):
|
|
||||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
for k, v in args_override.items():
|
|
||||||
setattr(args, k, v)
|
|
||||||
|
|
||||||
model = build_CNNMLP_model(args)
|
|
||||||
model.cuda()
|
|
||||||
|
|
||||||
param_dicts = [
|
|
||||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
|
||||||
"lr": args.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
|
||||||
weight_decay=args.weight_decay)
|
|
||||||
|
|
||||||
return model, optimizer
|
|
||||||
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
from .detr_vae import build as build_vae
|
|
||||||
from .detr_vae import build_cnnmlp as build_cnnmlp
|
|
||||||
|
|
||||||
def build_ACT_model(args):
|
|
||||||
return build_vae(args)
|
|
||||||
|
|
||||||
def build_CNNMLP_model(args):
|
|
||||||
return build_cnnmlp(args)
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Backbone modules.
|
|
||||||
"""
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchvision
|
|
||||||
from torch import nn
|
|
||||||
from torchvision.models._utils import IntermediateLayerGetter
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from util.misc import NestedTensor, is_main_process
|
|
||||||
|
|
||||||
from .position_encoding import build_position_encoding
|
|
||||||
|
|
||||||
class FrozenBatchNorm2d(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
|
||||||
|
|
||||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
|
||||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
|
||||||
produce nans.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n):
|
|
||||||
super(FrozenBatchNorm2d, self).__init__()
|
|
||||||
self.register_buffer("weight", torch.ones(n))
|
|
||||||
self.register_buffer("bias", torch.zeros(n))
|
|
||||||
self.register_buffer("running_mean", torch.zeros(n))
|
|
||||||
self.register_buffer("running_var", torch.ones(n))
|
|
||||||
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
|
||||||
missing_keys, unexpected_keys, error_msgs):
|
|
||||||
num_batches_tracked_key = prefix + 'num_batches_tracked'
|
|
||||||
if num_batches_tracked_key in state_dict:
|
|
||||||
del state_dict[num_batches_tracked_key]
|
|
||||||
|
|
||||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
|
||||||
state_dict, prefix, local_metadata, strict,
|
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# move reshapes to the beginning
|
|
||||||
# to make it fuser-friendly
|
|
||||||
w = self.weight.reshape(1, -1, 1, 1)
|
|
||||||
b = self.bias.reshape(1, -1, 1, 1)
|
|
||||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
|
||||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
|
||||||
eps = 1e-5
|
|
||||||
scale = w * (rv + eps).rsqrt()
|
|
||||||
bias = b - rm * scale
|
|
||||||
return x * scale + bias
|
|
||||||
|
|
||||||
|
|
||||||
class BackboneBase(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
|
|
||||||
super().__init__()
|
|
||||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
|
||||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
||||||
# parameter.requires_grad_(False)
|
|
||||||
if return_interm_layers:
|
|
||||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
|
||||||
else:
|
|
||||||
return_layers = {'layer4': "0"}
|
|
||||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
|
||||||
self.num_channels = num_channels
|
|
||||||
|
|
||||||
def forward(self, tensor):
|
|
||||||
xs = self.body(tensor)
|
|
||||||
return xs
|
|
||||||
# out: Dict[str, NestedTensor] = {}
|
|
||||||
# for name, x in xs.items():
|
|
||||||
# m = tensor_list.mask
|
|
||||||
# assert m is not None
|
|
||||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
|
||||||
# out[name] = NestedTensor(x, mask)
|
|
||||||
# return out
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(BackboneBase):
|
|
||||||
"""ResNet backbone with frozen BatchNorm."""
|
|
||||||
def __init__(self, name: str,
|
|
||||||
train_backbone: bool,
|
|
||||||
return_interm_layers: bool,
|
|
||||||
dilation: bool):
|
|
||||||
backbone = getattr(torchvision.models, name)(
|
|
||||||
replace_stride_with_dilation=[False, False, dilation],
|
|
||||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
|
||||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
|
||||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
|
||||||
|
|
||||||
|
|
||||||
# class DINOv2BackBone(nn.Module):
|
|
||||||
# def __init__(self) -> None:
|
|
||||||
# super().__init__()
|
|
||||||
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
|
||||||
# self.body.eval()
|
|
||||||
# self.num_channels = 384
|
|
||||||
|
|
||||||
# @torch.no_grad()
|
|
||||||
# def forward(self, tensor):
|
|
||||||
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
|
|
||||||
# od = OrderedDict()
|
|
||||||
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
# return od
|
|
||||||
|
|
||||||
class DINOv2BackBone(nn.Module):
|
|
||||||
def __init__(self, return_interm_layers: bool = False) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
|
||||||
self.body.eval()
|
|
||||||
self.num_channels = 384
|
|
||||||
self.return_interm_layers = return_interm_layers
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(self, tensor):
|
|
||||||
features = self.body.forward_features(tensor)
|
|
||||||
|
|
||||||
if self.return_interm_layers:
|
|
||||||
|
|
||||||
layer1 = features["x_norm_patchtokens"]
|
|
||||||
layer2 = features["x_norm_patchtokens"]
|
|
||||||
layer3 = features["x_norm_patchtokens"]
|
|
||||||
layer4 = features["x_norm_patchtokens"]
|
|
||||||
|
|
||||||
od = OrderedDict()
|
|
||||||
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
return od
|
|
||||||
else:
|
|
||||||
xs = features["x_norm_patchtokens"]
|
|
||||||
od = OrderedDict()
|
|
||||||
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
return od
|
|
||||||
|
|
||||||
class Joiner(nn.Sequential):
|
|
||||||
def __init__(self, backbone, position_embedding):
|
|
||||||
super().__init__(backbone, position_embedding)
|
|
||||||
|
|
||||||
def forward(self, tensor_list: NestedTensor):
|
|
||||||
xs = self[0](tensor_list)
|
|
||||||
out: List[NestedTensor] = []
|
|
||||||
pos = []
|
|
||||||
for name, x in xs.items():
|
|
||||||
out.append(x)
|
|
||||||
# position encoding
|
|
||||||
pos.append(self[1](x).to(x.dtype))
|
|
||||||
|
|
||||||
return out, pos
|
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(args):
|
|
||||||
position_embedding = build_position_encoding(args)
|
|
||||||
train_backbone = args.lr_backbone > 0
|
|
||||||
return_interm_layers = args.masks
|
|
||||||
if args.backbone == 'dino_v2':
|
|
||||||
backbone = DINOv2BackBone()
|
|
||||||
else:
|
|
||||||
assert args.backbone in ['resnet18', 'resnet34']
|
|
||||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
|
||||||
model = Joiner(backbone, position_embedding)
|
|
||||||
model.num_channels = backbone.num_channels
|
|
||||||
return model
|
|
||||||
@@ -1,300 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
DETR model and criterion classes.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.autograd import Variable
|
|
||||||
from .backbone import build_backbone
|
|
||||||
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def reparametrize(mu, logvar):
|
|
||||||
std = logvar.div(2).exp()
|
|
||||||
eps = Variable(std.data.new(std.size()).normal_())
|
|
||||||
return mu + std * eps
|
|
||||||
|
|
||||||
|
|
||||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
|
||||||
def get_position_angle_vec(position):
|
|
||||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
|
||||||
|
|
||||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
|
||||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
|
||||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
|
||||||
|
|
||||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class DETRVAE(nn.Module):
|
|
||||||
""" This is the DETR module that performs object detection """
|
|
||||||
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names):
|
|
||||||
""" Initializes the model.
|
|
||||||
Parameters:
|
|
||||||
backbones: torch module of the backbone to be used. See backbone.py
|
|
||||||
transformer: torch module of the transformer architecture. See transformer.py
|
|
||||||
state_dim: robot state dimension of the environment
|
|
||||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
|
||||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
|
||||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.num_queries = num_queries
|
|
||||||
self.camera_names = camera_names
|
|
||||||
self.transformer = transformer
|
|
||||||
self.encoder = encoder
|
|
||||||
hidden_dim = transformer.d_model
|
|
||||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
|
||||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
|
||||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
|
||||||
if backbones is not None:
|
|
||||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
|
||||||
self.backbones = nn.ModuleList(backbones)
|
|
||||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
# input_dim = 14 + 7 # robot_state + env_state
|
|
||||||
# self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
|
||||||
# self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
|
||||||
# self.pos = torch.nn.Embedding(2, hidden_dim)
|
|
||||||
# self.backbones = None
|
|
||||||
|
|
||||||
# encoder extra parameters
|
|
||||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
|
||||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
|
||||||
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
|
|
||||||
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
|
|
||||||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
|
||||||
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
|
|
||||||
|
|
||||||
# decoder extra parameters
|
|
||||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
|
||||||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
|
|
||||||
|
|
||||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
|
||||||
"""
|
|
||||||
qpos: batch, qpos_dim
|
|
||||||
image: batch, num_cam, channel, height, width
|
|
||||||
env_state: None
|
|
||||||
actions: batch, seq, action_dim
|
|
||||||
"""
|
|
||||||
is_training = actions is not None # train or val
|
|
||||||
bs, _ = qpos.shape
|
|
||||||
### Obtain latent z from action sequence
|
|
||||||
if is_training:
|
|
||||||
# project action sequence to embedding dim, and concat with a CLS token
|
|
||||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
|
||||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
|
||||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
|
||||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
|
||||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
|
||||||
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
|
|
||||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
|
||||||
# do not mask cls token
|
|
||||||
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
|
||||||
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
|
||||||
# obtain position embedding
|
|
||||||
pos_embed = self.pos_table.clone().detach()
|
|
||||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
|
||||||
# query model
|
|
||||||
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
|
|
||||||
encoder_output = encoder_output[0] # take cls output only
|
|
||||||
latent_info = self.latent_proj(encoder_output)
|
|
||||||
mu = latent_info[:, :self.latent_dim]
|
|
||||||
logvar = latent_info[:, self.latent_dim:]
|
|
||||||
latent_sample = reparametrize(mu, logvar)
|
|
||||||
latent_input = self.latent_out_proj(latent_sample)
|
|
||||||
else:
|
|
||||||
mu = logvar = None
|
|
||||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
|
||||||
latent_input = self.latent_out_proj(latent_sample)
|
|
||||||
|
|
||||||
if self.backbones is not None:
|
|
||||||
# Image observation features and position embeddings
|
|
||||||
all_cam_features = []
|
|
||||||
all_cam_pos = []
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# print(f"Image shape: {image.shape}, Number of cameras: {len(self.camera_names)}")
|
|
||||||
|
|
||||||
|
|
||||||
for cam_id, cam_name in enumerate(self.camera_names):
|
|
||||||
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
|
||||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
|
||||||
features = features[0] # take the last layer feature
|
|
||||||
pos = pos[0]
|
|
||||||
all_cam_features.append(self.input_proj(features))
|
|
||||||
all_cam_pos.append(pos)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# proprioception features
|
|
||||||
proprio_input = self.input_proj_robot_state(qpos)
|
|
||||||
# fold camera dimension into width dimension
|
|
||||||
src = torch.cat(all_cam_features, axis=3)
|
|
||||||
pos = torch.cat(all_cam_pos, axis=3)
|
|
||||||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
|
|
||||||
else:
|
|
||||||
qpos = self.input_proj_robot_state(qpos)
|
|
||||||
env_state = self.input_proj_env_state(env_state)
|
|
||||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
|
||||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
|
||||||
a_hat = self.action_head(hs)
|
|
||||||
is_pad_hat = self.is_pad_head(hs)
|
|
||||||
return a_hat, is_pad_hat, [mu, logvar]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CNNMLP(nn.Module):
|
|
||||||
def __init__(self, backbones, state_dim, camera_names):
|
|
||||||
""" Initializes the model.
|
|
||||||
Parameters:
|
|
||||||
backbones: torch module of the backbone to be used. See backbone.py
|
|
||||||
transformer: torch module of the transformer architecture. See transformer.py
|
|
||||||
state_dim: robot state dimension of the environment
|
|
||||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
|
||||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
|
||||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.camera_names = camera_names
|
|
||||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
|
||||||
if backbones is not None:
|
|
||||||
self.backbones = nn.ModuleList(backbones)
|
|
||||||
backbone_down_projs = []
|
|
||||||
for backbone in backbones:
|
|
||||||
down_proj = nn.Sequential(
|
|
||||||
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
|
||||||
nn.Conv2d(128, 64, kernel_size=5),
|
|
||||||
nn.Conv2d(64, 32, kernel_size=5)
|
|
||||||
)
|
|
||||||
backbone_down_projs.append(down_proj)
|
|
||||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
|
||||||
|
|
||||||
mlp_in_dim = 768 * len(backbones) + 14
|
|
||||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def forward(self, qpos, image, env_state, actions=None):
|
|
||||||
"""
|
|
||||||
qpos: batch, qpos_dim
|
|
||||||
image: batch, num_cam, channel, height, width
|
|
||||||
env_state: None
|
|
||||||
actions: batch, seq, action_dim
|
|
||||||
"""
|
|
||||||
is_training = actions is not None # train or val
|
|
||||||
bs, _ = qpos.shape
|
|
||||||
# Image observation features and position embeddings
|
|
||||||
all_cam_features = []
|
|
||||||
for cam_id, cam_name in enumerate(self.camera_names):
|
|
||||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
|
||||||
features = features[0] # take the last layer feature
|
|
||||||
pos = pos[0] # not used
|
|
||||||
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
|
||||||
# flatten everything
|
|
||||||
flattened_features = []
|
|
||||||
for cam_feature in all_cam_features:
|
|
||||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
|
||||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
|
||||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
|
||||||
a_hat = self.mlp(features)
|
|
||||||
return a_hat
|
|
||||||
|
|
||||||
|
|
||||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
|
||||||
if hidden_depth == 0:
|
|
||||||
mods = [nn.Linear(input_dim, output_dim)]
|
|
||||||
else:
|
|
||||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
|
||||||
for i in range(hidden_depth - 1):
|
|
||||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
|
||||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
|
||||||
trunk = nn.Sequential(*mods)
|
|
||||||
return trunk
|
|
||||||
|
|
||||||
|
|
||||||
def build_encoder(args):
|
|
||||||
d_model = args.hidden_dim # 256
|
|
||||||
dropout = args.dropout # 0.1
|
|
||||||
nhead = args.nheads # 8
|
|
||||||
dim_feedforward = args.dim_feedforward # 2048
|
|
||||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
|
||||||
normalize_before = args.pre_norm # False
|
|
||||||
activation = "relu"
|
|
||||||
|
|
||||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
|
||||||
dropout, activation, normalize_before)
|
|
||||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
|
||||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
|
||||||
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def build(args):
|
|
||||||
state_dim = args.state_dim
|
|
||||||
action_dim = args.action_dim
|
|
||||||
|
|
||||||
# From state
|
|
||||||
# backbone = None # from state for now, no need for conv nets
|
|
||||||
# From image
|
|
||||||
backbones = []
|
|
||||||
# backbone = build_backbone(args)
|
|
||||||
# backbones.append(backbone)
|
|
||||||
for _ in args.camera_names:
|
|
||||||
backbone = build_backbone(args)
|
|
||||||
backbones.append(backbone)
|
|
||||||
|
|
||||||
transformer = build_transformer(args)
|
|
||||||
|
|
||||||
encoder = build_encoder(args)
|
|
||||||
|
|
||||||
model = DETRVAE(
|
|
||||||
backbones,
|
|
||||||
transformer,
|
|
||||||
encoder,
|
|
||||||
state_dim=state_dim,
|
|
||||||
action_dim=action_dim,
|
|
||||||
num_queries=args.num_queries,
|
|
||||||
camera_names=args.camera_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def build_cnnmlp(args):
|
|
||||||
state_dim = 14 # TODO hardcode
|
|
||||||
|
|
||||||
# From state
|
|
||||||
# backbone = None # from state for now, no need for conv nets
|
|
||||||
# From image
|
|
||||||
backbones = []
|
|
||||||
for _ in args.camera_names:
|
|
||||||
backbone = build_backbone(args)
|
|
||||||
backbones.append(backbone)
|
|
||||||
|
|
||||||
model = CNNMLP(
|
|
||||||
backbones,
|
|
||||||
state_dim=state_dim,
|
|
||||||
camera_names=args.camera_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Various positional encodings for the transformer.
|
|
||||||
"""
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from util.misc import NestedTensor
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, tensor):
|
|
||||||
x = tensor
|
|
||||||
# mask = tensor_list.mask
|
|
||||||
# assert mask is not None
|
|
||||||
# not_mask = ~mask
|
|
||||||
|
|
||||||
not_mask = torch.ones_like(x[0, [0]])
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
|
||||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingLearned(nn.Module):
|
|
||||||
"""
|
|
||||||
Absolute pos embedding, learned.
|
|
||||||
"""
|
|
||||||
def __init__(self, num_pos_feats=256):
|
|
||||||
super().__init__()
|
|
||||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
|
||||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
|
||||||
self.reset_parameters()
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
nn.init.uniform_(self.row_embed.weight)
|
|
||||||
nn.init.uniform_(self.col_embed.weight)
|
|
||||||
|
|
||||||
def forward(self, tensor_list: NestedTensor):
|
|
||||||
x = tensor_list.tensors
|
|
||||||
h, w = x.shape[-2:]
|
|
||||||
i = torch.arange(w, device=x.device)
|
|
||||||
j = torch.arange(h, device=x.device)
|
|
||||||
x_emb = self.col_embed(i)
|
|
||||||
y_emb = self.row_embed(j)
|
|
||||||
pos = torch.cat([
|
|
||||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
|
||||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
|
||||||
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
def build_position_encoding(args):
|
|
||||||
N_steps = args.hidden_dim // 2
|
|
||||||
if args.position_embedding in ('v2', 'sine'):
|
|
||||||
# TODO find a better way of exposing other arguments
|
|
||||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
|
||||||
elif args.position_embedding in ('v3', 'learned'):
|
|
||||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"not supported {args.position_embedding}")
|
|
||||||
|
|
||||||
return position_embedding
|
|
||||||
@@ -1,312 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
DETR Transformer class.
|
|
||||||
|
|
||||||
Copy-paste from torch.nn.Transformer with modifications:
|
|
||||||
* positional encodings are passed in MHattention
|
|
||||||
* extra LN at the end of encoder is removed
|
|
||||||
* decoder returns a stack of activations from all decoding layers
|
|
||||||
"""
|
|
||||||
import copy
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn, Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
|
||||||
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
|
||||||
activation="relu", normalize_before=False,
|
|
||||||
return_intermediate_dec=False):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
|
||||||
dropout, activation, normalize_before)
|
|
||||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
|
||||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
|
||||||
|
|
||||||
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
|
||||||
dropout, activation, normalize_before)
|
|
||||||
decoder_norm = nn.LayerNorm(d_model)
|
|
||||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
|
||||||
return_intermediate=return_intermediate_dec)
|
|
||||||
|
|
||||||
self._reset_parameters()
|
|
||||||
|
|
||||||
self.d_model = d_model
|
|
||||||
self.nhead = nhead
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
|
||||||
for p in self.parameters():
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
|
|
||||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
|
||||||
# TODO flatten only when input has H and W
|
|
||||||
if len(src.shape) == 4: # has H and W
|
|
||||||
# flatten NxCxHxW to HWxNxC
|
|
||||||
bs, c, h, w = src.shape
|
|
||||||
src = src.flatten(2).permute(2, 0, 1)
|
|
||||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
|
||||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
|
||||||
# mask = mask.flatten(1)
|
|
||||||
|
|
||||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
|
||||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
|
||||||
|
|
||||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
|
||||||
src = torch.cat([addition_input, src], axis=0)
|
|
||||||
else:
|
|
||||||
assert len(src.shape) == 3
|
|
||||||
# flatten NxHWxC to HWxNxC
|
|
||||||
bs, hw, c = src.shape
|
|
||||||
src = src.permute(1, 0, 2)
|
|
||||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
|
||||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
|
||||||
|
|
||||||
tgt = torch.zeros_like(query_embed)
|
|
||||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
|
||||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
|
||||||
pos=pos_embed, query_pos=query_embed)
|
|
||||||
hs = hs.transpose(1, 2)
|
|
||||||
return hs
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = _get_clones(encoder_layer, num_layers)
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.norm = norm
|
|
||||||
|
|
||||||
def forward(self, src,
|
|
||||||
mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
output = src
|
|
||||||
|
|
||||||
for layer in self.layers:
|
|
||||||
output = layer(output, src_mask=mask,
|
|
||||||
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
|
||||||
|
|
||||||
if self.norm is not None:
|
|
||||||
output = self.norm(output)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = _get_clones(decoder_layer, num_layers)
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.norm = norm
|
|
||||||
self.return_intermediate = return_intermediate
|
|
||||||
|
|
||||||
def forward(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
output = tgt
|
|
||||||
|
|
||||||
intermediate = []
|
|
||||||
|
|
||||||
for layer in self.layers:
|
|
||||||
output = layer(output, memory, tgt_mask=tgt_mask,
|
|
||||||
memory_mask=memory_mask,
|
|
||||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
|
||||||
pos=pos, query_pos=query_pos)
|
|
||||||
if self.return_intermediate:
|
|
||||||
intermediate.append(self.norm(output))
|
|
||||||
|
|
||||||
if self.norm is not None:
|
|
||||||
output = self.norm(output)
|
|
||||||
if self.return_intermediate:
|
|
||||||
intermediate.pop()
|
|
||||||
intermediate.append(output)
|
|
||||||
|
|
||||||
if self.return_intermediate:
|
|
||||||
return torch.stack(intermediate)
|
|
||||||
|
|
||||||
return output.unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
|
||||||
activation="relu", normalize_before=False):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward_post(self,
|
|
||||||
src,
|
|
||||||
src_mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
q = k = self.with_pos_embed(src, pos)
|
|
||||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
|
||||||
key_padding_mask=src_key_padding_mask)[0]
|
|
||||||
src = src + self.dropout1(src2)
|
|
||||||
src = self.norm1(src)
|
|
||||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
|
||||||
src = src + self.dropout2(src2)
|
|
||||||
src = self.norm2(src)
|
|
||||||
return src
|
|
||||||
|
|
||||||
def forward_pre(self, src,
|
|
||||||
src_mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
src2 = self.norm1(src)
|
|
||||||
q = k = self.with_pos_embed(src2, pos)
|
|
||||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
|
||||||
key_padding_mask=src_key_padding_mask)[0]
|
|
||||||
src = src + self.dropout1(src2)
|
|
||||||
src2 = self.norm2(src)
|
|
||||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
|
||||||
src = src + self.dropout2(src2)
|
|
||||||
return src
|
|
||||||
|
|
||||||
def forward(self, src,
|
|
||||||
src_mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
if self.normalize_before:
|
|
||||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
|
||||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
|
||||||
activation="relu", normalize_before=False):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.norm3 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
self.dropout3 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward_post(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
q = k = self.with_pos_embed(tgt, query_pos)
|
|
||||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
|
||||||
key_padding_mask=tgt_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
tgt = self.norm1(tgt)
|
|
||||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
|
||||||
key=self.with_pos_embed(memory, pos),
|
|
||||||
value=memory, attn_mask=memory_mask,
|
|
||||||
key_padding_mask=memory_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
tgt = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
|
||||||
tgt = tgt + self.dropout3(tgt2)
|
|
||||||
tgt = self.norm3(tgt)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
def forward_pre(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
|
||||||
key_padding_mask=tgt_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
|
||||||
key=self.with_pos_embed(memory, pos),
|
|
||||||
value=memory, attn_mask=memory_mask,
|
|
||||||
key_padding_mask=memory_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
tgt2 = self.norm3(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout3(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
def forward(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
if self.normalize_before:
|
|
||||||
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
|
||||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
|
||||||
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
|
||||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_clones(module, N):
|
|
||||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
||||||
|
|
||||||
|
|
||||||
def build_transformer(args):
|
|
||||||
return Transformer(
|
|
||||||
d_model=args.hidden_dim,
|
|
||||||
dropout=args.dropout,
|
|
||||||
nhead=args.nheads,
|
|
||||||
dim_feedforward=args.dim_feedforward,
|
|
||||||
num_encoder_layers=args.enc_layers,
|
|
||||||
num_decoder_layers=args.dec_layers,
|
|
||||||
normalize_before=args.pre_norm,
|
|
||||||
return_intermediate_dec=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
import torch.nn as nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
import torch
|
|
||||||
from roboimi.detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class ACTPolicy(nn.Module):
|
|
||||||
def __init__(self, args_override):
|
|
||||||
super().__init__()
|
|
||||||
model, optimizer = build_ACT_model_and_optimizer(args_override)
|
|
||||||
self.model = model # CVAE decoder
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.kl_weight = args_override['kl_weight']
|
|
||||||
print(f'KL Weight {self.kl_weight}')
|
|
||||||
|
|
||||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
|
||||||
env_state = None
|
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
||||||
std=[0.229, 0.224, 0.225])
|
|
||||||
image = normalize(image)
|
|
||||||
if actions is not None: # training time
|
|
||||||
actions = actions[:, :self.model.num_queries]
|
|
||||||
is_pad = is_pad[:, :self.model.num_queries]
|
|
||||||
|
|
||||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
|
||||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
|
||||||
loss_dict = dict()
|
|
||||||
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
|
||||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
|
||||||
loss_dict['l1'] = l1
|
|
||||||
loss_dict['kl'] = total_kld[0]
|
|
||||||
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
|
||||||
return loss_dict
|
|
||||||
else: # inference time
|
|
||||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
|
||||||
return a_hat
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
class ACTTVPolicy(nn.Module):
|
|
||||||
def __init__(self, args_override):
|
|
||||||
super().__init__()
|
|
||||||
model, optimizer = build_ACT_model_and_optimizer(args_override)
|
|
||||||
self.model = model # CVAE decoder
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.kl_weight = args_override['kl_weight']
|
|
||||||
self.qpos_noise_std = args_override['qpos_noise_std']
|
|
||||||
print(f'KL Weight {self.kl_weight}')
|
|
||||||
|
|
||||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
|
||||||
env_state = None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
||||||
# std=[0.229, 0.224, 0.225])
|
|
||||||
# image = normalize(image)
|
|
||||||
|
|
||||||
|
|
||||||
patch_h = 16
|
|
||||||
patch_w = 22
|
|
||||||
if actions is not None:
|
|
||||||
transform = v2.Compose([
|
|
||||||
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
|
||||||
v2.RandomPerspective(distortion_scale=0.5),
|
|
||||||
v2.RandomAffine(degrees=10, translate=(0.1,0.1), scale=(0.9,1.1)),
|
|
||||||
v2.GaussianBlur(kernel_size=(9,9), sigma=(0.1,2.0)),
|
|
||||||
v2.Resize((patch_h * 14, patch_w * 14)),
|
|
||||||
# v2.CenterCrop((patch_h * 14, patch_w * 14)),
|
|
||||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
|
||||||
])
|
|
||||||
qpos += (self.qpos_noise_std**0.5)*torch.randn_like(qpos)
|
|
||||||
else: # inference time
|
|
||||||
transform = v2.Compose([
|
|
||||||
v2.Resize((patch_h * 14, patch_w * 14)),
|
|
||||||
# v2.CenterCrop((patch_h * 14, patch_w * 14)),
|
|
||||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
|
|
||||||
])
|
|
||||||
|
|
||||||
image = transform(image)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if actions is not None: # training time
|
|
||||||
actions = actions[:, :self.model.num_queries]
|
|
||||||
is_pad = is_pad[:, :self.model.num_queries]
|
|
||||||
|
|
||||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
|
||||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
|
||||||
loss_dict = dict()
|
|
||||||
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
|
||||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
|
||||||
loss_dict['l1'] = l1
|
|
||||||
loss_dict['kl'] = total_kld[0]
|
|
||||||
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
|
||||||
return loss_dict
|
|
||||||
else: # inference time
|
|
||||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
|
||||||
return a_hat
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class CNNMLPPolicy(nn.Module):
|
|
||||||
def __init__(self, args_override):
|
|
||||||
super().__init__()
|
|
||||||
model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
|
|
||||||
self.model = model # decoder
|
|
||||||
self.optimizer = optimizer
|
|
||||||
|
|
||||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
|
||||||
env_state = None # TODO
|
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
||||||
std=[0.229, 0.224, 0.225])
|
|
||||||
image = normalize(image)
|
|
||||||
if actions is not None: # training time
|
|
||||||
actions = actions[:, 0]
|
|
||||||
a_hat = self.model(qpos, image, env_state, actions)
|
|
||||||
mse = F.mse_loss(actions, a_hat)
|
|
||||||
loss_dict = dict()
|
|
||||||
loss_dict['mse'] = mse
|
|
||||||
loss_dict['loss'] = loss_dict['mse']
|
|
||||||
return loss_dict
|
|
||||||
else: # inference time
|
|
||||||
a_hat = self.model(qpos, image, env_state) # no action, sample from prior
|
|
||||||
return a_hat
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
def kl_divergence(mu, logvar):
|
|
||||||
batch_size = mu.size(0)
|
|
||||||
assert batch_size != 0
|
|
||||||
if mu.data.ndimension() == 4:
|
|
||||||
mu = mu.view(mu.size(0), mu.size(1))
|
|
||||||
if logvar.data.ndimension() == 4:
|
|
||||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
|
||||||
|
|
||||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
|
||||||
total_kld = klds.sum(1).mean(0, True)
|
|
||||||
dimension_wise_kld = klds.mean(0)
|
|
||||||
mean_kld = klds.mean(1).mean(0, True)
|
|
||||||
|
|
||||||
return total_kld, dimension_wise_kld, mean_kld
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
from distutils.core import setup
|
|
||||||
from setuptools import find_packages
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name='detr',
|
|
||||||
version='0.0.0',
|
|
||||||
packages=find_packages(),
|
|
||||||
license='MIT License',
|
|
||||||
long_description=open('README.md').read(),
|
|
||||||
)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Utilities for bounding box manipulation and GIoU.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from torchvision.ops.boxes import box_area
|
|
||||||
|
|
||||||
|
|
||||||
def box_cxcywh_to_xyxy(x):
|
|
||||||
x_c, y_c, w, h = x.unbind(-1)
|
|
||||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
|
||||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
|
||||||
return torch.stack(b, dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def box_xyxy_to_cxcywh(x):
|
|
||||||
x0, y0, x1, y1 = x.unbind(-1)
|
|
||||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
|
||||||
(x1 - x0), (y1 - y0)]
|
|
||||||
return torch.stack(b, dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# modified from torchvision to also return the union
|
|
||||||
def box_iou(boxes1, boxes2):
|
|
||||||
area1 = box_area(boxes1)
|
|
||||||
area2 = box_area(boxes2)
|
|
||||||
|
|
||||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
|
||||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
|
||||||
|
|
||||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
|
||||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
|
||||||
|
|
||||||
union = area1[:, None] + area2 - inter
|
|
||||||
|
|
||||||
iou = inter / union
|
|
||||||
return iou, union
|
|
||||||
|
|
||||||
|
|
||||||
def generalized_box_iou(boxes1, boxes2):
|
|
||||||
"""
|
|
||||||
Generalized IoU from https://giou.stanford.edu/
|
|
||||||
|
|
||||||
The boxes should be in [x0, y0, x1, y1] format
|
|
||||||
|
|
||||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
|
||||||
and M = len(boxes2)
|
|
||||||
"""
|
|
||||||
# degenerate boxes gives inf / nan results
|
|
||||||
# so do an early check
|
|
||||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
|
||||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
|
||||||
iou, union = box_iou(boxes1, boxes2)
|
|
||||||
|
|
||||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
|
||||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
|
||||||
|
|
||||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
|
||||||
area = wh[:, :, 0] * wh[:, :, 1]
|
|
||||||
|
|
||||||
return iou - (area - union) / area
|
|
||||||
|
|
||||||
|
|
||||||
def masks_to_boxes(masks):
|
|
||||||
"""Compute the bounding boxes around the provided masks
|
|
||||||
|
|
||||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
|
||||||
|
|
||||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
|
||||||
"""
|
|
||||||
if masks.numel() == 0:
|
|
||||||
return torch.zeros((0, 4), device=masks.device)
|
|
||||||
|
|
||||||
h, w = masks.shape[-2:]
|
|
||||||
|
|
||||||
y = torch.arange(0, h, dtype=torch.float)
|
|
||||||
x = torch.arange(0, w, dtype=torch.float)
|
|
||||||
y, x = torch.meshgrid(y, x)
|
|
||||||
|
|
||||||
x_mask = (masks * x.unsqueeze(0))
|
|
||||||
x_max = x_mask.flatten(1).max(-1)[0]
|
|
||||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
|
||||||
|
|
||||||
y_mask = (masks * y.unsqueeze(0))
|
|
||||||
y_max = y_mask.flatten(1).max(-1)[0]
|
|
||||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
|
||||||
|
|
||||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
|
||||||
@@ -1,468 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Misc functions, including distributed helpers.
|
|
||||||
|
|
||||||
Mostly copy-paste from torchvision references.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
import datetime
|
|
||||||
import pickle
|
|
||||||
from packaging import version
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
|
||||||
import torchvision
|
|
||||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
|
||||||
from torchvision.ops import _new_empty_tensor
|
|
||||||
from torchvision.ops.misc import _output_size
|
|
||||||
|
|
||||||
|
|
||||||
class SmoothedValue(object):
|
|
||||||
"""Track a series of values and provide access to smoothed values over a
|
|
||||||
window or the global series average.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, window_size=20, fmt=None):
|
|
||||||
if fmt is None:
|
|
||||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
|
||||||
self.deque = deque(maxlen=window_size)
|
|
||||||
self.total = 0.0
|
|
||||||
self.count = 0
|
|
||||||
self.fmt = fmt
|
|
||||||
|
|
||||||
def update(self, value, n=1):
|
|
||||||
self.deque.append(value)
|
|
||||||
self.count += n
|
|
||||||
self.total += value * n
|
|
||||||
|
|
||||||
def synchronize_between_processes(self):
|
|
||||||
"""
|
|
||||||
Warning: does not synchronize the deque!
|
|
||||||
"""
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return
|
|
||||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
|
||||||
dist.barrier()
|
|
||||||
dist.all_reduce(t)
|
|
||||||
t = t.tolist()
|
|
||||||
self.count = int(t[0])
|
|
||||||
self.total = t[1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def median(self):
|
|
||||||
d = torch.tensor(list(self.deque))
|
|
||||||
return d.median().item()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def avg(self):
|
|
||||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
|
||||||
return d.mean().item()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def global_avg(self):
|
|
||||||
return self.total / self.count
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max(self):
|
|
||||||
return max(self.deque)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def value(self):
|
|
||||||
return self.deque[-1]
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.fmt.format(
|
|
||||||
median=self.median,
|
|
||||||
avg=self.avg,
|
|
||||||
global_avg=self.global_avg,
|
|
||||||
max=self.max,
|
|
||||||
value=self.value)
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather(data):
|
|
||||||
"""
|
|
||||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
|
||||||
Args:
|
|
||||||
data: any picklable object
|
|
||||||
Returns:
|
|
||||||
list[data]: list of data gathered from each rank
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
if world_size == 1:
|
|
||||||
return [data]
|
|
||||||
|
|
||||||
# serialized to a Tensor
|
|
||||||
buffer = pickle.dumps(data)
|
|
||||||
storage = torch.ByteStorage.from_buffer(buffer)
|
|
||||||
tensor = torch.ByteTensor(storage).to("cuda")
|
|
||||||
|
|
||||||
# obtain Tensor size of each rank
|
|
||||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
|
||||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
|
||||||
dist.all_gather(size_list, local_size)
|
|
||||||
size_list = [int(size.item()) for size in size_list]
|
|
||||||
max_size = max(size_list)
|
|
||||||
|
|
||||||
# receiving Tensor from all ranks
|
|
||||||
# we pad the tensor because torch all_gather does not support
|
|
||||||
# gathering tensors of different shapes
|
|
||||||
tensor_list = []
|
|
||||||
for _ in size_list:
|
|
||||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
|
||||||
if local_size != max_size:
|
|
||||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
|
||||||
tensor = torch.cat((tensor, padding), dim=0)
|
|
||||||
dist.all_gather(tensor_list, tensor)
|
|
||||||
|
|
||||||
data_list = []
|
|
||||||
for size, tensor in zip(size_list, tensor_list):
|
|
||||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
|
||||||
data_list.append(pickle.loads(buffer))
|
|
||||||
|
|
||||||
return data_list
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_dict(input_dict, average=True):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
input_dict (dict): all the values will be reduced
|
|
||||||
average (bool): whether to do average or sum
|
|
||||||
Reduce the values in the dictionary from all processes so that all processes
|
|
||||||
have the averaged results. Returns a dict with the same fields as
|
|
||||||
input_dict, after reduction.
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
if world_size < 2:
|
|
||||||
return input_dict
|
|
||||||
with torch.no_grad():
|
|
||||||
names = []
|
|
||||||
values = []
|
|
||||||
# sort the keys so that they are consistent across processes
|
|
||||||
for k in sorted(input_dict.keys()):
|
|
||||||
names.append(k)
|
|
||||||
values.append(input_dict[k])
|
|
||||||
values = torch.stack(values, dim=0)
|
|
||||||
dist.all_reduce(values)
|
|
||||||
if average:
|
|
||||||
values /= world_size
|
|
||||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
|
||||||
return reduced_dict
|
|
||||||
|
|
||||||
|
|
||||||
class MetricLogger(object):
|
|
||||||
def __init__(self, delimiter="\t"):
|
|
||||||
self.meters = defaultdict(SmoothedValue)
|
|
||||||
self.delimiter = delimiter
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
v = v.item()
|
|
||||||
assert isinstance(v, (float, int))
|
|
||||||
self.meters[k].update(v)
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
if attr in self.meters:
|
|
||||||
return self.meters[attr]
|
|
||||||
if attr in self.__dict__:
|
|
||||||
return self.__dict__[attr]
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
|
||||||
type(self).__name__, attr))
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
loss_str = []
|
|
||||||
for name, meter in self.meters.items():
|
|
||||||
loss_str.append(
|
|
||||||
"{}: {}".format(name, str(meter))
|
|
||||||
)
|
|
||||||
return self.delimiter.join(loss_str)
|
|
||||||
|
|
||||||
def synchronize_between_processes(self):
|
|
||||||
for meter in self.meters.values():
|
|
||||||
meter.synchronize_between_processes()
|
|
||||||
|
|
||||||
def add_meter(self, name, meter):
|
|
||||||
self.meters[name] = meter
|
|
||||||
|
|
||||||
def log_every(self, iterable, print_freq, header=None):
|
|
||||||
i = 0
|
|
||||||
if not header:
|
|
||||||
header = ''
|
|
||||||
start_time = time.time()
|
|
||||||
end = time.time()
|
|
||||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
|
||||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
|
||||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log_msg = self.delimiter.join([
|
|
||||||
header,
|
|
||||||
'[{0' + space_fmt + '}/{1}]',
|
|
||||||
'eta: {eta}',
|
|
||||||
'{meters}',
|
|
||||||
'time: {time}',
|
|
||||||
'data: {data}',
|
|
||||||
'max mem: {memory:.0f}'
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
log_msg = self.delimiter.join([
|
|
||||||
header,
|
|
||||||
'[{0' + space_fmt + '}/{1}]',
|
|
||||||
'eta: {eta}',
|
|
||||||
'{meters}',
|
|
||||||
'time: {time}',
|
|
||||||
'data: {data}'
|
|
||||||
])
|
|
||||||
MB = 1024.0 * 1024.0
|
|
||||||
for obj in iterable:
|
|
||||||
data_time.update(time.time() - end)
|
|
||||||
yield obj
|
|
||||||
iter_time.update(time.time() - end)
|
|
||||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
|
||||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
|
||||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
print(log_msg.format(
|
|
||||||
i, len(iterable), eta=eta_string,
|
|
||||||
meters=str(self),
|
|
||||||
time=str(iter_time), data=str(data_time),
|
|
||||||
memory=torch.cuda.max_memory_allocated() / MB))
|
|
||||||
else:
|
|
||||||
print(log_msg.format(
|
|
||||||
i, len(iterable), eta=eta_string,
|
|
||||||
meters=str(self),
|
|
||||||
time=str(iter_time), data=str(data_time)))
|
|
||||||
i += 1
|
|
||||||
end = time.time()
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
|
||||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
|
||||||
header, total_time_str, total_time / len(iterable)))
|
|
||||||
|
|
||||||
|
|
||||||
def get_sha():
    """Describe the current git checkout as "sha: ..., status: ..., branch: ...".

    Falls back to the 'N/A' / "clean" placeholders when git is unavailable
    or this module does not live inside a repository.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git(*argv):
        # Run a git command inside the repo, returning its stripped output.
        return subprocess.check_output(list(argv), cwd=repo_dir).decode('ascii').strip()

    sha, diff, branch = 'N/A', "clean", 'N/A'
    try:
        sha = _git('git', 'rev-parse', 'HEAD')
        subprocess.check_output(['git', 'diff'], cwd=repo_dir)
        dirty = _git('git', 'diff-index', 'HEAD')
        diff = "has uncommited changes" if dirty else "clean"
        branch = _git('git', 'rev-parse', '--abbrev-ref', 'HEAD')
    except Exception:
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(batch):
    """Collate a list of samples into per-field tuples, wrapping the image
    field (index 0) in a padded NestedTensor with its mask."""
    fields = [field for field in zip(*batch)]
    fields[0] = nested_tensor_from_tensor_list(fields[0])
    return tuple(fields)
|
|
||||||
|
|
||||||
|
|
||||||
def _max_by_axis(the_list):
|
|
||||||
# type: (List[List[int]]) -> List[int]
|
|
||||||
maxes = the_list[0]
|
|
||||||
for sublist in the_list[1:]:
|
|
||||||
for index, item in enumerate(sublist):
|
|
||||||
maxes[index] = max(maxes[index], item)
|
|
||||||
return maxes
|
|
||||||
|
|
||||||
|
|
||||||
class NestedTensor(object):
    """A padded batch tensor paired with an optional padding mask.

    ``tensors`` holds the padded batch; ``mask`` is a bool tensor that is
    True at padded positions, or None when no mask is attached.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask) on *device*."""
        moved = self.tensors.to(device)
        if self.mask is None:
            return NestedTensor(moved, None)
        return NestedTensor(moved, self.mask.to(device))

    def decompose(self):
        """Split back into the (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
|
|
||||||
|
|
||||||
|
|
||||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of CHW image tensors to a common shape and build a NestedTensor.

    The output tensor has shape (batch, C, H_max, W_max); the bool mask is
    True at padded positions and False over real image content. Only 3-D
    (C, H, W) inputs are supported.
    """
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            # Copy each image into the top-left corner of its padded slot and
            # clear the mask over the region holding real pixels.
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
|
|
||||||
|
|
||||||
|
|
||||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """ONNX-traceable variant of nested_tensor_from_tensor_list().

    Computes the per-axis maximum shape with tensor ops (so sizes stay
    symbolic under tracing) and pads via F.pad instead of in-place slice
    assignment, which ONNX export does not support.
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Mask starts as zeros over the real region and is padded with ones,
        # so True marks padding (matching the eager implementation).
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # Passing force=True lets any rank print regardless of master status.
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    # Replace the global built-in print for the whole process.
    __builtin__.print = print
|
|
||||||
|
|
||||||
|
|
||||||
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
    """World size of the current process group, or 1 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank():
    """Rank of this process in the group, or 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process():
    """True on rank 0 (and always in non-distributed runs)."""
    rank = get_rank()
    return rank == 0
|
|
||||||
|
|
||||||
|
|
||||||
def save_on_master(*args, **kwargs):
    """torch.save that is a no-op on every rank except the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables, in place.

    Supports torchrun-style env vars (RANK/WORLD_SIZE/LOCAL_RANK) and SLURM
    (SLURM_PROCID). Sets args.rank / args.world_size / args.gpu /
    args.distributed / args.dist_backend, selects the CUDA device, creates
    the NCCL process group, and silences print() on non-master ranks.
    Falls back to single-process mode (args.distributed = False) when no
    launcher environment is detected.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    # NOTE(review): args.dist_url (and, in the SLURM branch, args.world_size)
    # is assumed to be pre-populated by the caller — confirm.
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) integer class labels.
        topk: iterable of k values to evaluate.

    Returns:
        List of 0-dim tensors, one per k, holding the top-k accuracy in
        percent. For an empty target a single zero tensor is returned.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): `correct` is a transpose, hence non-contiguous,
        # and view(-1) raises on recent PyTorch for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            # Non-empty input: plain F.interpolate handles it directly.
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        # Empty batch: compute the output spatial shape manually and return
        # an empty tensor of that shape.
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        # torchvision >= 0.7 supports empty batches itself.
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
"""
|
|
||||||
Plotting utilities to visualize training logs.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from pathlib import Path, PurePath
|
|
||||||
|
|
||||||
|
|
||||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.
    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error
    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
            Expect list[Path] or single Path obj, received {type(logs)}")

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for log_dir in logs:  # renamed from 'dir' to avoid shadowing the builtin
        if not isinstance(log_dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(log_dir)}: \n{log_dir}")
        if not log_dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{log_dir}")
        # verify log_name exists
        fn = Path(log_dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
    # plt.subplots returns a bare Axes (not an array) when ncols == 1;
    # normalize so axs[j] indexing below also works for a single field.
    axs = np.atleast_1d(axs)

    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                # mAP lives in the nested test_coco_eval_bbox column.
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                # Solid = train, dashed = test, same color per log file.
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_precision_recall(files, naming_scheme='iter'):
    """Plot precision/recall and score/recall curves from saved COCO eval files.

    Args:
        files: list of Paths to torch-saved COCO evaluation result dicts.
        naming_scheme: 'exp_id' (legend label = grandparent directory name)
            or 'iter' (legend label = file stem).

    Returns:
        (fig, axs): the matplotlib figure and its two axes.
    """
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'not supported {naming_scheme}')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision for all classes, all areas and 100 detections
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
GR00T (diffusion-based DiT policy) model builder.
|
|
||||||
|
|
||||||
This module provides functions to build GR00T models and optimizers
|
|
||||||
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from .models import build_gr00t_model
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
    """
    Create argument parser for GR00T model configuration.

    All parameters can be overridden via args_override dictionary in
    build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
    add = parser.add_argument  # shorthand; registration order is unchanged

    # Training parameters
    add('--lr', default=1e-5, type=float, help='Learning rate for main parameters')
    add('--lr_backbone', default=1e-5, type=float, help='Learning rate for backbone parameters')
    add('--weight_decay', default=1e-4, type=float, help='Weight decay for optimizer')

    # GR00T model architecture parameters
    add('--embed_dim', default=1536, type=int, help='Embedding dimension for transformer')
    add('--hidden_dim', default=1024, type=int, help='Hidden dimension for MLP layers')
    add('--state_dim', default=16, type=int, help='State (qpos) dimension')
    add('--action_dim', default=16, type=int, help='Action dimension')
    add('--num_queries', default=16, type=int, help='Number of action queries (chunk size)')

    # DiT (Diffusion Transformer) parameters
    add('--num_layers', default=16, type=int, help='Number of transformer layers')
    add('--nheads', default=32, type=int, help='Number of attention heads')
    add('--mlp_ratio', default=4, type=float, help='MLP hidden dimension ratio')
    add('--dropout', default=0.2, type=float, help='Dropout rate')

    # Backbone parameters
    add('--backbone', default='dino_v2', type=str, help='Backbone architecture (dino_v2, resnet18, resnet34)')
    add('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
        help='Type of positional encoding')

    # Camera configuration
    add('--camera_names', default=[], nargs='+', help='List of camera names for observations')

    # Other parameters (not directly used but kept for compatibility)
    add('--batch_size', default=15, type=int)
    add('--epochs', default=20000, type=int)
    add('--masks', action='store_true', help='Use intermediate layer features')
    add('--dilation', action='store_false', help='Use dilated convolution in backbone')

    return parser
|
|
||||||
|
|
||||||
|
|
||||||
def build_gr00t_model_and_optimizer(args_override):
    """
    Build GR00T model and optimizer from config dictionary.

    This function is designed to work with config.yaml loading:
    1. Parse default arguments
    2. Override with values from args_override (typically from config['gr00t'])
    3. Build model and optimizer

    Args:
        args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
                       Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
                       num_queries, nheads, mlp_ratio, dropout, num_layers,
                       lr, lr_backbone, camera_names, backbone, etc.

    Returns:
        model: GR00T model on CUDA
        optimizer: AdamW optimizer with separate learning rates for backbone and other params
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script',
                                     parents=[get_args_parser()])
    # Parse an empty argv: configuration comes exclusively from args_override,
    # and reading sys.argv here would crash on the host script's own flags.
    args = parser.parse_args([])

    # Override with config values
    for k, v in args_override.items():
        setattr(args, k, v)

    # Build model
    model = build_gr00t_model(args)
    model.cuda()

    # Separate parameter groups: backbone parameters train at lr_backbone,
    # everything else at the main lr.
    param_dicts = [
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" not in n and p.requires_grad]
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]

    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    return model, optimizer
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .gr00t import build_gr00t_model
|
|
||||||
|
|
||||||
__all__ = ['build_gr00t_model']
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Backbone modules.
|
|
||||||
"""
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchvision
|
|
||||||
from torch import nn
|
|
||||||
from torchvision.models._utils import IntermediateLayerGetter
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from util.misc import NestedTensor, is_main_process
|
|
||||||
|
|
||||||
from .position_encoding import build_position_encoding
|
|
||||||
|
|
||||||
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # All parameters live as plain buffers, so neither the optimizer nor
        # forward passes ever update them. Registration order matches the
        # original (weight, bias, running_mean, running_var).
        for buf_name, init in (("weight", torch.ones(n)),
                               ("bias", torch.zeros(n)),
                               ("running_mean", torch.zeros(n)),
                               ("running_var", torch.ones(n))):
            self.register_buffer(buf_name, init)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Drop the num_batches_tracked entry that regular BatchNorm
        # checkpoints carry but this frozen variant has no buffer for.
        state_dict.pop(prefix + 'num_batches_tracked', None)
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        eps = 1e-5
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + eps).rsqrt()
        shift = b - rm * scale
        return x * scale + shift
|
|
||||||
|
|
||||||
|
|
||||||
class BackboneBase(nn.Module):
    """Wrap a torchvision-style backbone behind an IntermediateLayerGetter.

    Exposes either layer1-4 (return_interm_layers=True) or only layer4 as an
    ordered dict of feature maps keyed "0".."3" (or just "0").
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
        #     if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
        #         parameter.requires_grad_(False)
        # NOTE(review): train_backbone is currently unused here (the freezing
        # logic above is commented out) — confirm this is intentional.
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor):
        # Returns the raw OrderedDict of feature maps; the NestedTensor/mask
        # handling from upstream DETR is commented out below.
        xs = self.body(tensor)
        return xs
        # out: Dict[str, NestedTensor] = {}
        # for name, x in xs.items():
        #     m = tensor_list.mask
        #     assert m is not None
        #     mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        #     out[name] = NestedTensor(x, mask)
        # return out
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # Only the main process triggers a pretrained-weight download
        # (pretrained=is_main_process()). NOTE(review): the 'pretrained='
        # kwarg is deprecated in newer torchvision — confirm the pinned
        # torchvision version still accepts it.
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)  # pretrained # TODO do we want frozen batch_norm??
        # ResNet-18/34 end at 512 channels; deeper variants at 2048.
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
|
||||||
|
|
||||||
|
|
||||||
# class DINOv2BackBone(nn.Module):
|
|
||||||
# def __init__(self) -> None:
|
|
||||||
# super().__init__()
|
|
||||||
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
|
||||||
# self.body.eval()
|
|
||||||
# self.num_channels = 384
|
|
||||||
|
|
||||||
# @torch.no_grad()
|
|
||||||
# def forward(self, tensor):
|
|
||||||
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
|
|
||||||
# od = OrderedDict()
|
|
||||||
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
|
||||||
# return od
|
|
||||||
|
|
||||||
class DINOv2BackBone(nn.Module):
    """Frozen DINOv2 ViT-S/14 feature extractor (fetched via torch.hub).

    Patch tokens are reshaped to a (B, 384, 16, 22) feature map; the fixed
    22 x 16 token grid implies an input resolution of 308 x 224
    (22*14 x 16*14) — NOTE(review): confirm against the dataloader.
    """
    def __init__(self, return_interm_layers: bool = False) -> None:
        super().__init__()
        # Network download at construction time; weights stay in eval mode.
        self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        self.body.eval()
        self.num_channels = 384
        self.return_interm_layers = return_interm_layers

    @torch.no_grad()
    def forward(self, tensor):
        features = self.body.forward_features(tensor)

        if self.return_interm_layers:
            # NOTE(review): all four "layers" are the same final patch-token
            # tensor — true intermediate ViT layers are not exposed here, so
            # keys "0".."3" are duplicates of one another.
            layer1 = features["x_norm_patchtokens"]
            layer2 = features["x_norm_patchtokens"]
            layer3 = features["x_norm_patchtokens"]
            layer4 = features["x_norm_patchtokens"]

            od = OrderedDict()
            od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
            od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
            od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
            od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
            return od
        else:
            xs = features["x_norm_patchtokens"]
            od = OrderedDict()
            od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
            return od
|
|
||||||
|
|
||||||
class Joiner(nn.Sequential):
    """Pair a backbone (self[0]) with a position-embedding module (self[1]).

    forward returns the backbone feature maps as a list, together with a
    matching list of positional encodings, one entry per returned layer.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        features = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for feature in features.values():
            out.append(feature)
            # position encoding, cast to the feature dtype
            pos.append(self[1](feature).to(feature.dtype))

        return out, pos
|
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(args):
    """Build the visual backbone + positional-encoding Joiner from args.

    args.backbone selects 'dino_v2' (frozen DINOv2 ViT-S/14) or a ResNet
    ('resnet18'/'resnet34'); args.masks requests intermediate layers and
    args.lr_backbone > 0 marks the backbone as trainable (ResNet path only).
    """
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    if args.backbone == 'dino_v2':
        # NOTE(review): return_interm_layers is not forwarded to
        # DINOv2BackBone here — confirm that is intentional.
        backbone = DINOv2BackBone()
    else:
        assert args.backbone in ['resnet18', 'resnet34']
        backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from diffusers import ConfigMixin, ModelMixin
|
|
||||||
from diffusers.configuration_utils import register_to_config
|
|
||||||
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
class TimestepEncoder(nn.Module):
    """Encode diffusion timesteps into args.embed_dim conditioning vectors.

    Uses diffusers' sinusoidal Timesteps projection (256 channels) followed
    by a learned TimestepEmbedding MLP up to the transformer width.
    """
    def __init__(self, args):
        super().__init__()
        embedding_dim = args.embed_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timesteps):
        # Cast the sinusoidal projection to the module's parameter dtype
        # (e.g. under fp16/bf16 training).
        dtype = next(self.parameters()).dtype
        timesteps_proj = self.time_proj(timesteps).to(dtype)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb
|
|
||||||
|
|
||||||
|
|
||||||
class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale/shift are regressed from a conditioning embedding.

    The conditioning vector passes through SiLU and a linear layer producing
    per-sample (scale, shift); the normalized input is then modulated as
    norm(x) * (1 + scale) + shift, broadcast over the sequence dimension.
    """

    def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
        super().__init__()
        # Module creation order matches the original (silu, linear, norm).
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        modulation = self.linear(self.silu(temb))
        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
    """One DiT block: AdaLayerNorm -> (self- or cross-) attention -> MLP.

    With use_self_attn=False the attention layer cross-attends to an external
    context of width crosss_attention_dim (the existing 'crosss' spelling is
    kept for caller compatibility); with use_self_attn=True it self-attends
    over the hidden states. Both sub-layers use residual connections.
    """
    def __init__(self, args, crosss_attention_dim, use_self_attn=False):
        super().__init__()
        dim = args.embed_dim
        num_heads = args.nheads
        mlp_ratio = args.mlp_ratio
        dropout = args.dropout
        # Timestep-conditioned LayerNorm ahead of the attention layer.
        self.norm1 = AdaLayerNorm(dim)

        if not use_self_attn:
            # Cross-attention: keys/values come from the external context.
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=num_heads,
                dropout=dropout,
                kdim=crosss_attention_dim,
                vdim=crosss_attention_dim,
                batch_first=True,
            )
        else:
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=True,
            )

        self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)

        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mlp_ratio, dim),
            nn.Dropout(dropout)
        )

    def forward(self, hidden_states, temb, context=None):
        """Attention (cross if *context* is given, else self) then MLP, each
        with a residual connection; *temb* conditions the first norm."""
        norm_hidden_states = self.norm1(hidden_states, temb)

        attn_output = self.attn(
            norm_hidden_states,
            context if context is not None else norm_hidden_states,
            context if context is not None else norm_hidden_states,
        )[0]

        hidden_states = attn_output + hidden_states

        norm_hidden_states = self.norm2(hidden_states)

        ff_output = self.mlp(norm_hidden_states)

        hidden_states = ff_output + hidden_states

        return hidden_states
|
|
||||||
|
|
||||||
class DiT(nn.Module):
    """Diffusion Transformer with alternating cross-/self-attention blocks.

    Even-indexed blocks cross-attend to encoder_hidden_states; odd-indexed
    blocks self-attend. The output head applies a timestep-conditioned
    (shift, scale) modulation before projecting to args.hidden_dim.
    """
    def __init__(self, args, cross_attention_dim):
        super().__init__()
        inner_dim = args.embed_dim
        num_layers = args.num_layers
        output_dim = args.hidden_dim

        self.timestep_encoder = TimestepEncoder(args)

        all_blocks = []
        for idx in range(num_layers):
            # Odd layers self-attend, even layers cross-attend to the context.
            use_self_attn = idx % 2 == 1
            if use_self_attn:
                block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
            else:
                block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
            all_blocks.append(block)

        self.transformer_blocks = nn.ModuleList(all_blocks)

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, output_dim)

    def forward(self, hidden_states, timestep, encoder_hidden_states):
        """Denoise *hidden_states* conditioned on *timestep* and context."""
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        for idx, block in enumerate(self.transformer_blocks):
            # Same parity rule as in __init__: odd = self-attn, even = cross.
            if idx % 2 == 1:
                hidden_states = block(hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb, context=encoder_hidden_states)

        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
def build_dit(args, cross_attention_dim):
    """Factory for the DiT denoiser (keeps builder signatures uniform)."""
    return DiT(args, cross_attention_dim)
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
|
|
||||||
from .modules import (
|
|
||||||
build_action_decoder,
|
|
||||||
build_action_encoder,
|
|
||||||
build_state_encoder,
|
|
||||||
build_time_sampler,
|
|
||||||
build_noise_scheduler,
|
|
||||||
)
|
|
||||||
from .backbone import build_backbone
|
|
||||||
from .dit import build_dit
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
class gr00t(nn.Module):
    """Flow-matching visuomotor policy.

    Per-camera backbones produce image feature tokens; together with an
    embedded robot state they condition a DiT that denoises an action chunk.
    Training regresses the flow-matching velocity target; inference integrates
    the learned velocity field with a few Euler steps.
    """

    def __init__(
        self,
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        num_queries,
        camera_names,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.dit = dit
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.action_decoder = action_decoder
        self.time_sampler = time_sampler
        self.noise_scheduler = noise_scheduler

        if backbones is not None:
            # One backbone per camera; forward() indexes them by camera id.
            self.backbones = nn.ModuleList(backbones)
        else:
            raise NotImplementedError

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Training step (when `actions` is given) or Euler-integration sampling.

        Args:
            qpos: robot state [B, state_dim].
            image: camera images [B, num_cameras, C, H, W].
            actions: ground-truth chunk [B, T, action_dim]; presence enables training.
            is_pad: padding mask (unused here; the policy wrapper masks the loss).

        Returns:
            Training: (pred_actions, per-element MSE vs. the velocity target).
            Inference: (sampled actions [B, num_queries, action_dim], None).
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape

        # Flatten each camera's feature map into a token sequence and concat.
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0]  # take the last layer feature
            B, C, H, W = features.shape
            features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
            all_cam_features.append(features_seq)
        encoder_hidden_states = torch.cat(all_cam_features, dim=1)

        state_features = self.state_encoder(qpos)  # [B, 1, emb_dim]

        if is_training:
            # Sample one continuous time per batch element, corrupt the clean
            # actions, and regress the straight-path velocity.
            timesteps = self.time_sampler(bs, actions.device, actions.dtype)
            noisy_actions, target_velocity = self.noise_scheduler.add_noise(
                actions, timesteps
            )
            # Discretize time to an integer index for the timestep embedding.
            t_discretized = (timesteps[:, 0, 0] * 1000).long()
            action_features = self.action_encoder(noisy_actions, t_discretized)
            sa_embs = torch.cat((state_features, action_features), dim=1)
            model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
            pred = self.action_decoder(model_output)
            pred_actions = pred[:, -actions.shape[1] :]
            action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
            return pred_actions, action_loss
        else:
            # Start from pure noise and integrate the velocity field with k
            # Euler steps (t: 0 -> 1).
            # NOTE(review): noise dim uses qpos.shape[-1], i.e. assumes
            # action_dim == state_dim — confirm against the configs.
            actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
            k = 5
            dt = 1.0 / k
            for t in range(k):
                t_cont = t / float(k)
                t_discretized = int(t_cont * 1000)
                # NOTE(review): float dtype here vs. long in the training path —
                # presumably the encoders accept both; confirm.
                timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
                action_features = self.action_encoder(actions, timesteps)
                sa_embs = torch.cat((state_features, action_features), dim=1)
                # Tensor of shape [B] for the DiT, consistent with training.
                model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
                pred = self.action_decoder(model_output)
                pred_velocity = pred[:, -self.num_queries :]
                actions = actions + pred_velocity * dt
            # Fix: the original returned the leftover `_` from the earlier
            # shape unpacking; return an explicit None (callers discard it).
            return actions, None
|
|
||||||
def build_gr00t_model(args):
    """Assemble a gr00t model from the per-component builders.

    Builds one backbone per camera, wires its channel count into the DiT's
    cross-attention, constructs the remaining encoders/decoders/schedulers,
    and prints the trainable parameter count.
    """
    # One backbone per camera name (removed unused state_dim/action_dim locals).
    backbones = [build_backbone(args) for _ in args.camera_names]

    # Cross-attention keys/values come from backbone feature channels.
    cross_attention_dim = backbones[0].num_channels

    dit = build_dit(args, cross_attention_dim)

    state_encoder = build_state_encoder(args)
    action_encoder = build_action_encoder(args)
    action_decoder = build_action_decoder(args)
    time_sampler = build_time_sampler(args)
    noise_scheduler = build_noise_scheduler(args)

    model = gr00t(
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        args.num_queries,
        args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,179 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# ActionEncoder
|
|
||||||
class SinusoidalPositionalEncoding(nn.Module):
    """Transformer-style sin/cos encoding of continuous timestep values."""

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.embed_dim

    def forward(self, timesteps):
        """Encode `timesteps` of shape (B, T) as (B, T, embed_dim) = [sin | cos]."""
        timesteps = timesteps.float()
        _batch, _steps = timesteps.shape  # enforce a 2-D input
        device = timesteps.device

        half_dim = self.embed_dim // 2

        # Log-spaced frequencies: exp(-k * ln(10000) / half_dim), k = 0..half_dim-1.
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )
        phase = timesteps.unsqueeze(-1) * exponent.exp()

        # First half of the channels holds sin, second half cos.
        return torch.cat([phase.sin(), phase.cos()], dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionEncoder(nn.Module):
    """Embed a noisy action chunk together with its flow-matching timestep.

    Each action step is linearly projected, concatenated with a sinusoidal
    embedding of the (per-batch) timestep replicated across the chunk, and
    passed through a small swish MLP.
    """

    def __init__(self, args):
        super().__init__()
        action_dim = args.action_dim
        embed_dim = args.embed_dim

        self.W1 = nn.Linear(action_dim, embed_dim)
        self.W2 = nn.Linear(2 * embed_dim, embed_dim)
        self.W3 = nn.Linear(embed_dim, embed_dim)

        self.pos_encoder = SinusoidalPositionalEncoding(args)

    def forward(self, actions, timesteps):
        """
        Args:
            actions: noisy action chunk (B, T, action_dim).
            timesteps: one scalar time per batch element, shape (B,).

        Returns:
            Embedded chunk of shape (B, T, embed_dim).

        Raises:
            ValueError: if `timesteps` is not a (B,) tensor.
        """
        B, T, _ = actions.shape

        # Replicate each batch element's scalar time across all T steps.
        # (Removed dead commented-out handling of (B,1)/(B,1,1) shapes —
        # callers pass (B,) tensors.)
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Action MLP step => (B, T, w).
        a_emb = self.W1(actions)

        # Sinusoidal time embedding => (B, T, w), matched to a_emb's dtype.
        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)

        # Concat => (B, T, 2w), project => (B, T, w), swish.
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.W2(x))

        # Final projection => (B, T, w).
        x = self.W3(x)

        return x
|
|
||||||
|
|
||||||
|
|
||||||
def build_action_encoder(args):
    """Factory for the action-chunk encoder."""
    return ActionEncoder(args)
|
|
||||||
|
|
||||||
|
|
||||||
# StateEncoder
|
|
||||||
class StateEncoder(nn.Module):
    """Project the robot state vector to a single embedding token."""

    def __init__(self, args):
        super().__init__()
        # state_dim -> hidden_dim -> embed_dim, with one ReLU in between.
        self.mlp = nn.Sequential(
            nn.Linear(args.state_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.embed_dim),
        )

    def forward(self, states):
        """Map states [B, state_dim] to one token per element: [B, 1, embed_dim]."""
        return self.mlp(states).unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
def build_state_encoder(args):
    """Factory for the robot-state encoder."""
    return StateEncoder(args)
|
|
||||||
|
|
||||||
|
|
||||||
# ActionDecoder
|
|
||||||
class ActionDecoder(nn.Module):
    """Map DiT hidden states back to actions, keeping only the action tokens."""

    def __init__(self, args):
        super().__init__()
        self.num_queries = args.num_queries
        # hidden_dim -> hidden_dim -> action_dim, with one ReLU in between.
        self.mlp = nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.action_dim),
        )

    def forward(self, model_output):
        """Decode [B, S, hidden_dim] and keep the trailing num_queries tokens."""
        decoded = self.mlp(model_output)
        # Leading tokens (e.g. the state token) are dropped here.
        return decoded[:, -self.num_queries:]
|
|
||||||
|
|
||||||
|
|
||||||
def build_action_decoder(args):
    """Factory for the action decoder head."""
    return ActionDecoder(args)
|
|
||||||
|
|
||||||
|
|
||||||
# TimeSampler
|
|
||||||
class TimeSampler(nn.Module):
    """Draw flow-matching timesteps from a flipped, scaled Beta distribution."""

    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
        super().__init__()
        self.noise_s = noise_s
        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)

    def forward(self, batch_size, device, dtype):
        """Return one timestep per batch element, shaped [B, 1, 1]."""
        draws = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        # Flip toward t=1 and cap just below 1 via noise_s.
        scaled = (1 - draws) * self.noise_s
        return scaled[:, None, None]
|
|
||||||
|
|
||||||
|
|
||||||
def build_time_sampler(args):
    """Factory; `args` is accepted for builder-signature symmetry but unused."""
    return TimeSampler()
|
|
||||||
|
|
||||||
|
|
||||||
# NoiseScheduler
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class FlowMatchingScheduler(nn.Module):
    """Linear flow matching: x_t = t * data + (1 - t) * noise."""

    def __init__(self):
        super().__init__()

    # --- training: corrupt actions and produce the regression target ---
    def add_noise(self, actions, timesteps):
        """Return (noisy_samples, target_velocity) for the given timesteps."""
        noise = torch.randn_like(actions)
        # Interpolate between pure noise (t=0) and clean actions (t=1).
        noisy_samples = actions * timesteps + noise * (1 - timesteps)
        # Constant velocity of the straight path from noise to data.
        target_velocity = actions - noise
        return noisy_samples, target_velocity

    # --- inference: one explicit Euler step ---
    def step(self, model_output, sample, dt):
        """Advance `sample` by `dt` along the predicted velocity."""
        return sample + model_output * dt
|
|
||||||
|
|
||||||
def build_noise_scheduler(args):
    """Factory; `args` is accepted for builder-signature symmetry but unused."""
    return FlowMatchingScheduler()
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Various positional encodings for the transformer.
|
|
||||||
"""
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from util.misc import NestedTensor
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        # num_pos_feats: channels per spatial axis (output has 2*num_pos_feats channels).
        # temperature: base of the per-channel frequency progression.
        # normalize: rescale coordinates to [0, scale] before encoding.
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        # Build a sine/cosine positional map for a (B, C, H, W) feature tensor.
        x = tensor
        # mask = tensor_list.mask
        # assert mask is not None
        # not_mask = ~mask

        # No padding mask here: every pixel counts as valid.  Built from a
        # single (1, H, W) slice, so the returned map has batch dim 1 —
        # presumably broadcast against the batch by the caller; confirm.
        not_mask = torch.ones_like(x[0, [0]])
        y_embed = not_mask.cumsum(1, dtype=torch.float32)  # 1..H down each column
        x_embed = not_mask.cumsum(2, dtype=torch.float32)  # 1..W along each row
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Per-channel frequency: temperature^(2*floor(i/2)/num_pos_feats).
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels and cos on odd channels, then flatten.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate y- and x-encodings channel-wise: (1, 2*num_pos_feats, H, W).
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        # 50 rows/columns is the maximum supported feature-map extent.
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init for both axis embeddings.
        for emb in (self.row_embed, self.col_embed):
            nn.init.uniform_(emb.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        # Per-axis embeddings for every column / row index.
        cols = self.col_embed(torch.arange(w, device=x.device))
        rows = self.row_embed(torch.arange(h, device=x.device))
        # Tile each axis across the other to form an (h, w, 2*num_pos_feats) grid.
        grid = torch.cat(
            [
                cols.unsqueeze(0).repeat(h, 1, 1),
                rows.unsqueeze(1).repeat(1, w, 1),
            ],
            dim=-1,
        )
        # Channels-first, then replicate across the batch.
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def build_position_encoding(args):
    """Select sine vs. learned positional encoding from args.position_embedding."""
    N_steps = args.hidden_dim // 2
    kind = args.position_embedding
    if kind in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(N_steps, normalize=True)
    if kind in ('v3', 'learned'):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {kind}")
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
"""
|
|
||||||
GR00T Policy wrapper for imitation learning.
|
|
||||||
|
|
||||||
This module provides the gr00tPolicy class that wraps the GR00T model
|
|
||||||
for training and evaluation in the imitation learning framework.
|
|
||||||
"""
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
import torch
|
|
||||||
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class gr00tPolicy(nn.Module):
    """
    GR00T Policy for action prediction using diffusion-based DiT architecture.

    Wraps the GR00T model and handles:
    - image resizing to multiples of the DINOv2 patch size,
    - ImageNet normalization (with augmentation at training time),
    - clipping action chunks to the model's query horizon and masking
      padded steps out of the training loss,
    - plain forward for inference.
    """

    def __init__(self, args_override):
        super().__init__()
        self.model, self.optimizer = build_gr00t_model_and_optimizer(args_override)

        # DINOv2 requires image dimensions to be multiples of its patch size (14).
        self.patch_h = 16  # patches along the height
        self.patch_w = 22  # patches along the width
        target_size = (self.patch_h * 14, self.patch_w * 14)  # (224, 308)

        imagenet_norm = v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

        # Training pipeline: photometric + geometric augmentation, then
        # resize and normalize.
        self.train_transform = v2.Compose([
            v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            v2.RandomPerspective(distortion_scale=0.5),
            v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
            v2.Resize(target_size),
            imagenet_norm,
        ])

        # Inference pipeline: resize and normalize only.
        self.inference_transform = v2.Compose([
            v2.Resize(target_size),
            imagenet_norm,
        ])

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """
        Run one training step (when `actions` is given) or inference.

        Args:
            qpos: Joint positions [B, state_dim]
            image: Camera images [B, num_cameras, C, H, W]
            actions: Ground truth actions [B, chunk_size, action_dim] (training only)
            is_pad: Padding mask [B, chunk_size] (training only)

        Returns:
            Training: dict with key 'loss' (padding-masked MSE).
            Inference: predicted actions [B, num_queries, action_dim].
        """
        if actions is not None:  # training time
            image = self.train_transform(image)
            horizon = self.model.num_queries
            actions = actions[:, :horizon]
            is_pad = is_pad[:, :horizon]
            _, action_loss = self.model(qpos, image, actions, is_pad)

            # Zero out contributions from padded timesteps before averaging.
            mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
            return {'loss': mse_loss}

        # Inference time: no augmentation, ignore the (None) second output.
        image = self.inference_transform(image)
        a_hat, _ = self.model(qpos, image)
        return a_hat

    def configure_optimizers(self):
        """Return the optimizer built alongside the model."""
        return self.optimizer
|
|
||||||
Reference in New Issue
Block a user