Modules

Top-level package for MemCNN.

memcnn.is_invertible_module(module_in, test_input_shape, test_input_dtype=torch.float32, atol=1e-06, random_seed=42)[source]

Test if a torch.nn.Module is invertible

Parameters:
  • module_in (torch.nn.Module) – A torch.nn.Module to test.
  • test_input_shape (tuple of int or tuple of tuple of int) – Dimensions of the test tensor(s) to perform the test with.
  • test_input_dtype (torch.dtype, optional) – Data type of the test tensor(s) to perform the test with.
  • atol (float, optional) – Tolerance value used for comparing the outputs.
  • random_seed (int, optional) – Seed used to generate the pseudo-random test input(s).
Returns:

True if the input module is invertible, False otherwise.

Return type:

bool
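
Example

A minimal sketch of the invertibility check; the AddOne module below is hypothetical and defined only for this illustration.

    import torch
    import memcnn

    class AddOne(torch.nn.Module):
        """Toy invertible module: forward adds one, inverse subtracts it."""
        def forward(self, x):
            return x + 1

        def inverse(self, y):
            return y - 1

    # True: inverse(forward(x)) matches x within atol for the random test input
    print(memcnn.is_invertible_module(AddOne(), test_input_shape=(1, 4)))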

class memcnn.InvertibleModuleWrapper(fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1, disable=False, preserve_rng_state=False)[source]

The InvertibleModuleWrapper enables memory savings during training by exploiting the invertible properties of the wrapped module.

Parameters:
  • fn (torch.nn.Module) – A torch.nn.Module with forward and inverse functions implemented such that \(x == m.inverse(m.forward(x))\).
  • keep_input (bool, optional) – Set to retain the input on forward; by default it is discarded, since it will be reconstructed during the backward pass.
  • keep_input_inverse (bool, optional) – Set to retain the input on inverse; by default it is discarded, since it will be reconstructed during the backward pass.
  • num_bwd_passes (int, optional) – Number of backward passes that retain a link with the output. After the last backward pass the output is discarded and its memory is freed. Warning: if this value is set higher than the number of required passes, memory will no longer be freed correctly and training can quickly run out of memory. Hence, the typical use case is to keep this at 1 and raise it only if an error indicates that more passes are required.
  • disable (bool, optional) – Disables the InvertibleCheckpointFunction altogether. Essentially this renders the function as \(y = fn(x)\) without any of the memory savings. Setting this to True also ignores the keep_input and keep_input_inverse properties.
  • preserve_rng_state (bool, optional) – Ensures that the same RNG state is used during reconstruction of the inputs, i.e. when keep_input = False on forward or keep_input_inverse = False on inverse. By default this is False, since most invertible modules have a valid deterministic inverse.
keep_input

Set to retain the input on forward; by default it is discarded, since it will be reconstructed during the backward pass.

Type: bool, optional

keep_input_inverse

Set to retain the input on inverse; by default it is discarded, since it will be reconstructed during the backward pass.

Type: bool, optional

Note

The InvertibleModuleWrapper can be used with mixed-precision training using torch.cuda.amp.autocast as of torch v1.6. However, inputs will always be cast to torch.float32 internally. This is done to avoid autocasting the inputs to a different data type, which usually results in a disconnected computation graph and raises an error on the backward pass.
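
Example

A minimal training sketch; the Shift module is hypothetical and defined only for this illustration. With keep_input=False the input storage can be freed after the forward pass and the input is reconstructed via the inverse on the backward pass.

    import torch
    import memcnn

    class Shift(torch.nn.Module):
        """Hypothetical invertible module: adds a learnable per-channel bias."""
        def __init__(self, channels):
            super().__init__()
            self.bias = torch.nn.Parameter(torch.zeros(channels, 1, 1))

        def forward(self, x):
            return x + self.bias

        def inverse(self, y):
            return y - self.bias

    shift = Shift(4)
    wrapper = memcnn.InvertibleModuleWrapper(fn=shift, keep_input=False)

    x = torch.randn(2, 4, 8, 8, requires_grad=True)
    y = wrapper(x)      # the input storage can be freed after this call
    y.sum().backward()  # x is reconstructed from y to compute gradients
    print(shift.bias.grad.shape)  # torch.Size([4, 1, 1])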

forward(*xin)[source]

Forward operation \(R(x) = y\)

Parameters: *xin (torch.Tensor tuple) – Input torch tensor(s).
Returns: Output torch tensor(s) *y.
Return type: torch.Tensor tuple
inverse(*yin)[source]

Inverse operation \(R^{-1}(y) = x\)

Parameters: *yin (torch.Tensor tuple) – Input torch tensor(s).
Returns: Output torch tensor(s) *x.
Return type: torch.Tensor tuple
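
Example

A round-trip sketch of the forward/inverse pair in evaluation mode; the Shift module is hypothetical, and the keep_input settings follow the inference pattern where no reconstruction is needed.

    import torch
    import memcnn

    class Shift(torch.nn.Module):
        """Hypothetical invertible module: adds one."""
        def forward(self, x):
            return x + 1

        def inverse(self, y):
            return y - 1

    wrapper = memcnn.InvertibleModuleWrapper(fn=Shift(), keep_input=True,
                                             keep_input_inverse=True)
    wrapper.eval()  # required to pass inputs with requires_grad=False

    with torch.no_grad():
        x = torch.randn(2, 4)
        y = wrapper(x)              # forward: R(x) = y
        x_rec = wrapper.inverse(y)  # inverse: R^{-1}(y) = x
        print(torch.allclose(x, x_rec))  # True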
class memcnn.AdditiveCoupling(Fm, Gm=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1)[source]

This computes the output \(y\) on forward given input \(x\) and arbitrary modules \(Fm\) and \(Gm\) according to:

\((x1, x2) = x\)

\(y1 = x1 + Fm(x2)\)

\(y2 = x2 + Gm(y1)\)

\(y = (y1, y2)\)

Parameters:
  • Fm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function.
  • Gm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function (if not specified, a deepcopy of Fm is used).
  • implementation_fwd (int) – Switch between different Additive Operation implementations for the forward pass. Default = -1.
  • implementation_bwd (int) – Switch between different Additive Operation implementations for the inverse pass. Default = -1.
  • split_dim (int) – Dimension to split the input tensors on. Default = 1, generally corresponding to channels.
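
Example

A sketch of an additive coupling over a 4-channel input: the tensor is split on dim 1 into two 2-channel halves, so Fm (and the implicit deepcopy used for Gm) must map 2 channels to 2 channels. The convolution is an illustrative choice, not a requirement.

    import torch
    import memcnn

    Fm = torch.nn.Conv2d(2, 2, kernel_size=3, padding=1)
    coupling = memcnn.AdditiveCoupling(Fm=Fm)  # Gm defaults to a deepcopy of Fm

    x = torch.randn(1, 4, 8, 8)
    y = coupling(x)
    x_rec = coupling.inverse(y)
    print(torch.allclose(x, x_rec, atol=1e-6))  # True (within atol)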
class memcnn.AffineCoupling(Fm, Gm=None, adapter=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1)[source]

This computes the output \(y\) on forward given input \(x\) and arbitrary modules \(Fm\) and \(Gm\) according to:

\((x1, x2) = x\)

\((log({s1}), t1) = Fm(x2)\)

\(s1 = exp(log({s1}))\)

\(y1 = s1 * x1 + t1\)

\((log({s2}), t2) = Gm(y1)\)

\(s2 = exp(log({s2}))\)

\(y2 = s2 * x2 + t2\)

\(y = (y1, y2)\)

Parameters:
  • Fm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function.
  • Gm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function (if not specified, a deepcopy of Fm is used).
  • adapter (torch.nn.Module class) – An optional wrapper class A for Fm and Gm, which must output s, t = A(x) with shape(s) = shape(t) = shape(x). Here s and t are respectively the scale and shift tensors for the affine coupling.
  • implementation_fwd (int) – Switch between different Affine Operation implementations for the forward pass. Default = -1.
  • implementation_bwd (int) – Switch between different Affine Operation implementations for the inverse pass. Default = -1.
  • split_dim (int) – Dimension to split the input tensors on. Default = 1, generally corresponding to channels.
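
Example

A sketch using AffineAdapterNaive (layer choice illustrative): Fm outputs a single tensor h with the shape of one input half, and the adapter derives the scale exp(h) and shift h from it.

    import torch
    import memcnn

    Fm = torch.nn.Conv2d(2, 2, kernel_size=3, padding=1)
    coupling = memcnn.AffineCoupling(Fm=Fm, adapter=memcnn.AffineAdapterNaive)

    x = torch.randn(1, 4, 8, 8)
    y = coupling(x)
    x_rec = coupling.inverse(y)
    print(torch.allclose(x, x_rec, atol=1e-5))  # True up to floating-point error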
class memcnn.AffineAdapterNaive(module)[source]

Naive Affine adapter

Outputs s = exp(f(x)) and t = f(x), given the module f(.) and input x

class memcnn.AffineAdapterSigmoid(module)[source]

Sigmoid-based affine adapter

Partitions the output h = f(x) into s and t by extracting the odd and even channels, respectively. Outputs sigmoid(s), t.
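
Example

Since s and t are taken from alternating channels, Fm must output twice as many channels as one input half; a sketch (layer choice illustrative):

    import torch
    import memcnn

    Fm = torch.nn.Conv2d(2, 4, kernel_size=3, padding=1)  # 2 -> 2 * 2 channels
    coupling = memcnn.AffineCoupling(Fm=Fm, adapter=memcnn.AffineAdapterSigmoid)

    x = torch.randn(1, 4, 8, 8)
    print(torch.allclose(coupling.inverse(coupling(x)), x, atol=1e-5))  # True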

class memcnn.ReversibleBlock(Fm, Gm=None, coupling='additive', keep_input=False, keep_input_inverse=False, implementation_fwd=-1, implementation_bwd=-1, adapter=None)[source]

The ReversibleBlock

Warning

This class has been deprecated. Use the more flexible InvertibleModuleWrapper class.

Note

The implementation_fwd and implementation_bwd parameters can be set to one of the following implementations:

  • -1 Naive implementation without reconstruction on the backward pass.
  • 0 Memory efficient implementation, compute gradients directly.
  • 1 Memory efficient implementation, similar to approach in Gomez et al. 2017.
Parameters:
  • Fm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function.
  • Gm (torch.nn.Module, optional) – A torch.nn.Module encapsulating an arbitrary function (if not specified, a deepcopy of Fm is used).
  • coupling (str, optional) – Type of coupling [‘additive’, ‘affine’]. Default = ‘additive’.
  • keep_input (bool, optional) – Set to retain the input on forward; by default it is discarded, since it will be reconstructed during the backward pass.
  • keep_input_inverse (bool, optional) – Set to retain the input on inverse; by default it is discarded, since it will be reconstructed during the backward pass.
  • implementation_fwd (int, optional) – Switch between different Operation implementations for forward training (Default = 1). If using the naive implementation (-1), keep_input should be True.
  • implementation_bwd (int, optional) – Switch between different Operation implementations for backward training (Default = 1). If using the naive implementation (-1), keep_input_inverse should be True.
  • adapter (class, optional) – Only relevant when using the ‘affine’ coupling. Should be a torch.nn.Module class that serves as an optional wrapper A for Fm and Gm, which must output s, t = A(x) with shape(s) = shape(t) = shape(x). Here s and t are respectively the scale and shift tensors for the affine coupling.
keep_input

Set to retain the input on forward; by default it is discarded, since it will be reconstructed during the backward pass.

Type: bool, optional

keep_input_inverse

Set to retain the input on inverse; by default it is discarded, since it will be reconstructed during the backward pass.

Type: bool, optional

Raises: NotImplementedError – If an unknown coupling or implementation is given.
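
Example

Since ReversibleBlock is deprecated, new code should build the equivalent wrapped coupling directly; a migration sketch (layer choice illustrative):

    import torch
    import memcnn

    Fm = torch.nn.Conv2d(2, 2, kernel_size=3, padding=1)

    # Deprecated:
    # block = memcnn.ReversibleBlock(Fm=Fm, coupling='additive', keep_input=False)

    # Preferred equivalent:
    block = memcnn.InvertibleModuleWrapper(fn=memcnn.AdditiveCoupling(Fm=Fm),
                                           keep_input=False)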
forward(*xin)

Forward operation \(R(x) = y\)

Parameters: *xin (torch.Tensor tuple) – Input torch tensor(s).
Returns: Output torch tensor(s) *y.
Return type: torch.Tensor tuple
inverse(*yin)

Inverse operation \(R^{-1}(y) = x\)

Parameters: *yin (torch.Tensor tuple) – Input torch tensor(s).
Returns: Output torch tensor(s) *x.
Return type: torch.Tensor tuple