chore: 导入gr00t
This commit is contained in:
124
gr00t/models/gr00t.py
Normal file
124
gr00t/models/gr00t.py
Normal file
@@ -0,0 +1,124 @@
|
||||
|
||||
from .modules import (
|
||||
build_action_decoder,
|
||||
build_action_encoder,
|
||||
build_state_encoder,
|
||||
build_time_sampler,
|
||||
build_noise_scheduler,
|
||||
)
|
||||
from .backbone import build_backbone
|
||||
from .dit import build_dit
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class gr00t(nn.Module):
    """Action-chunk policy built from per-camera backbones, encoders and a DiT.

    Training (``actions`` given): the ground-truth actions are noised by
    ``noise_scheduler.add_noise`` and the network is regressed (element-wise
    MSE) onto the velocity target that the scheduler returns.

    Inference (``actions`` omitted): a chunk of ``num_queries`` actions is
    drawn from a standard normal and refined with ``num_inference_steps``
    fixed-size Euler updates ``actions += predicted_velocity * dt``.

    Args:
        backbones: Iterable of image backbones, indexed in the same order as
            ``camera_names``. Each is called on one camera's images and must
            return a pair whose first item is an indexable collection of
            4-D feature maps.
        dit: Callable ``dit(tokens, timesteps, encoder_hidden_states)``.
        state_encoder: Callable mapping ``qpos`` to state tokens.
        action_encoder: Callable ``action_encoder(actions, timesteps)``.
        action_decoder: Callable mapping the DiT output to predictions.
        time_sampler: Callable ``time_sampler(batch_size, device, dtype)``;
            its result is indexed as ``[:, 0, 0]``, so it must be at least 3-D.
        noise_scheduler: Object exposing ``add_noise(actions, timesteps)``
            returning ``(noisy_actions, target_velocity)``.
        num_queries: Number of action steps generated at inference time.
        camera_names: Sequence of camera names; only its length/order is used.
        action_dim: Last dimension of the noise sampled at inference. When
            ``None`` (default) the state dimension ``qpos.shape[-1]`` is used,
            which is only correct if actions and states share a width.
        num_inference_steps: Number of Euler integration steps (default 5).

    Raises:
        NotImplementedError: If ``backbones`` is ``None``.
        ValueError: If ``num_inference_steps`` is smaller than 1.
    """

    def __init__(
        self,
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        num_queries,
        camera_names,
        action_dim=None,
        num_inference_steps=5,
    ):
        super().__init__()
        if backbones is None:
            raise NotImplementedError(
                "gr00t requires image backbones; backbones=None is not supported"
            )
        if num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )

        self.num_queries = num_queries
        self.camera_names = camera_names
        self.action_dim = action_dim
        self.num_inference_steps = num_inference_steps

        self.dit = dit
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.action_decoder = action_decoder
        self.time_sampler = time_sampler
        self.noise_scheduler = noise_scheduler
        # ModuleList so that the backbones' parameters are registered.
        self.backbones = nn.ModuleList(backbones)

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Run a training pass (``actions`` given) or sample an action chunk.

        Args:
            qpos: 2-D state tensor ``(batch, state_dim)``.
            image: Camera images with the camera axis at dim 1; slice
                ``image[:, i]`` is fed to backbone ``i``.
            actions: Ground-truth action chunk. Its presence selects the
                training/validation path.
            is_pad: Accepted for interface compatibility; not used here. The
                loss is returned unreduced, presumably so the caller can mask
                padded steps itself -- confirm against the training loop.

        Returns:
            Training: ``(pred_actions, action_loss)`` where ``action_loss`` is
            the element-wise MSE against the scheduler's velocity target.
            Inference: ``(actions, None)``.
        """
        batch_size, _ = qpos.shape  # also enforces that qpos is 2-D

        encoder_hidden_states = self._encode_images(image)
        # Per the original author's note this is expected to be [B, 1, emb_dim].
        state_features = self.state_encoder(qpos)

        if actions is not None:
            return self._training_step(
                batch_size, state_features, encoder_hidden_states, actions
            )
        return self._sample_actions(qpos, state_features, encoder_hidden_states)

    def _encode_images(self, image):
        """Turn every camera view into a token sequence and concatenate them."""
        per_camera_tokens = []
        for cam_id, _ in enumerate(self.camera_names):
            features, _ = self.backbones[cam_id](image[:, cam_id])
            # NOTE(review): the original comment calls index 0 the "last layer
            # feature" -- confirm against the backbone's return order.
            feature_map = features[0]
            batch, channels, height, width = feature_map.shape
            # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
            tokens = feature_map.permute(0, 2, 3, 1).reshape(
                batch, height * width, channels
            )
            per_camera_tokens.append(tokens)
        return torch.cat(per_camera_tokens, dim=1)

    def _predict_velocity(
        self, state_features, actions, timesteps, encoder_hidden_states
    ):
        """Encode actions, run the DiT and decode the trailing action tokens."""
        action_features = self.action_encoder(actions, timesteps)
        sa_embs = torch.cat((state_features, action_features), dim=1)
        model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
        pred = self.action_decoder(model_output)
        # State tokens were concatenated first, so the action predictions are
        # the last ``actions.shape[1]`` positions along dim 1.
        return pred[:, -actions.shape[1]:]

    def _training_step(
        self, batch_size, state_features, encoder_hidden_states, actions
    ):
        """Noise the target actions and regress the scheduler's velocity."""
        timesteps = self.time_sampler(batch_size, actions.device, actions.dtype)
        noisy_actions, target_velocity = self.noise_scheduler.add_noise(
            actions, timesteps
        )
        # One integer timestep per sample, scaled by 1000.
        t_discretized = (timesteps[:, 0, 0] * 1000).long()
        pred_actions = self._predict_velocity(
            state_features, noisy_actions, t_discretized, encoder_hidden_states
        )
        action_loss = F.mse_loss(pred_actions, target_velocity, reduction="none")
        return pred_actions, action_loss

    def _sample_actions(self, qpos, state_features, encoder_hidden_states):
        """Integrate predicted velocities from Gaussian noise with Euler steps."""
        batch_size, state_dim = qpos.shape
        action_dim = state_dim if self.action_dim is None else self.action_dim
        actions = torch.randn(
            batch_size,
            self.num_queries,
            action_dim,
            device=qpos.device,
            dtype=qpos.dtype,
        )

        num_steps = self.num_inference_steps
        dt = 1.0 / num_steps
        for step in range(num_steps):
            t_discretized = int(step / float(num_steps) * 1000)
            # Integer (long) timesteps of shape [B], matching the dtype that
            # the training path hands to the action encoder and the DiT.
            timesteps = torch.full(
                (batch_size,), t_discretized, device=qpos.device, dtype=torch.long
            )
            pred_velocity = self._predict_velocity(
                state_features, actions, timesteps, encoder_hidden_states
            )
            actions = actions + pred_velocity * dt

        # No loss exists at inference time; keep the 2-tuple return shape.
        return actions, None
|
||||
def build_gr00t_model(args):
    """Assemble a ``gr00t`` model from a configuration object.

    Args:
        args: Configuration namespace. This function reads
            ``args.camera_names`` and ``args.num_queries`` directly and
            forwards ``args`` unchanged to every ``build_*`` helper.

    Returns:
        The constructed ``gr00t`` module. The number of trainable parameters
        (in millions) is printed as a side effect.

    Raises:
        IndexError: If ``args.camera_names`` is empty, because the channel
            count is read from the first backbone.
    """
    # One separately built backbone per camera.
    backbones = [build_backbone(args) for _ in args.camera_names]

    # The first backbone's channel count is handed to the DiT builder as its
    # cross-attention dimension.
    cross_attention_dim = backbones[0].num_channels
    dit = build_dit(args, cross_attention_dim)

    state_encoder = build_state_encoder(args)
    action_encoder = build_action_encoder(args)
    action_decoder = build_action_decoder(args)
    time_sampler = build_time_sampler(args)
    noise_scheduler = build_noise_scheduler(args)

    model = gr00t(
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        args.num_queries,
        args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"number of parameters: {n_parameters / 1e6:.2f}M")
    return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user