Skip to content

Commit e829684

Browse files
committed
fix cpu only
1 parent 81260bc commit e829684

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

train_ssd.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@
118118
tf.app.flags.DEFINE_boolean(
119119
'ignore_missing_vars', True,
120120
'When restoring a checkpoint would ignore missing variables.')
121+
tf.app.flags.DEFINE_boolean(
122+
'multi_gpu', True,
123+
'Whether there is GPU to use for training.')
121124

122125
FLAGS = tf.app.flags.FLAGS
123126
#CUDA_VISIBLE_DEVICES
@@ -129,22 +132,24 @@ def validate_batch_size_for_multi_gpu(batch_size):
129132
directly. Multi-GPU support is currently experimental, however,
130133
so doing the work here until that feature is in place.
131134
"""
132-
from tensorflow.python.client import device_lib
133-
134-
local_device_protos = device_lib.list_local_devices()
135-
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
136-
if not num_gpus:
137-
raise ValueError('Multi-GPU mode was specified, but no GPUs '
138-
'were found. To use CPU, run without --multi_gpu.')
139-
140-
remainder = batch_size % num_gpus
141-
if remainder:
142-
err = ('When running with multiple GPUs, batch size '
143-
'must be a multiple of the number of available GPUs. '
144-
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
145-
).format(num_gpus, batch_size, batch_size - remainder)
146-
raise ValueError(err)
147-
return num_gpus
135+
if FLAGS.multi_gpu:
136+
from tensorflow.python.client import device_lib
137+
138+
local_device_protos = device_lib.list_local_devices()
139+
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
140+
if not num_gpus:
141+
raise ValueError('Multi-GPU mode was specified, but no GPUs '
142+
'were found. To use CPU, run --multi_gpu=False.')
143+
144+
remainder = batch_size % num_gpus
145+
if remainder:
146+
err = ('When running with multiple GPUs, batch size '
147+
'must be a multiple of the number of available GPUs. '
148+
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
149+
).format(num_gpus, batch_size, batch_size - remainder)
150+
raise ValueError(err)
151+
return num_gpus
152+
return 0
148153

149154
def get_init_fn():
150155
return scaffolds.get_init_fn_for_scaffold(FLAGS.model_dir, FLAGS.checkpoint_path,

utility/checkpint_inspect.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ def print_all_tensors_name(file_name):
5252
"with SNAPPY.")
5353

5454
if __name__ == "__main__":
55-
print_all_tensors_name('./model/vgg16.ckpt')
55+
print_all_tensors_name('./model/vgg16_reducedfc.ckpt')

0 commit comments

Comments
 (0)