import torch
import torch.nn as nn


def _init_conv_weights(module: nn.Module) -> None:
    """Initialise every ``nn.Conv2d`` found inside *module* (in place).

    Weights get Kaiming-normal initialisation (``mode='fan_out'``,
    ``nonlinearity='relu'``); biases, when present, are set to zero.
    """
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class Inception_Block_V1(nn.Module):
    """Average of parallel square 2-D convolutions with growing kernel size.

    Branch ``i`` (``0 <= i < num_kernels``) is an ``nn.Conv2d`` with kernel
    size ``2 * i + 1`` and padding ``i``, so every branch keeps the spatial
    size of its input and the branch outputs can be averaged element-wise.

    Args:
        in_channels: Number of input channels of every branch.
        out_channels: Number of output channels of every branch.
        num_kernels: Number of parallel branches (kernel sizes 1, 3, 5, ...).
        init_weight: If true, initialise the convolutions via
            ``_initialize_weights`` at construction time.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 num_kernels: int = 6, init_weight: bool = True) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.kernels = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
            for i in range(num_kernels)
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for conv weights, zeros for conv biases."""
        _init_conv_weights(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every branch to ``x`` and return the element-wise mean."""
        # Iterate the ModuleList itself so the set of branches used is always
        # exactly the set that was built, independent of ``self.num_kernels``.
        branch_outputs = [conv(x) for conv in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)


class Inception_Block_V2(nn.Module):
    """Average of parallel 1-D-shaped (row / column) 2-D convolutions.

    For each ``i`` in ``range(num_kernels // 2)`` two branches are created:
    a ``1 x (2 * i + 3)`` convolution and a ``(2 * i + 3) x 1`` convolution,
    each padded so the spatial size is preserved.  A final ``1 x 1``
    convolution is appended, giving ``2 * (num_kernels // 2) + 1`` branches
    in total.

    Args:
        in_channels: Number of input channels of every branch.
        out_channels: Number of output channels of every branch.
        num_kernels: Controls the branch count as described above.
        init_weight: If true, initialise the convolutions via
            ``_initialize_weights`` at construction time.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 num_kernels: int = 6, init_weight: bool = True) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels

        kernels = []
        for i in range(num_kernels // 2):
            size, pad = 2 * i + 3, i + 1
            # Horizontal (1 x size) branch followed by vertical (size x 1).
            kernels.append(nn.Conv2d(in_channels, out_channels,
                                     kernel_size=(1, size), padding=(0, pad)))
            kernels.append(nn.Conv2d(in_channels, out_channels,
                                     kernel_size=(size, 1), padding=(pad, 0)))
        # Point-wise branch is always present, even when num_kernels < 2.
        kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.kernels = nn.ModuleList(kernels)

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for conv weights, zeros for conv biases."""
        _init_conv_weights(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every branch to ``x`` and return the element-wise mean."""
        # Iterating the ModuleList avoids re-deriving the branch count
        # (2 * (num_kernels // 2) + 1) and keeps it in sync with __init__.
        branch_outputs = [conv(x) for conv in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)