Skip to content

Commit 2329dbb

Browse files
committed
Merge branch 'm-a-hall-resumable'
2 parents 34db7da + 19b46f6 commit 2329dbb

File tree

10 files changed

+276
-137
lines changed

10 files changed

+276
-137
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
distributionBase=GRADLE_USER_HOME
22
distributionPath=wrapper/dists
3+
distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.2-bin.zip
34
zipStoreBase=GRADLE_USER_HOME
45
zipStorePath=wrapper/dists
5-
distributionUrl=https\://services.gradle.org/distributions/gradle-4.8-all.zip

src/main/java/weka/classifiers/functions/Dl4jMlpClassifier.java

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,15 @@ public class Dl4jMlpClassifier extends RandomizableClassifier
155155
* The number of epochs to perform.
156156
*/
157157
protected int numEpochs = 10;
158+
/** The current upper bound for the number of epochs */
159+
protected int maxEpochs = 0;
158160
/**
159-
* The number of epochs that have been performed.
161+
* The total number of epochs that have been performed.
160162
*/
161163
protected int numEpochsPerformed;
164+
165+
/** The number of epochs performed in this session of iterating */
166+
protected int numEpochsPerformedThisSession;
162167
/**
163168
* The dataset trainIterator.
164169
*/
@@ -211,6 +216,12 @@ public class Dl4jMlpClassifier extends RandomizableClassifier
211216
*/
212217
protected LogConfiguration logConfig = new LogConfiguration();
213218

219+
/**
220+
* Whether to allow training to continue at a later point after the initial
221+
* model is built.
222+
*/
223+
protected boolean resume;
224+
214225

