# Files
# TSlib/layers/DWT_Decomposition.py
# 2025-08-28 10:17:59 +00:00
#
# 1268 lines
# 46 KiB
# Python
# -*- coding: utf-8 -*-
"""
Created on Sun Jan 5
@author: Murad
SISLab, USF
mmurad@usf.edu
https://github.com/Secure-and-Intelligent-Systems-Lab/WPMixer
"""
import torch
import torch.nn as nn
import pywt
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function
class Decomposition(nn.Module):
    """Multi-level 1d wavelet decomposition/reconstruction for time series.

    Wraps DWT1DForward / DWT1DInverse to split a (batch, channel, seq)
    series into one lowpass band plus ``level`` highpass bands, optionally
    applying a learnable per-band affine transform, and to invert that
    split. With ``no_decomposition=True`` both directions are identity.
    """

    def __init__(self,
                 input_length=None,
                 pred_length=None,
                 wavelet_name=None,
                 level=None,
                 batch_size=None,
                 channel=None,
                 d_model=None,
                 tfactor=None,
                 dfactor=None,
                 device=None,
                 no_decomposition=None,
                 use_amp=None):
        """
        Args:
            input_length (int): input sequence length.
            pred_length (int): prediction sequence length.
            wavelet_name (str): pywt wavelet name (e.g. 'db2').
            level (int): number of decomposition levels J.
            batch_size (int): batch size, used only for the dummy shape probe.
            channel (int): number of channels/variates.
            d_model (int): model width (stored for callers; unused here).
            tfactor, dfactor: expansion factors (stored for callers; unused here).
            device (torch.device): device the module should live on.
            no_decomposition (bool): if True, bypass the wavelet transform.
            use_amp (bool): run the internal convolutions under CUDA AMP.

        Note: the ``None`` defaults replace the original mutable-list
        placeholders (``[]``); all arguments are effectively required.
        """
        super(Decomposition, self).__init__()
        self.input_length = input_length
        self.pred_length = pred_length
        self.wavelet_name = wavelet_name
        self.level = level
        self.batch_size = batch_size
        self.channel = channel
        self.d_model = d_model
        self.device = device
        self.no_decomposition = no_decomposition
        self.use_amp = use_amp
        self.eps = 1e-5  # guards the division in the inverse affine step

        # Forward / inverse DWT; moved to the GPU only for a CUDA device.
        self.dwt = DWT1DForward(wave=self.wavelet_name, J=self.level,
                                use_amp=self.use_amp)
        self.idwt = DWT1DInverse(wave=self.wavelet_name, use_amp=self.use_amp)
        if self.device.type == 'cuda':
            self.dwt = self.dwt.cuda()
            self.idwt = self.idwt.cuda()

        # Per-band sequence lengths after decomposition; with
        # no_decomposition there is a single band of the original length.
        self.input_w_dim = (self._dummy_forward(self.input_length)
                            if not self.no_decomposition else [self.input_length])
        self.pred_w_dim = (self._dummy_forward(self.pred_length)
                           if not self.no_decomposition else [self.pred_length])
        self.tfactor = tfactor
        self.dfactor = dfactor
        #################################
        self.affine = False  # learnable per-band scale/shift (off by default)
        #################################
        if self.affine:
            self._init_params()

    def transform(self, x):
        """Decompose x of shape (batch, channel, seq) into (yl, yh)."""
        if not self.no_decomposition:
            yl, yh = self._wavelet_decompose(x)
        else:
            yl, yh = x, []  # no decompose: returning the same value in yl
        return yl, yh

    def inv_transform(self, yl, yh):
        """Reconstruct the (batch, channel, seq) series from (yl, yh)."""
        if not self.no_decomposition:
            x = self._wavelet_reverse_decompose(yl, yh)
        else:
            x = yl  # no decompose: returning the same value in x
        return x

    def _dummy_forward(self, input_length):
        """Probe the DWT with a dummy batch and return each band's length
        (lowpass first, then each highpass level)."""
        dummy_x = torch.ones((self.batch_size, self.channel, input_length)).to(self.device)
        yl, yh = self.dwt(dummy_x)
        return [yl.shape[-1]] + [band.shape[-1] for band in yh]

    def _init_params(self):
        # One (scale, shift) pair per band and channel: row 0 is the lowpass
        # band, rows 1..level the highpass bands.
        self.affine_weight = nn.Parameter(torch.ones((self.level + 1, self.channel)))
        self.affine_bias = nn.Parameter(torch.zeros((self.level + 1, self.channel)))

    def _wavelet_decompose(self, x):
        # input: x shape: batch, channel, seq
        yl, yh = self.dwt(x)
        if self.affine:
            yl = yl.transpose(1, 2)  # batch, seq, channel
            yl = yl * self.affine_weight[0] + self.affine_bias[0]
            yl = yl.transpose(1, 2)  # batch, channel, seq
            for i in range(self.level):
                yh_ = yh[i].transpose(1, 2)  # batch, seq, channel
                yh_ = yh_ * self.affine_weight[i + 1] + self.affine_bias[i + 1]
                yh[i] = yh_.transpose(1, 2)  # batch, channel, seq
        return yl, yh

    def _wavelet_reverse_decompose(self, yl, yh):
        if self.affine:
            # Undo the affine map; eps avoids division by a zero weight.
            yl = yl.transpose(1, 2)  # batch, seq, channel
            yl = (yl - self.affine_bias[0]) / (self.affine_weight[0] + self.eps)
            yl = yl.transpose(1, 2)  # batch, channel, seq
            for i in range(self.level):
                yh_ = yh[i].transpose(1, 2)  # batch, seq, channel
                yh_ = (yh_ - self.affine_bias[i + 1]) / (self.affine_weight[i + 1] + self.eps)
                yh[i] = yh_.transpose(1, 2)  # batch, channel, seq
        x = self.idwt((yl, yh))
        return x  # shape: batch, channel, seq
