 tf.app.flags.DEFINE_boolean(
     'ignore_missing_vars', True,
     'When restoring a checkpoint would ignore missing variables.')
+tf.app.flags.DEFINE_boolean(
+    'multi_gpu', True,
+    'Whether there is a GPU to use for training.')
 
 FLAGS = tf.app.flags.FLAGS
 #CUDA_VISIBLE_DEVICES
@@ -129,22 +132,24 @@ def validate_batch_size_for_multi_gpu(batch_size):
     directly. Multi-GPU support is currently experimental, however,
     so doing the work here until that feature is in place.
     """
-    from tensorflow.python.client import device_lib
-
-    local_device_protos = device_lib.list_local_devices()
-    num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
-    if not num_gpus:
-        raise ValueError('Multi-GPU mode was specified, but no GPUs '
-                         'were found. To use CPU, run without --multi_gpu.')
-
-    remainder = batch_size % num_gpus
-    if remainder:
-        err = ('When running with multiple GPUs, batch size '
-               'must be a multiple of the number of available GPUs. '
-               'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
-               ).format(num_gpus, batch_size, batch_size - remainder)
-        raise ValueError(err)
-    return num_gpus
+    if FLAGS.multi_gpu:
+        from tensorflow.python.client import device_lib
+
+        local_device_protos = device_lib.list_local_devices()
+        num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
+        if not num_gpus:
+            raise ValueError('Multi-GPU mode was specified, but no GPUs '
+                             'were found. To use CPU, run --multi_gpu=False.')
+
+        remainder = batch_size % num_gpus
+        if remainder:
+            err = ('When running with multiple GPUs, batch size '
+                   'must be a multiple of the number of available GPUs. '
+                   'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
+                   ).format(num_gpus, batch_size, batch_size - remainder)
+            raise ValueError(err)
+        return num_gpus
+    return 0
 
 def get_init_fn():
     return scaffolds.get_init_fn_for_scaffold(FLAGS.model_dir, FLAGS.checkpoint_path,
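Since the hunk only shows the validation helper itself, here is a minimal sketch of how its return value could be consumed by the training entry point. FLAGS.batch_size is implied by the --batch_size hint in the error message; main and per_device_batch_size are illustrative names, not code from this PR.

# Hedged sketch only -- not part of this diff. Assumes the script already
# defines FLAGS (including batch_size) and validate_batch_size_for_multi_gpu
# as shown above.
import tensorflow as tf

def main(_):
    num_gpus = validate_batch_size_for_multi_gpu(FLAGS.batch_size)
    if num_gpus > 1:
        # The remainder check above guarantees an exact split of the global
        # batch across the available GPUs.
        per_device_batch_size = FLAGS.batch_size // num_gpus
    else:
        # num_gpus is 0 (CPU, --multi_gpu=False) or 1: the whole batch stays
        # on a single device.
        per_device_batch_size = FLAGS.batch_size
    tf.logging.info('Using %d device(s), %d examples per device per step.',
                    max(num_gpus, 1), per_device_batch_size)

if __name__ == '__main__':
    tf.app.run()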