31 lines
816 B
Python
31 lines
816 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class TeLU(nn.Module):
|
|
"""
|
|
实现论文中提出的 TeLU 激活函数。
|
|
论文: TeLU Activation Function for Fast and Stable Deep Learning
|
|
公式: TeLU(x) = x * tanh(e^x)
|
|
"""
|
|
def __init__(self):
|
|
"""
|
|
TeLU 激活函数没有可学习的参数,所以 __init__ 方法很简单。
|
|
"""
|
|
super(TeLU, self).__init__()
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
前向传播的计算逻辑。
|
|
"""
|
|
# 直接应用公式
|
|
return x * torch.tanh(torch.exp(x))
|
|
|
|
def __repr__(self):
|
|
"""
|
|
(可选但推荐) 定义一个好的字符串表示,方便打印模型结构。
|
|
"""
|
|
return f"{self.__class__.__name__}()"
|
|
|
|
|