import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# ActionEncoder
# ---------------------------------------------------------------------------
class SinusoidalPositionalEncoding(nn.Module):
    """Transformer-style sinusoidal encoding over continuous timesteps.

    Maps each scalar timestep to ``args.embed_dim`` features: the first half
    are sines, the second half cosines, over log-spaced frequencies.

    NOTE(review): assumes ``embed_dim`` is even — for odd values the output
    width is ``2 * (embed_dim // 2)``, i.e. one less than ``embed_dim``.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.embed_dim

    def forward(self, timesteps):
        """Encode a ``(B, T)`` tensor of timesteps into ``(B, T, embed_dim)``."""
        timesteps = timesteps.float()
        device = timesteps.device
        half_dim = self.embed_dim // 2
        # Log-spaced frequencies: exp(-k * ln(10000) / half_dim), k = 0..half_dim-1.
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            math.log(10000.0) / half_dim
        )
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)
        # Concatenate sines then cosines along the feature dim -> (B, T, embed_dim).
        return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)


class ActionEncoder(nn.Module):
    """Embed an action chunk jointly with its flow/diffusion timestep.

    Computes ``W3(silu(W2([W1(actions) ; PE(tau)])))`` where ``PE`` is the
    sinusoidal timestep encoding.
    """

    def __init__(self, args):
        super().__init__()
        action_dim = args.action_dim
        embed_dim = args.embed_dim
        self.W1 = nn.Linear(action_dim, embed_dim)
        self.W2 = nn.Linear(2 * embed_dim, embed_dim)
        self.W3 = nn.Linear(embed_dim, embed_dim)
        self.pos_encoder = SinusoidalPositionalEncoding(args)

    def forward(self, actions, timesteps):
        """
        Args:
            actions: ``(B, T, action_dim)`` action chunk.
            timesteps: flow time(s). Accepted shapes: ``(B,)``, ``(B, 1)``,
                ``(B, 1, 1)`` (one scalar per batch element, replicated
                across all T steps) or ``(B, T)`` (one value per step).

        Returns:
            ``(B, T, embed_dim)`` embedding.

        Raises:
            ValueError: if ``timesteps`` cannot be normalized to ``(B, T)``.
        """
        B, T, _ = actions.shape

        # Normalize `timesteps` to shape (B, T).
        if timesteps.dim() == 3:
            # (B, 1, 1) or (B, T, 1) -> drop the trailing singleton dim.
            timesteps = timesteps[..., 0]
        if timesteps.dim() == 2 and timesteps.shape == (B, 1):
            timesteps = timesteps[:, 0]
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            # One scalar per batch element: replicate across all T steps.
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        if timesteps.shape != (B, T):
            raise ValueError(
                f"Expected `timesteps` of shape (B,), (B, 1), (B, 1, 1) or "
                f"(B, T); got {tuple(timesteps.shape)} with B={B}, T={T}."
            )

        # Action MLP step -> (B, T, w).
        a_emb = self.W1(actions)
        # Sinusoidal timestep encoding, matched to the action embedding dtype.
        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
        # Concat along last dim -> (B, T, 2w); project and apply swish.
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.W2(x))
        # Final projection -> (B, T, w).
        return self.W3(x)


def build_action_encoder(args):
    """Factory for :class:`ActionEncoder`."""
    return ActionEncoder(args)


# ---------------------------------------------------------------------------
# StateEncoder
# ---------------------------------------------------------------------------
class StateEncoder(nn.Module):
    """Two-layer MLP embedding a state vector into a length-1 token sequence."""

    def __init__(self, args):
        super().__init__()
        input_dim = args.state_dim
        hidden_dim = args.hidden_dim
        output_dim = args.embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, states):
        """Map ``(B, state_dim)`` states to ``(B, 1, embed_dim)`` tokens."""
        state_emb = self.mlp(states)  # (B, embed_dim)
        # Add a sequence dimension so the state acts as a single token.
        return state_emb.unsqueeze(1)  # (B, 1, embed_dim)


def build_state_encoder(args):
    """Factory for :class:`StateEncoder`."""
    return StateEncoder(args)


# ---------------------------------------------------------------------------
# ActionDecoder
# ---------------------------------------------------------------------------
class ActionDecoder(nn.Module):
    """MLP head mapping model outputs back to actions.

    Keeps only the last ``args.num_queries`` positions of the sequence.
    """

    def __init__(self, args):
        super().__init__()
        input_dim = args.hidden_dim
        hidden_dim = args.hidden_dim
        output_dim = args.action_dim
        self.num_queries = args.num_queries
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, model_output):
        """Map ``(B, S, hidden_dim)`` to ``(B, num_queries, action_dim)``."""
        pred_actions = self.mlp(model_output)
        # Only the trailing query positions hold action predictions.
        return pred_actions[:, -self.num_queries:]


def build_action_decoder(args):
    """Factory for :class:`ActionDecoder`."""
    return ActionDecoder(args)


# ---------------------------------------------------------------------------
# TimeSampler
# ---------------------------------------------------------------------------
class TimeSampler(nn.Module):
    """Sample flow-matching timesteps from a flipped, scaled Beta distribution.

    Draws ``u ~ Beta(alpha, beta)`` and returns ``(1 - u) * noise_s``, so
    samples lie in ``[0, noise_s]`` with mass skewed toward small times.
    """

    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
        super().__init__()
        self.noise_s = noise_s
        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)

    def forward(self, batch_size, device, dtype):
        """Return ``(batch_size, 1, 1)`` timesteps, broadcastable over (T, dim)."""
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        sample = (1 - sample) * self.noise_s
        return sample[:, None, None]


def build_time_sampler(args):
    """Factory for :class:`TimeSampler` with default hyperparameters."""
    return TimeSampler()


# ---------------------------------------------------------------------------
# NoiseScheduler
# ---------------------------------------------------------------------------
class FlowMatchingScheduler(nn.Module):
    """Rectified-flow scheduler: linear interpolation between noise and data.

    Convention here: ``t = 1`` is clean data, ``t = 0`` is pure noise; the
    training target is the constant velocity ``actions - noise`` along the
    straight path.
    """

    def __init__(self):
        super().__init__()

    # --- Training: corrupt actions and compute the regression target ---
    def add_noise(self, actions, timesteps):
        """Return ``(noisy_samples, target_velocity)`` for flow-matching loss.

        ``timesteps`` must broadcast against ``actions`` (e.g. ``(B, 1, 1)``).
        """
        noise = torch.randn_like(actions)
        # Linear interpolation: t=1 -> clean actions, t=0 -> pure noise.
        noisy_samples = actions * timesteps + noise * (1 - timesteps)
        target_velocity = actions - noise
        return noisy_samples, target_velocity

    # --- Inference: one explicit Euler integration step ---
    def step(self, model_output, sample, dt):
        """Advance ``sample`` by ``dt`` along the predicted velocity field."""
        return sample + model_output * dt


def build_noise_scheduler(args):
    """Factory for :class:`FlowMatchingScheduler`."""
    return FlowMatchingScheduler()