588 lines
22 KiB
Python
588 lines
22 KiB
Python
import torch
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from typing import List, Tuple
|
|
import math
|
|
from functools import partial
|
|
from torch import nn, einsum, diagonal
|
|
from math import log2, ceil
|
|
import pdb
|
|
from sympy import Poly, legendre, Symbol, chebyshevt
|
|
from scipy.special import eval_legendre
|
|
|
|
|
|
def legendreDer(k, x):
    """Evaluate the first derivative of the degree-``k`` Legendre polynomial.

    The value is accumulated as ``sum((2 * i + 1) * P_i(x))`` for
    ``i = k - 1, k - 3, ...`` while ``i >= 0``, with ``P_i`` evaluated by
    ``scipy.special.eval_legendre``.

    Args:
        k: Polynomial degree.
        x: Scalar or NumPy array of evaluation points.

    Returns:
        The accumulated sum, with the same shape as ``x``. When the index
        range is empty (e.g. ``k == 0``) the plain integer ``0`` is returned.
    """
    out = 0
    for i in np.arange(k - 1, -1, -2):
        out += (2 * i + 1) * eval_legendre(i, x)
    return out
|
|
|
|
|
|
def phi_(phi_c, x, lb=0, ub=1):
    """Evaluate a polynomial inside ``[lb, ub]`` and return zero outside it.

    Args:
        phi_c: Polynomial coefficients ordered from the constant term upwards
            (the ordering used by ``numpy.polynomial.polynomial.Polynomial``).
        x: Scalar or NumPy array of evaluation points.
        lb: Lower bound of the interval (inclusive).
        ub: Upper bound of the interval (inclusive).

    Returns:
        The polynomial values, multiplied by ``0`` wherever ``x < lb`` or
        ``x > ub``.
    """
    # 1.0 where x lies outside the interval, 0.0 inside it.
    outside = np.logical_or(x < lb, x > ub) * 1.0
    return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - outside)