###############################################################################################
"""
Following codes are combined from https://github.com/fbcotter/pytorch_wavelets.
To use Wavelet decomposition, you do not need to modify any of the codes below this line,
we can just play with the class Decomposition(above)
"""
###############################################################################################
class DWT1DForward(nn.Module):
    """Performs a 1d DWT Forward decomposition of a signal.

    Args:
        J (int): number of levels of decomposition.
        wave (str or pywt.Wavelet or tuple(ndarray)): wavelet to use;
            either a name accepted by pywt.Wavelet, a pywt.Wavelet
            instance, or a pair of filter arrays (h0, h1).
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization' —
            the padding scheme.
        use_amp (bool): run the filtering convolutions under CUDA AMP.
    """

    def __init__(self, J=1, wave='db1', mode='zero', use_amp=False):
        super().__init__()
        self.use_amp = use_amp
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            lo, hi = wave.dec_lo, wave.dec_hi
        else:
            assert len(wave) == 2
            lo, hi = wave[0], wave[1]
        # Turn the filter pair into the column form the filter bank expects.
        dec_filts = prep_filt_afb1d(lo, hi)
        self.register_buffer('h0', dec_filts[0])
        self.register_buffer('h1', dec_filts[1])
        self.J = J
        self.mode = mode

    def forward(self, x):
        """Decompose a batch of signals.

        Args:
            x (tensor): input of shape (N, C, L).
        Returns:
            (yl, yh): lowpass coefficients and a list of J highpass
            tensors, finest scale first.
        """
        assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        mode_code = mode_to_int(self.mode)
        lows = x
        bands = []
        # Peel off one level per iteration, feeding the lowpass back in.
        for _ in range(self.J):
            lows, detail = AFB1D.apply(lows, self.h0, self.h1, mode_code, self.use_amp)
            bands.append(detail)
        return lows, bands
class DWT1DInverse(nn.Module):
    """Performs a 1d DWT Inverse reconstruction of a signal.

    Args:
        wave (str or pywt.Wavelet or tuple(ndarray)): wavelet to use;
            either a name accepted by pywt.Wavelet, a pywt.Wavelet
            instance, or a pair of filter arrays (g0, g1).
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization' —
            the padding scheme.
        use_amp (bool): run the filtering convolutions under CUDA AMP.
    """

    def __init__(self, wave='db1', mode='zero', use_amp=False):
        super().__init__()
        self.use_amp = use_amp
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            rec_lo, rec_hi = wave.rec_lo, wave.rec_hi
        else:
            assert len(wave) == 2
            rec_lo, rec_hi = wave[0], wave[1]
        # Turn the filter pair into the form the synthesis bank expects.
        rec_filts = prep_filt_sfb1d(rec_lo, rec_hi)
        self.register_buffer('g0', rec_filts[0])
        self.register_buffer('g1', rec_filts[1])
        self.mode = mode

    def forward(self, coeffs):
        """Reconstruct from (yl, yh) as returned by DWT1DForward.

        Any highpass entry may be None and is treated as zeros.
        Returns a tensor of shape (N, C, L).
        """
        low, bands = coeffs
        assert low.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        mode_code = mode_to_int(self.mode)
        # Coarsest level first: walk the highpass list in reverse.
        for detail in bands[::-1]:
            if detail is None:
                detail = torch.zeros_like(low)
            # 'Unpad' the extra sample the forward transform may have added.
            if low.shape[-1] > detail.shape[-1]:
                low = low[..., :-1]
            low = SFB1D.apply(low, detail, self.g0, self.g1, mode_code, self.use_amp)
        return low
def roll(x, n, dim, make_even=False):
    """Circularly shift tensor ``x`` by ``n`` places along dimension ``dim``.

    A positive ``n`` moves the last ``n`` entries to the front (like
    ``np.roll``). If ``make_even`` is True and the dimension has odd
    length, one extra wrapped element is kept so the output length is even.

    Args:
        x (tensor): tensor of up to 4 dimensions.
        n (int): shift amount; negative values shift the other way.
        dim (int): one of 0, 1, 2/-2, 3/-1.
        make_even (bool): pad odd-length dims to even length while rolling.

    Raises:
        ValueError: if ``dim`` is not a supported dimension.
    """
    if n < 0:
        n = x.shape[dim] + n
    # Keep one wrapped sample when evening out an odd-length dimension.
    end = 1 if (make_even and x.shape[dim] % 2 == 1) else 0
    if dim == 0:
        return torch.cat((x[-n:], x[:-n + end]), dim=0)
    elif dim == 1:
        return torch.cat((x[:, -n:], x[:, :-n + end]), dim=1)
    elif dim == 2 or dim == -2:
        return torch.cat((x[:, :, -n:], x[:, :, :-n + end]), dim=2)
    elif dim == 3 or dim == -1:
        return torch.cat((x[:, :, :, -n:], x[:, :, :, :-n + end]), dim=3)
    # BUGFIX: previously fell through and returned None; fail loudly instead.
    raise ValueError("roll only supports dims 0-3, got {}".format(dim))
