
Commit 886e93a

Removes the positivity constraint on the weights of the skip connections as it was unnecessary for maintaining convexity.
1 parent 8b931f6 commit 886e93a

File tree

85 files changed: +254 -237 lines changed


README.md

+11 -16

@@ -1,12 +1,13 @@
-# AI Verification: Constrained Deep Learning [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/constrained-deep-learning)
+# AI Verification: Constrained Deep Learning
 
 Constrained deep learning is an advanced approach to training deep neural networks by incorporating domain-specific constraints into the learning process. By integrating these constraints into the construction and training of neural networks, you can guarantee desirable behaviour in safety-critical scenarios where such guarantees are paramount.
 
 This project aims to develop and evaluate deep learning models that adhere to predefined constraints, which could be in the form of physical laws, logical rules, or any other domain-specific knowledge. In the context of AI verification, constrained deep learning provides guarantees that certain desirable properties are present in the trained neural network by design. These desirable properties could include monotonicity, boundedness, and robustness amongst others.
 
 <figure>
 <p align="center">
-<img src="./documentation/figures/constrained_learning.svg">
+<img src="./documentation/figures/constrained_learning.svg"
+style="width:4in;height:1.1in">
 </p>
 </figure>

@@ -32,12 +33,12 @@ The repository contains several introductory, interactive examples as well as lo
 
 ### Introductory Examples (Short)
 Below are links for markdown versions of MATLAB Live Scripts that you can view in GitHub&reg;.
-- [Fully Input Convex Neural Networks in 1-Dimension](examples/convex/introductory/PoC_Ex1_1DFICNN.md)
-- [Fully Input Convex Neural Networks in n-Dimensions](examples/convex/introductory/PoC_Ex2_nDFICNN.md)
-- [Partially Input Convex Neural Networks in n-Dimensions](examples/convex/introductory/PoC_Ex3_nDPICNN.md)
-- [Fully Input Monotonic Neural Networks in 1-Dimension](examples/monotonic/introductory/PoC_Ex1_1DFMNN.md)
-- [Fully Input Monotonic Neural Networks in n-Dimensions](examples/monotonic/introductory/PoC_Ex2_nDFMNN.md)
-- [Lipschitz Continuous Neural Networks in 1-Dimension](examples/lipschitz/introductory/PoC_Ex1_1DLNN.md)
+- [Fully input convex neural networks in 1-dimension](examples/convex/introductory/PoC_Ex1_1DFICNN.md)
+- [Fully input convex neural networks in n-dimensions](examples/convex/introductory/PoC_Ex2_nDFICNN.md)
+- [Partially input convex neural networks in n-dimensions](examples/convex/introductory/PoC_Ex3_nDPICNN.md)
+- [Fully input monotonic neural networks in 1-dimension](examples/monotonic/introductory/PoC_Ex1_1DFMNN.md)
+- [Fully input monotonic neural networks in n-dimensions](examples/monotonic/introductory/PoC_Ex2_nDFMNN.md)
+- [Lipschitz continuous neural networks in 1-dimensions](examples/lipschitz/introductory/PoC_Ex1_1DLNN.md)
 
 These examples make use of [custom training loops](https://uk.mathworks.com/help/deeplearning/deep-learning-custom-training-loops.html) and the [`arrayDatastore`](https://uk.mathworks.com/help/matlab/ref/matlab.io.datastore.arraydatastore.html) object. To learn more, click the links.
 

@@ -70,13 +71,7 @@ As discussed in [1] (see 3.4.1.5), in certain situations, small violations in th
 
 ## Technical Articles
 
-This repository focuses on the development and evaluation of deep learning models that adhere to constraints crucial for safety-critical applications, such as predictive maintenance for industrial machinery and equipment. Specifically, it focuses on enforcing monotonicity, convexity, and Lipschitz continuity within neural networks to ensure predictable and controlled behavior.
-
-By emphasizing constraints like monotonicity, constrained neural networks ensure that predictions of the Remaining Useful Life (RUL) of components behave intuitively: as a machine's condition deteriorates, the estimated RUL should monotonically decrease. This is crucial in applications like aerospace or manufacturing, where an accurate and reliable estimation of RUL can prevent failures and save costs.
-
-Alongside monotonicity, Lipschitz continuity is also enforced to guarantee model robustness and controlled behavior. This is essential in environments where safety and precision are paramount such as control systems in autonomous vehicles or precision equipment in healthcare.
-
-Convexity is especially beneficial for control systems as it inherently provides boundedness properties. For instance, by ensuring that the output of a neural network lies within a convex hull, it is possible to guarantee that the control commands remain within a safe and predefined operational space, preventing erratic or unsafe system behaviors. This boundedness property, derived from the convex nature of the model's output space, is critical for maintaining the integrity and safety of control systems under various conditions.
+This repository focuses on the development and evaluation of deep learning models that adhere to constraints crucial for safety-critical applications, such as predictive maintenance for industrial machinery and equipment. Specifically, it focuses on enforcing monotonicity, convexity, and Lipschitz continuity within neural networks to ensure predictable and controlled behavior. By emphasizing constraints like monotonicity, constrained neural networks ensure that predictions of the Remaining Useful Life (RUL) of components behave intuitively: as a machine's condition deteriorates, the estimated RUL should monotonically decrease. This is crucial in applications like aerospace or manufacturing, where an accurate and reliable estimation of RUL can prevent failures and save costs. Alongside monotonicity, Lipschitz continuity is also enforced to guarantee model robustness and controlled behavior. This is essential in environments where safety and precision are paramount such as control systems in autonomous vehicles or precision equipment in healthcare. Convexity is especially beneficial for control systems as it inherently provides boundedness properties. For instance, by ensuring that the output of a neural network lies within a convex hull, it is possible to guarantee that the control commands remain within a safe and predefined operational space, preventing erratic or unsafe system behaviors. This boundedness property, derived from the convex nature of the model's output space, is critical for maintaining the integrity and safety of control systems under various conditions.
 
 These technical articles explain key concepts of AI verification in the context of constrained deep learning. They include discussions on how to achieve the specified constraints in neural networks at construction and training time, as well as deriving and proving useful properties of constrained networks in AI verification applications. It is not necessary to go through these articles in order to explore this repository, however, you can find references and more in depth discussion here.
 

@@ -90,4 +85,4 @@ These technical articles explain key concepts of AI verification in the context
 - [3] Gouk, Henry, et al. “Regularisation of Neural Networks by Enforcing Lipschitz Continuity.” Machine Learning, vol. 110, no. 2, Feb. 2021, pp. 393–416. DOI.org (Crossref), https://doi.org/10.1007/s10994-020-05929-w
 - [4] Kitouni, Ouail, et al. Expressive Monotonic Neural Networks. arXiv:2307.07512, arXiv, 14 July 2023. arXiv.org, http://arxiv.org/abs/2307.07512.
 
-Copyright 2024, The MathWorks, Inc.
+Copyright (c) 2024, The MathWorks, Inc.
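
Not part of this commit, but the README hunk above points readers to custom training loops and the `arrayDatastore` object. As a minimal, hedged sketch of one common way to pair an `arrayDatastore` with a custom training loop (the data below is made up for illustration):

```matlab
% Minimal sketch (illustrative data, not from this repository): wrap arrays in
% datastores and feed a custom training loop through a minibatchqueue.
X = rand(100, 1);                     % made-up feature data
T = 3*X + 0.1*rand(100, 1);           % made-up targets
ds  = combine(arrayDatastore(X), arrayDatastore(T));
mbq = minibatchqueue(ds, MiniBatchSize=16);   % yields mini-batches for the loop
```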

conslearn/+conslearn/+convex/buildFICNN.m

+7 -7

@@ -13,13 +13,13 @@
 %
 % BUILDFICNN name-value arguments:
 %
-% 'PositiveNonDecreasingActivation' - Specify the positive, convex,
+% 'ConvexNonDecreasingActivation' - Specify the convex,
 % non-decreasing activation functions.
 % The options are 'softplus' or 'relu'.
 % The default is 'softplus'.
 %
 % The construction of this network corresponds to Eq 2 in [1] with the
-% exception that the application of the positive, non-decreasing activation
+% exception that the application of the convex, non-decreasing activation
 % function on the network output is not applied. This maintains convexity
 % but permits positive and negative network outputs.
 %

@@ -31,7 +31,7 @@
 arguments
     inputSize (1,:)
     numHiddenUnits (1,:)
-    options.PositiveNonDecreasingActivation = 'softplus'
+    options.ConvexNonDecreasingActivation = 'softplus'
 end
 
 % Construct the correct input layer

@@ -43,7 +43,7 @@
 end
 
 % Loop over construction of hidden units
-switch options.PositiveNonDecreasingActivation
+switch options.ConvexNonDecreasingActivation
     case 'relu'
         pndFcn = @(k)reluLayer(Name="pnd_" + k);
     case 'softplus'

@@ -68,10 +68,10 @@
 
 % Add a cascading residual connection
 for ii = 2:depth
-    tempLayers = fullyConnectedLayer(numHiddenUnits(ii),Name="fc_y_+_" + ii);
+    tempLayers = fullyConnectedLayer(numHiddenUnits(ii),Name="fc_y_" + ii);
     lgraph = addLayers(lgraph,tempLayers);
-    lgraph = connectLayers(lgraph,"input","fc_y_+_" + ii);
-    lgraph = connectLayers(lgraph,"fc_y_+_" + ii,"add_" + ii + "/in2");
+    lgraph = connectLayers(lgraph,"input","fc_y_" + ii);
+    lgraph = connectLayers(lgraph,"fc_y_" + ii,"add_" + ii + "/in2");
 end
 
 % Initialize dlnetwork
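
For orientation, a minimal sketch of building a fully input convex network directly from the package function using the renamed name-value argument. The input size and hidden-unit sizes below are assumptions for the example, not values from this commit:

```matlab
% Illustrative sketch (sizes are assumptions): build a fully input convex
% network with the renamed name-value argument.
net = conslearn.convex.buildFICNN(1, [16 16 1], ...
    ConvexNonDecreasingActivation="softplus");

% The cascading skip-connection layers now appear as "fc_y_2", "fc_y_3", ...
% rather than "fc_y_+_2", "fc_y_+_3", ...
disp({net.Layers.Name}')
```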

conslearn/+conslearn/+convex/buildPICNN.m

+4 -4

@@ -13,7 +13,7 @@
 %
 % BUILDPICNN name-value arguments:
 %
-% 'PositiveNonDecreasingActivation' - Specify the positive, convex,
+% 'ConvexNonDecreasingActivation' - Specify the convex,
 % non-decreasing activation functions.
 % The options are 'softplus' or 'relu'.
 % The default is 'softplus'.

@@ -32,7 +32,7 @@
 % default value is 1.
 %
 % The construction of this network corresponds to Eq 3 in [1] with the
-% exception that the application of the positive, non-decreasing activation
+% exception that the application of the convex, non-decreasing activation
 % function on the network output is not applied. This maintains convexity
 % but permits positive and negative network outputs. Additionally, and in
 % keeping with the notation used in the reference, in this implementation

@@ -50,7 +50,7 @@
 arguments
     inputSize (1,:) {iValidateInputSize(inputSize)}
     numHiddenUnits (1,:)
-    options.PositiveNonDecreasingActivation = 'softplus'
+    options.ConvexNonDecreasingActivation = 'softplus'
     options.Activation = 'tanh'
     options.ConvexChannelIdx = 1
 end

@@ -63,7 +63,7 @@
 convexInputSize = numel(convexChannels);
 
 % Prepare the two types of valid activation functions
-switch options.PositiveNonDecreasingActivation
+switch options.ConvexNonDecreasingActivation
     case 'relu'
         pndFcn = @(k)reluLayer(Name="pnd_" + k);
     case 'softplus'
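
Similarly, a hedged sketch of the partially input convex builder with the renamed option. The input size, hidden-unit sizes, and channel choice are assumptions for illustration only:

```matlab
% Illustrative sketch: two input channels, with only the first channel
% constrained to be convex (all sizes here are made up for the example).
net = conslearn.convex.buildPICNN(2, [16 16 1], ...
    ConvexNonDecreasingActivation="softplus", ...
    Activation="tanh", ...
    ConvexChannelIdx=1);
```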

conslearn/buildConstrainedNetwork.m

+12 -12

@@ -19,7 +19,7 @@
 %
 % These options and default values apply to convex constrained networks:
 %
-% PositiveNonDecreasingActivation - Positive, convex, non-decreasing
+% ConvexNonDecreasingActivation - Convex, non-decreasing
 % ("fully-convex") activation functions.
 % ("partially-convex") The options are "softplus" or "relu".
 % The default is "softplus".

@@ -96,10 +96,10 @@
     iValidateInputSize(inputSize)}
     numHiddenUnits (1,:) {mustBeInteger,mustBeReal,mustBePositive}
     % Convex
-    options.PositiveNonDecreasingActivation {...
+    options.ConvexNonDecreasingActivation {...
         mustBeTextScalar, ...
-        mustBeMember(options.PositiveNonDecreasingActivation,["relu","softplus"]),...
-        iValidateConstraintWithPositiveNonDecreasingActivation(options.PositiveNonDecreasingActivation, constraint)}
+        mustBeMember(options.ConvexNonDecreasingActivation,["relu","softplus"]),...
+        iValidateConstraintWithConvexNonDecreasingActivation(options.ConvexNonDecreasingActivation, constraint)}
     options.ConvexChannelIdx (1,:) {...
         iValidateConstraintWithConvexChannelIdx(options.ConvexChannelIdx, inputSize, constraint), ...
         mustBeNumeric,mustBePositive,mustBeInteger}

@@ -131,15 +131,15 @@
 switch constraint
     case "fully-convex"
         % Set defaults
-        if ~any(fields(options) == "PositiveNonDecreasingActivation")
-            options.PositiveNonDecreasingActivation = "softplus";
+        if ~any(fields(options) == "ConvexNonDecreasingActivation")
+            options.ConvexNonDecreasingActivation = "softplus";
         end
         net = conslearn.convex.buildFICNN(inputSize, numHiddenUnits, ...
-            PositiveNonDecreasingActivation=options.PositiveNonDecreasingActivation);
+            ConvexNonDecreasingActivation=options.ConvexNonDecreasingActivation);
     case "partially-convex"
         % Set defaults
-        if ~any(fields(options) == "PositiveNonDecreasingActivation")
-            options.PositiveNonDecreasingActivation = "softplus";
+        if ~any(fields(options) == "ConvexNonDecreasingActivation")
+            options.ConvexNonDecreasingActivation = "softplus";
         end
         if ~any(fields(options) == "Activation")
             options.Activation = "tanh";

@@ -148,7 +148,7 @@
             options.ConvexChannelIdx = 1;
         end
         net = conslearn.convex.buildPICNN(inputSize, numHiddenUnits,...
-            PositiveNonDecreasingActivation=options.PositiveNonDecreasingActivation,...
+            ConvexNonDecreasingActivation=options.ConvexNonDecreasingActivation,...
             Activation=options.Activation,...
             ConvexChannelIdx=options.ConvexChannelIdx);
     case "fully-monotonic"

@@ -259,9 +259,9 @@ function iValidateConstraintWithMonotonicTrend(param, constraint)
     end
 end
 
-function iValidateConstraintWithPositiveNonDecreasingActivation(param, constraint)
+function iValidateConstraintWithConvexNonDecreasingActivation(param, constraint)
     if ( ~isequal(constraint, "fully-convex") && ~isequal(constraint,"partially-convex") ) && ~isempty(param)
-        error("'PositiveNonDecreasingActivation' is not an option for constraint " + constraint);
+        error("'ConvexNonDecreasingActivation' is not an option for constraint " + constraint);
     end
 end
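
At the user-facing level, the rename surfaces in `buildConstrainedNetwork` roughly as follows. This is a sketch only: it assumes the constraint string is the first positional argument, and the input size and hidden-unit sizes are illustrative rather than taken from the commit:

```matlab
% Illustrative sketch: request a fully convex network with the renamed option.
% Input size and hidden-unit sizes are assumptions for the example.
net = buildConstrainedNetwork("fully-convex", 2, [32 32 1], ...
    ConvexNonDecreasingActivation="relu");
```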

conslearn/trainConstrainedNetwork.m

+10 -0

@@ -167,6 +167,16 @@
             end
         end
     end
+
+    % Update the training monitor status
+    if trainingOptions.TrainingMonitor
+        if monitor.Stop == 1
+            monitor.Status = "Training stopped";
+        else
+            monitor.Status = "Training complete";
+        end
+    end
+
 end
 
 %% Helpers
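
The added lines rely on MATLAB's `trainingProgressMonitor` object, whose `Stop` property becomes true when the user clicks Stop in the training window and whose `Status` string is displayed there. A minimal standalone sketch of that pattern (not code from this repository):

```matlab
% Minimal sketch of the trainingProgressMonitor Stop/Status pattern.
monitor = trainingProgressMonitor(Metrics="Loss", XLabel="Iteration");

numIterations = 100;
iteration = 0;
while iteration < numIterations && ~monitor.Stop
    iteration = iteration + 1;
    loss = 1/iteration;                               % placeholder metric
    recordMetrics(monitor, iteration, Loss=loss);     % update the plot
    monitor.Progress = 100*iteration/numIterations;   % percent complete
end

if monitor.Stop
    monitor.Status = "Training stopped";
else
    monitor.Status = "Training complete";
end
```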