|
|
|
|
|
|
def get_phi_psi(k, base):
    """Build the ``phi`` functions and the two ``psi`` halves for a basis.

    For both bases the ``psi`` coefficient rows start from the coefficients of
    ``sqrt(2) * phi(2x)``, have their projections onto every ``phi`` and onto
    the previously built ``psi`` rows subtracted, and are then normalised by
    ``sqrt(norm1 + norm2)``.

    * ``'legendre'``: inner products are exact polynomial integrals over
      ``[0, 0.5]`` (and ``[0.5, 1]`` for the second norm term); the results
      are returned as ``numpy.poly1d`` objects.
    * ``'chebyshev'``: inner products are weighted sums over the roots of the
      degree ``2 * k`` Chebyshev polynomial in ``2x - 1``; the results are
      returned as ``functools.partial`` wrappers around ``phi_`` that are
      zero outside their interval.

    Args:
        k: Number of basis functions.
        base: ``'legendre'`` or ``'chebyshev'``.

    Returns:
        Tuple ``(phi, psi1, psi2)`` of three length-``k`` lists of callables.

    Raises:
        ValueError: If ``base`` is not one of the supported names.
    """
    # Validate first: previously an unknown base fell through to the final
    # ``return`` with ``phi``/``psi1``/``psi2`` unbound.
    if base not in ('legendre', 'chebyshev'):
        raise ValueError('Base not supported')

    def _half_interval_inner(a, b, upper_half=False):
        # Integral of the product of two polynomials (ascending coefficients)
        # over [0, 0.5], or over [0.5, 1] when ``upper_half`` is set:
        # the integral of x**n is 0.5**(n+1)/(n+1) resp. (1 - 0.5**(n+1))/(n+1).
        prod_ = np.convolve(a, b)
        prod_[np.abs(prod_) < 1e-8] = 0
        powers = np.arange(len(prod_)) + 1
        half = np.power(0.5, powers)
        weight = (1 - half) if upper_half else half
        return (prod_ * 1 / powers * weight).sum()

    x = Symbol('x')
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))

    if base == 'legendre':
        for ki in range(k):
            # Coefficients are stored in ascending order, hence the flip.
            coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
            phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
            coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
            phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            # Remove the components along every phi.
            for i in range(k):
                proj_ = _half_interval_inner(phi_2x_coeff[ki, :ki + 1], phi_coeff[i, :i + 1])
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            # Remove the components along the psi rows built so far.
            for j in range(ki):
                proj_ = _half_interval_inner(phi_2x_coeff[ki, :ki + 1], psi1_coeff[j, :])
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            norm1 = _half_interval_inner(psi1_coeff[ki, :], psi1_coeff[ki, :])
            norm2 = _half_interval_inner(psi2_coeff[ki, :], psi2_coeff[ki, :], upper_half=True)
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            # Thresholding is applied to the whole matrices on every pass.
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    else:  # base == 'chebyshev'
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi)
                phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
                phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
                phi_2x_coeff[ki, :ki + 1] = np.flip(
                    np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Temporary callables used only to measure the norm. They wrap
            # views of the coefficient rows, so they must be created before
            # the in-place normalisation below.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            # Final callables, with the shared boundary nudged by 1e-16.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)

    return phi, psi1, psi2
|
|
|
|
|
|
def get_filter(base, k):
    """Compute the ``k x k`` filter matrices for the given polynomial basis.

    ``H0``/``H1`` are built from ``phi`` and ``G0``/``G1`` from ``psi``, each
    as a weighted sum over quadrature nodes ``x_m`` of the function evaluated
    at ``x_m / 2`` (index 0) or ``(x_m + 1) / 2`` (index 1) times
    ``phi[kpi](x_m)``, scaled by ``1 / sqrt(2)``. ``PHI0``/``PHI1`` are
    identity matrices for ``'legendre'`` and quadrature sums of ``phi``
    products for ``'chebyshev'``. Entries with magnitude below ``1e-8`` are
    set to zero.

    Args:
        base: ``'legendre'`` or ``'chebyshev'``.
        k: Number of basis functions.

    Returns:
        Tuple ``(H0, H1, G0, G1, PHI0, PHI1)`` of NumPy arrays of shape
        ``(k, k)``.

    Raises:
        Exception: If ``base`` is not supported.
    """

    def psi(psi1, psi2, i, inp):
        # psi1 is used for inputs <= 0.5 and psi2 for the rest.
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ['legendre', 'chebyshev']:
        raise Exception('Base not supported')

    x = Symbol('x')
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    PHI0 = np.zeros((k, k))
    PHI1 = np.zeros((k, k))
    phi, psi1, psi2 = get_phi_psi(k, base)

    is_chebyshev = base == 'chebyshev'
    if is_chebyshev:
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2
    else:
        roots = Poly(legendre(k, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)

    # The H/G sums are identical for both bases; only nodes/weights differ.
    for ki in range(k):
        for kpi in range(k):
            H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
            G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
            H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
            G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()

            if is_chebyshev:
                PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2

    if is_chebyshev:
        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0
    else:
        PHI0 = np.eye(k)
        PHI1 = np.eye(k)

    H0[np.abs(H0) < 1e-8] = 0
    H1[np.abs(H1) < 1e-8] = 0
    G0[np.abs(G0) < 1e-8] = 0
    G1[np.abs(G1) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
|
|
|
|
|
|
class MultiWaveletTransform(nn.Module):
    """
    1D multiwavelet block.

    Projects the flattened ``values`` tensor to ``c * k`` features, passes it
    through ``nCZ`` stacked ``MWT_CZ1d`` layers (ReLU between layers, none
    after the last one) and projects it back to ``ich`` features.
    """

    def __init__(self, ich=1, k=8, alpha=16, c=128,
                 nCZ=1, L=0, base='legendre', attention_dropout=0.1):
        """
        Args:
            ich: Size of the flattened last two input dimensions.
            k: Forwarded to ``MWT_CZ1d``; also the last dimension of the
                projected tensor.
            alpha: Forwarded to ``MWT_CZ1d``.
            c: Channel multiplier; the projection has ``c * k`` features.
            nCZ: Number of stacked ``MWT_CZ1d`` layers.
            L: Forwarded to ``MWT_CZ1d``.
            base: Polynomial basis name forwarded to ``MWT_CZ1d``.
            attention_dropout: Accepted but not used by this class.
        """
        super(MultiWaveletTransform, self).__init__()
        print('base', base)
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))

    def forward(self, queries, keys, values, attn_mask):
        """Apply the block to ``values``.

        Only the shape of ``queries`` is used (its length ``L`` sets the
        output length); ``keys`` and ``attn_mask`` are accepted for interface
        compatibility and ignored.

        Args:
            queries: Tensor of shape ``(B, L, H, E)``.
            keys: Unused.
            values: Tensor of shape ``(B, S, H, D)``; zero-padded or truncated
                along dimension 1 to length ``L``.
            attn_mask: Unused.

        Returns:
            Tuple ``(V, None)`` with ``V`` of shape ``(B, L, -1, D)``.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # Pad with zeros shaped like ``values`` itself so that the
            # concatenation also works when E != D.
            pad = values.new_zeros(B, L - S, *values.shape[2:])
            values = torch.cat([values, pad], dim=1)
        else:
            values = values[:, :L, :, :]
        # reshape (not view) so non-contiguous inputs are accepted.
        values = values.reshape(B, L, -1)

        V = self.Lk0(values).view(B, L, self.c, -1)
        for i in range(self.nCZ):
            V = self.MWT_CZ[i](V)
            if i < self.nCZ - 1:
                V = F.relu(V)

        V = self.Lk1(V.reshape(B, L, -1))
        V = V.view(B, L, -1, D)
        return (V.contiguous(), None)
|
|
|
|
|
|
class MultiWaveletCross(nn.Module):
    """
    1D Multiwavelet Cross Attention layer.

    Queries, keys and values are projected to ``(B, N, c, k)``, decomposed
    level by level with ``wavelet_transform``, mixed per level by four
    ``FourierCrossAttentionW`` modules, and reconstructed with ``evenOdd``.
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64,
                 k=8, ich=512,
                 L=0,
                 base='legendre',
                 mode_select_method='random',
                 initializer=None, activation='tanh',
                 **kwargs):
        """
        Args:
            in_channels, out_channels, seq_len_q, seq_len_kv, modes,
            activation, mode_select_method: Forwarded to each
                ``FourierCrossAttentionW``.
            c: Channel multiplier of the projections (``c * k`` features).
            k: Filter size passed to ``get_filter``.
            ich: Size of the flattened last two input dimensions.
            L: Number of decomposition levels to skip.
            base: Polynomial basis name passed to ``get_filter``.
            initializer: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        super(MultiWaveletCross, self).__init__()
        print('base', base)

        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        def _make_attention():
            return FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels,
                                          seq_len_q=seq_len_q, seq_len_kv=seq_len_kv,
                                          modes=modes, activation=activation,
                                          mode_select_method=mode_select_method)

        self.attn1 = _make_attention()
        self.attn2 = _make_attention()
        self.attn3 = _make_attention()
        self.attn4 = _make_attention()
        # Not used in forward(); kept so existing state dicts still load.
        self.T0 = nn.Linear(k, k)
        # Decomposition filters, each of shape (2k, k).
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))

        # Reconstruction filters for the even / odd output positions.
        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))

        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes

    def forward(self, q, k, v, mask=None):
        """Run the cross layer.

        Args:
            q: Query tensor of shape ``(B, N, H, E)``.
            k: Key tensor of shape ``(B, S, H, E)``.
            v: Value tensor; padded/truncated together with ``k``.
            mask: Passed through to the attention modules.

        Returns:
            Tuple ``(out, None)`` with ``out`` of shape ``(B, N, ich)``.
        """
        B, N, H, E = q.shape  # (B, N, H, E) torch.Size([3, 768, 8, 2])
        _, S, _, _ = k.shape  # (B, S, H, E) torch.Size([3, 96, 8, 2])

        # reshape (not view) so non-contiguous inputs are accepted.
        q = q.reshape(q.shape[0], q.shape[1], -1)
        k = k.reshape(k.shape[0], k.shape[1], -1)
        v = v.reshape(v.shape[0], v.shape[1], -1)
        q = self.Lq(q)
        q = q.view(q.shape[0], q.shape[1], self.c, self.k)
        k = self.Lk(k)
        k = k.view(k.shape[0], k.shape[1], self.c, self.k)
        v = self.Lv(v)
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)

        # Bring keys/values to the query length.
        if N > S:
            zeros = torch.zeros_like(q[:, :(N - S), :]).float()
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]

        # Extend to the next power of two by repeating the leading steps.
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_q = q[:, 0:nl - N, :, :]
        extra_k = k[:, 0:nl - N, :, :]
        extra_v = v[:, 0:nl - N, :, :]
        q = torch.cat([q, extra_q], 1)
        k = torch.cat([k, extra_k], 1)
        v = torch.cat([v, extra_v], 1)

        Ud_q = torch.jit.annotate(List[Tuple[Tensor]], [])
        Ud_k = torch.jit.annotate(List[Tuple[Tensor]], [])
        Ud_v = torch.jit.annotate(List[Tuple[Tensor]], [])

        Us_q = torch.jit.annotate(List[Tensor], [])
        Us_k = torch.jit.annotate(List[Tensor], [])
        Us_v = torch.jit.annotate(List[Tensor], [])

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose: each level stores the pair (d, s) and, separately, d.
        levels = ns - self.L
        for i in range(levels):
            d, q = self.wavelet_transform(q)
            Ud_q += [tuple([d, q])]
            Us_q += [d]
        for i in range(levels):
            d, k = self.wavelet_transform(k)
            Ud_k += [tuple([d, k])]
            Us_k += [d]
        for i in range(levels):
            d, v = self.wavelet_transform(v)
            Ud_v += [tuple([d, v])]
            Us_v += [d]
        for i in range(levels):
            dk, sk = Ud_k[i], Us_k[i]
            dq, sq = Ud_q[i], Us_q[i]
            dv, sv = Ud_v[i], Us_v[i]
            Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]]
            Us += [self.attn3(sq, sk, sv, mask)[0]]
        v = self.attn4(q, k, v, mask)[0]

        # reconstruct
        for i in range(levels - 1, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return (v.contiguous(), None)

    def wavelet_transform(self, x):
        """Split ``x`` of shape ``(B, N, c, k)`` into ``(d, s)`` of half length.

        Even and odd positions along dimension 1 are concatenated on the last
        axis and multiplied by ``ec_d`` and ``ec_s`` respectively.
        """
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Reconstruct a sequence of double length from ``(B, N, c, 2k)``.

        ``x @ rc_e`` fills the even positions of the output and ``x @ rc_o``
        the odd positions; the result has shape ``(B, 2N, c, k)`` and the
        dtype of the matmul results (instead of the default dtype).
        """
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # (B, N, 2, c, k) -> (B, 2N, c, k) interleaves even/odd entries.
        return torch.stack((x_e, x_o), dim=2).reshape(B, N * 2, c, self.k)
