Welcome to MemCNN’s documentation!¶
MemCNN¶
A PyTorch framework for developing memory-efficient invertible neural networks.
- Free software: MIT license (please cite our work if you use it)
- Documentation: https://memcnn.readthedocs.io.
- Installation: https://memcnn.readthedocs.io/en/latest/installation.html
Features¶
- Enable memory savings during training by wrapping arbitrary invertible PyTorch functions with the InvertibleModuleWrapper class.
- Simple toggling of memory saving by setting the keep_input property of the InvertibleModuleWrapper.
- Turn arbitrary non-linear PyTorch functions into invertible versions using the AdditiveCoupling or the AffineCoupling classes.
- Training and evaluation code for reproducing RevNet experiments using MemCNN.
- CI tests for Python v3.7 and torch v1.0, v1.1, v1.4 and v1.7 with good code coverage.
Examples¶
Creating an AdditiveCoupling with memory savings¶
import torch
import torch.nn as nn
import memcnn
# define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d
class ExampleOperation(nn.Module):
def __init__(self, channels):
super(ExampleOperation, self).__init__()
self.seq = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels,
kernel_size=(3, 3), padding=1),
nn.BatchNorm2d(num_features=channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.seq(x)
# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)
# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()
Y = model_normal(X)
# turn the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
Fm=ExampleOperation(channels=10 // 2),
Gm=ExampleOperation(channels=10 // 2)
)
# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)
# wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)
# by default the module is set to training, the following sets this to evaluation
# note that this is required to pass input tensors to the model with requires_grad=False (inference only)
invertible_module_wrapper.eval()
# test that the wrapped module is also a valid invertible module
assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)
# compute the forward pass using the wrapper
Y2 = invertible_module_wrapper.forward(X)
# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2
X2 = invertible_module_wrapper.inverse(Y2)
# test that the input and approximation are similar
assert torch.allclose(X, X2, atol=1e-06)
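The example above keeps its inputs (keep_input=True) for demonstration purposes. To actually benefit from the memory savings during training, input retention can be switched off. The following is a minimal training-mode sketch continuing from the example above; the loss is a placeholder for illustration:
# wrap the same invertible_module, but discard the inputs after the forward pass
training_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=False, keep_input_inverse=False)
training_wrapper.train()
X_train = torch.rand(2, 10, 8, 8, requires_grad=True)
Y_train = training_wrapper(X_train)
loss = Y_train.sum()   # placeholder loss for illustration
loss.backward()        # X_train is reconstructed from Y_train here to compute gradients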
Run PyTorch Experiments¶
After installing MemCNN run:
python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda]
- Available values for DATASET are cifar10 and cifar100.
- Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110, and revnet164.
- Use the --fresh flag to remove earlier experiment results.
- Use the --no-cuda flag to train on the CPU rather than the GPU through CUDA.
For example, to train revnet110 on cifar10: python -m memcnn.train revnet110 cifar10
Datasets are automatically downloaded if they are not available.
When using Python 3.*, replace the python directive with the appropriate Python 3 directive; for example, when using the MemCNN docker image, use python3.7.
When MemCNN was installed using pip or from sources, you might need to set up a configuration file before running this command. Read the corresponding section about how to do this here: https://memcnn.readthedocs.io/en/latest/installation.html
Results¶
The TensorFlow results were obtained by running the code from the reversible residual network authors' GitHub repository.
The PyTorch results listed were recomputed on June 11th, 2018, and differ from the results in the ICLR paper. The TensorFlow results are unchanged.
Prediction accuracy¶
Model | Cifar-10 (TensorFlow) | Cifar-10 (PyTorch) | Cifar-100 (TensorFlow) | Cifar-100 (PyTorch)
---|---|---|---|---
resnet-32 | 92.74 | 92.86 | 69.10 | 69.81
resnet-110 | 93.99 | 93.55 | 73.30 | 72.40
resnet-164 | 94.57 | 94.80 | 76.79 | 76.47
revnet-38 | 93.14 | 92.80 | 71.17 | 69.90
revnet-110 | 94.02 | 94.10 | 74.00 | 73.30
revnet-164 | 94.56 | 94.90 | 76.39 | 76.90
Training time (hours : minutes)¶
Model | Cifar-10 (TensorFlow) | Cifar-10 (PyTorch) | Cifar-100 (TensorFlow) | Cifar-100 (PyTorch)
---|---|---|---|---
resnet-32 | 2:04 | 1:51 | 1:58 | 1:51
resnet-110 | 4:11 | 2:51 | 6:44 | 2:39
resnet-164 | 11:05 | 4:59 | 10:59 | 3:45
revnet-38 | 2:17 | 2:09 | 2:20 | 2:16
revnet-110 | 6:59 | 3:42 | 7:03 | 3:50
revnet-164 | 13:09 | 7:21 | 13:12 | 7:17
Memory consumption of model training in PyTorch¶
Layers (ResNet / RevNet) | Parameters (ResNet / RevNet) | Parameters in MB (ResNet / RevNet) | Activations in MB (ResNet / RevNet)
---|---|---|---
32 / 38 | 466906 / 573994 | 1.9 / 2.3 | 238.6 / 85.6
110 / 110 | 1730714 / 1854890 | 6.8 / 7.3 | 810.7 / 85.7
164 / 164 | 1704154 / 1983786 | 6.8 / 7.9 | 2452.8 / 432.7
The ResNet model is the conventional Residual Network implementation in PyTorch, while the RevNet model uses the memcnn.InvertibleModuleWrapper to achieve memory savings.
Works using MemCNN¶
- MemCNN: a Framework for Developing Memory Efficient Deep Invertible Networks by Sil C. van de Leemput et al.
- Reversible GANs for Memory-efficient Image-to-Image Translation by Tycho van der Ouderaa et al.
- Chest CT Super-resolution and Domain-adaptation using Memory-efficient 3D Reversible GANs by Tycho van der Ouderaa et al.
- iUNets: Fully invertible U-Nets with Learnable Up- and Downsampling by Christian Etmann et al.
Citation¶
Sil C. van de Leemput, Jonas Teuwen, Bram van Ginneken, and Rashindra Manniesing. MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks. Journal of Open Source Software, 4, 1576, http://dx.doi.org/10.21105/joss.01576, 2019.
If you use our code, please cite:
@article{vandeLeemput2019MemCNN,
journal = {Journal of Open Source Software},
doi = {10.21105/joss.01576},
issn = {2475-9066},
number = {39},
publisher = {The Open Journal},
title = {MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks},
url = {http://dx.doi.org/10.21105/joss.01576},
volume = {4},
author = {Sil C. {van de} Leemput and Jonas Teuwen and Bram {van} Ginneken and Rashindra Manniesing},
pages = {1576},
date = {2019-07-30},
year = {2019},
month = {7},
day = {30},
}
Installation¶
Stable release¶
These are the preferred methods to install MemCNN, as they will always install the most recent stable release.
PyPI¶
To install MemCNN using the Python package manager, run this command in your terminal:
$ pip install memcnn
If you don’t have pip installed, this Python installation guide can guide you through the process.
Anaconda¶
To install MemCNN using Anaconda, run this command in your terminal:
$ conda install -c silvandeleemput -c pytorch -c simpleitk -c conda-forge memcnn
If you don’t have conda installed, this Anaconda installation guide can guide you through the process.
From sources¶
The sources for MemCNN can be downloaded from the Github repo.
You can either clone the public repository:
$ git clone git://github.com/silvandeleemput/memcnn
Or download the tarball:
$ curl -OL https://github.com/silvandeleemput/memcnn/tarball/master
Once you have a copy of the source, you can install it with:
$ python setup.py install
Using docker¶
MemCNN has several pre-built Docker images hosted on Docker Hub. You can pull these directly to get a working environment for running the experiments.
Run image from repository¶
Run the latest docker build of MemCNN from the repository (automatically pulls the image):
$ docker run --shm-size=4g --runtime=nvidia -it silvandeleemput/memcnn:latest
For --runtime=nvidia to work, nvidia-docker must be installed on your system. The flag can be omitted, but this will drop GPU training support.
This will open a bash shell that is correctly configured to run the experiments. The latest image has Ubuntu 18.04 and Python 3.7 installed.
By default, the datasets and experiment results will be put inside the created docker container under /home/user/data and /home/user/experiments respectively.
Build image from source¶
Requirements:
- NVIDIA graphics card and the proper NVIDIA-drivers on your system
The following bash commands will clone this repository and do a one-time build of the docker image with the right environment installed:
$ git clone https://github.com/silvandeleemput/memcnn.git
$ docker build ./memcnn/docker --tag=silvandeleemput/memcnn:latest
After the one-time install on your machine, the docker image can be invoked using the same commands as listed above.
Experiment configuration file¶
To run the experiments, MemCNN requires setting up a configuration file containing locations to put the data files. This step is not necessary for the docker builds.
The configuration file config.json goes in the /memcnn/config/ directory of the library and should be formatted as follows:
{
"data_dir": "/home/user/data",
"results_dir": "/home/user/experiments"
}
- data_dir : location for storing the input training datasets
- results_dir : location for storing the experiment files during training
Change the data paths to your liking.
If you are unsure where MemCNN and/or the configuration file is located on your machine run:
$ python -m memcnn.train
If the configuration file is not set up correctly, this command should give you the correct path to the configuration file. Next, create or edit the file at the given location.
Usage¶
To use MemCNN in a project:
import memcnn
For worked examples and for running the PyTorch experiments, see the Examples and Run PyTorch Experiments sections above.
Modules¶
Top-level package for MemCNN.
memcnn.is_invertible_module(module_in, test_input_shape, test_input_dtype=torch.float32, atol=1e-06, random_seed=42)¶
Test whether a torch.nn.Module is invertible.
Parameters:
- module_in (torch.nn.Module) – A torch.nn.Module to test.
- test_input_shape (tuple of int, or tuple of tuple of int) – Dimensions of the test tensor(s) used to perform the test.
- test_input_dtype (torch.dtype, optional) – Data type of the test tensor(s) used to perform the test.
- atol (float, optional) – Tolerance value used for comparing the outputs.
- random_seed (int, optional) – Seed value used to generate the pseudo-random test inputs.
Returns: bool – True if the input module is invertible, False otherwise.
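Any torch.nn.Module that implements a matching forward and inverse can be checked this way. Below is a minimal sketch with a toy invertible module (ScaleByTwo is a hypothetical example, not part of MemCNN):
import torch
import torch.nn as nn
import memcnn

class ScaleByTwo(nn.Module):
    # toy invertible operation: y = 2 * x on forward, x = y / 2 on inverse
    def forward(self, x):
        return x * 2.0
    def inverse(self, y):
        return y / 2.0

assert memcnn.is_invertible_module(ScaleByTwo(), test_input_shape=(2, 10, 8, 8))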
class memcnn.InvertibleModuleWrapper(fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1, disable=False, preserve_rng_state=False)¶
The InvertibleModuleWrapper enables memory savings during training by exploiting the invertible properties of the wrapped module.
Parameters:
- fn (torch.nn.Module) – A torch.nn.Module which has a forward and an inverse function implemented such that \(x == m.inverse(m.forward(x))\).
- keep_input (bool, optional) – Set to retain the input information on forward; by default it is discarded, since it will be reconstructed upon the backward pass.
- keep_input_inverse (bool, optional) – Set to retain the input information on inverse; by default it is discarded, since it will be reconstructed upon the backward pass.
- num_bwd_passes (int, optional) – Number of backward passes that retain a link with the output. After the last backward pass the output is discarded and its memory is freed. Warning: if this value is set higher than the number of required passes, memory will not be freed correctly anymore and the training process can quickly run out of memory. The typical use case is to keep this at 1 and only raise it if an error indicates that more passes are needed.
- disable (bool, optional) – Disables the InvertibleCheckpointFunction altogether; the wrapper then simply computes \(y = fn(x)\) without any of the memory savings. Setting this to True also ignores the keep_input and keep_input_inverse properties.
- preserve_rng_state (bool, optional) – Ensures that the same RNG state is used when reconstructing the inputs, i.e. when keep_input=False on forward or keep_input_inverse=False on inverse. By default this is False, since most invertible modules have a valid inverse and are hence deterministic.
Attributes:
- keep_input (bool, optional) – Set to retain the input information on forward; by default it is discarded, since it will be reconstructed upon the backward pass.
- keep_input_inverse (bool, optional) – Set to retain the input information on inverse; by default it is discarded, since it will be reconstructed upon the backward pass.
Note
The InvertibleModuleWrapper can be used with mixed-precision training using torch.cuda.amp.autocast as of torch v1.6 and above. However, inputs will always be cast to torch.float32 internally. This is done to minimize autocasting inputs to a different data type, which usually results in a disconnected computation graph and raises an error on the backward pass.
forward(*xin)¶
Forward operation \(R(x) = y\)
Parameters: *xin (torch.Tensor tuple) – Input torch tensor(s).
Returns: torch.Tensor tuple – Output torch tensor(s) *y.
inverse(*yin)¶
Inverse operation \(R^{-1}(y) = x\)
Parameters: *yin (torch.Tensor tuple) – Input torch tensor(s).
Returns: torch.Tensor tuple – Output torch tensor(s) *x.
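As the note above states, the wrapper can be combined with torch.cuda.amp.autocast for mixed-precision training. The following is a hedged sketch, assuming a CUDA-capable device and torch v1.6 or newer; the modules, sizes, and loss are purely illustrative:
import torch
import torch.nn as nn
import memcnn

coupling = memcnn.AdditiveCoupling(
    Fm=nn.Conv2d(5, 5, kernel_size=3, padding=1),
    Gm=nn.Conv2d(5, 5, kernel_size=3, padding=1),
)
wrapper = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=False).cuda()

x = torch.rand(2, 10, 8, 8, device="cuda", requires_grad=True)
with torch.cuda.amp.autocast():
    y = wrapper(x)            # inputs are always cast to torch.float32 internally
    loss = y.float().sum()    # placeholder loss for illustration
loss.backward()               # x is reconstructed from y here, since keep_input=False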
class memcnn.AdditiveCoupling(Fm, Gm=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1)¶
This computes the output \(y\) on forward given input \(x\) and arbitrary modules \(Fm\) and \(Gm\) according to:
\((x1, x2) = x\)
\(y1 = x1 + Fm(x2)\)
\(y2 = x2 + Gm(y1)\)
\(y = (y1, y2)\)
Parameters:
- Fm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function.
- Gm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function (if not specified, a deep copy of Fm is used).
- implementation_fwd (int) – Switch between different additive operation implementations for the forward pass. Default = -1.
- implementation_bwd (int) – Switch between different additive operation implementations for the inverse pass. Default = -1.
- split_dim (int) – Dimension along which the input tensors are split. Default = 1, generally corresponding to channels.
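The coupling is invertible because each half can be recovered by subtracting the same function values in reverse order: \(x2 = y2 - Gm(y1)\) followed by \(x1 = y1 - Fm(x2)\). The following self-contained sketch verifies this numerically, using plain convolutions as illustrative stand-ins for Fm and Gm:
import torch
import torch.nn as nn

# any modules will do for Fm and Gm; simple convolutions are used here for illustration
Fm = nn.Conv2d(5, 5, kernel_size=3, padding=1)
Gm = nn.Conv2d(5, 5, kernel_size=3, padding=1)

x = torch.rand(2, 10, 8, 8)
x1, x2 = torch.chunk(x, 2, dim=1)  # split along the channel dimension (split_dim=1)

with torch.no_grad():
    # forward: y1 = x1 + Fm(x2), y2 = x2 + Gm(y1)
    y1 = x1 + Fm(x2)
    y2 = x2 + Gm(y1)
    # inverse: subtract the same terms in reverse order
    x2_rec = y2 - Gm(y1)
    x1_rec = y1 - Fm(x2_rec)

assert torch.allclose(torch.cat([x1_rec, x2_rec], dim=1), x, atol=1e-06)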
class memcnn.AffineCoupling(Fm, Gm=None, adapter=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1)¶
This computes the output \(y\) on forward given input \(x\) and arbitrary modules \(Fm\) and \(Gm\) according to:
\((x1, x2) = x\)
\((log({s1}), t1) = Fm(x2)\)
\(s1 = exp(log({s1}))\)
\(y1 = s1 * x1 + t1\)
\((log({s2}), t2) = Gm(y1)\)
\(s2 = exp(log({s2}))\)
\(y2 = s2 * x2 + t2\)
\(y = (y1, y2)\)
Parameters:
- Fm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function.
- Gm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function (if not specified, a deep copy of Fm is used).
- adapter (torch.nn.Module class) – An optional wrapper class A for Fm and Gm which must output s, t = A(x) with shape(s) = shape(t) = shape(x); s and t are respectively the scale and shift tensors for the affine coupling.
- implementation_fwd (int) – Switch between different affine operation implementations for the forward pass. Default = -1.
- implementation_bwd (int) – Switch between different affine operation implementations for the inverse pass. Default = -1.
- split_dim (int) – Dimension along which the input tensors are split. Default = 1, generally corresponding to channels.
class memcnn.AffineAdapterNaive(module)¶
Naive affine adapter.
Outputs exp(f(x)), f(x) given f(.) and x.
class memcnn.AffineAdapterSigmoid(module)¶
Sigmoid-based affine adapter.
Partitions the output h of f(x) = h into s and t by extracting every odd and even channel. Outputs sigmoid(s), t.
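Since Fm and Gm must provide both a scale and a shift, they are typically wrapped with one of the adapter classes above via the adapter argument. A hedged sketch, assuming the adapter class is passed as documented; the modules and sizes are illustrative, and a slightly looser tolerance is used for the float32 exp/scale arithmetic:
import torch
import torch.nn as nn
import memcnn

# Fm and Gm operate on one half of the 10 input channels; AffineAdapterNaive derives
# both the scale (exp(f(x))) and the shift (f(x)) from the same module output
affine_coupling = memcnn.AffineCoupling(
    Fm=nn.Conv2d(5, 5, kernel_size=3, padding=1),
    Gm=nn.Conv2d(5, 5, kernel_size=3, padding=1),
    adapter=memcnn.AffineAdapterNaive,
)
affine_coupling.eval()

assert memcnn.is_invertible_module(affine_coupling, test_input_shape=(2, 10, 8, 8), atol=1e-05)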
class memcnn.ReversibleBlock(Fm, Gm=None, coupling='additive', keep_input=False, keep_input_inverse=False, implementation_fwd=-1, implementation_bwd=-1, adapter=None)¶
The ReversibleBlock.
Warning
This class has been deprecated. Use the more flexible InvertibleModuleWrapper class instead.
Note
The implementation_fwd and implementation_bwd parameters can be set to one of the following implementations:
- -1: Naive implementation without reconstruction on the backward pass.
- 0: Memory-efficient implementation; gradients are computed directly.
- 1: Memory-efficient implementation, similar to the approach in Gomez et al. 2017.
Parameters:
- Fm (torch.nn.Module) – A torch.nn.Module encapsulating an arbitrary function.
- Gm (torch.nn.Module, optional) – A torch.nn.Module encapsulating an arbitrary function (if not specified, a deep copy of Fm is used).
- coupling (str, optional) – Type of coupling ['additive', 'affine']. Default = 'additive'.
- keep_input (bool, optional) – Set to retain the input information on forward; by default it is discarded, since it will be reconstructed upon the backward pass.
- keep_input_inverse (bool, optional) – Set to retain the input information on inverse; by default it is discarded, since it will be reconstructed upon the backward pass.
- implementation_fwd (int, optional) – Switch between different operation implementations for forward training (Default = 1). If using the naive implementation (-1), keep_input should be True.
- implementation_bwd (int, optional) – Switch between different operation implementations for backward training (Default = 1). If using the naive implementation (-1), keep_input_inverse should be True.
- adapter (class, optional) – Only relevant when using the 'affine' coupling. Should be a class of type torch.nn.Module that serves as an optional wrapper class A for Fm and Gm which must output s, t = A(x) with shape(s) = shape(t) = shape(x); s and t are respectively the scale and shift tensors for the affine coupling.
Attributes:
- keep_input (bool, optional) – Set to retain the input information on forward; by default it is discarded, since it will be reconstructed upon the backward pass.
- keep_input_inverse (bool, optional) – Set to retain the input information on inverse; by default it is discarded, since it will be reconstructed upon the backward pass.
Raises: NotImplementedError – If an unknown coupling or implementation is given.
forward(*xin)¶
Forward operation \(R(x) = y\)
Parameters: *xin (torch.Tensor tuple) – Input torch tensor(s).
Returns: torch.Tensor tuple – Output torch tensor(s) *y.
inverse(*yin)¶
Inverse operation \(R^{-1}(y) = x\)
Parameters: *yin (torch.Tensor tuple) – Input torch tensor(s).
Returns: torch.Tensor tuple – Output torch tensor(s) *x.
Contributing¶
Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given.
You can contribute in many ways:
Types of Contributions¶
Report Bugs¶
Report bugs at https://github.com/silvandeleemput/memcnn/issues.
If you are reporting a bug, please include:
- Your operating system name and version.
- Any details about your local setup that might be helpful in troubleshooting.
- Detailed steps to reproduce the bug.
Fix Bugs¶
Look through the GitHub issues for bugs. Anything tagged with “bug” and “help wanted” is open to whoever wants to implement it.
Implement Features¶
Look through the GitHub issues for features. Anything tagged with “enhancement” and “help wanted” is open to whoever wants to implement it.
Write Documentation¶
MemCNN could always use more documentation, whether as part of the official MemCNN docs, in docstrings, or even on the web in blog posts, articles, and such.
Submit Feedback¶
The best way to send feedback is to file an issue at https://github.com/silvandeleemput/memcnn/issues.
If you are proposing a feature:
- Explain in detail how it would work.
- Keep the scope as narrow as possible, to make it easier to implement.
- Remember that this is a volunteer-driven project, and that contributions are welcome :)
Get Started!¶
Ready to contribute? Here’s how to set up memcnn for local development.
Fork the memcnn repo on GitHub.
Clone your fork locally:
$ git clone git@github.com:your_name_here/memcnn.git
Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:
$ mkvirtualenv memcnn
$ cd memcnn/
$ python setup.py develop
Create a branch for local development:
$ git checkout -b name-of-your-bugfix-or-feature
Now you can make your changes locally.
When you’re done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:
$ flake8 memcnn tests
$ python setup.py test or py.test
$ tox
To get flake8 and tox, just pip install them into your virtualenv.
Commit your changes and push your branch to GitHub:
$ git add .
$ git commit -m "Your detailed description of your changes."
$ git push origin name-of-your-bugfix-or-feature
Submit a pull request through the GitHub website.
Pull Request Guidelines¶
Before you submit a pull request, check that it meets these guidelines:
- The pull request should include tests.
- If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst.
- The pull request should work for all supported Python versions (see the CI configuration; Python 2 support has been dropped). Check through tox that the tests pass for all supported Python versions.
Tips¶
To run a subset of tests:
$ pytest memcnn/memcnn/models/tests/test_revop.py
To run a specific test:
$ pytest memcnn/memcnn/models/tests/test_revop.py::test_reversible_block_fwd_bwd
Deploying¶
A reminder for the maintainers on how to deploy. Make sure all your changes are committed (including an entry in HISTORY.rst). Then run:
$ bumpversion patch # possible: major / minor / patch
$ git push
$ git push origin <tag_name>
CircleCI will then deploy to PyPI if tests pass.
Credits¶
Development Lead¶
- Sil van de Leemput <silvandeleemput@gmail.com>
Contributors¶
- Tycho van der Ouderaa
- Jonas Teuwen
- Bram van Ginneken
- Rashindra Manniesing
History¶
1.5.1 (2021-08-07)¶
- Added support for 2-dimensional inputs for AffineAdapterSigmoid
- Fixed CI issues
1.5.0 (2020-11-24)¶
- Added support for mixed-precision training using torch.cuda.amp (inputs fixed to float32 for now)
- Added support for PyTorch v1.7
- Dropped support for PyTorch < v1.0 and Python 2
- Removed the version limit for Pillow in the requirements
1.4.0 (2020-06-05)¶
- Added support for splitting on arbitrary dimensions to the Couplings. Big thanks to ClashLuke for the PR
- Added a preserve_rng_state option to the InvertibleModuleWrapper
1.3.2 (2020-03-05)¶
- Improved InvertibleModuleWrapper: added support for multi input/output invertible operations! Big thanks to Christian Etmann for the PR
- Improved the is_invertible_module test: added multi input/output checks, fixed the random seed per default, and added additional warning checks
1.3.1 (2020-03-02)¶
- HOTFIX InvertibleCheckpointFunction uses ref_count for inputs as well to avoid memory spikes
1.3.0 (2020-03-01)¶
- Updated the underlying mechanics of the InvertibleModuleWrapper: hooks have been replaced by a torch.autograd.Function called InvertibleCheckpointFunction, and identity functions are now supported
- Reported unstable memory behavior should be fixed now when using the InvertibleModuleWrapper!
- Minor changes to test suite
1.2.1 (2020-02-24)¶
- Added InvertibleModuleWrapper support to is_invertible_module test
1.2.0 (2020-01-19)¶
- Replaced TensorBoard logging with simple json file logging which removed the cumbersome TensorBoard and TensorFlow dependencies
- Updated the Dockerfile for Python37 and PyTorch 1.4.0
- Updated the CI tests Py36 versions to Py37, also added a new CI test for PyTorch 1.4.0
1.1.1 (2020-01-11)¶
- Fixed some versions in the requirements for TensorFlow and Pillow to avoid errors and segfaults
- The module auto documentation has been updated for the new API changes
1.1.0 (2019-12-15)¶
- A complete refactor of MemCNN with changes to the API
- Factored out the code responsible for the memory savings into a separate InvertibleModuleWrapper and reimplemented it using hooks
- The InvertibleModuleWrapper allows for arbitrary invertible functions now (not just the additive and affine couplings)
- The AdditiveBlock and AffineBlock have been refactored to AdditiveCoupling and AffineCoupling
- The ReversibleBlock is now deprecated
- The documentation and examples have been updated for the new API changes
1.0.1 (2019-12-08)¶
- Bug fixes related to the SummaryIterator import in TensorFlow 2 (the location of summary_iterator has changed in TensorFlow)
- Bug fixes related to the NSamplesRandomSampler nsamples attribute (it would crash if no GPU was available and a numpy.int was given)
1.0.0 (2019-07-28)¶
- Major release for completing the JOSS review:
- Anaconda cloud and codacy code quality CI
- Updated/improved documentation
0.3.5 (2019-07-28)¶
- Added CI for anaconda cloud
- Documented conda installation steps
- Minor test release for testing CI build
0.3.4 (2019-07-26)¶
- Performed changes recommended by JOSS reviewers:
- Added requirements.txt to manifest.in
- Added codacy code quality integration
- Improved documentation
- Setup proper github contribution templates
0.3.3 (2019-07-10)¶
- Added docker build triggers to CI
- Finalized JOSS paper.md
0.3.2 (2019-07-10)¶
- Added docker build shield
- Fixed a bug with device agnostic tensor generation for loss.py
- Code cleanup resnet.py
- Added examples to distribution with pytests
- Improved documentation
0.3.1 (2019-07-09)¶
- Added experiments.json and config.json.example data files to the distribution
- Fixed documentation issues with mock modules
0.3.0 (2019-07-09)¶
- Fixed a major bug in the distribution setup.py
- Removed older releases due to bug
- Added the ReversibleBlock at the module level
- Split keep_input into keep_input and keep_input_inverse
0.2.1 (2019-06-06 - Removed)¶
- Patched the memory saving tests
0.2.0 (2019-05-28 - Removed)¶
- Minor update with better coverage and affine coupling support
0.1.0 (2019-05-24 - Removed)¶
- First release on PyPI