chore: import gr00t
This commit is contained in:
179
gr00t/models/modules.py
Normal file
179
gr00t/models/modules.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
||||
|
||||
# ActionEncoder
|
||||
class SinusoidalPositionalEncoding(nn.Module):
    """Transformer-style sinusoidal encoding of (possibly fractional) timesteps.

    Maps a (B, T) tensor of timestep values to a (B, T, embed_dim) tensor of
    concatenated [sin | cos] features over a geometric frequency series.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.embed_dim
        # The output width is 2 * (embed_dim // 2); an odd embed_dim would
        # silently produce embed_dim - 1 features and break downstream
        # concatenations, so fail fast here instead.
        if self.embed_dim % 2 != 0:
            raise ValueError(f"embed_dim must be even, got {self.embed_dim}")

    def forward(self, timesteps):
        """Return the (B, T, embed_dim) sinusoidal encoding of ``timesteps``.

        Args:
            timesteps: (B, T) tensor of timestep values (any numeric dtype;
                cast to float internally).

        Returns:
            (B, T, embed_dim) float tensor: ``[sin(freqs) | cos(freqs)]``.
        """
        timesteps = timesteps.float()
        B, T = timesteps.shape  # unpack doubles as a rank-2 check
        device = timesteps.device

        half_dim = self.embed_dim // 2
        # Frequencies form the geometric series 10000^(-k/half_dim).
        # math.log is a plain float, avoiding the per-call scalar-tensor
        # allocation of the original torch.log(torch.tensor(10000.0)).
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            math.log(10000.0) / half_dim
        )
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, embed_dim)

        return enc
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
    """Embed an action chunk together with its flow-matching timestep.

    Each action vector is projected to ``embed_dim`` (W1), concatenated with
    a sinusoidal encoding of the per-sample timestep, then mixed through a
    small swish MLP (W2 -> SiLU -> W3).
    """

    def __init__(self, args):
        super().__init__()
        action_dim = args.action_dim
        embed_dim = args.embed_dim

        self.W1 = nn.Linear(action_dim, embed_dim)
        self.W2 = nn.Linear(2 * embed_dim, embed_dim)
        self.W3 = nn.Linear(embed_dim, embed_dim)

        self.pos_encoder = SinusoidalPositionalEncoding(args)

    def forward(self, actions, timesteps):
        """Encode ``actions`` at the given per-batch timesteps.

        Args:
            actions: (B, T, action_dim) action chunk.
            timesteps: one scalar time per batch element, as (B,), (B, 1) or
                (B, 1, 1); it is replicated across all T steps.  (The (B, 1)
                and (B, 1, 1) forms generalize the original, which accepted
                only (B,) — existing (B,) callers are unaffected.)

        Returns:
            (B, T, embed_dim) embedding.

        Raises:
            ValueError: if ``timesteps`` cannot be reduced to shape (B,).
        """
        B, T, _ = actions.shape

        # 1) Reduce timesteps to (B,) by stripping trailing singleton dims,
        #    then replicate each batch's scalar time across all T steps.
        tau = timesteps
        while tau.dim() > 1 and tau.shape[-1] == 1:
            tau = tau.squeeze(-1)
        if tau.dim() != 1 or tau.shape[0] != B:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )
        tau = tau.unsqueeze(1).expand(-1, T)  # (B, T)

        # 2) Standard action MLP step => (B, T, w)
        a_emb = self.W1(actions)

        # 3) Sinusoidal encoding of the timesteps => (B, T, w)
        tau_emb = self.pos_encoder(tau).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = F.silu(self.W2(torch.cat([a_emb, tau_emb], dim=-1)))

        # 5) Final projection W3 => (B, T, w)
        return self.W3(x)
|
||||
|
||||
|
||||
def build_action_encoder(args):
    """Factory: construct an :class:`ActionEncoder` from ``args``."""
    encoder = ActionEncoder(args)
    return encoder
|
||||
|
||||
|
||||
# StateEncoder
|
||||
class StateEncoder(nn.Module):
    """Project a proprioceptive state vector into the embedding space.

    A two-layer ReLU MLP maps (B, state_dim) -> (B, embed_dim); the result
    is returned with a singleton sequence axis so it can be treated as a
    one-token sequence downstream.
    """

    def __init__(self, args):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(args.state_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.embed_dim),
        )

    def forward(self, states):
        """Return the state embedding, shape (B, 1, embed_dim)."""
        embedded = self.mlp(states)   # (B, embed_dim)
        return embedded.unsqueeze(1)  # (B, 1, embed_dim)
|
||||
|
||||
|
||||
def build_state_encoder(args):
    """Factory: construct a :class:`StateEncoder` from ``args``."""
    encoder = StateEncoder(args)
    return encoder
|
||||
|
||||
|
||||
# ActionDecoder
|
||||
class ActionDecoder(nn.Module):
    """Map backbone hidden states to predicted actions.

    A two-layer ReLU MLP projects (B, S, hidden_dim) -> (B, S, action_dim);
    only the last ``num_queries`` positions along the sequence axis are kept.
    """

    def __init__(self, args):
        super().__init__()
        self.num_queries = args.num_queries
        self.mlp = nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.action_dim),
        )

    def forward(self, model_output):
        """Decode actions from the trailing ``num_queries`` positions."""
        decoded = self.mlp(model_output)
        tail = decoded[:, -self.num_queries:]
        return tail
|
||||
|
||||
|
||||
def build_action_decoder(args):
    """Factory: construct an :class:`ActionDecoder` from ``args``."""
    decoder = ActionDecoder(args)
    return decoder
|
||||
|
||||
|
||||
# TimeSampler
|
||||
class TimeSampler(nn.Module):
    """Draw flow-matching timesteps from a flipped, scaled Beta distribution.

    Samples tau ~ Beta(noise_beta_alpha, noise_beta_beta) and returns
    (1 - tau) * noise_s, so values lie in [0, noise_s] with the skew
    determined by the Beta parameters.
    """

    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
        super().__init__()
        self.noise_s = noise_s
        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)

    def forward(self, batch_size, device, dtype):
        """Return a (batch_size, 1, 1) tensor of sampled timesteps."""
        tau = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        scaled = (1 - tau) * self.noise_s
        # Two trailing singleton dims so the result broadcasts over (B, T, D).
        return scaled[:, None, None]
|
||||
|
||||
|
||||
def build_time_sampler(args):
    """Factory: construct a :class:`TimeSampler` with default parameters.

    NOTE(review): ``args`` is currently unused — the sampler always uses its
    defaults; kept for signature parity with the other ``build_*`` factories.
    """
    sampler = TimeSampler()
    return sampler
|
||||
|
||||
|
||||
# NoiseScheduler
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class FlowMatchingScheduler(nn.Module):
    """Linear-interpolation flow-matching noise scheduler.

    Training: interpolate between Gaussian noise (t=0) and clean actions
    (t=1) and expose the constant velocity target.  Inference: integrate the
    predicted velocity with explicit Euler steps.
    """

    def __init__(self):
        super().__init__()

    # --- Training: corrupt actions and compute the regression target ---
    def add_noise(self, actions, timesteps):
        """Return ``(noisy_samples, target_velocity)`` for ``timesteps``.

        ``timesteps`` must broadcast against ``actions``; t=1 yields the
        clean actions, t=0 pure noise, and the velocity target is the
        constant ``actions - noise``.
        """
        noise = torch.randn_like(actions)
        noisy_samples = timesteps * actions + (1 - timesteps) * noise
        target_velocity = actions - noise
        return noisy_samples, target_velocity

    # --- Inference: one explicit Euler integration step ---
    def step(self, model_output, sample, dt):
        """Advance ``sample`` by ``dt`` along the predicted velocity field."""
        return sample + dt * model_output
|
||||
|
||||
def build_noise_scheduler(args):
    """Factory: construct a :class:`FlowMatchingScheduler`.

    NOTE(review): ``args`` is currently unused; kept for signature parity
    with the other ``build_*`` factories.
    """
    scheduler = FlowMatchingScheduler()
    return scheduler
|
||||
Reference in New Issue
Block a user