Implementation of Lipschitz Monotonic Networks, from the ICLR 2023 Submission: https://openreview.net/pdf?id=w2P7fMy_RH
The code here allows one to apply various weight constraints on torch.nn.Linear layers through the kind keyword. Here are the available weight norms:
"one", # |W|_1 constraint
"inf", # |W|_inf constraint
"one-inf", # |W|_1,inf constraint
"two-inf", # |W|_2,inf constraint
Check that you have the right kind of Lipschitz constraint. If you are not sure, use kind="one-inf" in the first layer and kind="inf" in all following layers. The default (kind="one") works well ONLY when the output is scalar!
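For instance, a minimal two-layer network following this recommendation could look like the sketch below (the layer sizes are arbitrary; the same pattern appears in the MonotonicWrapper example further down):

from torch import nn
import monotonicnetworks as lmn

lip_nn = nn.Sequential(
    lmn.LipschitzLinear(4, 32, kind="one-inf"),  # first layer: |W|_1,inf constraint
    lmn.GroupSort(2),                            # gradient-norm-preserving activation
    lmn.LipschitzLinear(32, 1, kind="inf"),      # following layers: |W|_inf constraint
)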
Note that the package used to be called monotonenorm and was renamed to monotonicnetworks on 2023-07-15. The old package is still available on PyPI and conda-forge, but will not be updated.
(deprecated) pip install monotonenorm
(deprecated) conda install -c okitouni monotonenorm
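Assuming the renamed package is published on PyPI under its new name (the note above only confirms availability for the old name), the current install command would be:
pip install monotonicnetworks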
Make sure you have the following packages installed:
- torch (required)
- matplotlib (optional, for plotting examples)
- tqdm (optional, to run the examples with a progress bar)
Here's an example showing two ways to create a Lipschitz-constrained linear layer.
from torch import nn
import monotonicnetworks as lmn
linear_by_norming = lmn.direct_norm(nn.Linear(10, 10), kind="one-inf") # |W|_1,inf constraint
linear_native = lmn.LipschitzLinear(10, 10, kind="one-inf") # |W|_1,inf constraint
The function lmn.direct_norm can apply various weight constraints on torch.nn.Linear layers through the kind keyword and returns a Lipschitz-constrained linear layer. Alternatively, the code in monotonicnetworks/LipschitzMonotonicNetwork.py contains several classes that can be used to create Lipschitz and monotonic layers directly.
- The LipschitzLinear class is a linear layer with a Lipschitz constraint on its weights.
- The MonotonicLayer class is a linear layer with a Lipschitz constraint on its weights and monotonicity constraints that can be specified for each input dimension or for each input-output pair. For instance, suppose we want to model a linear layer with 2 inputs and 3 outputs. We specify the monotonic constraints with respect to the first input as [1, 0, -1]: the first output is monotonically increasing (1), the second has no constraint (0), and the third is monotonically decreasing (-1) with respect to the first input. For the second input, [0, 1, 0] means only the second output has a monotonically increasing constraint. The code for this looks as follows:
import monotonicnetworks as lmn
linear = lmn.MonotonicLayer(2, 3, monotonic_constraints=[[1, 0, -1], [0, 1, 0]])
The accepted 2D tensor shape for monotonic constraints is [input_dim, output_dim].
Passing a 1D tensor assumes the same constraint for every output dimension. By default, the code assumes all outputs are monotonically increasing with respect to all inputs.
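For example, the 1D form below is a sketch of that behavior: every output is constrained to be increasing in the first input, and the second input is left unconstrained.

import monotonicnetworks as lmn

# 1D constraints: one entry per input dimension, applied to every output
linear = lmn.MonotonicLayer(2, 3, monotonic_constraints=[1, 0])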
- The MonotonicWrapper class is a wrapper around a module with a Lipschitz constant. It adds a term to the output of the module which enforces the monotonicity constraints given by monotonic_constraints. The class returns a module that is monotonic and Lipschitz with constant lipschitz_const. This is the preferred way to create a monotonic network. Example:
from torch import nn
import monotonicnetworks as lmn
lip_nn = nn.Sequential(
    lmn.LipschitzLinear(2, 32, kind="one-inf"),
    lmn.GroupSort(2),
    lmn.LipschitzLinear(32, 2, kind="inf"),
)
monotonic_nn = lmn.MonotonicWrapper(lip_nn, monotonic_constraints=[1,0]) # first input increasing, no monotonicity constraints on second input
Note that one can stack monotonic modules.
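As a sketch of what such stacking could look like (the sizes and constraints here are illustrative), two wrapped blocks can be composed with nn.Sequential; the composition of monotonically increasing Lipschitz functions is again monotonically increasing, with a Lipschitz constant at most the product of the individual constants.

from torch import nn
import monotonicnetworks as lmn

block1 = lmn.MonotonicWrapper(nn.Sequential(
    lmn.LipschitzLinear(2, 16, kind="one-inf"),
    lmn.GroupSort(2),
    lmn.LipschitzLinear(16, 4, kind="inf"),
), monotonic_constraints=[1, 0])  # increasing in the first input only

block2 = lmn.MonotonicWrapper(nn.Sequential(
    lmn.LipschitzLinear(4, 16, kind="one-inf"),
    lmn.GroupSort(2),
    lmn.LipschitzLinear(16, 1, kind="inf"),
), monotonic_constraints=[1, 1, 1, 1])  # increasing in all of its inputs

stacked = nn.Sequential(block1, block2)  # increasing in the first input overall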
- The SigmaNet class is a deprecated class equivalent to the MonotonicWrapper class.
- The RMSNorm class implements the RMSNorm normalization layer. It can help when training a model with many Lipschitz-constrained layers.
Check out the Examples directory for more details. Specifically, Examples/flower.py shows how to train a Lipschitz Monotonic Network to regress on a complex decision boundary in 2D (under Lipschitz NNs can describe arbitrarily complex boundaries), and Examples/Examples_paper.ipynb contains the code used to make the plots under Monotonicity and Robustness.
We will make a simple toy regression model to fit the following 1D function
Training a monotonic NN and an unconstrained NN on the purple points and evaluating the networks on a uniform grid gives the following result:
Now we will make a different toy model with one noisy data point. This will show that the Lipschitz continuous network is more robust against outliers than an unconstrained network because its gradient with respect to the input is bounded between -1 and 1. Additionally, it is more robust against adversarial attacks/data corruption for the same reason.
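A minimal training sketch along these lines (the data, layer sizes, and optimizer settings are illustrative, not the exact ones used for the plots) could look like this:

import torch
from torch import nn
import monotonicnetworks as lmn

# toy 1D data: a noisy, monotonically increasing target
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = x + 0.1 * torch.randn_like(x)

lip_nn = nn.Sequential(
    lmn.LipschitzLinear(1, 32, kind="one-inf"),
    lmn.GroupSort(2),
    lmn.LipschitzLinear(32, 1, kind="inf"),
)
model = lmn.MonotonicWrapper(lip_nn, monotonic_constraints=[1])  # increasing in x

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(2000):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()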
GroupSort weight-constrained neural networks are universal approximators of Lipschitz continuous functions. Furthermore, they can describe arbitrarily complex decision boundaries in classification problems, provided the proper objective function is used in training. In Examples/flower.py we provide code to regress on an example "complex" decision boundary in 2D.
Here are the contour lines of the resulting network (along with the training points in black).