Files
roboimi/roboimi/vla/models/backbones/attnres_resnet2d.py
2026-04-05 00:07:59 +08:00

229 lines
6.8 KiB
Python

from __future__ import annotations
from typing import Iterable, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from roboimi.vla.models.heads.attnres_transformer_components import AttnResTransformerBackbone
def _make_norm2d(num_channels: int, use_group_norm: bool) -> nn.Module:
if use_group_norm:
num_groups = max(1, num_channels // 16)
return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
return nn.BatchNorm2d(num_channels)
class _ConvNormAct(nn.Module):
    """Bias-free Conv2d followed by a norm layer and SiLU activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int,
        stride: int,
        padding: int,
        use_group_norm: bool,
    ) -> None:
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,  # bias is redundant immediately before normalization
        )
        norm = _make_norm2d(out_channels, use_group_norm)
        self.block = nn.Sequential(conv, norm, nn.SiLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> norm -> SiLU to a (B, C, H, W) tensor."""
        return self.block(x)
class AttnResImageBlock2D(nn.Module):
    """Windowed (Swin-style) non-causal attention over a 2D feature map.

    The map is zero-padded on the right/bottom so both spatial dims divide
    ``window_size``, tiled into non-overlapping windows, each window flattened
    to a token sequence and run through one non-causal transformer block, then
    the tiles are stitched back and the padding cropped away.
    """

    def __init__(
        self,
        dim: int,
        *,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        self.window_size = int(window_size)
        # A single transformer block; max_seq_len covers one flattened window.
        self.block = AttnResTransformerBackbone(
            d_model=dim,
            n_blocks=1,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            max_seq_len=self.window_size * self.window_size,
            dropout=dropout,
            ffn_mult=ffn_mult,
            eps=eps,
            rope_theta=rope_theta,
            causal_attn=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply windowed attention; input and output are (B, C, H, W)."""
        batch, channels, height, width = x.shape
        ws = self.window_size
        # Right/bottom zero padding up to the next multiple of the window size.
        extra_h = -height % ws
        extra_w = -width % ws
        if extra_h or extra_w:
            x = F.pad(x, (0, extra_w, 0, extra_h))
        full_h, full_w = x.shape[-2:]
        tiles_h, tiles_w = full_h // ws, full_w // ws
        # (B, C, H, W) -> (B * n_windows, ws*ws, C): one token sequence per
        # window.  reshape (not view) so non-contiguous inputs are handled.
        tokens = (
            x.reshape(batch, channels, tiles_h, ws, tiles_w, ws)
            .permute(0, 2, 4, 3, 5, 1)
            .reshape(batch * tiles_h * tiles_w, ws * ws, channels)
        )
        tokens = self.block(tokens)
        # Inverse of the partition above, back to (B, C, H, W).
        out = (
            tokens.reshape(batch, tiles_h, tiles_w, ws, ws, channels)
            .permute(0, 5, 1, 3, 2, 4)
            .reshape(batch, channels, full_h, full_w)
        )
        # Crop the padding back off.
        return out[:, :, :height, :width]
class _AttnResStage2D(nn.Module):
    """One backbone stage: optional conv projection/downsample, then a stack
    of windowed attention blocks at the stage's output width."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        depth: int,
        downsample_stride: int,
        use_group_norm: bool,
        window_size: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float,
        ffn_mult: float,
        eps: float,
        rope_theta: float,
    ) -> None:
        super().__init__()
        needs_projection = in_channels != out_channels or downsample_stride != 1
        if not needs_projection:
            self.downsample = None
        else:
            # 1x1 conv when only the channel count changes; 3x3 when the
            # stage also downsamples spatially.
            spatial = downsample_stride != 1
            self.downsample = _ConvNormAct(
                in_channels,
                out_channels,
                kernel_size=3 if spatial else 1,
                stride=downsample_stride,
                padding=1 if spatial else 0,
                use_group_norm=use_group_norm,
            )
        self.blocks = nn.ModuleList(
            AttnResImageBlock2D(
                out_channels,
                window_size=window_size,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                dropout=dropout,
                ffn_mult=ffn_mult,
                eps=eps,
                rope_theta=rope_theta,
            )
            for _ in range(depth)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project/downsample (if configured), then apply each attention block."""
        if self.downsample is not None:
            x = self.downsample(x)
        for attn_block in self.blocks:
            x = attn_block(x)
        return x
class AttnResResNetLikeBackbone2D(nn.Module):
    """ResNet-like 2D backbone whose stages use windowed attention blocks.

    Layout mirrors ResNet: a 7x7/2 conv + 3x3/2 max-pool stem (overall
    stride 4), then one stage per entry in ``stage_dims``; every stage after
    the first halves the spatial resolution.  ``output_channels`` records the
    channel count of the final feature map.
    """

    def __init__(
        self,
        *,
        input_channels: int = 3,
        stem_dim: int = 64,
        stage_dims: Sequence[int] = (64, 128, 256, 512),
        stage_depths: Sequence[int] = (2, 2, 2, 2),
        stage_heads: Sequence[int] = (4, 4, 8, 8),
        stage_kv_heads: Sequence[int] = (1, 1, 1, 1),
        stage_window_sizes: Sequence[int] = (7, 7, 7, 7),
        use_group_norm: bool = True,
        dropout: float = 0.0,
        ffn_mult: float = 2.667,
        eps: float = 1e-6,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        per_stage = (stage_dims, stage_depths, stage_heads, stage_kv_heads, stage_window_sizes)
        # All per-stage config sequences must describe the same number of stages.
        if any(len(seq) != len(stage_dims) for seq in per_stage):
            raise ValueError('stage_dims/depths/heads/kv_heads/window_sizes 长度必须一致')
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, stem_dim, kernel_size=7, stride=2, padding=3, bias=False),
            _make_norm2d(stem_dim, use_group_norm),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        stage_modules: list[nn.Module] = []
        prev_channels = stem_dim
        for idx, (dim, depth, heads, kv_heads, win) in enumerate(zip(*per_stage)):
            stage_modules.append(
                _AttnResStage2D(
                    in_channels=prev_channels,
                    out_channels=dim,
                    depth=int(depth),
                    # The first stage keeps resolution; later stages halve it.
                    downsample_stride=1 if idx == 0 else 2,
                    use_group_norm=use_group_norm,
                    window_size=int(win),
                    n_heads=int(heads),
                    n_kv_heads=int(kv_heads),
                    dropout=float(dropout),
                    ffn_mult=float(ffn_mult),
                    eps=float(eps),
                    rope_theta=float(rope_theta),
                )
            )
            prev_channels = dim
        self.stages = nn.ModuleList(stage_modules)
        # Channel count of the last stage's output feature map.
        self.output_channels = prev_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem then every stage; returns a (B, C, H', W') feature map."""
        features = self.stem(x)
        for stage in self.stages:
            features = stage(features)
        return features