"""ResNet-style 2D image backbone whose residual blocks are windowed AttnRes transformer blocks."""

from __future__ import annotations

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone


def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
    """Build a 2D norm layer: GroupNorm (~16 channels per group) or BatchNorm2d.

    GroupNorm requires the group count to divide ``num_channels``; this holds
    for the power-of-two widths used by the default backbone config.
    """
    if use_group_norm:
        num_groups = max(1, num_channels // 16)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    return nn.BatchNorm2d(num_channels)


class _ConvNormAct(nn.Module):
    """Conv2d -> norm -> SiLU, the basic stem/downsampling unit of the backbone."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,  # bias is redundant before normalization
            ),
            _make_norm2d(out_channels, use_group_norm),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class AttnResImageBlock2D(nn.Module):
    """Applies a single AttnRes transformer block over non-overlapping spatial windows.

    The feature map is zero-padded up to a multiple of ``window_size``, split into
    ``window_size x window_size`` patches, and each patch is processed as an
    independent token sequence of length ``window_size ** 2``.
    """

    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,  # spatial windows attend bidirectionally
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, channels, height, width = x.shape
        ws = self.window_size

        # Zero-pad H and W up to the next multiple of the window size. Padded
        # positions take part in window attention unmasked; they are cropped
        # away again before returning.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_height, padded_width = x.shape[-2:]
        num_h = padded_height // ws
        num_w = padded_width // ws

        # Partition into windows: (B, C, H, W) -> (B * num_h * num_w, ws * ws, C).
        windows = (
            x.permute(0, 2, 3, 1)
            .contiguous()
            .view(bsz, num_h, ws, num_w, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz * num_h * num_w, ws * ws, channels)
        )
        windows = self.block(windows)
        # Reverse the partition: (B * num_h * num_w, ws * ws, C) -> (B, C, H_pad, W_pad).
        x = (
            windows.view(bsz, num_h, num_w, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(bsz, padded_height, padded_width, channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        # Crop off any padding to restore the original spatial size.
        return x[:, :, :height, :width]


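# Shape walkthrough of the window round-trip above (illustrative, not executed):
# with window_size=7 and input (B, C, H, W) = (2, 64, 10, 12), forward pads to
# (2, 64, 14, 14), giving num_h = num_w = 2; partitions into (2 * 2 * 2, 49, 64)
# window sequences; runs attention; folds back to (2, 64, 14, 14); and finally
# crops to the original (2, 64, 10, 12).

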
class _AttnResStage2D(nn.Module):
    """One backbone stage: optional strided conv downsample, then ``depth`` window-attention blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        # A projection is only needed when the channel count or resolution changes.
        self.downsample: nn.Module | None = None
        if in_channels != out_channels or downsample_stride != 1:
            kernel_size = 1 if downsample_stride == 1 else 3
            padding = 0 if downsample_stride == 1 else 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=downsample_stride,
                padding=padding,
                use_group_norm=use_group_norm,
            )

        self.blocks = nn.ModuleList(
            [
                AttnResImageBlock2D(
                    out_channels,
                    window_size=window_size,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    dropout=dropout,
                    ffn_mult=ffn_mult,
                    eps=eps,
                    rope_theta=rope_theta,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


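# Illustrative per-stage shape check (not executed): a stage with
# in_channels=64, out_channels=128, downsample_stride=2 maps
# (B, 64, 56, 56) -> (B, 128, 28, 28); the window-attention blocks
# then preserve that shape.

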
class AttnResResNetLikeBackbone2D(nn.Module):
    """ResNet-like 2D backbone with windowed AttnRes transformer blocks in place of residual convs.

    The stem (stride-2 conv plus stride-2 max-pool) reduces resolution 4x; each
    stage after the first halves it again, so the default four-stage config has
    an overall stride of 32, matching a standard ResNet.
    """

    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()

        lengths = {
            len(stage_dims),
            len(stage_depths),
            len(stage_heads),
            len(stage_kv_heads),
            len(stage_window_sizes),
        }
        if len(lengths) != 1:
            raise ValueError(
                'stage_dims/depths/heads/kv_heads/window_sizes must all have the same length'
            )

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        in_channels = stem_dim
        stages = []
        for stage_idx, (out_channels, depth, n_heads, n_kv_heads, window_size) in enumerate(
            zip(stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        ):
            stages.append(
                _AttnResStage2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=int(depth),
                    # Only stages after the first downsample; the stem already
                    # provides the initial 4x reduction.
                    downsample_stride=1 if stage_idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(window_size),
                    n_heads=int(n_heads),
                    n_kv_heads=int(n_kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            in_channels = out_channels

        self.stages = nn.ModuleList(stages)
        self.output_channels = in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x


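# A minimal smoke-test sketch of expected usage, assuming the
# AttnResTransformerBackbone import above resolves in your environment. The
# expected output shape follows from the overall stride of 32 under the
# default config (224 / 32 = 7) and the final stage width of 512.
if __name__ == "__main__":
    backbone = AttnResResNetLikeBackbone2D()
    images = torch.randn(2, 3, 224, 224)
    features = backbone(images)
    print(features.shape)  # torch.Size([2, 512, 7, 7])
    assert features.shape == (2, backbone.output_channels, 7, 7)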