feat: 编写状态编码器、动作编码器
This commit is contained in:
@@ -1 +1,106 @@
|
|||||||
# StateEncoder, ActionEncoder
|
# StateEncoder, ActionEncoder
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
hidden_dim,
|
||||||
|
output_dim
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_dim, output_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input
|
||||||
|
):
|
||||||
|
output = self.model(input)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPositionalEncoding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
emb_dim
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.emb_dim = emb_dim
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
timesteps = timesteps.float()
|
||||||
|
B, T = timesteps.shape
|
||||||
|
device = timesteps.device
|
||||||
|
|
||||||
|
half_dim = self.emb_dim // 2
|
||||||
|
|
||||||
|
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||||
|
torch.log(torch.tensor(10000.0)) / half_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||||
|
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
cos = torch.cos(freqs)
|
||||||
|
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||||
|
|
||||||
|
return enc
|
||||||
|
|
||||||
|
class ActionEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
action_dim,
|
||||||
|
emb_dim,
|
||||||
|
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.W1 = nn.Linear(action_dim, emb_dim)
|
||||||
|
self.W2 = nn.Linear(2 * action_dim, action_dim)
|
||||||
|
self.W3 = nn.Linear(emb_dim, emb_dim)
|
||||||
|
self.pos_encoder = SinusoidalPositionalEncoding(emb_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
actions,
|
||||||
|
timesteps
|
||||||
|
):
|
||||||
|
B, T, _ = actions.shape
|
||||||
|
timesteps = timesteps.unsqueeze(1).expand(-1, T)
|
||||||
|
|
||||||
|
a_emb = self.W1(actions)
|
||||||
|
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
|
||||||
|
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||||
|
x = F.silu(self.W2(x))
|
||||||
|
x = self.W3(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class StateEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state_dim,
|
||||||
|
hidden_dim,
|
||||||
|
emb_dim
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = MLP(
|
||||||
|
state_dim,
|
||||||
|
hidden_dim,
|
||||||
|
emb_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
states
|
||||||
|
):
|
||||||
|
state_emb = self.mlp(states)
|
||||||
|
return state_emb # [B, 1, emb_dim]
|
||||||
Reference in New Issue
Block a user