# StateEncoder, ActionEncoder
import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Two-layer perceptron: ``Linear -> ReLU -> Linear``."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
    ) -> None:
        """Build the network.

        Args:
            input_dim: Size of the last dimension of the input.
            hidden_dim: Width of the single hidden layer.
            output_dim: Size of the last dimension of the output.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    # NOTE: the parameter is named ``input`` (shadowing the builtin) because
    # it is part of the caller-visible signature and may be passed by keyword.
    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``input``.

        Args:
            input: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``output_dim``.
        """
        return self.model(input)


class SinusoidalPositionalEncoding(nn.Module):
    """Parameter-free sinusoidal embedding of scalar timesteps."""

    def __init__(
        self,
        emb_dim: int,
    ) -> None:
        """Store the embedding width.

        Args:
            emb_dim: Size of the last dimension of the produced encoding.
        """
        super().__init__()
        self.emb_dim = emb_dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Encode a 2-D grid of timesteps.

        Args:
            timesteps: Tensor of shape ``(B, T)``; it is cast to float.

        Returns:
            Tensor of shape ``(B, T, emb_dim)`` laid out as
            ``[sin(...), cos(...)]`` along the last dimension. When
            ``emb_dim`` is odd, one trailing zero column is appended so the
            width is exactly ``emb_dim``.

        Raises:
            ValueError: If ``timesteps`` is not 2-dimensional.
        """
        timesteps = timesteps.float()
        if timesteps.dim() != 2:
            raise ValueError(
                f"timesteps must have shape (B, T), got {tuple(timesteps.shape)}"
            )
        device = timesteps.device

        half_dim = self.emb_dim // 2

        # Log-spaced frequencies: exp(-i * ln(10000) / half_dim), i in [0, half_dim).
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )

        freqs = timesteps.unsqueeze(-1) * exponent.exp()

        enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)

        # sin/cos halves only cover 2 * (emb_dim // 2) features; pad the last
        # column with zeros for odd emb_dim so downstream layers sized by
        # emb_dim receive the width they expect.
        if self.emb_dim % 2 == 1:
            enc = F.pad(enc, (0, 1))

        return enc


class ActionEncoder(nn.Module):
    """Embed an action sequence together with a per-sample timestep.

    Computes ``W3(silu(W2(cat[W1(actions), pos_enc(timesteps)])))``.
    """

    def __init__(
        self,
        action_dim: int,
        emb_dim: int,
    ) -> None:
        """Build the projection layers.

        Args:
            action_dim: Size of the last dimension of the action tensor.
            emb_dim: Width of the produced embedding.
        """
        super().__init__()
        self.W1 = nn.Linear(action_dim, emb_dim)
        # W2 consumes the concatenation of the action embedding and the
        # timestep encoding (each emb_dim wide) and must emit emb_dim features
        # for W3. Sizing it by action_dim only worked when
        # action_dim == emb_dim.
        self.W2 = nn.Linear(2 * emb_dim, emb_dim)
        self.W3 = nn.Linear(emb_dim, emb_dim)
        self.pos_encoder = SinusoidalPositionalEncoding(emb_dim)

    def forward(
        self,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Encode actions conditioned on timesteps.

        Args:
            actions: Tensor of shape ``(B, T, action_dim)``.
            timesteps: Tensor of shape ``(B,)``; each value is broadcast to
                all ``T`` positions of its sample.

        Returns:
            Tensor of shape ``(B, T, emb_dim)``.
        """
        # Unpacking also enforces that ``actions`` is 3-dimensional.
        _, T, _ = actions.shape
        timesteps = timesteps.unsqueeze(1).expand(-1, T)

        a_emb = self.W1(actions)
        # The positional encoding is float32; match the action embedding dtype
        # so the concatenation works under reduced precision.
        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.W2(x))
        x = self.W3(x)

        return x


class StateEncoder(nn.Module):
    """Embed a state vector with a two-layer MLP."""

    def __init__(
        self,
        state_dim: int,
        hidden_dim: int,
        emb_dim: int,
    ) -> None:
        """Build the underlying MLP.

        Args:
            state_dim: Size of the last dimension of the state tensor.
            hidden_dim: Width of the MLP hidden layer.
            emb_dim: Width of the produced embedding.
        """
        super().__init__()
        self.mlp = MLP(
            state_dim,
            hidden_dim,
            emb_dim,
        )

    def forward(
        self,
        states: torch.Tensor,
    ) -> torch.Tensor:
        """Project states to the embedding space.

        Args:
            states: Tensor whose last dimension equals ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``states`` and a last
            dimension of ``emb_dim`` (e.g. ``[B, 1, emb_dim]`` for a
            ``[B, 1, state_dim]`` input).
        """
        state_emb = self.mlp(states)
        return state_emb