|
|
|
|
|
|
class FourierCrossAttentionW(nn.Module):
    """Parameter-free cross attention computed on low-frequency Fourier modes.

    The query and key tensors are transformed with ``rfft`` along the sequence
    axis, truncated to at most ``modes`` leading frequencies, combined, passed
    through the chosen activation and transformed back with ``irfft``.
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
                 mode_select_method='random'):
        """
        Args:
            in_channels: Divides the output spectrum in ``forward``.
            out_channels: Divides the output spectrum in ``forward``.
            seq_len_q: Accepted but not used.
            seq_len_kv: Accepted but not used.
            modes: Upper bound on the number of frequency modes kept.
            activation: ``'tanh'`` or ``'softmax'``; checked in ``forward``.
            mode_select_method: Accepted but not used.
        """
        super(FourierCrossAttentionW, self).__init__()
        print('corss fourier correlation used!')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def compl_mul1d(self, order, x, weights):
        """Einsum of two tensors using explicit real/imaginary arithmetic.

        Args:
            order: ``torch.einsum`` equation string.
            x: Real or complex tensor.
            weights: Real or complex tensor.

        Returns:
            A real tensor when both inputs are real, otherwise the complex
            product ``(xr*wr - xi*wi) + i*(xr*wi + xi*wr)`` contracted
            according to ``order``.
        """
        x_is_complex = torch.is_complex(x)
        w_is_complex = torch.is_complex(weights)
        if not (x_is_complex or w_is_complex):
            return torch.einsum(order, x, weights)
        if not x_is_complex:
            x = torch.complex(x, torch.zeros_like(x))
        if not w_is_complex:
            weights = torch.complex(weights, torch.zeros_like(weights))
        return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                             torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))

    def forward(self, q, k, v, mask):
        """Compute the frequency-domain cross attention.

        Args:
            q: Tensor whose shape is unpacked as ``(B, L, E, H)``.
            k: Tensor laid out like ``q``.
            v: Only its sequence length (dimension 1) is read.
            mask: Unused.

        Returns:
            Tuple ``(out, None)`` where ``out`` has the same shape as ``q``.

        Raises:
            Exception: If ``self.activation`` is neither ``'tanh'`` nor
                ``'softmax'``.
        """
        B, L, E, H = q.shape

        xq = q.permute(0, 3, 2, 1)  # size = [B, H, E, L] torch.Size([3, 8, 64, 512])
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))
        n_q = len(self.index_q)
        n_kv = len(self.index_k_v)

        # Compute Fourier coefficients. The index lists are always the
        # leading modes 0..m-1, so a slice replaces the per-mode copy loop.
        xq_ft_ = torch.fft.rfft(xq, dim=-1)[..., :n_q].to(torch.cfloat)
        xk_ft_ = torch.fft.rfft(xk, dim=-1)[..., :n_kv].to(torch.cfloat)

        xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
        if self.activation == 'tanh':
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == 'softmax':
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception('{} actiation function is not implemented'.format(self.activation))
        # NOTE(review): the key spectrum is used here as well; the contents of
        # ``v`` never enter the result. Left unchanged to keep numerical
        # behaviour - confirm whether this is intended.
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        out_ft[..., :n_q] = xqkv_ft

        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
        # Same layout as the input ``q``.
        return (out, None)
|
|
|
|
|
|
class sparseKernelFT1d(nn.Module):
    """Learned linear map applied to the leading Fourier modes of a sequence.

    The real and imaginary parts of the weights are stored as two separate
    real parameters of shape ``(c * k, c * k, alpha)``.
    """

    def __init__(self,
                 k, alpha, c=1,
                 nl=1,
                 initializer=None,
                 **kwargs):
        """
        Args:
            k: Size of the last input dimension.
            alpha: Number of Fourier modes that carry weights.
            c: Size of the third input dimension.
            nl: Accepted but not used.
            initializer: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        super(sparseKernelFT1d, self).__init__()

        self.modes1 = alpha
        self.scale = (1 / (c * k * c * k))
        # nn.Parameter already sets requires_grad=True.
        self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.k = k

    def compl_mul1d(self, order, x, weights):
        """Einsum of two tensors using explicit real/imaginary arithmetic.

        Returns a real tensor when both inputs are real; otherwise the complex
        product contracted according to the einsum equation ``order``.
        """
        x_is_complex = torch.is_complex(x)
        w_is_complex = torch.is_complex(weights)
        if not (x_is_complex or w_is_complex):
            return torch.einsum(order, x, weights)
        if not x_is_complex:
            x = torch.complex(x, torch.zeros_like(x))
        if not w_is_complex:
            weights = torch.complex(weights, torch.zeros_like(weights))
        return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                             torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))

    def forward(self, x):
        """Filter ``x`` of shape ``(B, N, c, k)`` in the frequency domain.

        The ``c`` and ``k`` axes are flattened into channels, ``rfft`` is
        taken along ``N``, the first ``min(alpha, N // 2 + 1)`` modes are
        multiplied by the complex weights, the remaining modes are zero, and
        the result is transformed back.

        Returns:
            Tensor of shape ``(B, N, c, k)``.
        """
        B, N, c, k = x.shape  # (B, N, c, k)

        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(B, N, -1)
        x = x.permute(0, 2, 1)
        x_fft = torch.fft.rfft(x)
        # Multiply relevant Fourier modes
        n_modes = min(self.modes1, N // 2 + 1)
        out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :n_modes] = self.compl_mul1d("bix,iox->box", x_fft[:, :, :n_modes],
                                                   torch.complex(self.weights1, self.weights2)[:, :, :n_modes])
        x = torch.fft.irfft(out_ft, n=N)
        x = x.permute(0, 2, 1).reshape(B, N, c, k)
        return x