def mypad(x, pad, mode='constant', value=0):
    """ Function to do numpy like padding on tensors. Only works for 2-D
    padding.
    Inputs:
        x (tensor): 4D tensor (N, C, H, W) to pad
        pad (tuple): tuple of (left, right, top, bottom) pad sizes
        mode (str): 'symmetric', 'periodic', 'constant', 'reflect',
            'replicate', or 'zero'. The padding technique.
        value: fill value, used only by the 'constant' mode of F.pad

    NOTE(review): the 'symmetric' branch uses `reflect`, an index-folding
    helper defined elsewhere in this module (outside this excerpt) —
    presumably the half-sample symmetric extension from pytorch_wavelets.
    """
    if mode == 'symmetric':
        # Symmetric (half-sample) extension implemented by fancy indexing
        # with folded index arrays rather than by copying data.
        # Vertical only
        if pad[0] == 0 and pad[1] == 0:
            m1, m2 = pad[2], pad[3]
            l = x.shape[-2]
            xe = reflect(np.arange(-m1, l + m2, dtype='int32'), -0.5, l - 0.5)
            return x[:, :, xe]
        # horizontal only
        elif pad[2] == 0 and pad[3] == 0:
            m1, m2 = pad[0], pad[1]
            l = x.shape[-1]
            xe = reflect(np.arange(-m1, l + m2, dtype='int32'), -0.5, l - 0.5)
            return x[:, :, :, xe]
        # Both
        else:
            m1, m2 = pad[0], pad[1]
            l1 = x.shape[-1]
            xe_row = reflect(np.arange(-m1, l1 + m2, dtype='int32'), -0.5, l1 - 0.5)
            m1, m2 = pad[2], pad[3]
            l2 = x.shape[-2]
            xe_col = reflect(np.arange(-m1, l2 + m2, dtype='int32'), -0.5, l2 - 0.5)
            # Build a full 2d index grid from the folded row/col indices.
            i = np.outer(xe_col, np.ones(xe_row.shape[0]))
            j = np.outer(np.ones(xe_col.shape[0]), xe_row)
            return x[:, :, i, j]
    elif mode == 'periodic':
        # Periodic (wrap-around) extension via np.pad on the index arrays.
        # Vertical only
        if pad[0] == 0 and pad[1] == 0:
            xe = np.arange(x.shape[-2])
            xe = np.pad(xe, (pad[2], pad[3]), mode='wrap')
            return x[:, :, xe]
        # Horizontal only
        elif pad[2] == 0 and pad[3] == 0:
            xe = np.arange(x.shape[-1])
            xe = np.pad(xe, (pad[0], pad[1]), mode='wrap')
            return x[:, :, :, xe]
        # Both
        else:
            xe_col = np.arange(x.shape[-2])
            xe_col = np.pad(xe_col, (pad[2], pad[3]), mode='wrap')
            xe_row = np.arange(x.shape[-1])
            xe_row = np.pad(xe_row, (pad[0], pad[1]), mode='wrap')
            i = np.outer(xe_col, np.ones(xe_row.shape[0]))
            j = np.outer(np.ones(xe_col.shape[0]), xe_row)
            return x[:, :, i, j]
    elif mode == 'constant' or mode == 'reflect' or mode == 'replicate':
        # These modes map directly onto torch's built-in padding.
        return F.pad(x, pad, mode, value)
    elif mode == 'zero':
        return F.pad(x, pad)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))
