first commit

This commit is contained in:
gameloader
2025-08-28 10:17:59 +00:00
commit d6dd462886
350 changed files with 39789 additions and 0 deletions

View File

@ -0,0 +1,162 @@
# coding=utf-8
# author=maziqing
# email=maziqing.mzq@alibaba-inc.com
import numpy as np
import torch
import torch.nn as nn
def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):
"""
get modes on frequency domain:
'random' means sampling randomly;
'else' means sampling the lowest modes;
"""
modes = min(modes, seq_len // 2)
if mode_select_method == 'random':
index = list(range(0, seq_len // 2))
np.random.shuffle(index)
index = index[:modes]
else:
index = list(range(0, modes))
index.sort()
return index
# ########## fourier layer #############
class FourierBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_heads, seq_len, modes=0, mode_select_method='random'):
super(FourierBlock, self).__init__()
print('fourier enhanced block used!')
"""
1D Fourier block. It performs representation learning on frequency domain,
it does FFT, linear transform, and Inverse FFT.
"""
# get modes on frequency domain
self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
print('modes={}, index={}'.format(modes, self.index))
self.n_heads = n_heads
self.scale = (1 / (in_channels * out_channels))
self.weights1 = nn.Parameter(
self.scale * torch.rand(self.n_heads, in_channels // self.n_heads, out_channels // self.n_heads,
len(self.index), dtype=torch.float))
self.weights2 = nn.Parameter(
self.scale * torch.rand(self.n_heads, in_channels // self.n_heads, out_channels // self.n_heads,
len(self.index), dtype=torch.float))
# Complex multiplication
def compl_mul1d(self, order, x, weights):
x_flag = True
w_flag = True
if not torch.is_complex(x):
x_flag = False
x = torch.complex(x, torch.zeros_like(x).to(x.device))
if not torch.is_complex(weights):
w_flag = False
weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
if x_flag or w_flag:
return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
else:
return torch.einsum(order, x.real, weights.real)
def forward(self, q, k, v, mask):
# size = [B, L, H, E]
B, L, H, E = q.shape
x = q.permute(0, 2, 3, 1)
# Compute Fourier coefficients
x_ft = torch.fft.rfft(x, dim=-1)
# Perform Fourier neural operations
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
for wi, i in enumerate(self.index):
if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
continue
out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i],
torch.complex(self.weights1, self.weights2)[:, :, :, wi])
# Return to time domain
x = torch.fft.irfft(out_ft, n=x.size(-1))
return (x, None)
# ########## Fourier Cross Former ####################
class FourierCrossAttention(nn.Module):
def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',
activation='tanh', policy=0, num_heads=8):
super(FourierCrossAttention, self).__init__()
print(' fourier enhanced cross attention used!')
"""
1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.
"""
self.activation = activation
self.in_channels = in_channels
self.out_channels = out_channels
# get modes for queries and keys (& values) on frequency domain
self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)
print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))
print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))
self.scale = (1 / (in_channels * out_channels))
self.weights1 = nn.Parameter(
self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))
self.weights2 = nn.Parameter(
self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))
# Complex multiplication
def compl_mul1d(self, order, x, weights):
x_flag = True
w_flag = True
if not torch.is_complex(x):
x_flag = False
x = torch.complex(x, torch.zeros_like(x).to(x.device))
if not torch.is_complex(weights):
w_flag = False
weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
if x_flag or w_flag:
return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
else:
return torch.einsum(order, x.real, weights.real)
def forward(self, q, k, v, mask):
# size = [B, L, H, E]
B, L, H, E = q.shape
xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L]
xk = k.permute(0, 2, 3, 1)
xv = v.permute(0, 2, 3, 1)
# Compute Fourier coefficients
xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
xq_ft = torch.fft.rfft(xq, dim=-1)
for i, j in enumerate(self.index_q):
if j >= xq_ft.shape[3]:
continue
xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)
xk_ft = torch.fft.rfft(xk, dim=-1)
for i, j in enumerate(self.index_kv):
if j >= xk_ft.shape[3]:
continue
xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
# perform attention mechanism on frequency domain
xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
if self.activation == 'tanh':
xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
elif self.activation == 'softmax':
xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
else:
raise Exception('{} actiation function is not implemented'.format(self.activation))
xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
for i, j in enumerate(self.index_q):
if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
continue
out_ft[:, :, :, j] = xqkvw[:, :, :, i]
# Return to time domain
out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
return (out, None)