|
|
|
|
|
|
# ##
|
|
class MWT_CZ1d(nn.Module):
    """One multiwavelet decomposition / reconstruction layer.

    The input is repeatedly split by ``wavelet_transform``; at each level the
    ``d`` and ``s`` parts are processed by the ``A``/``B``/``C`` kernels, the
    coarsest ``s`` part goes through ``T0``, and the sequence is rebuilt
    level by level with ``evenOdd``.
    """

    def __init__(self,
                 k=3, alpha=64,
                 L=0, c=1,
                 base='legendre',
                 initializer=None,
                 **kwargs):
        """
        Args:
            k: Filter size passed to ``get_filter``; last input dimension.
            alpha: Number of Fourier modes for the ``sparseKernelFT1d``s.
            L: Number of decomposition levels to skip.
            c: Channel count passed to the ``sparseKernelFT1d``s.
            base: Polynomial basis name passed to ``get_filter``.
            initializer: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        super(MWT_CZ1d, self).__init__()

        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)

        self.T0 = nn.Linear(k, k)

        # Decomposition filters, each of shape (2k, k).
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))

        # Reconstruction filters for the even / odd output positions.
        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        """Transform ``x`` of shape ``(B, N, c, k)``; the shape is preserved.

        When ``N`` is not a power of two the sequence is first extended to
        the next power of two by appending its own leading time steps, and
        the extension is cut off again at the end.
        """
        B, N, c, k = x.shape  # (B, N, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_x = x[:, 0:nl - N, :, :]
        x = torch.cat([x, extra_x], 1)
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
        # decompose; ``x`` holds the ``s`` part after each step
        for i in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x)  # coarsest scale transform

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]

        return x

    def wavelet_transform(self, x):
        """Split ``x`` of shape ``(B, N, c, k)`` into ``(d, s)`` of half length.

        Even and odd positions along dimension 1 are concatenated on the last
        axis and multiplied by ``ec_d`` and ``ec_s`` respectively.
        """
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Reconstruct a sequence of double length from ``(B, N, c, 2k)``.

        ``x @ rc_e`` fills the even positions of the output and ``x @ rc_o``
        the odd positions; the result has shape ``(B, 2N, c, k)`` and the
        dtype of the matmul results (instead of the default dtype).
        """
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # (B, N, 2, c, k) -> (B, 2N, c, k) interleaves even/odd entries.
        return torch.stack((x_e, x_o), dim=2).reshape(B, N * 2, c, self.k)
|