def afb1d(x, h0, h1, use_amp, mode='zero', dim=-1):
    """ 1D analysis filter bank (along one dimension only) of an image
    Inputs:
        x (tensor): 4D input with the last two dimensions the spatial input
        h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w)
        h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w)
        use_amp (bool): run the convolution under torch.cuda.amp.autocast
        mode (str): padding method
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).
    Returns:
        lohi: lowpass and highpass subbands concatenated along the channel
            dimension (channel 2c = lowpass of input channel c, 2c+1 = highpass)
    """
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    # Stride 2 along the filtered dimension only (decimation by 2).
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]
    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        # [::-1] flips the filter because conv2d actually cross-correlates.
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1, 1, 1, 1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    # One (lowpass, highpass) filter pair per channel for a grouped conv.
    h = torch.cat([h0, h1] * C, dim=0)
    if mode == 'per' or mode == 'periodization':
        # Even out odd-length signals by repeating the last sample.
        if x.shape[dim] % 2 == 1:
            if d == 2:
                x = torch.cat((x, x[:, :, -1:]), dim=2)
            else:
                x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            N += 1
        x = roll(x, -L2, dim=d)
        pad = (L - 1, 0) if d == 2 else (0, L - 1)
        if use_amp:
            with torch.cuda.amp.autocast():  # for mixed precision
                lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        else:
            lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        N2 = N // 2
        # Fold the wrapped-around tail back onto the head, then crop to N2.
        if d == 2:
            lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2 + L2]
            lohi = lohi[:, :, :N2]
        else:
            lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2 + L2]
            lohi = lohi[:, :, :, :N2]
    else:
        # Calculate the pad size
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        if mode == 'zero':
            # Sadly, pytorch only allows for same padding before and after, if
            # we need to do more padding after for odd length signals, have to
            # prepad
            if p % 2 == 1:
                pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
                x = F.pad(x, pad)
            pad = (p // 2, 0) if d == 2 else (0, p // 2)
            # Calculate the high and lowpass
            if use_amp:
                with torch.cuda.amp.autocast():
                    lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
            else:
                lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
            # Pre-extend the signal explicitly, then convolve without padding.
            pad = (0, 0, p // 2, (p + 1) // 2) if d == 2 else (p // 2, (p + 1) // 2, 0, 0)
            x = mypad(x, pad=pad, mode=mode)
            if use_amp:
                with torch.cuda.amp.autocast():
                    lohi = F.conv2d(x, h, stride=s, groups=C)
            else:
                lohi = F.conv2d(x, h, stride=s, groups=C)
        else:
            raise ValueError("Unkown pad type: {}".format(mode))
    return lohi
def afb1d_atrous(x, h0, h1, mode='periodic', dim=-1, dilation=1):
    """ 1D analysis filter bank (along one dimension only) of an image without
    downsampling. Does the a trous algorithm.
    Inputs:
        x (tensor): 4D input with the last two dimensions the spatial input
        h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w)
        h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w)
        mode (str): padding method
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).
        dilation (int): dilation factor. Should be a power of 2.
    Returns:
        lohi: lowpass and highpass subbands concatenated along the channel
            dimension
    """
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        # [::-1] flips the filter because conv2d actually cross-correlates.
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    L = h0.numel()
    shape = [1, 1, 1, 1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    # One (lowpass, highpass) filter pair per channel for a grouped conv.
    h = torch.cat([h0, h1] * C, dim=0)
    # Calculate the pad size; the effective filter length grows with dilation.
    L2 = (L * dilation) // 2
    pad = (0, 0, L2 - dilation, L2) if d == 2 else (L2 - dilation, L2, 0, 0)
    x = mypad(x, pad=pad, mode=mode)
    # No stride: undecimated (a trous) filtering.
    lohi = F.conv2d(x, h, groups=C, dilation=dilation)
    return lohi
def sfb1d(lo, hi, g0, g1, use_amp, mode='zero', dim=-1):
    """ 1D synthesis filter bank of an image tensor

    Upsamples (stride-2 transposed convolution) the lowpass `lo` and
    highpass `hi` subbands with filters g0/g1 along `dim` and sums them.

    Inputs:
        lo, hi (tensor): 4D subband tensors to recombine
        g0, g1: lowpass / highpass synthesis filters (tensor or array-like)
        use_amp (bool): run the transposed convolutions under CUDA AMP
        mode (str): padding method that was used at analysis time
        dim (int): dimension to filter along (2 = columns, 3 = rows)
    """
    C = lo.shape[1]
    d = dim % 4
    # If g0, g1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=torch.float, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=torch.float, device=lo.device)
    L = g0.numel()
    shape = [1, 1, 1, 1]
    shape[d] = L
    # Upsampled output length before any cropping.
    N = 2 * lo.shape[d]
    # If g aren't in the right shape, make them so
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)
    s = (2, 1) if d == 2 else (1, 2)
    # One copy of each filter per channel for the grouped transposed conv.
    g0 = torch.cat([g0] * C, dim=0)
    g1 = torch.cat([g1] * C, dim=0)
    if mode == 'per' or mode == 'periodization':
        if use_amp:
            with torch.cuda.amp.autocast():
                y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
                    F.conv_transpose2d(hi, g1, stride=s, groups=C)
        else:
            y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
                F.conv_transpose2d(hi, g1, stride=s, groups=C)
        # Fold the overhanging tail back onto the head, crop to N samples,
        # and undo the roll applied by the analysis bank.
        if d == 2:
            y[:, :, :L - 2] = y[:, :, :L - 2] + y[:, :, N:N + L - 2]
            y = y[:, :, :N]
        else:
            y[:, :, :, :L - 2] = y[:, :, :, :L - 2] + y[:, :, :, N:N + L - 2]
            y = y[:, :, :, :N]
        y = roll(y, 1 - L // 2, dim=dim)
    else:
        if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
                mode == 'periodic':
            # padding here trims the transient introduced by the filters.
            pad = (L - 2, 0) if d == 2 else (0, L - 2)
            if use_amp:
                with torch.cuda.amp.autocast():
                    y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
                        F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
            else:
                y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
                    F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
        else:
            raise ValueError("Unkown pad type: {}".format(mode))
    return y
# Padding-mode name -> integer code used by the autograd Functions
# (gradcheck rejects string arguments, so modes travel as ints).
_MODE_TO_INT = {
    'zero': 0,
    'symmetric': 1,
    'per': 2,
    'periodization': 2,  # 'per' is an alias for 'periodization'
    'constant': 3,
    'reflect': 4,
    'replicate': 5,
    'periodic': 6,
}


def mode_to_int(mode):
    """Return the integer code for the padding mode name ``mode``.

    Raises:
        ValueError: if ``mode`` is not a recognised padding mode.
    """
    try:
        return _MODE_TO_INT[mode]
    except KeyError:
        # Typo "Unkown" fixed relative to the original message.
        raise ValueError("Unknown pad type: {}".format(mode)) from None
# Integer code -> padding-mode name; inverse of mode_to_int (the shared
# code 2 maps back to the canonical name 'periodization').
_INT_TO_MODE = {
    0: 'zero',
    1: 'symmetric',
    2: 'periodization',
    3: 'constant',
    4: 'reflect',
    5: 'replicate',
    6: 'periodic',
}


def int_to_mode(mode):
    """Return the padding-mode name for integer code ``mode``.

    Raises:
        ValueError: if ``mode`` is not a recognised code.
    """
    try:
        return _INT_TO_MODE[mode]
    except (KeyError, TypeError):
        # Typo "Unkown" fixed relative to the original message.
        raise ValueError("Unknown pad type: {}".format(mode)) from None
class AFB2D(Function):
    """ Does a single level 2d wavelet decomposition of an input. Does separate
    row and column filtering by two calls to
    :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d`
    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, saves on memory by not having to save the input
    tensors.
    Inputs:
        x (torch.Tensor): Input to decompose
        h0_row: row lowpass
        h1_row: row highpass
        h0_col: col lowpass
        h1_col: col highpass
        mode (int): use mode_to_int to get the int code here
    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.
    Returns:
        (low, highs): lowpass of shape (N, C, H', W') and the three
        highpass bands stacked as (N, C, 3, H', W')
    """

    @staticmethod
    def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
        ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
        ctx.shape = x.shape[-2:]  # original spatial size, to crop grads
        mode = int_to_mode(mode)
        ctx.mode = mode
        # BUGFIX: afb1d in this module takes a required use_amp positional
        # argument (unlike upstream pytorch_wavelets); the original calls
        # omitted it and raised TypeError. The 2d path does not use AMP,
        # so pass False.
        lohi = afb1d(x, h0_row, h1_row, False, mode=mode, dim=3)
        y = afb1d(lohi, h0_col, h1_col, False, mode=mode, dim=2)
        s = y.shape
        # Regroup the interleaved channels into the four (ll, lh, hl, hh) bands.
        y = y.reshape(s[0], -1, 4, s[-2], s[-1])
        low = y[:, :, 0].contiguous()
        highs = y[:, :, 1:].contiguous()
        return low, highs

    @staticmethod
    def backward(ctx, low, highs):
        dx = None
        if ctx.needs_input_grad[0]:
            mode = ctx.mode
            h0_row, h1_row, h0_col, h1_col = ctx.saved_tensors
            lh, hl, hh = torch.unbind(highs, dim=2)
            # The adjoint of the analysis bank is the synthesis bank.
            lo = sfb1d(low, lh, h0_col, h1_col, False, mode=mode, dim=2)
            hi = sfb1d(hl, hh, h0_col, h1_col, False, mode=mode, dim=2)
            dx = sfb1d(lo, hi, h0_row, h1_row, False, mode=mode, dim=3)
            # Crop any rows/cols added by padding so dx matches the input.
            if dx.shape[-2] > ctx.shape[-2] and dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:, :, :ctx.shape[-2], :ctx.shape[-1]]
            elif dx.shape[-2] > ctx.shape[-2]:
                dx = dx[:, :, :ctx.shape[-2]]
            elif dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:, :, :, :ctx.shape[-1]]
        return dx, None, None, None, None, None
class AFB1D(Function):
    """ Does a single level 1d wavelet decomposition of an input.
    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, saves on memory by not having to save the input
    tensors.
    Inputs:
        x (torch.Tensor): Input to decompose, shape (N, C, L)
        h0: lowpass filter
        h1: highpass filter
        mode (int): use mode_to_int to get the int code here
        use_amp (bool): run the filtering convolutions under CUDA AMP
    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.
    Returns:
        x0: Tensor of shape (N, C, L') - lowpass
        x1: Tensor of shape (N, C, L') - highpass
    """
    @staticmethod
    def forward(ctx, x, h0, h1, mode, use_amp):
        mode = int_to_mode(mode)
        # Make inputs 4d so the shared 2d filter-bank routine can be reused.
        x = x[:, :, None, :]
        h0 = h0[:, :, None, :]
        h1 = h1[:, :, None, :]
        # Save for backwards
        ctx.save_for_backward(h0, h1)
        ctx.shape = x.shape[3]  # original length, to crop padding in backward
        ctx.mode = mode
        ctx.use_amp = use_amp
        lohi = afb1d(x, h0, h1, use_amp, mode=mode, dim=3)
        # afb1d interleaves the outputs per channel: even channels are the
        # lowpass bands, odd channels the highpass; drop the dummy height dim.
        x0 = lohi[:, ::2, 0].contiguous()
        x1 = lohi[:, 1::2, 0].contiguous()
        return x0, x1

    @staticmethod
    def backward(ctx, dx0, dx1):
        dx = None
        if ctx.needs_input_grad[0]:
            mode = ctx.mode
            h0, h1 = ctx.saved_tensors
            use_amp = ctx.use_amp
            # Make grads 4d
            dx0 = dx0[:, :, None, :]
            dx1 = dx1[:, :, None, :]
            # The adjoint of the analysis bank is the synthesis bank.
            dx = sfb1d(dx0, dx1, h0, h1, use_amp, mode=mode, dim=3)[:, :, 0]
            # Check for odd input
            if dx.shape[2] > ctx.shape:
                dx = dx[:, :, :ctx.shape]
        # Gradients for (x, h0, h1, mode, use_amp); autograd accepts and
        # discards surplus trailing None values.
        return dx, None, None, None, None, None
def afb2d(x, filts, mode='zero'):
    """ Does a single level 2d wavelet decomposition of an input. Does separate
    row and column filtering by two calls to
    :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d`
    Inputs:
        x (torch.Tensor): Input to decompose
        filts (list of ndarray or torch.Tensor): If a list of tensors has been
            given, this function assumes they are in the right form (the form
            returned by
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`).
            Otherwise, this function will prepare the filters to be of the right
            form by calling
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size. Otherwise, the output size will be slightly larger than
            half.
    Returns:
        y: Tensor of shape (N, C*4, H, W)
    """
    tensorize = [not isinstance(f, torch.Tensor) for f in filts]
    if len(filts) == 2:
        h0, h1 = filts
        if True in tensorize:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                h0, h1, device=x.device)
        else:
            # Same filters for rows and columns; rows use the transpose.
            h0_col = h0
            h0_row = h0.transpose(2, 3)
            h1_col = h1
            h1_row = h1.transpose(2, 3)
    elif len(filts) == 4:
        if True in tensorize:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                *filts, device=x.device)
        else:
            h0_col, h1_col, h0_row, h1_row = filts
    else:
        raise ValueError("Unknown form for input filts")
    # BUGFIX: afb1d in this module takes a required use_amp positional
    # argument; the original calls omitted it and raised TypeError. The 2d
    # path does not use AMP, so pass False.
    lohi = afb1d(x, h0_row, h1_row, False, mode=mode, dim=3)
    y = afb1d(lohi, h0_col, h1_col, False, mode=mode, dim=2)
    return y
