Source code for memcnn.models.revop

import functools
import warnings
import numpy as np
import torch
import torch.nn as nn

from memcnn.models.additive import AdditiveCoupling
from memcnn.models.affine import AffineCoupling

try:
    from torch.cuda.amp import custom_fwd, custom_bwd
except ModuleNotFoundError:
    # Fall back to no-op decorators when torch.cuda.amp is not available (torch < 1.6).
    def custom_fwd(fwd=None, *, cast_inputs=None):
        if fwd is None:
            return functools.partial(custom_fwd)
        return functools.partial(fwd)

    def custom_bwd(bwd):
        return functools.partial(bwd)


class InvertibleCheckpointFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, fn, fn_inverse, keep_input, num_bwd_passes, preserve_rng_state, num_inputs, *inputs_and_weights):
        # store in context
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.keep_input = keep_input
        ctx.weights = inputs_and_weights[num_inputs:]
        ctx.num_bwd_passes = num_bwd_passes
        ctx.preserve_rng_state = preserve_rng_state
        ctx.num_inputs = num_inputs
        inputs = inputs_and_weights[:num_inputs]

        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*inputs)

        ctx.input_requires_grad = [element.requires_grad for element in inputs]

        with torch.no_grad():
            # Makes a detached copy which shares the storage
            x = [element.detach() for element in inputs]
            outputs = ctx.fn(*x)

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Detaches y in-place (in-between computations can now be discarded)
        detached_outputs = tuple([element.detach_() for element in outputs])

        # clear memory from inputs
        if not ctx.keep_input:
            # PyTorch 1.0+ way to clear storage
            for element in inputs:
                element.storage().resize_(0)

        # store these tensor nodes for backward pass
        ctx.inputs = [inputs] * num_bwd_passes
        ctx.outputs = [detached_outputs] * num_bwd_passes

        return detached_outputs

    @staticmethod
    @custom_bwd
    def backward(ctx, *grad_outputs):  # pragma: no cover
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("InvertibleCheckpointFunction is not compatible with .grad(), please use .backward() if possible")
        # retrieve input and output tensor nodes
        if len(ctx.outputs) == 0:
            raise RuntimeError("Trying to perform backward on the InvertibleCheckpointFunction for more than "
                               "{} times! Try raising `num_bwd_passes` by one.".format(ctx.num_bwd_passes))
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        # recompute input if necessary
        if not ctx.keep_input:
            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
                rng_devices = ctx.fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
                if ctx.preserve_rng_state:
                    torch.set_rng_state(ctx.fwd_cpu_state)
                    if ctx.had_cuda_in_fwd:
                        set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
                # recompute input
                with torch.no_grad():
                    inputs_inverted = ctx.fn_inverse(*outputs)
                    if not isinstance(inputs_inverted, tuple):
                        inputs_inverted = (inputs_inverted,)
                    for element_original, element_inverted in zip(inputs, inputs_inverted):
                        element_original.storage().resize_(int(np.prod(element_original.size())))
                        element_original.set_(element_inverted)

        # compute gradients
        with torch.set_grad_enabled(True):
            detached_inputs = tuple([element.detach().requires_grad_() for element in inputs])
            temp_output = ctx.fn(*detached_inputs)
        if not isinstance(temp_output, tuple):
            temp_output = (temp_output,)

        gradients = torch.autograd.grad(outputs=temp_output,
                                        inputs=detached_inputs + ctx.weights,
                                        grad_outputs=grad_outputs)

        # Setting the gradients manually on the inputs and outputs (mimic backwards)
        for element, element_grad in zip(inputs, gradients[:ctx.num_inputs]):
            element.grad = element_grad

        for element, element_grad in zip(outputs, grad_outputs):
            element.grad = element_grad

        return (None, None, None, None, None, None) + gradients
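

# Illustrative sketch (editorial addition, not part of the memcnn API): the raw calling
# convention of InvertibleCheckpointFunction.apply, mirroring what InvertibleModuleWrapper
# does below. The AdditiveCoupling sub-networks and tensor shapes are hypothetical choices.
def _example_invertible_checkpoint_usage():
    # Fm and Gm each map half of the 10 input channels (5) to the same shape.
    coupling = AdditiveCoupling(Fm=nn.Conv2d(5, 5, kernel_size=3, padding=1),
                                Gm=nn.Conv2d(5, 5, kernel_size=3, padding=1))
    x = torch.randn(2, 10, 8, 8, requires_grad=True)
    weights = tuple(p for p in coupling.parameters() if p.requires_grad)
    y = InvertibleCheckpointFunction.apply(
        coupling.forward,   # fn: computes y from x
        coupling.inverse,   # fn_inverse: used to reconstruct x on the backward pass
        True,               # keep_input: True keeps x's storage; False frees it after forward
        1,                  # num_bwd_passes: the output stays linked for a single backward pass
        False,              # preserve_rng_state
        1,                  # num_inputs: everything after this argument is inputs, then weights
        x, *weights)
    # y is a tuple of detached output tensors (here of length one)
    return y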


class InvertibleModuleWrapper(nn.Module):
    def __init__(self, fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1,
                 disable=False, preserve_rng_state=False):
        """
        The InvertibleModuleWrapper which enables memory savings during training by exploiting
        the invertible properties of the wrapped module.

        Parameters
        ----------
        fn : :obj:`torch.nn.Module`
            A torch.nn.Module which has a forward and an inverse function implemented with
            :math:`x == m.inverse(m.forward(x))`

        keep_input : :obj:`bool`, optional
            Set to retain the input information on forward, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        keep_input_inverse : :obj:`bool`, optional
            Set to retain the input information on inverse, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        num_bwd_passes : :obj:`int`, optional
            Number of backward passes to retain a link with the output. After the last backward pass the output
            is discarded and memory is freed.
            Warning: if this value is set higher than the number of required passes, memory will not be freed
            correctly anymore and the training process can quickly run out of memory. Hence, the typical use case
            is to keep this at 1 until an error suggests raising this value.

        disable : :obj:`bool`, optional
            This will disable using the InvertibleCheckpointFunction altogether.
            Essentially this renders the function as :math:`y = fn(x)` without any of the memory savings.
            Setting this to True will also ignore the keep_input and keep_input_inverse properties.

        preserve_rng_state : :obj:`bool`, optional
            Setting this will ensure that the same RNG state is used during reconstruction of the inputs,
            i.e. if keep_input = False on forward or keep_input_inverse = False on inverse. By default
            this is False, since most invertible modules should have a valid inverse and hence are
            deterministic.

        Attributes
        ----------
        keep_input : :obj:`bool`, optional
            Set to retain the input information on forward, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        keep_input_inverse : :obj:`bool`, optional
            Set to retain the input information on inverse, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        Note
        ----
        The InvertibleModuleWrapper can be used with mixed-precision training using
        :obj:`torch.cuda.amp.autocast` as of torch v1.6 and above. However, inputs are always cast to
        :obj:`torch.float32` internally. This avoids autocasting the inputs to a different datatype,
        which usually results in a disconnected computation graph and an error on the backward pass.

        """
        super(InvertibleModuleWrapper, self).__init__()
        self.disable = disable
        self.keep_input = keep_input
        self.keep_input_inverse = keep_input_inverse
        self.num_bwd_passes = num_bwd_passes
        self.preserve_rng_state = preserve_rng_state
        self._fn = fn

    def forward(self, *xin):
        """Forward operation :math:`R(x) = y`

        Parameters
        ----------
        *xin : :obj:`torch.Tensor` tuple
            Input torch tensor(s).

        Returns
        -------
        :obj:`torch.Tensor` tuple
            Output torch tensor(s) *y.

        """
        if not self.disable:
            y = InvertibleCheckpointFunction.apply(
                self._fn.forward,
                self._fn.inverse,
                self.keep_input,
                self.num_bwd_passes,
                self.preserve_rng_state,
                len(xin),
                *(xin + tuple([p for p in self._fn.parameters() if p.requires_grad])))
        else:
            y = self._fn(*xin)

        # If the layer has only one output, unpack the tuple again
        if isinstance(y, tuple) and len(y) == 1:
            return y[0]
        return y

    def inverse(self, *yin):
        """Inverse operation :math:`R^{-1}(y) = x`

        Parameters
        ----------
        *yin : :obj:`torch.Tensor` tuple
            Input torch tensor(s).

        Returns
        -------
        :obj:`torch.Tensor` tuple
            Output torch tensor(s) *x.

        """
        if not self.disable:
            x = InvertibleCheckpointFunction.apply(
                self._fn.inverse,
                self._fn.forward,
                self.keep_input_inverse,
                self.num_bwd_passes,
                self.preserve_rng_state,
                len(yin),
                *(yin + tuple([p for p in self._fn.parameters() if p.requires_grad])))
        else:
            x = self._fn.inverse(*yin)

        # If the layer has only one output, unpack the tuple again
        if isinstance(x, tuple) and len(x) == 1:
            return x[0]
        return x
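

# Usage sketch (editorial addition, loosely following the memcnn README): an additive coupling
# wrapped with InvertibleModuleWrapper. The Conv2d/ReLU sub-networks and the (2, 10, 8, 8) input
# are illustrative assumptions; keep_input/keep_input_inverse are set to True here so the demo
# tensors are not freed, whereas training for memory savings would typically use False.
def _example_invertible_module_wrapper():
    fm = nn.Sequential(nn.Conv2d(5, 5, kernel_size=3, padding=1), nn.ReLU())
    gm = nn.Sequential(nn.Conv2d(5, 5, kernel_size=3, padding=1), nn.ReLU())
    invertible_module = AdditiveCoupling(Fm=fm, Gm=gm)
    wrapper = InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)

    x = torch.rand(2, 10, 8, 8)
    y = wrapper.forward(x)         # checkpointed forward pass through the coupling
    x_approx = wrapper.inverse(y)  # x is recovered from y up to numerical tolerance
    return y, x_approx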


class ReversibleBlock(InvertibleModuleWrapper):
    def __init__(self, Fm, Gm=None, coupling='additive', keep_input=False, keep_input_inverse=False,
                 implementation_fwd=-1, implementation_bwd=-1, adapter=None):
        """The ReversibleBlock

        Warning
        -------
        This class has been deprecated. Use the more flexible InvertibleModuleWrapper class.

        Note
        ----
        The `implementation_fwd` and `implementation_bwd` parameters can be set to one of the following
        implementations:

        * -1 Naive implementation without reconstruction on the backward pass.
        * 0  Memory efficient implementation, compute gradients directly.
        * 1  Memory efficient implementation, similar to approach in Gomez et al. 2017.

        Parameters
        ----------
        Fm : :obj:`torch.nn.Module`
            A torch.nn.Module encapsulating an arbitrary function

        Gm : :obj:`torch.nn.Module`, optional
            A torch.nn.Module encapsulating an arbitrary function
            (If not specified a deepcopy of Fm is used as a Module)

        coupling : :obj:`str`, optional
            Type of coupling ['additive', 'affine']. Default = 'additive'

        keep_input : :obj:`bool`, optional
            Set to retain the input information on forward, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        keep_input_inverse : :obj:`bool`, optional
            Set to retain the input information on inverse, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        implementation_fwd : :obj:`int`, optional
            Switch between different Operation implementations for forward training (Default = -1).
            If using the naive implementation (-1) then `keep_input` should be True.

        implementation_bwd : :obj:`int`, optional
            Switch between different Operation implementations for backward training (Default = -1).
            If using the naive implementation (-1) then `keep_input_inverse` should be True.

        adapter : :obj:`class`, optional
            Only relevant when using the 'affine' coupling.
            Should be a class of type :obj:`torch.nn.Module` that serves as an
            optional wrapper class A for Fm and Gm which must output
            s, t = A(x) with shape(s) = shape(t) = shape(x).
            s, t are respectively the scale and shift tensors for the affine coupling.

        Attributes
        ----------
        keep_input : :obj:`bool`, optional
            Set to retain the input information on forward, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        keep_input_inverse : :obj:`bool`, optional
            Set to retain the input information on inverse, by default it can be discarded since it will be
            reconstructed upon the backward pass.

        Raises
        ------
        NotImplementedError
            If an unknown coupling or implementation is given.

        """
        warnings.warn("This class has been deprecated. Use the more flexible InvertibleModuleWrapper class",
                      DeprecationWarning)
        fn = create_coupling(Fm=Fm, Gm=Gm, coupling=coupling,
                             implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd,
                             adapter=adapter)
        super(ReversibleBlock, self).__init__(fn, keep_input=keep_input, keep_input_inverse=keep_input_inverse)


def create_coupling(Fm, Gm=None, coupling='additive', implementation_fwd=-1, implementation_bwd=-1, adapter=None):
    if coupling == 'additive':
        fn = AdditiveCoupling(Fm, Gm,
                              implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
    elif coupling == 'affine':
        fn = AffineCoupling(Fm, Gm, adapter=adapter,
                            implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
    else:
        raise NotImplementedError('Unknown coupling method: %s' % coupling)
    return fn
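

# Sketch (editorial addition): create_coupling is the factory that ReversibleBlock uses internally.
# The recommended, non-deprecated equivalent of ReversibleBlock(Fm, Gm) is to build the coupling
# explicitly and wrap it, as below. The half-channel Conv2d sub-networks are hypothetical choices.
def _example_create_coupling():
    fm = nn.Conv2d(5, 5, kernel_size=3, padding=1)
    gm = nn.Conv2d(5, 5, kernel_size=3, padding=1)
    coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive')
    return InvertibleModuleWrapper(fn=coupling, keep_input=False, keep_input_inverse=False)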


def is_invertible_module(module_in, test_input_shape, test_input_dtype=torch.float32, atol=1e-6, random_seed=42):
    """Test if a :obj:`torch.nn.Module` is invertible

    Parameters
    ----------
    module_in : :obj:`torch.nn.Module`
        A torch.nn.Module to test.
    test_input_shape : :obj:`tuple` of :obj:`int` or :obj:`tuple` of :obj:`tuple` of :obj:`int`
        Dimensions of the test tensor(s) to perform the test with.
    test_input_dtype : :obj:`torch.dtype`, optional
        Data type of the test tensor(s) to perform the test with.
    atol : :obj:`float`, optional
        Tolerance value used for comparing the outputs.
    random_seed : :obj:`int`, optional
        Use this value to seed the pseudo-random test inputs with different numbers.

    Returns
    -------
    :obj:`bool`
        True if the input module is invertible, False otherwise.

    """
    if isinstance(module_in, InvertibleModuleWrapper):
        module_in = module_in._fn

    if not hasattr(module_in, "inverse"):
        return False

    def _type_check_input_shape(test_input_shape):
        if isinstance(test_input_shape, (tuple, list)):
            if all([isinstance(e, int) for e in test_input_shape]):
                return True
            elif all([isinstance(e, (tuple, list)) for e in test_input_shape]):
                return all([isinstance(ee, int) for e in test_input_shape for ee in e])
            else:
                return False
        else:
            return False

    if not _type_check_input_shape(test_input_shape):
        raise ValueError("test_input_shape should be of type Tuple[int, ...] or "
                         "Tuple[Tuple[int, ...], ...], but {} found".format(type(test_input_shape)))

    if not isinstance(test_input_shape[0], (tuple, list)):
        test_input_shape = (test_input_shape,)

    def _check_inputs_allclose(inputs, reference, atol):
        for inp, ref in zip(inputs, reference):
            if not torch.allclose(inp, ref, atol=atol):
                return False
        return True

    def _pack_if_no_tuple(x):
        if not isinstance(x, tuple):
            return (x, )
        return x

    with torch.no_grad():
        torch.manual_seed(random_seed)
        test_inputs = tuple([torch.rand(shape, dtype=test_input_dtype) for shape in test_input_shape])
        if any([torch.equal(torch.zeros_like(e), e) for e in test_inputs]):  # pragma: no cover
            warnings.warn("Some inputs were detected to be all zeros, you might want to set a different random_seed.")

        if not _check_inputs_allclose(
                _pack_if_no_tuple(module_in.inverse(*_pack_if_no_tuple(module_in(*test_inputs)))),
                test_inputs, atol=atol):
            return False

        test_outputs = _pack_if_no_tuple(module_in(*test_inputs))
        if any([torch.equal(torch.zeros_like(e), e) for e in test_outputs]):  # pragma: no cover
            warnings.warn("Some outputs were detected to be all zeros, you might want to set a different random_seed.")

        if not _check_inputs_allclose(
                _pack_if_no_tuple(module_in(*_pack_if_no_tuple(module_in.inverse(*test_outputs)))),
                test_outputs, atol=atol):  # pragma: no cover
            return False

        test_reconstructed_inputs = _pack_if_no_tuple(module_in.inverse(*test_outputs))

    def _test_shared(inputs, outputs, msg):
        shared_inputs = set(inputs)
        shared_outputs = set(outputs)
        if len(inputs) != len(shared_inputs):  # pragma: no cover
            warnings.warn("Some inputs (*x) share the same tensor, are you sure this is what you want? ({})".format(msg))
        if len(outputs) != len(shared_outputs):
            warnings.warn("Some outputs (*y) share the same tensor, are you sure this is what you want? ({})".format(msg))
        if any([inp in shared_inputs for inp in shared_outputs]):
            warnings.warn("Some inputs (*x) and outputs (*y) share the same tensor, this is typically not a "
                          "good function to use with memcnn.InvertibleModuleWrapper as it might increase memory "
                          "usage. E.g. an identity function. ({})".format(msg))

    _test_shared(test_inputs, test_outputs, msg="forward")
    _test_shared(test_reconstructed_inputs, test_outputs, msg="inverse")

    return True


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
#
# get_device_states and set_device_states cannot be imported from torch.utils.checkpoint, since they were
# not present in older versions, so we include a copy here.
def get_device_states(*args):
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states):
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)