Source code for memcnn.models.affine

import torch
import torch.nn as nn
import copy
import warnings
from torch import set_grad_enabled

warnings.filterwarnings(action='ignore', category=UserWarning)


class AffineAdapterNaive(nn.Module):
    """ Naive Affine adapter

        Outputs exp(f(x)), f(x) given f(.) and x
    """
    def __init__(self, module):
        super(AffineAdapterNaive, self).__init__()
        self.f = module

    def forward(self, x):
        t = self.f(x)
        s = torch.exp(t)
        return s, t


class AffineAdapterSigmoid(nn.Module):
    """ Sigmoid based affine adapter

        Partitions the output h of f(x) = h into s and t by extracting every odd and even channel
        Outputs sigmoid(s), t
    """
    def __init__(self, module):
        super(AffineAdapterSigmoid, self).__init__()
        self.f = module

    def forward(self, x):
        h = self.f(x)
        assert h.shape[1] % 2 == 0  # nosec
        scale = torch.sigmoid(h[:, 1::2] + 2.0)
        shift = h[:, 0::2]
        return scale, shift
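

# Illustrative sketch (not part of the original memcnn source): AffineAdapterSigmoid
# assumes the wrapped module doubles the channel count, so that the even channels
# become the shift t and the odd channels (after a shifted sigmoid) become the
# scale s. The helper name, module, and sizes below are arbitrary example choices.
def _example_sigmoid_adapter():
    fm = nn.Conv2d(2, 4, kernel_size=3, padding=1)  # 2 input channels -> 2 scale + 2 shift channels
    adapter = AffineAdapterSigmoid(fm)
    x = torch.randn(8, 2, 16, 16)
    scale, shift = adapter(x)
    # both outputs match the adapter input shape, as the coupling below requires
    assert scale.shape == x.shape and shift.shape == x.shape
    return scale, shift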


class AffineCoupling(nn.Module):
    def __init__(self, Fm, Gm=None, adapter=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1):
        """
        This computes the output :math:`y` on forward given input :math:`x` and arbitrary modules :math:`Fm` and :math:`Gm` according to:

        :math:`(x1, x2) = x`

        :math:`(log({s1}), t1) = Fm(x2)`

        :math:`s1 = exp(log({s1}))`

        :math:`y1 = s1 * x1 + t1`

        :math:`(log({s2}), t2) = Gm(y1)`

        :math:`s2 = exp(log({s2}))`

        :math:`y2 = s2 * x2 + t2`

        :math:`y = (y1, y2)`

        Parameters
        ----------
            Fm : :obj:`torch.nn.Module`
                A torch.nn.Module encapsulating an arbitrary function

            Gm : :obj:`torch.nn.Module`
                A torch.nn.Module encapsulating an arbitrary function
                (If not specified a deepcopy of Fm is used as a Module)

            adapter : :obj:`torch.nn.Module` class
                An optional wrapper class A for Fm and Gm which must output s, t = A(x)
                with shape(s) = shape(t) = shape(x).
                s, t are respectively the scale and shift tensors for the affine coupling.

            implementation_fwd : :obj:`int`
                Switch between different Affine Operation implementations for forward pass. Default = -1

            implementation_bwd : :obj:`int`
                Switch between different Affine Operation implementations for inverse pass. Default = -1

            split_dim : :obj:`int`
                Dimension to split the input tensors on. Default = 1, generally corresponding to channels.
        """
        super(AffineCoupling, self).__init__()
        # mirror the passed module, without parameter sharing...
        if Gm is None:
            Gm = copy.deepcopy(Fm)
        # apply the adapter class if it is given
        self.Gm = adapter(Gm) if adapter is not None else Gm
        self.Fm = adapter(Fm) if adapter is not None else Fm
        self.implementation_fwd = implementation_fwd
        self.implementation_bwd = implementation_bwd
        self.split_dim = split_dim
        if implementation_bwd != -1 or implementation_fwd != -1:
            warnings.warn("Other implementations than the default (-1) are now deprecated.",
                          DeprecationWarning)

    def forward(self, x):
        args = [x, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]

        if self.implementation_fwd == 0:
            out = AffineBlockFunction.apply(*args)
        elif self.implementation_fwd == 1:
            out = AffineBlockFunction2.apply(*args)
        elif self.implementation_fwd == -1:
            x1, x2 = torch.chunk(x, 2, dim=self.split_dim)
            x1, x2 = x1.contiguous(), x2.contiguous()
            fmr1, fmr2 = self.Fm.forward(x2)
            y1 = (x1 * fmr1) + fmr2
            gmr1, gmr2 = self.Gm.forward(y1)
            y2 = (x2 * gmr1) + gmr2
            out = torch.cat([y1, y2], dim=self.split_dim)
        else:
            raise NotImplementedError("Selected implementation ({}) not implemented..."
                                      .format(self.implementation_fwd))

        return out

    def inverse(self, y):
        args = [y, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]

        if self.implementation_bwd == 0:
            x = AffineBlockInverseFunction.apply(*args)
        elif self.implementation_bwd == 1:
            x = AffineBlockInverseFunction2.apply(*args)
        elif self.implementation_bwd == -1:
            y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
            y1, y2 = y1.contiguous(), y2.contiguous()
            gmr1, gmr2 = self.Gm.forward(y1)
            x2 = (y2 - gmr2) / gmr1
            fmr1, fmr2 = self.Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1
            x = torch.cat([x1, x2], dim=self.split_dim)
        else:
            raise NotImplementedError("Inverse for selected implementation ({}) not implemented..."
                                      .format(self.implementation_bwd))

        return x
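

# Illustrative sketch (not part of the original memcnn source): a round trip through
# the default (-1) implementation of AffineCoupling. Fm is wrapped by
# AffineAdapterNaive so it returns a scale exp(t) and a shift t, and inverse()
# should reconstruct the input up to floating point error. The helper name and
# all sizes are arbitrary example choices.
def _example_affine_coupling():
    fm = nn.Conv2d(2, 2, kernel_size=3, padding=1)  # acts on one half of the 4 input channels
    coupling = AffineCoupling(Fm=fm, adapter=AffineAdapterNaive)
    x = torch.randn(4, 4, 8, 8)
    y = coupling(x)                # forward affine coupling
    x_recon = coupling.inverse(y)  # invert the coupling
    assert torch.allclose(x, x_recon, atol=1e-4)
    return y, x_recon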


class AffineBlock(AffineCoupling):
    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
        warnings.warn("This class has been deprecated. Use the AffineCoupling class instead.",
                      DeprecationWarning)
        super(AffineBlock, self).__init__(Fm=Fm, Gm=Gm,
                                          implementation_fwd=implementation_fwd,
                                          implementation_bwd=implementation_bwd)


class AffineBlockFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xin, Fm, Gm, *weights):
        """Forward pass for the affine block computes:
        {x1, x2} = x
        {log_s1, t1} = Fm(x2)
        s1 = exp(log_s1)
        y1 = s1 * x1 + t1
        {log_s2, t2} = Gm(y1)
        s2 = exp(log_s2)
        y2 = s2 * x2 + t2
        output = {y1, y2}

        Parameters
        ----------
        ctx : torch.autograd.function.RevNetFunctionBackward
            The backward pass context object
        x : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert xin.shape[1] % 2 == 0  # nosec

        # store partition size, Fm and Gm functions in context
        ctx.Fm = Fm
        ctx.Gm = Gm

        with torch.no_grad():
            # partition in two equally sized set of channels
            x = xin.detach()
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()

            # compute outputs
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)

            y1 = (x1 * fmr1) + fmr2
            x1.set_()
            del x1
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)
            y2 = (x2 * gmr1) + gmr2
            x2.set_()
            del x2
            output = torch.cat([y1, y2], dim=1).detach_()

        # save the (empty) input and (non-empty) output variables
        ctx.save_for_backward(xin, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # retrieve weight references
        Fm, Gm = ctx.Fm, ctx.Gm

        # retrieve input and output references
        xin, output = ctx.saved_tensors
        x = xin.detach()
        x1, x2 = torch.chunk(x.detach(), 2, dim=1)
        GWeights = [p for p in Gm.parameters()]

        # partition output gradient also on channels
        assert (grad_output.shape[1] % 2 == 0)  # nosec

        with set_grad_enabled(True):
            # compute outputs building a sub-graph
            x1.requires_grad = True
            x2.requires_grad = True

            fmr1, fmr2 = Fm.forward(x2)
            y1 = x1 * fmr1 + fmr2
            gmr1, gmr2 = Gm.forward(y1)
            y2 = x2 * gmr1 + gmr2
            y = torch.cat([y1, y2], dim=1)

            # perform full backward pass on graph...
            dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)

            GWgrads = dd[2:2 + len(GWeights)]
            FWgrads = dd[2 + len(GWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)

        return (grad_input, None, None) + FWgrads + GWgrads


class AffineBlockInverseFunction(torch.autograd.Function):
    @staticmethod
    def forward(cty, yin, Fm, Gm, *weights):
        """Forward inverse pass for the affine block computes:
        {y1, y2} = y
        {log_s2, t2} = Gm(y1)
        s2 = exp(log_s2)
        x2 = (y2 - t2) / s2
        {log_s1, t1} = Fm(x2)
        s1 = exp(log_s1)
        x1 = (y1 - t1) / s1
        output = {x1, x2}

        Parameters
        ----------
        cty : torch.autograd.function.RevNetInverseFunctionBackward
            The backward pass context object
        y : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert yin.shape[1] % 2 == 0  # nosec

        # store partition size, Fm and Gm functions in context
        cty.Fm = Fm
        cty.Gm = Gm

        with torch.no_grad():
            # partition in two equally sized set of channels
            y = yin.detach()
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()

            # compute outputs
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)

            x2 = (y2 - gmr2) / gmr1
            y2.set_()
            del y2
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)
            x1 = (y1 - fmr2) / fmr1
            y1.set_()
            del y1
            output = torch.cat([x1, x2], dim=1).detach_()

        # save input and output variables
        cty.save_for_backward(yin, output)

        return output

    @staticmethod
    def backward(cty, grad_output):  # pragma: no cover
        # retrieve weight references
        Fm, Gm = cty.Fm, cty.Gm

        # retrieve input and output references
        yin, output = cty.saved_tensors
        y = yin.detach()
        y1, y2 = torch.chunk(y.detach(), 2, dim=1)
        FWeights = [p for p in Fm.parameters()]  # count Fm's parameters to slice its gradients below

        # partition output gradient also on channels
        assert grad_output.shape[1] % 2 == 0  # nosec

        with set_grad_enabled(True):
            # compute outputs building a sub-graph
            y2.requires_grad = True
            y1.requires_grad = True

            gmr1, gmr2 = Gm.forward(y1)
            x2 = (y2 - gmr2) / gmr1
            fmr1, fmr2 = Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1
            x = torch.cat([x1, x2], dim=1)

            # perform full backward pass on graph...
            dd = torch.autograd.grad(x, (y2, y1) + tuple(Fm.parameters()) + tuple(Gm.parameters()), grad_output)

            FWgrads = dd[2:2 + len(FWeights)]
            GWgrads = dd[2 + len(FWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)

        return (grad_input, None, None) + FWgrads + GWgrads


class AffineBlockFunction2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xin, Fm, Gm, *weights):
        """Forward pass for the affine block computes:
        {x1, x2} = x
        {log_s1, t1} = Fm(x2)
        s1 = exp(log_s1)
        y1 = s1 * x1 + t1
        {log_s2, t2} = Gm(y1)
        s2 = exp(log_s2)
        y2 = s2 * x2 + t2
        output = {y1, y2}

        Parameters
        ----------
        ctx : torch.autograd.function.RevNetFunctionBackward
            The backward pass context object
        x : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert xin.shape[1] % 2 == 0  # nosec

        # store partition size, Fm and Gm functions in context
        ctx.Fm = Fm
        ctx.Gm = Gm

        with torch.no_grad():
            # partition in two equally sized set of channels
            x = xin.detach()
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()

            # compute outputs
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)

            y1 = x1 * fmr1 + fmr2
            x1.set_()
            del x1
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)
            y2 = x2 * gmr1 + gmr2
            x2.set_()
            del x2
            output = torch.cat([y1, y2], dim=1).detach_()

        # save the input and output variables
        ctx.save_for_backward(xin, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        Fm, Gm = ctx.Fm, ctx.Gm  # are all variable objects now
        x, output = ctx.saved_tensors
        with set_grad_enabled(False):
            y1, y2 = torch.chunk(output, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()

            # partition output gradient also on channels
            assert (grad_output.shape[1] % 2 == 0)  # nosec
            y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)
            y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()

        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, x2_stop, GW, FW
        # Also recompute inputs (x1, x2) from outputs (y1, y2)
        with set_grad_enabled(True):
            z1_stop = y1
            z1_stop.requires_grad = True

            G_z11, G_z12 = Gm.forward(z1_stop)
            x2 = (y2 - G_z12) / G_z11
            x2_stop = x2.detach()
            x2_stop.requires_grad = True

            F_x21, F_x22 = Fm.forward(x2_stop)
            x1 = (y1 - F_x22) / F_x21
            x1_stop = x1.detach()
            x1_stop.requires_grad = True

            # compute outputs building a sub-graph
            z1 = x1_stop * F_x21 + F_x22
            y2_ = x2_stop * G_z11 + G_z12
            y1_ = z1

            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(y2_, (z1_stop,) + tuple(Gm.parameters()), y2_grad)
            z1_grad = dd[0] + y1_grad
            GWgrads = dd[1:]

            dd = torch.autograd.grad(y1_, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)

            FWgrads = dd[2:]
            x2_grad = dd[1] + y2_grad
            x1_grad = dd[0]
            grad_input = torch.cat([x1_grad, x2_grad], dim=1)

            y1_.detach_()
            y2_.detach_()
            del y1_, y2_

        return (grad_input, None, None) + FWgrads + GWgrads


class AffineBlockInverseFunction2(torch.autograd.Function):
    @staticmethod
    def forward(cty, yin, Fm, Gm, *weights):
        """Forward pass for the affine block computes:

        Parameters
        ----------
        cty : torch.autograd.function.RevNetInverseFunctionBackward
            The backward pass context object
        y : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert yin.shape[1] % 2 == 0  # nosec

        # store partition size, Fm and Gm functions in context
        cty.Fm = Fm
        cty.Gm = Gm

        with torch.no_grad():
            # partition in two equally sized set of channels
            y = yin.detach()
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()

            # compute outputs
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)

            x2 = (y2 - gmr2) / gmr1
            y2.set_()
            del y2
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)
            x1 = (y1 - fmr2) / fmr1
            y1.set_()
            del y1
            output = torch.cat([x1, x2], dim=1).detach_()

        # save the input and output variables
        cty.save_for_backward(yin, output)

        return output

    @staticmethod
    def backward(cty, grad_output):  # pragma: no cover
        Fm, Gm = cty.Fm, cty.Gm  # are all variable objects now
        y, output = cty.saved_tensors
        with set_grad_enabled(False):
            x1, x2 = torch.chunk(output, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()

            # partition output gradient also on channels
            assert (grad_output.shape[1] % 2 == 0)  # nosec
            x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)
            x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()

        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, y1_stop, GW, FW
        # Also recompute inputs (y1, y2) from outputs (x1, x2)
        with set_grad_enabled(True):
            z1_stop = x2
            z1_stop.requires_grad = True

            F_z11, F_z12 = Fm.forward(z1_stop)
            y1 = x1 * F_z11 + F_z12
            y1_stop = y1.detach()
            y1_stop.requires_grad = True

            G_y11, G_y12 = Gm.forward(y1_stop)
            y2 = x2 * G_y11 + G_y12
            y2_stop = y2.detach()
            y2_stop.requires_grad = True

            # compute outputs building a sub-graph
            z1 = (y2_stop - G_y12) / G_y11
            x1_ = (y1_stop - F_z12) / F_z11
            x2_ = z1

            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(x1_, (z1_stop,) + tuple(Fm.parameters()), x1_grad)
            z1_grad = dd[0] + x2_grad
            FWgrads = dd[1:]

            dd = torch.autograd.grad(x2_, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)

            GWgrads = dd[2:]
            y1_grad = dd[1] + x1_grad
            y2_grad = dd[0]

            grad_input = torch.cat([y1_grad, y2_grad], dim=1)

        return (grad_input, None, None) + FWgrads + GWgrads
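

# Illustrative sketch (not part of the original memcnn source): the explicit
# autograd Functions above back the deprecated implementations 0 and 1 of
# AffineCoupling and recompute activations in backward instead of storing them.
# A quick sanity check is to route a coupling through implementation 1 and
# confirm that inverse() still reconstructs the input; the module choice and
# all sizes are arbitrary example values.
if __name__ == "__main__":
    example_fm = nn.Conv2d(2, 2, kernel_size=3, padding=1)
    example_coupling = AffineCoupling(Fm=example_fm, adapter=AffineAdapterNaive,
                                      implementation_fwd=1, implementation_bwd=1)
    example_x = torch.randn(1, 4, 8, 8)
    example_y = example_coupling(example_x)                 # routed through AffineBlockFunction2
    example_x_recon = example_coupling.inverse(example_y)   # routed through AffineBlockInverseFunction2
    print("max reconstruction error:", (example_x - example_x_recon).abs().max().item())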