# roboimi/roboimi/gr00t/models/modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F


# ActionEncoder
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.embed_dim

    def forward(self, timesteps):
        timesteps = timesteps.float()
        B, T = timesteps.shape
        device = timesteps.device
        half_dim = self.embed_dim // 2
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )
        freqs = timesteps.unsqueeze(-1) * exponent.exp()
        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, w)
        return enc
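

# Usage sketch (hypothetical helper, not referenced elsewhere): a minimal shape
# check for the encoding above. The config stub only needs `embed_dim`; note the
# sin/cos halves only add up to `embed_dim` when it is even.
def _demo_sinusoidal_encoding():
    from types import SimpleNamespace

    pe = SinusoidalPositionalEncoding(SimpleNamespace(embed_dim=64))
    t = torch.rand(4, 16)  # (B, T) flow-matching times in [0, 1]
    assert pe(t).shape == (4, 16, 64)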


class ActionEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        action_dim = args.action_dim
        embed_dim = args.embed_dim
        self.W1 = nn.Linear(action_dim, embed_dim)
        self.W2 = nn.Linear(2 * embed_dim, embed_dim)
        self.W3 = nn.Linear(embed_dim, embed_dim)
        self.pos_encoder = SinusoidalPositionalEncoding(args)
    def forward(self, actions, timesteps):
        B, T, _ = actions.shape
        # 1) Expand each batch's single scalar time `tau` across all T steps,
        #    so that shape => (B, T). Only (B,) timesteps are accepted here.
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )
        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions)
        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.W2(x))
        # 5) Finally W3 => (B, T, w)
        x = self.W3(x)
        return x


def build_action_encoder(args):
    return ActionEncoder(args)
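

# Usage sketch (hypothetical helper, not referenced elsewhere): ActionEncoder embeds
# an action chunk together with one flow-matching time per sample; the config stub
# assumes only `action_dim` and `embed_dim` are needed.
def _demo_action_encoder():
    from types import SimpleNamespace

    args = SimpleNamespace(action_dim=7, embed_dim=64)
    encoder = ActionEncoder(args)
    actions = torch.randn(4, 16, 7)  # (B, T, action_dim)
    timesteps = torch.rand(4)        # (B,) one scalar time per sample
    assert encoder(actions, timesteps).shape == (4, 16, 64)  # (B, T, w)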


# StateEncoder
class StateEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        input_dim = args.state_dim
        hidden_dim = args.hidden_dim
        output_dim = args.embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, states):
        state_emb = self.mlp(states)  # [B, emb_dim]
        state_emb = state_emb.unsqueeze(1)
        return state_emb  # [B, 1, emb_dim]


def build_state_encoder(args):
    return StateEncoder(args)
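

# Usage sketch (hypothetical helper): StateEncoder maps a proprioceptive state
# vector to a single embedding token; the config stub assumes `state_dim`,
# `hidden_dim`, and `embed_dim`.
def _demo_state_encoder():
    from types import SimpleNamespace

    args = SimpleNamespace(state_dim=14, hidden_dim=256, embed_dim=64)
    encoder = StateEncoder(args)
    states = torch.randn(4, 14)  # (B, state_dim)
    assert encoder(states).shape == (4, 1, 64)  # (B, 1, emb_dim)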


# ActionDecoder
class ActionDecoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        input_dim = args.hidden_dim
        hidden_dim = args.hidden_dim
        output_dim = args.action_dim
        self.num_queries = args.num_queries
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, model_output):
        pred_actions = self.mlp(model_output)
        return pred_actions[:, -self.num_queries:]


def build_action_decoder(args):
    return ActionDecoder(args)
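

# Usage sketch (hypothetical helper): ActionDecoder projects backbone tokens back
# into action space and keeps the last `num_queries` tokens as the predicted chunk;
# the config stub assumes `hidden_dim`, `action_dim`, and `num_queries`.
def _demo_action_decoder():
    from types import SimpleNamespace

    args = SimpleNamespace(hidden_dim=256, action_dim=7, num_queries=16)
    decoder = ActionDecoder(args)
    tokens = torch.randn(4, 32, 256)  # (B, seq_len, hidden_dim)
    assert decoder(tokens).shape == (4, 16, 7)  # (B, num_queries, action_dim)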


# TimeSampler
class TimeSampler(nn.Module):
    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
        super().__init__()
        self.noise_s = noise_s
        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)

    def forward(self, batch_size, device, dtype):
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        sample = (1 - sample) * self.noise_s
        return sample[:, None, None]  # (B, 1, 1), broadcastable over (B, T, action_dim)


def build_time_sampler(args):
    # `args` is currently unused; the sampler falls back to its default hyperparameters.
    return TimeSampler()
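

# Usage sketch (hypothetical helper): TimeSampler draws times as 1 - Beta(1.5, 1.0),
# scaled by `noise_s`, so values lie in [0, noise_s] with extra mass near 0; the
# trailing singleton dims broadcast against (B, T, action_dim) action tensors.
def _demo_time_sampler():
    sampler = TimeSampler()
    t = sampler(batch_size=4, device=torch.device("cpu"), dtype=torch.float32)
    assert t.shape == (4, 1, 1)
    assert torch.all((t >= 0) & (t <= 1))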


# NoiseScheduler
class FlowMatchingScheduler(nn.Module):
    def __init__(self):
        super().__init__()

    # --- Training logic: add noise and compute the regression target ---
    def add_noise(self, actions, timesteps):
        noise = torch.randn_like(actions)
        noisy_samples = actions * timesteps + noise * (1 - timesteps)
        target_velocity = actions - noise
        return noisy_samples, target_velocity

    # --- Inference logic: Euler step ---
    def step(self, model_output, sample, dt):
        prev_sample = sample + model_output * dt
        return prev_sample


def build_noise_scheduler(args):
    # `args` is currently unused; the scheduler has no configurable parameters.
    return FlowMatchingScheduler()
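

# End-to-end sketch (hypothetical; `model` stands in for the real velocity-prediction
# backbone): how the sampler and scheduler compose for flow-matching training and
# Euler-integration inference over an action chunk.
def _demo_flow_matching_roundtrip():
    scheduler = FlowMatchingScheduler()
    sampler = TimeSampler()
    actions = torch.randn(4, 16, 7)  # (B, T, action_dim)

    # Training side: sample a time, interpolate between noise and actions, and
    # regress the model output onto the target velocity (actions - noise).
    t = sampler(4, actions.device, actions.dtype)  # (B, 1, 1)
    noisy, target_velocity = scheduler.add_noise(actions, t)
    # loss = F.mse_loss(model(noisy, t), target_velocity)

    # Inference side: start from pure noise (t = 0) and Euler-integrate the
    # predicted velocity toward the clean action chunk (t = 1).
    num_steps = 10
    dt = 1.0 / num_steps
    sample = torch.randn_like(actions)
    for _ in range(num_steps):
        velocity = target_velocity  # stand-in for model(sample, t)
        sample = scheduler.step(velocity, sample, dt)
    assert sample.shape == actions.shape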