215226
/**
216227
* Get the log configuration.
@@ -231,8 +242,7 @@ public LogConfiguration getLogConfig() {
231242
description = "The log configuration.",
232243
commandLineParamName = "logConfig",
233244
commandLineParamSynopsis = "-logConfig <LogConfiguration>",
234-
displayOrder = 1
235-
)
245+
displayOrder = 1)
236246
public void setLogConfig(LogConfiguration logConfig) {
237247
this.logConfig = logConfig;
238248
}
@@ -617,6 +627,34 @@ public void setQueueSize(int QueueSize) {
617627
queueSize = QueueSize;
618628
}
619629

630+
/**
631+
* If called with argument true, then the next time done() is called the model is effectively
632+
* "frozen" and no further iterations can be performed
633+
*
634+
* @param resume true if the model is to be finalized after performing iterations
635+
*/
636+
637+
@OptionMetadata(description = "Set whether training can be resumed at a later date",
638+
displayName = "Allow training to be resumed after the set number of epochs",
639+
commandLineParamName = "resume",
640+
commandLineParamSynopsis = "-resume",
641+
commandLineParamIsFlag = true,
642+
displayOrder = 31)
643+
public void setResume(boolean resume) {
644+
this.resume = resume;
645+
}
646+
647+
/**
648+
* Returns true if the model is to be finalized (or has been finalized) after
649+
* training.
650+
*
651+
* @return the current value of finalize
652+
*/
653+
public boolean getResume() {
654+
return resume;
655+
}
656+
657+
620658
/**
621659
* The method used to train the classifier.
622660
*
@@ -654,9 +692,20 @@ public void buildClassifier(Instances data) throws Exception {
654692
public void initializeClassifier(Instances data) throws Exception {
655693
logConfig.apply();
656694

695+
if (trainData != null && trainData.numInstances() > 0) {
696+
// Resume run: only initialize iterator
697+
trainIterator = getDataSetIterator(trainData);
698+
return;
699+
}
700+
701+
if (trainData != null) {
702+
if (!trainData.equalHeaders(data)) {
703+
throw new WekaException(trainData.equalHeadersMsg(data));
704+
}
705+
}
657706

658707
// If only class is present, build zeroR
659-
if (data.numAttributes() == 1 && data.classIndex() == 0) {
708+
if (zeroR == null && data.numAttributes() == 1 && data.classIndex() == 0) {
660709
zeroR = new ZeroR();
661710
zeroR.buildClassifier(data);
662711
return;
@@ -698,11 +747,14 @@ public void initializeClassifier(Instances data) throws Exception {
698747
try {
699748
Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());
700749

701-
// If zoo model was set, use this model as internal MultiLayerNetwork
702-
if (useZooModel()) {
703-
createZooModel();
704-
} else {
705-
createModel();
750+
// Could be null due to resuming from a previous run
751+
if (model == null) {
752+
// If zoo model was set, use this model as internal MultiLayerNetwork
753+
if (useZooModel()) {
754+
createZooModel();
755+
} else {
756+
createModel();
757+
}
706758
}
707759
// Initialize iterator
708760
instanceIterator.initialize();
@@ -715,11 +767,12 @@ public void initializeClassifier(Instances data) throws Exception {
715767
log.info(model.conf().toYaml());
716768
}
717769

770+
numEpochsPerformedThisSession = 0;
771+
maxEpochs += numEpochs; // set the current upper bound
772+
718773
// Set the iteration listener
719774
model.setListeners(getListener());
720775

721-
numEpochsPerformed = 0;
722-
723776
isInitializationFinished = true;
724777
} finally {
725778
Thread.currentThread().setContextClassLoader(origLoader);
@@ -1201,10 +1254,12 @@ protected TrainingListener getListener() throws Exception {
12011254

12021255
// Initialize weka listener
12031256
if (iterationListener instanceof weka.dl4j.listener.EpochListener) {
1204-
int numEpochs = getNumEpochs();
1257+
// int numEpochs = getNumEpochs();
1258+
int numEpochs = maxEpochs;
12051259
iterationListener
12061260
.init(
12071261
trainData.numClasses(),
1262+
numEpochsPerformed,
12081263
numEpochs,
12091264
numSamples,
12101265
trainIterator,
@@ -1219,7 +1274,7 @@ protected TrainingListener getListener() throws Exception {
12191274
*/
12201275
public boolean next() throws Exception {
12211276

1222-
if (numEpochsPerformed >= getNumEpochs() || zeroR != null || trainData == null) {
1277+
if (numEpochsPerformedThisSession >= getNumEpochs() || zeroR != null || trainData == null) {
12231278
return false;
12241279
}
12251280

@@ -1238,7 +1293,8 @@ public boolean next() throws Exception {
12381293
trainIterator.reset();
12391294
sw.stop();
12401295
numEpochsPerformed++;
1241-
log.info("Epoch [{}/{}] took {}", numEpochsPerformed, numEpochs, sw.toString());
1296+
numEpochsPerformedThisSession++;
1297+
log.info("Epoch [{}/{}] took {}", numEpochsPerformed, maxEpochs, sw.toString());
12421298
} finally {
12431299
Thread.currentThread().setContextClassLoader(origLoader);
12441300
}
@@ -1274,7 +1330,7 @@ public boolean useEarlyStopping() {
12741330
*/
12751331
public void done() {
12761332

1277-
trainData = null;
1333+
trainData = new Instances(trainData,0);
12781334
}
12791335

12801336
/**

src/main/java/weka/classifiers/functions/RnnSequenceClassifier.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,8 @@
3737
import org.nd4j.linalg.indexing.NDArrayIndex;
3838
import weka.classifiers.Classifier;
3939
import weka.classifiers.rules.ZeroR;
40-
import weka.core.Capabilities;
40+
import weka.core.*;
4141
import weka.core.Capabilities.Capability;
42-
import weka.core.CapabilitiesHandler;
43-
import weka.core.Instance;
44-
import weka.core.Instances;
45-
import weka.core.MissingOutputLayerException;
46-
import weka.core.OptionHandler;
47-
import weka.core.OptionMetadata;
4842
import weka.dl4j.CacheMode;
4943
import weka.dl4j.layers.Layer;
5044
import weka.dl4j.zoo.CustomNet;
@@ -84,6 +78,18 @@ public void initializeClassifier(Instances data) throws Exception {
8478
// Can classifier handle the data?
8579
getCapabilities().testWithFail(data);
8680

81+
if (trainData != null && trainData.numInstances() > 0) {
82+
// Resume run: only initialize iterator
83+
trainIterator = getDataSetIterator(trainData);
84+
return;
85+
}
86+
87+
if (trainData != null) {
88+
if (!trainData.equalHeaders(data)) {
89+
throw new WekaException(trainData.equalHeadersMsg(data));
90+
}
91+
}
92+
8793
// Check basic network structure
8894
if (layers.length == 0) {
8995
throw new MissingOutputLayerException("No layers have been added!");
@@ -95,7 +101,7 @@ public void initializeClassifier(Instances data) throws Exception {
95101
}
96102

97103
// If only class is present, build zeroR
98-
if(data.numAttributes() == 1 && data.classIndex() == 0){
104+
if(zeroR == null && data.numAttributes() == 1 && data.classIndex() == 0){
99105
zeroR = new ZeroR();
100106
zeroR.buildClassifier(data);
101107
return;
@@ -110,15 +116,20 @@ public void initializeClassifier(Instances data) throws Exception {
110116

111117
instanceIterator.initialize();
112118

113-
createModel();
119+
// Could be null due to resuming from a previous run
120+
if (model == null){
121+
createModel();
122+
}
114123

115124
// Setup the datasetiterators (needs to be done after the model initialization)
116125
trainIterator = getDataSetIterator(this.trainData);
117126

118127
// Set the iteration listener
119128
model.setListeners(getListener());
120129

121-
numEpochsPerformed = 0;
130+
numEpochsPerformedThisSession = 0;
131+
maxEpochs += numEpochs; // set the current upper bound
132+
122133
} finally {
123134
Thread.currentThread().setContextClassLoader(origLoader);
124135
}

src/main/java/weka/core/LogConfiguration.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,15 @@
3737
import java.net.URI;
3838
import java.nio.file.Paths;
3939
import java.util.Collection;
40+
import java.util.Enumeration;
4041

4142
/**
4243
* General logger configuration.
4344
*
4445
* @author Steven Lang
4546
*/
4647
@Log4j2
47-
public class LogConfiguration implements Serializable {
48+
public class LogConfiguration implements Serializable, OptionHandler {
4849

4950
private static final long serialVersionUID = 7910114399022582661L;
5051
/**
@@ -299,6 +300,34 @@ protected void updateLogLevel(String loggerName, LogLevel level) {
299300
Configurator.setLevel(loggerName, level.getLevel());
300301
}
301302

303+
/**
304+
* Returns an enumeration describing the available options.
305+
*
306+
* @return an enumeration of all the available options
307+
*/
308+
@Override public Enumeration<Option> listOptions() {
309+
return Option.listOptionsForClass(this.getClass()).elements();
310+
}
311+
312+
/**
313+
* Parses a given list of options.
314+
*
315+
* @param options the list of options as an array of strings
316+
* @throws Exception if an option is not supported
317+
*/
318+
@Override public void setOptions(String[] options) throws Exception {
319+
Option.setOptions(options, this, this.getClass());
320+
}
321+
322+
/**
323+
* Gets the current settings of the log configuration
324+
*
325+
* @return return an array of strings suitable for passing to setOptions
326+
*/
327+
@Override public String[] getOptions() {
328+
return Option.getOptions(this, this.getClass());
329+
}
330+
302331
/**
303332
* Available log levels.
304333
*

src/main/java/weka/dl4j/listener/EpochListener.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class EpochListener extends TrainingListener {
4444
private static final long serialVersionUID = -8852994767947925554L;
4545

4646
/** Epoch counter */
47-
private int currentEpoch = 0;
47+
// private int currentEpoch = 0;
4848

4949
/** Evaluate every N epochs */
5050
private int n = 5;

src/main/java/weka/dl4j/listener/TrainingListener.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public abstract class TrainingListener
4343
protected int numClasses;
4444
/** Number of classes */
4545
protected int numEpochs;
46+
/** The current epoch */
47+
protected int currentEpoch;
4648
/** Training dataset iterator */
4749
protected transient DataSetIterator validationIterator;
4850
/** Validation dataset iterator */
@@ -58,11 +60,13 @@ public abstract class TrainingListener
5860
*/
5961
public void init(
6062
int numClasses,
63+
int currentEpoch,
6164
int numEpochs,
6265
int numSamples,
6366
DataSetIterator trainIterator,
6467
DataSetIterator validationIterator) {
6568
this.numClasses = numClasses;
69+
this.currentEpoch = currentEpoch;
6670
this.numEpochs = numEpochs;
6771
this.numSamples = numSamples;
6872
this.trainIterator = trainIterator;

0 commit comments

Comments
 (0)