import warnings
import torch
import torch.nn as nn
import copy
from torch import set_grad_enabled
class AdditiveCoupling(nn.Module):
def __init__(self, Fm, Gm=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1):
"""
This computes the output :math:`y` on forward given input :math:`x` and arbitrary modules :math:`Fm` and :math:`Gm` according to:
:math:`(x1, x2) = x`
:math:`y1 = x1 + Fm(x2)`
:math:`y2 = x2 + Gm(y1)`
:math:`y = (y1, y2)`
Parameters
----------
Fm : :obj:`torch.nn.Module`
A torch.nn.Module encapsulating an arbitrary function
Gm : :obj:`torch.nn.Module`
A torch.nn.Module encapsulating an arbitrary function
            (If not specified, a deepcopy of Fm is used as a Module)
implementation_fwd : :obj:`int`
Switch between different Additive Operation implementations for forward pass. Default = -1
implementation_bwd : :obj:`int`
Switch between different Additive Operation implementations for inverse pass. Default = -1
split_dim : :obj:`int`
Dimension to split the input tensors on. Default = 1, generally corresponding to channels.
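
        Examples
        --------
        A minimal usage sketch; the convolution module and the tensor shapes below
        are illustrative assumptions, not part of the API:

        >>> Fm = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1)
        >>> coupling = AdditiveCoupling(Fm)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> y = coupling(x)
        >>> x_rec = coupling.inverse(y)
        >>> torch.allclose(x, x_rec, atol=1e-6)
        True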
"""
super(AdditiveCoupling, self).__init__()
# mirror the passed module, without parameter sharing...
if Gm is None:
Gm = copy.deepcopy(Fm)
self.Gm = Gm
self.Fm = Fm
self.implementation_fwd = implementation_fwd
self.implementation_bwd = implementation_bwd
self.split_dim = split_dim
if implementation_bwd != -1 or implementation_fwd != -1:
warnings.warn("Other implementations than the default (-1) are now deprecated.",
DeprecationWarning)
def forward(self, x):
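        """Forward pass: split x into (x1, x2) along split_dim, compute
        y1 = x1 + Fm(x2) and y2 = x2 + Gm(y1), and return their concatenation."""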
args = [x, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]
if self.implementation_fwd == 0:
out = AdditiveBlockFunction.apply(*args)
elif self.implementation_fwd == 1:
out = AdditiveBlockFunction2.apply(*args)
elif self.implementation_fwd == -1:
x1, x2 = torch.chunk(x, 2, dim=self.split_dim)
x1, x2 = x1.contiguous(), x2.contiguous()
fmd = self.Fm.forward(x2)
y1 = x1 + fmd
gmd = self.Gm.forward(y1)
y2 = x2 + gmd
out = torch.cat([y1, y2], dim=self.split_dim)
else:
raise NotImplementedError("Selected implementation ({}) not implemented..."
.format(self.implementation_fwd))
return out
def inverse(self, y):
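        """Inverse pass: split y into (y1, y2) along split_dim, compute
        x2 = y2 - Gm(y1) and x1 = y1 - Fm(x2), and return their concatenation."""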
args = [y, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]
if self.implementation_bwd == 0:
x = AdditiveBlockInverseFunction.apply(*args)
elif self.implementation_bwd == 1:
x = AdditiveBlockInverseFunction2.apply(*args)
elif self.implementation_bwd == -1:
y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
y1, y2 = y1.contiguous(), y2.contiguous()
gmd = self.Gm.forward(y1)
x2 = y2 - gmd
fmd = self.Fm.forward(x2)
x1 = y1 - fmd
x = torch.cat([x1, x2], dim=self.split_dim)
else:
raise NotImplementedError("Inverse for selected implementation ({}) not implemented..."
.format(self.implementation_bwd))
return x
class AdditiveBlock(AdditiveCoupling):
def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
warnings.warn("This class has been deprecated. Use the AdditiveCoupling class instead.",
DeprecationWarning)
super(AdditiveBlock, self).__init__(Fm=Fm, Gm=Gm,
implementation_fwd=implementation_fwd,
implementation_bwd=implementation_bwd)
class AdditiveBlockFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xin, Fm, Gm, *weights):
"""Forward pass computes:
{x1, x2} = x
y1 = x1 + Fm(x2)
y2 = x2 + Gm(y1)
output = {y1, y2}
Parameters
----------
ctx : torch.autograd.Function
The backward pass context object
        xin : TorchTensor
Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
Fm : nn.Module
Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
Gm : nn.Module
Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
*weights : TorchTensor
weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}
Note
----
All tensor/autograd variable input arguments and the output are
TorchTensors for the scope of this function
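
        Example
        -------
        A hedged sketch of the expected call pattern (it mirrors what
        AdditiveCoupling.forward does when implementation_fwd == 0):

        >>> args = [x, Fm, Gm] + list(Fm.parameters()) + list(Gm.parameters())
        >>> y = AdditiveBlockFunction.apply(*args)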
"""
# check if possible to partition into two equally sized partitions
assert(xin.shape[1] % 2 == 0) # nosec
# store partition size, Fm and Gm functions in context
ctx.Fm = Fm
ctx.Gm = Gm
with torch.no_grad():
x = xin.detach()
# partition in two equally sized set of channels
x1, x2 = torch.chunk(x, 2, dim=1)
x1, x2 = x1.contiguous(), x2.contiguous()
# compute outputs
fmr = Fm.forward(x2)
y1 = x1 + fmr
x1.set_()
del x1
gmr = Gm.forward(y1)
y2 = x2 + gmr
x2.set_()
del x2
output = torch.cat([y1, y2], dim=1)
ctx.save_for_backward(xin, output)
return output
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
# retrieve weight references
Fm, Gm = ctx.Fm, ctx.Gm
# retrieve input and output references
xin, output = ctx.saved_tensors
x = xin.detach()
x1, x2 = torch.chunk(x, 2, dim=1)
GWeights = [p for p in Gm.parameters()]
# partition output gradient also on channels
assert grad_output.shape[1] % 2 == 0 # nosec
with set_grad_enabled(True):
# compute outputs building a sub-graph
x1.requires_grad_()
x2.requires_grad_()
y1 = x1 + Fm.forward(x2)
y2 = x2 + Gm.forward(y1)
y = torch.cat([y1, y2], dim=1)
# perform full backward pass on graph...
            dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)
GWgrads = dd[2:2+len(GWeights)]
FWgrads = dd[2+len(GWeights):]
grad_input = torch.cat([dd[0], dd[1]], dim=1)
return (grad_input, None, None) + FWgrads + GWgrads
class AdditiveBlockInverseFunction(torch.autograd.Function):
@staticmethod
def forward(cty, y, Fm, Gm, *weights):
"""Forward pass computes:
{y1, y2} = y
x2 = y2 - Gm(y1)
x1 = y1 - Fm(x2)
output = {x1, x2}
Parameters
----------
cty : torch.autograd.Function
The backward pass context object
y : TorchTensor
Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
Fm : nn.Module
Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
Gm : nn.Module
Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
*weights : TorchTensor
weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}
Note
----
All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function
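
        Example
        -------
        A hedged sketch of the expected call pattern (it mirrors what
        AdditiveCoupling.inverse does when implementation_bwd == 0):

        >>> args = [y, Fm, Gm] + list(Fm.parameters()) + list(Gm.parameters())
        >>> x = AdditiveBlockInverseFunction.apply(*args)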
"""
# check if possible to partition into two equally sized partitions
assert(y.shape[1] % 2 == 0) # nosec
# store partition size, Fm and Gm functions in context
cty.Fm = Fm
cty.Gm = Gm
with torch.no_grad():
# partition in two equally sized set of channels
y1, y2 = torch.chunk(y, 2, dim=1)
y1, y2 = y1.contiguous(), y2.contiguous()
# compute outputs
gmr = Gm.forward(y1)
x2 = y2 - gmr
y2.set_()
del y2
fmr = Fm.forward(x2)
x1 = y1 - fmr
y1.set_()
del y1
output = torch.cat([x1, x2], dim=1)
x1.set_()
x2.set_()
del x1, x2
# save the (empty) input and (non-empty) output variables
cty.save_for_backward(y.data, output)
return output
@staticmethod
def backward(cty, grad_output): # pragma: no cover
# retrieve weight references
Fm, Gm = cty.Fm, cty.Gm
# retrieve input and output references
yin, output = cty.saved_tensors
y = yin.detach()
y1, y2 = torch.chunk(y, 2, dim=1)
FWeights = [p for p in Fm.parameters()]
# partition output gradient also on channels
assert grad_output.shape[1] % 2 == 0 # nosec
with set_grad_enabled(True):
# compute outputs building a sub-graph
y2.requires_grad = True
y1.requires_grad = True
x2 = y2 - Gm.forward(y1)
x1 = y1 - Fm.forward(x2)
x = torch.cat([x1, x2], dim=1)
# perform full backward pass on graph...
            dd = torch.autograd.grad(x, (y2, y1) + tuple(Fm.parameters()) + tuple(Gm.parameters()), grad_output)
FWgrads = dd[2:2+len(FWeights)]
GWgrads = dd[2+len(FWeights):]
grad_input = torch.cat([dd[0], dd[1]], dim=1)
return (grad_input, None, None) + FWgrads + GWgrads
class AdditiveBlockFunction2(torch.autograd.Function):
@staticmethod
def forward(ctx, xin, Fm, Gm, *weights):
"""Forward pass computes:
{x1, x2} = x
y1 = x1 + Fm(x2)
y2 = x2 + Gm(y1)
output = {y1, y2}
Parameters
----------
ctx : torch.autograd.Function
The backward pass context object
        xin : TorchTensor
Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
Fm : nn.Module
Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
Gm : nn.Module
Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
*weights : TorchTensor
weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}
Note
----
All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function
"""
# check if possible to partition into two equally sized partitions
assert xin.shape[1] % 2 == 0 # nosec
# store partition size, Fm and Gm functions in context
ctx.Fm = Fm
ctx.Gm = Gm
with torch.no_grad():
# partition in two equally sized set of channels
x = xin.detach()
x1, x2 = torch.chunk(x, 2, dim=1)
x1, x2 = x1.contiguous(), x2.contiguous()
# compute outputs
fmr = Fm.forward(x2)
y1 = x1 + fmr
x1.set_()
del x1
gmr = Gm.forward(y1)
y2 = x2 + gmr
x2.set_()
del x2
output = torch.cat([y1, y2], dim=1).detach_()
# save the input and output variables
ctx.save_for_backward(x, output)
return output
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
Fm, Gm = ctx.Fm, ctx.Gm
        # retrieve the saved input and output tensors
x, output = ctx.saved_tensors
with torch.no_grad():
y1, y2 = torch.chunk(output, 2, dim=1)
y1, y2 = y1.contiguous(), y2.contiguous()
# partition output gradient also on channels
assert(grad_output.shape[1] % 2 == 0) # nosec
y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)
y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()
# Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
# z1_stop, x2_stop, GW, FW
# Also recompute inputs (x1, x2) from outputs (y1, y2)
with set_grad_enabled(True):
z1_stop = y1.detach()
z1_stop.requires_grad = True
G_z1 = Gm.forward(z1_stop)
x2 = y2 - G_z1
x2_stop = x2.detach()
x2_stop.requires_grad = True
F_x2 = Fm.forward(x2_stop)
x1 = y1 - F_x2
x1_stop = x1.detach()
x1_stop.requires_grad = True
# compute outputs building a sub-graph
y1 = x1_stop + F_x2
y2 = x2_stop + G_z1
# calculate the final gradients for the weights and inputs
dd = torch.autograd.grad(y2, (z1_stop,) + tuple(Gm.parameters()), y2_grad, retain_graph=False)
z1_grad = dd[0] + y1_grad
GWgrads = dd[1:]
dd = torch.autograd.grad(y1, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)
FWgrads = dd[2:]
x2_grad = dd[1] + y2_grad
x1_grad = dd[0]
grad_input = torch.cat([x1_grad, x2_grad], dim=1)
return (grad_input, None, None) + FWgrads + GWgrads
class AdditiveBlockInverseFunction2(torch.autograd.Function):
@staticmethod
def forward(cty, y, Fm, Gm, *weights):
"""Forward pass computes:
{y1, y2} = y
x2 = y2 - Gm(y1)
x1 = y1 - Fm(x2)
output = {x1, x2}
Parameters
----------
cty : torch.autograd.Function
The backward pass context object
y : TorchTensor
Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
Fm : nn.Module
Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
Gm : nn.Module
Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
*weights : TorchTensor
weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}
Note
----
All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function
"""
# check if possible to partition into two equally sized partitions
assert(y.shape[1] % 2 == 0) # nosec
# store partition size, Fm and Gm functions in context
cty.Fm = Fm
cty.Gm = Gm
with torch.no_grad():
# partition in two equally sized set of channels
y1, y2 = torch.chunk(y, 2, dim=1)
y1, y2 = y1.contiguous(), y2.contiguous()
# compute outputs
gmr = Gm.forward(y1)
x2 = y2 - gmr
y2.set_()
del y2
fmr = Fm.forward(x2)
x1 = y1 - fmr
y1.set_()
del y1
output = torch.cat([x1, x2], dim=1).detach_()
# save the input and output variables
cty.save_for_backward(y, output)
return output
@staticmethod
def backward(cty, grad_output): # pragma: no cover
Fm, Gm = cty.Fm, cty.Gm
        # retrieve the saved input and output tensors
y, output = cty.saved_tensors
with torch.no_grad():
x1, x2 = torch.chunk(output, 2, dim=1)
x1, x2 = x1.contiguous(), x2.contiguous()
# partition output gradient also on channels
assert(grad_output.shape[1] % 2 == 0) # nosec
x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)
x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()
# Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
# z1_stop, y1_stop, GW, FW
# Also recompute inputs (y1, y2) from outputs (x1, x2)
with set_grad_enabled(True):
z1_stop = x2.detach()
z1_stop.requires_grad = True
F_z1 = Fm.forward(z1_stop)
y1 = x1 + F_z1
y1_stop = y1.detach()
y1_stop.requires_grad = True
G_y1 = Gm.forward(y1_stop)
y2 = x2 + G_y1
y2_stop = y2.detach()
y2_stop.requires_grad = True
# compute outputs building a sub-graph
z1 = y2_stop - G_y1
x1 = y1_stop - F_z1
x2 = z1
# calculate the final gradients for the weights and inputs
dd = torch.autograd.grad(x1, (z1_stop,) + tuple(Fm.parameters()), x1_grad)
z1_grad = dd[0] + x2_grad
FWgrads = dd[1:]
dd = torch.autograd.grad(x2, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)
GWgrads = dd[2:]
y1_grad = dd[1] + x1_grad
y2_grad = dd[0]
grad_input = torch.cat([y1_grad, y2_grad], dim=1)
return (grad_input, None, None) + FWgrads + GWgrads