def afb2d_atrous(x, filts, mode='periodization', dilation=1):
    """ Does a single level 2d wavelet decomposition of an input. Does separate
    row and column filtering by two calls to
    :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d`
    Inputs:
        x (torch.Tensor): Input to decompose
        filts (list of ndarray or torch.Tensor): If a list of tensors has been
            given, this function assumes they are in the right form (the form
            returned by
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`).
            Otherwise, this function will prepare the filters to be of the right
            form by calling
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d`.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size. Otherwise, the output size will be slightly larger than
            half.
        dilation (int): dilation factor for the filters. Should be 2**level
    Returns:
        y: Tensor of shape (N, C, 4, H, W)
    """
    # Filters need preparing whenever any of them is not already a tensor.
    needs_prep = True in [not isinstance(f, torch.Tensor) for f in filts]
    if len(filts) == 2:
        if needs_prep:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                filts[0], filts[1], device=x.device)
        else:
            # Same filters for rows and columns; rows use the transpose.
            h0_col, h1_col = filts
            h0_row = h0_col.transpose(2, 3)
            h1_row = h1_col.transpose(2, 3)
    elif len(filts) == 4:
        if needs_prep:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                *filts, device=x.device)
        else:
            h0_col, h1_col, h0_row, h1_row = filts
    else:
        raise ValueError("Unknown form for input filts")
    # Filter rows first, then columns, without downsampling (a trous).
    lohi = afb1d_atrous(x, h0_row, h1_row, mode=mode, dim=3, dilation=dilation)
    return afb1d_atrous(lohi, h0_col, h1_col, mode=mode, dim=2, dilation=dilation)
