# Model components: positional/action/state encoders, action decoder,
# flow-matching time sampler, and noise scheduler.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# ActionEncoder
class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal embedding of (possibly fractional) timesteps.

    Maps a (B, T) tensor of timesteps to (B, T, embed_dim) features using
    the standard transformer sin/cos frequency bands.
    """

    def __init__(self, args):
        super().__init__()
        # NOTE(review): embed_dim is assumed even — sin and cos halves are
        # concatenated, so an odd embed_dim would produce embed_dim - 1
        # features. Confirm upstream config guarantees evenness.
        self.embed_dim = args.embed_dim

    def forward(self, timesteps):
        """Return the (B, T, embed_dim) encoding for `timesteps` of shape (B, T)."""
        timesteps = timesteps.float()
        B, T = timesteps.shape  # unpack doubles as a 2-D shape check
        device = timesteps.device

        half_dim = self.embed_dim // 2

        # Frequency exponents: 10000^(-k / half_dim) for k in [0, half_dim).
        # math.log(10000.0) is a plain Python float; the original built a
        # throwaway CPU scalar tensor via torch.log on every forward pass.
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            math.log(10000.0) / half_dim
        )

        freqs = timesteps.unsqueeze(-1) * exponent.exp()

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, embed_dim)

        return enc
class ActionEncoder(nn.Module):
    """Embed an action chunk together with its flow/diffusion timestep.

    Projects a (B, T, action_dim) action sequence to embed_dim, concatenates
    a sinusoidal embedding of the per-batch scalar timestep (replicated
    across the T steps), and fuses the pair with a small swish MLP.
    """

    def __init__(self, args):
        super().__init__()
        action_dim = args.action_dim
        embed_dim = args.embed_dim

        self.W1 = nn.Linear(action_dim, embed_dim)     # action projection
        self.W2 = nn.Linear(2 * embed_dim, embed_dim)  # fuse action + time
        self.W3 = nn.Linear(embed_dim, embed_dim)      # output projection

        self.pos_encoder = SinusoidalPositionalEncoding(args)

    def forward(self, actions, timesteps):
        """Encode actions conditioned on their timestep.

        Args:
            actions: (B, T, action_dim) action chunk.
            timesteps: (B,) one scalar time per batch element.

        Returns:
            (B, T, embed_dim) fused embedding.

        Raises:
            ValueError: if `timesteps` is not a 1-D tensor of length B.
        """
        B, T, _ = actions.shape

        # 1) Replicate each batch's single scalar time across all T steps
        #    so that shape => (B, T).
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # 2) Action MLP step => (B, T, w)
        a_emb = self.W1(actions)

        # 3) Sinusoidal encoding => (B, T, w); cast to match activations
        #    (pos_encoder always computes in float32).
        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.W2(x))

        # 5) Finally W3 => (B, T, w)
        x = self.W3(x)

        return x
def build_action_encoder(args):
    """Factory: instantiate an ActionEncoder from the config namespace."""
    encoder = ActionEncoder(args)
    return encoder
# StateEncoder
class StateEncoder(nn.Module):
    """Project a raw state vector into a single embedding token."""

    def __init__(self, args):
        super().__init__()
        layers = [
            nn.Linear(args.state_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.embed_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, states):
        """Map a (B, state_dim) batch to a (B, 1, embed_dim) token."""
        embedded = self.mlp(states)   # [B, emb_dim]
        return embedded.unsqueeze(1)  # [B, 1, emb_dim]
def build_state_encoder(args):
    """Factory: instantiate a StateEncoder from the config namespace."""
    encoder = StateEncoder(args)
    return encoder
# ActionDecoder
class ActionDecoder(nn.Module):
    """Map model hidden states to actions, keeping the last num_queries steps."""

    def __init__(self, args):
        super().__init__()
        self.num_queries = args.num_queries

        self.mlp = nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.action_dim),
        )

    def forward(self, model_output):
        """Decode (B, T, hidden_dim) into (B, num_queries, action_dim)."""
        decoded = self.mlp(model_output)
        # Only the trailing num_queries positions carry action predictions.
        return decoded[:, -self.num_queries:]
def build_action_decoder(args):
    """Factory: instantiate an ActionDecoder from the config namespace."""
    decoder = ActionDecoder(args)
    return decoder
# TimeSampler
class TimeSampler(nn.Module):
    """Sample flow-matching timesteps in [0, noise_s] from a flipped Beta."""

    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
        super().__init__()
        self.noise_s = noise_s
        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)

    def forward(self, batch_size, device, dtype):
        """Return (batch_size, 1, 1) sampled timesteps on `device` / `dtype`."""
        draw = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        # Flip the Beta draw and scale so samples never quite reach 1.
        tau = (1 - draw) * self.noise_s
        return tau.unsqueeze(-1).unsqueeze(-1)
def build_time_sampler(args):
    """Factory: instantiate a TimeSampler, honoring overrides on `args`.

    Backward compatible generalization: the original ignored `args` entirely;
    here each hyperparameter falls back to TimeSampler's default when the
    corresponding attribute is absent from `args`, so existing configs
    behave identically while new ones can tune the sampler.
    """
    return TimeSampler(
        noise_s=getattr(args, "noise_s", 0.999),
        noise_beta_alpha=getattr(args, "noise_beta_alpha", 1.5),
        noise_beta_beta=getattr(args, "noise_beta_beta", 1.0),
    )
# NoiseScheduler
# (duplicate mid-file `import torch` / `import torch.nn as nn` removed:
#  both are already imported at the top of the file — PEP 8 places all
#  imports there)
class FlowMatchingScheduler(nn.Module):
    """Rectified-flow (flow matching) noise scheduler.

    Linearly interpolates between Gaussian noise (tau = 0) and the clean
    actions (tau = 1); the training target is the constant velocity
    `actions - noise`. Inference integrates that velocity with Euler steps.
    """

    def __init__(self):
        super().__init__()

    # --- Training: blend in noise and compute the velocity target ---
    def add_noise(self, actions, timesteps):
        """Interpolate `actions` with fresh Gaussian noise at `timesteps`.

        Args:
            actions: clean action tensor.
            timesteps: interpolation factor in [0, 1], broadcastable to
                `actions` (e.g. the (B, 1, 1) output of TimeSampler).

        Returns:
            (noisy_samples, target_velocity) where
            noisy = actions * tau + noise * (1 - tau) and
            velocity = actions - noise.
        """
        noise = torch.randn_like(actions)
        noisy_samples = actions * timesteps + noise * (1 - timesteps)
        target_velocity = actions - noise

        return noisy_samples, target_velocity

    # --- Inference: explicit Euler integration step ---
    def step(self, model_output, sample, dt):
        """Advance `sample` one Euler step along the predicted velocity."""
        prev_sample = sample + model_output * dt
        return prev_sample
def build_noise_scheduler(args):
    """Factory: instantiate a FlowMatchingScheduler (no config fields used)."""
    scheduler = FlowMatchingScheduler()
    return scheduler