def afb2d_nonsep(x, filts, mode='zero'):
    """ Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate
    row and column filtering.
    Inputs:
        x (torch.Tensor): Input to decompose
        filts (list or torch.Tensor): If a list is given, should be the low and
            highpass filter banks. If a tensor is given, it should be of the
            form created by
            :py:func:`pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d_nonsep`
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size. Otherwise, the output size will be slightly larger than
            half.
    Returns:
        y: Tensor of shape (N, C, 4, H, W)
    """
    C = x.shape[1]
    Ny = x.shape[2]
    Nx = x.shape[3]
    # Check the filter inputs
    if isinstance(filts, (tuple, list)):
        if len(filts) == 2:
            filts = prep_filt_afb2d_nonsep(filts[0], filts[1], device=x.device)
        else:
            filts = prep_filt_afb2d_nonsep(
                filts[0], filts[1], filts[2], filts[3], device=x.device)
    # One copy of the 4-band filter stack per channel (grouped conv).
    f = torch.cat([filts] * C, dim=0)
    Ly = f.shape[2]
    Lx = f.shape[3]
    if mode == 'periodization' or mode == 'per':
        # Even out odd-sized inputs by repeating the last row/column.
        if x.shape[2] % 2 == 1:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
            Ny += 1
        if x.shape[3] % 2 == 1:
            x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            Nx += 1
        pad = (Ly - 1, Lx - 1)
        stride = (2, 2)
        x = roll(roll(x, -Ly // 2, dim=2), -Lx // 2, dim=3)
        y = F.conv2d(x, f, padding=pad, stride=stride, groups=C)
        # Fold the wrapped-around borders back in, then crop to half size.
        y[:, :, :Ly // 2] += y[:, :, Ny // 2:Ny // 2 + Ly // 2]
        y[:, :, :, :Lx // 2] += y[:, :, :, Nx // 2:Nx // 2 + Lx // 2]
        y = y[:, :, :Ny // 2, :Nx // 2]
    elif mode == 'zero' or mode == 'symmetric' or mode == 'reflect':
        # Calculate the pad size
        out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode)
        out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode)
        p1 = 2 * (out1 - 1) - Ny + Ly
        p2 = 2 * (out2 - 1) - Nx + Lx
        if mode == 'zero':
            # Sadly, pytorch only allows for same padding before and after, if
            # we need to do more padding after for odd length signals, have to
            # prepad
            if p1 % 2 == 1 and p2 % 2 == 1:
                x = F.pad(x, (0, 1, 0, 1))
            elif p1 % 2 == 1:
                x = F.pad(x, (0, 0, 0, 1))
            elif p2 % 2 == 1:
                x = F.pad(x, (0, 1, 0, 0))
            # Calculate the high and lowpass
            y = F.conv2d(
                x, f, padding=(p1 // 2, p2 // 2), stride=2, groups=C)
        elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
            # NOTE(review): 'periodic' here is unreachable — the outer
            # branch only admits 'zero'/'symmetric'/'reflect', so a
            # 'periodic' mode falls through to the ValueError below.
            pad = (p2 // 2, (p2 + 1) // 2, p1 // 2, (p1 + 1) // 2)
            x = mypad(x, pad=pad, mode=mode)
            y = F.conv2d(x, f, stride=2, groups=C)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))
    return y
def sfb2d(ll, lh, hl, hh, filts, mode='zero'):
    """ Does a single level 2d wavelet reconstruction of wavelet coefficients.
    Does separate row and column filtering by two calls to
    :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d`
    Inputs:
        ll (torch.Tensor): lowpass coefficients
        lh (torch.Tensor): horizontal coefficients
        hl (torch.Tensor): vertical coefficients
        hh (torch.Tensor): diagonal coefficients
        filts (list of ndarray or torch.Tensor): If a list of tensors has been
            given, this function assumes they are in the right form (the form
            returned by
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`).
            Otherwise, this function will prepare the filters to be of the right
            form by calling
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size. Otherwise, the output size will be slightly larger than
            half.
    Returns:
        y: the reconstructed tensor
    """
    tensorize = [not isinstance(x, torch.Tensor) for x in filts]
    if len(filts) == 2:
        g0, g1 = filts
        if True in tensorize:
            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1)
        else:
            # Same filters for rows and columns; rows use the transpose.
            g0_col = g0
            g0_row = g0.transpose(2, 3)
            g1_col = g1
            g1_row = g1.transpose(2, 3)
    elif len(filts) == 4:
        if True in tensorize:
            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts)
        else:
            g0_col, g1_col, g0_row, g1_row = filts
    else:
        raise ValueError("Unknown form for input filts")
    # BUGFIX: sfb1d in this module takes a required use_amp positional
    # argument; the original calls omitted it and raised TypeError. The 2d
    # path does not use AMP, so pass False.
    lo = sfb1d(ll, lh, g0_col, g1_col, False, mode=mode, dim=2)
    hi = sfb1d(hl, hh, g0_col, g1_col, False, mode=mode, dim=2)
    y = sfb1d(lo, hi, g0_row, g1_row, False, mode=mode, dim=3)
    return y
class SFB2D(Function):
    """ Does a single level 2d wavelet reconstruction of an input. Does separate
    row and column filtering by calls to
    :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d`
    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, saves on memory by not having to save the input
    tensors.
    Inputs:
        low (torch.Tensor): lowpass band
        highs (torch.Tensor): the three highpass bands, stacked on dim 2
        g0_row: row lowpass
        g1_row: row highpass
        g0_col: col lowpass
        g1_col: col highpass
        mode (int): use mode_to_int to get the int code here
    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.
    Returns:
        y: the reconstructed tensor
    """

    @staticmethod
    def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
        mode = int_to_mode(mode)
        ctx.mode = mode
        ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)
        lh, hl, hh = torch.unbind(highs, dim=2)
        # BUGFIX: sfb1d/afb1d in this module take a required use_amp
        # positional argument; the original calls omitted it and raised
        # TypeError. The 2d path does not use AMP, so pass False.
        lo = sfb1d(low, lh, g0_col, g1_col, False, mode=mode, dim=2)
        hi = sfb1d(hl, hh, g0_col, g1_col, False, mode=mode, dim=2)
        y = sfb1d(lo, hi, g0_row, g1_row, False, mode=mode, dim=3)
        return y

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            mode = ctx.mode
            g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
            # The adjoint of the synthesis bank is the analysis bank.
            dx = afb1d(dy, g0_row, g1_row, False, mode=mode, dim=3)
            dx = afb1d(dx, g0_col, g1_col, False, mode=mode, dim=2)
            s = dx.shape
            # Regroup interleaved channels into the four subband gradients.
            dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
            dlow = dx[:, :, 0].contiguous()
            dhigh = dx[:, :, 1:].contiguous()
        return dlow, dhigh, None, None, None, None, None
class SFB1D(Function):
    """ Does a single level 1d wavelet reconstruction of an input.
    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, saves on memory by not having to save the input
    tensors.
    Inputs:
        low (torch.Tensor): Lowpass to reconstruct of shape (N, C, L)
        high (torch.Tensor): Highpass to reconstruct of shape (N, C, L)
        g0: lowpass filter
        g1: highpass filter
        mode (int): use mode_to_int to get the int code here
        use_amp (bool): run the filtering convolutions under CUDA AMP
    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.
    Returns:
        y: Tensor of shape (N, C, L')
    """
    @staticmethod
    def forward(ctx, low, high, g0, g1, mode, use_amp):
        mode = int_to_mode(mode)
        # Make into a 2d tensor with 1 row so the 2d routine can be reused.
        low = low[:, :, None, :]
        high = high[:, :, None, :]
        g0 = g0[:, :, None, :]
        g1 = g1[:, :, None, :]
        ctx.mode = mode
        ctx.save_for_backward(g0, g1)
        ctx.use_amp = use_amp
        # Drop the dummy height dimension on the way out.
        return sfb1d(low, high, g0, g1, use_amp, mode=mode, dim=3)[:, :, 0]

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            mode = ctx.mode
            use_amp = ctx.use_amp
            g0, g1, = ctx.saved_tensors
            dy = dy[:, :, None, :]
            # The adjoint of the synthesis bank is the analysis bank; its
            # output interleaves lowpass (even) and highpass (odd) channels.
            dx = afb1d(dy, g0, g1, use_amp, mode=mode, dim=3)
            dlow = dx[:, ::2, 0].contiguous()
            dhigh = dx[:, 1::2, 0].contiguous()
        # Gradients for (low, high, g0, g1, mode, use_amp); autograd
        # accepts and discards surplus trailing None values.
        return dlow, dhigh, None, None, None, None, None
def sfb2d_nonsep(coeffs, filts, mode='zero'):
    """ Does a single level 2d wavelet reconstruction of wavelet coefficients.
    Does not do separable filtering.
    Inputs:
        coeffs (torch.Tensor): tensor of coefficients of shape (N, C, 4, H, W)
            where the third dimension indexes across the (ll, lh, hl, hh) bands.
        filts (list of ndarray or torch.Tensor): If a list of tensors has been
            given, this function assumes they are in the right form (the form
            returned by
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d_nonsep`).
            Otherwise, this function will prepare the filters to be of the right
            form by calling
            :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d_nonsep`.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size. Otherwise, the output size will be slightly larger than
            half.
    """
    C = coeffs.shape[1]
    Ny = coeffs.shape[-2]
    Nx = coeffs.shape[-1]
    # Check the filter inputs - should be in the form of a torch tensor, but if
    # not, tensorize it here.
    if isinstance(filts, (tuple, list)):
        if len(filts) == 2:
            filts = prep_filt_sfb2d_nonsep(filts[0], filts[1],
                                           device=coeffs.device)
        elif len(filts) == 4:
            filts = prep_filt_sfb2d_nonsep(
                filts[0], filts[1], filts[2], filts[3], device=coeffs.device)
        else:
            raise ValueError("Unkown form for input filts")
    # One copy of the 4-band filter stack per channel (grouped transposed conv).
    f = torch.cat([filts] * C, dim=0)
    Ly = f.shape[2]
    Lx = f.shape[3]
    # Fold the 4 bands into the channel dimension: (N, C*4, H, W).
    x = coeffs.reshape(coeffs.shape[0], -1, coeffs.shape[-2], coeffs.shape[-1])
    if mode == 'periodization' or mode == 'per':
        ll = F.conv_transpose2d(x, f, groups=C, stride=2)
        # Fold overhanging borders back in, crop to 2N, and undo the roll
        # applied by the analysis transform.
        ll[:, :, :Ly - 2] += ll[:, :, 2 * Ny:2 * Ny + Ly - 2]
        ll[:, :, :, :Lx - 2] += ll[:, :, :, 2 * Nx:2 * Nx + Lx - 2]
        ll = ll[:, :, :2 * Ny, :2 * Nx]
        ll = roll(roll(ll, 1 - Ly // 2, dim=2), 1 - Lx // 2, dim=3)
    elif mode == 'symmetric' or mode == 'zero' or mode == 'reflect' or \
            mode == 'periodic':
        # padding trims the transient introduced by the filters.
        pad = (Ly - 2, Lx - 2)
        ll = F.conv_transpose2d(x, f, padding=pad, groups=C, stride=2)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))
    return ll.contiguous()
def prep_filt_afb2d_nonsep(h0_col, h1_col, h0_row=None, h1_row=None,
                           device=None):
    """
    Prepares the filters to be of the right form for the afb2d_nonsep function.
    In particular, makes 2d point spread functions, and mirror images them in
    preparation to do torch.conv2d.
    Inputs:
        h0_col (array-like): low pass column filter bank
        h1_col (array-like): high pass column filter bank
        h0_row (array-like): low pass row filter bank. If none, will assume the
            same as column filter
        h1_row (array-like): high pass row filter bank. If none, will assume the
            same as column filter
        device: which device to put the tensors on to
    Returns:
        filts: (4, 1, h, w) tensor ready to get the four subbands
    """
    h0_col = np.array(h0_col).ravel()
    h1_col = np.array(h1_col).ravel()
    h0_row = h0_col if h0_row is None else h0_row
    h1_row = h1_col if h1_row is None else h1_row
    # Separable 2d point spread functions for the four subbands.
    bands = [np.outer(h0_col, h0_row),   # ll
             np.outer(h1_col, h0_row),   # lh
             np.outer(h0_col, h1_row),   # hl
             np.outer(h1_col, h1_row)]   # hh
    # Flip both spatial axes so conv2d (cross-correlation) acts as convolution.
    filts = np.stack([band[None, ::-1, ::-1] for band in bands], axis=0)
    return torch.tensor(filts, dtype=torch.get_default_dtype(), device=device)
def prep_filt_sfb2d_nonsep(g0_col, g1_col, g0_row=None, g1_row=None,
                           device=None):
    """
    Prepares the filters to be of the right form for the sfb2d_nonsep function.
    In particular, makes 2d point spread functions. Does not mirror image them
    as sfb2d_nonsep uses conv2d_transpose which acts like normal convolution.
    Inputs:
        g0_col (array-like): low pass column filter bank
        g1_col (array-like): high pass column filter bank
        g0_row (array-like): low pass row filter bank. If none, will assume the
            same as column filter
        g1_row (array-like): high pass row filter bank. If none, will assume the
            same as column filter
        device: which device to put the tensors on to
    Returns:
        filts: (4, 1, h, w) tensor ready to combine the four subbands
    """
    g0_col = np.array(g0_col).ravel()
    g1_col = np.array(g1_col).ravel()
    g0_row = g0_col if g0_row is None else g0_row
    g1_row = g1_col if g1_row is None else g1_row
    # Separable 2d point spread functions; no flipping needed for
    # conv_transpose2d (it already behaves like true convolution).
    bands = [np.outer(g0_col, g0_row),   # ll
             np.outer(g1_col, g0_row),   # lh
             np.outer(g0_col, g1_row),   # hl
             np.outer(g1_col, g1_row)]   # hh
    filts = np.stack([band[None] for band in bands], axis=0)
    return torch.tensor(filts, dtype=torch.get_default_dtype(), device=device)
def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
    """
    Prepares the filters to be of the right form for the sfb2d function. In
    particular, makes the tensors the right shape. It does not mirror image
    them as sfb2d uses conv2d_transpose which acts like normal convolution.
    Inputs:
        g0_col (array-like): low pass column filter bank
        g1_col (array-like): high pass column filter bank
        g0_row (array-like): low pass row filter bank. If none, will assume the
            same as column filter
        g1_row (array-like): high pass row filter bank. If none, will assume the
            same as column filter
        device: which device to put the tensors on to
    Returns:
        (g0_col, g1_col, g0_row, g1_row)
    """
    g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
    if g0_row is None:
        g0_row, g1_row = g0_col, g1_col
    else:
        g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device)
    # Column filters convolve along H, row filters along W.
    col_shape, row_shape = (1, 1, -1, 1), (1, 1, 1, -1)
    return (g0_col.reshape(col_shape), g1_col.reshape(col_shape),
            g0_row.reshape(row_shape), g1_row.reshape(row_shape))
def prep_filt_sfb1d(g0, g1, device=None):
    """
    Prepares the filters to be of the right form for the sfb1d function. In
    particular, makes the tensors the right shape. It does not mirror image
    them as sfb1d uses conv_transpose which acts like normal convolution.
    Inputs:
        g0 (array-like): low pass filter bank
        g1 (array-like): high pass filter bank
        device: which device to put the tensors on to
    Returns:
        (g0, g1)
    """
    dtype = torch.get_default_dtype()
    low = torch.tensor(np.array(g0).ravel(), device=device, dtype=dtype)
    high = torch.tensor(np.array(g1).ravel(), device=device, dtype=dtype)
    # Shape (out_channels=1, in_channels=1, length) expected by conv ops.
    return low.reshape((1, 1, -1)), high.reshape((1, 1, -1))
def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None):
    """
    Prepares the filters to be of the right form for the afb2d function. In
    particular, makes the tensors the right shape. It takes mirror images of
    them as afb2d uses conv2d which acts like normal correlation.
    Inputs:
        h0_col (array-like): low pass column filter bank
        h1_col (array-like): high pass column filter bank
        h0_row (array-like): low pass row filter bank. If none, will assume the
            same as column filter
        h1_row (array-like): high pass row filter bank. If none, will assume the
            same as column filter
        device: which device to put the tensors on to
    Returns:
        (h0_col, h1_col, h0_row, h1_row)
    """
    h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device)
    if h0_row is None:
        h0_row, h1_row = h0_col, h1_col
    else:
        h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device)
    # Column filters convolve along H, row filters along W.
    col_shape, row_shape = (1, 1, -1, 1), (1, 1, 1, -1)
    return (h0_col.reshape(col_shape), h1_col.reshape(col_shape),
            h0_row.reshape(row_shape), h1_row.reshape(row_shape))
def prep_filt_afb1d(h0, h1, device=None):
    """
    Prepares the filters to be of the right form for the afb1d function. In
    particular, makes the tensors the right shape. It takes mirror images of
    them as afb1d uses conv which acts like normal correlation.
    Inputs:
        h0 (array-like): low pass column filter bank
        h1 (array-like): high pass column filter bank
        device: which device to put the tensors on to
    Returns:
        (h0, h1)
    """
    dtype = torch.get_default_dtype()
    # Time-reverse the taps so conv (cross-correlation) performs convolution.
    low = torch.tensor(np.array(h0[::-1]).ravel(), device=device, dtype=dtype)
    high = torch.tensor(np.array(h1[::-1]).ravel(), device=device, dtype=dtype)
    # Shape (out_channels=1, in_channels=1, length) expected by conv ops.
    return low.reshape((1, 1, -1)), high.reshape((1, 1, -1))
def reflect(x, minx, maxx):
    """Reflect the values in matrix *x* about the scalar values *minx* and
    *maxx*. Hence a vector *x* containing a long linearly increasing series is
    converted into a waveform which ramps linearly up and down between *minx*
    and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
    0.5), the ramps will have repeated max and min samples.
    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
    """
    x = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span
    # Position within one full up-down period, shifted to start at minx.
    phase = np.fmod(x - minx, period)
    phase = np.where(phase < 0, phase + period, phase)
    # First half of the period ramps up, second half ramps back down.
    folded = np.where(phase >= span, period - phase, phase)
    return np.array(folded + minx, dtype=x.dtype)