Skip to content

Commit 7ca8795

Browse files
committed
workaround for iteration valid sample check
1 parent be72d3e commit 7ca8795

File tree

1 file changed

+38
-8
lines changed

1 file changed

+38
-8
lines changed

preprocessing/ssd_preprocessing.py

+38-8
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,36 @@ def ssd_random_expand(image, bboxes, ratio=2., name=None):
315315

316316
return big_canvas, absolute_bboxes / tf.cast(tf.stack([canvas_height, canvas_width, canvas_height, canvas_width]), bboxes.dtype)
317317

318+
# def ssd_random_sample_patch_wrapper(image, labels, bboxes):
319+
# with tf.name_scope('ssd_random_sample_patch_wrapper'):
320+
# orgi_image, orgi_labels, orgi_bboxes = image, labels, bboxes
321+
# def check_bboxes(bboxes):
322+
# areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
323+
# return tf.logical_and(tf.logical_and(areas < 0.9, areas > 0.001),
324+
# tf.logical_and((bboxes[:, 3] - bboxes[:, 1]) > 0.025, (bboxes[:, 2] - bboxes[:, 0]) > 0.025))
325+
326+
# index = 0
327+
# max_attempt = 3
328+
# def condition(index, image, labels, bboxes):
329+
# return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(check_bboxes(bboxes), tf.int64)) < 1, tf.less(index, max_attempt)), tf.less(index, 1))
330+
331+
# def body(index, image, labels, bboxes):
332+
# image, bboxes = tf.cond(tf.random_uniform([], minval=0., maxval=1., dtype=tf.float32) < 0.5,
333+
# lambda: (image, bboxes),
334+
# lambda: ssd_random_expand(image, bboxes, tf.random_uniform([1], minval=1.1, maxval=4., dtype=tf.float32)[0]))
335+
# # Distort image and bounding boxes.
336+
# random_sample_image, labels, bboxes = ssd_random_sample_patch(image, labels, bboxes, ratio_list=[-0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.])
337+
# random_sample_image.set_shape([None, None, 3])
338+
# return index+1, random_sample_image, labels, bboxes
339+
340+
# [index, image, labels, bboxes] = tf.while_loop(condition, body, [index, orgi_image, orgi_labels, orgi_bboxes], parallel_iterations=4, back_prop=False, swap_memory=True)
341+
342+
# valid_mask = check_bboxes(bboxes)
343+
# labels, bboxes = tf.boolean_mask(labels, valid_mask), tf.boolean_mask(bboxes, valid_mask)
344+
# return tf.cond(tf.less(index, max_attempt),
345+
# lambda : (image, labels, bboxes),
346+
# lambda : (orgi_image, orgi_labels, orgi_bboxes))
347+
318348
def ssd_random_sample_patch_wrapper(image, labels, bboxes):
319349
with tf.name_scope('ssd_random_sample_patch_wrapper'):
320350
orgi_image, orgi_labels, orgi_bboxes = image, labels, bboxes
@@ -324,20 +354,20 @@ def check_bboxes(bboxes):
324354
tf.logical_and((bboxes[:, 3] - bboxes[:, 1]) > 0.025, (bboxes[:, 2] - bboxes[:, 0]) > 0.025))
325355

326356
index = 0
327-
max_attempt = 1
328-
def condition(index, image, labels, bboxes):
357+
max_attempt = 3
358+
def condition(index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes):
329359
return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(check_bboxes(bboxes), tf.int64)) < 1, tf.less(index, max_attempt)), tf.less(index, 1))
330360

331-
def body(index, image, labels, bboxes):
361+
def body(index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes):
332362
image, bboxes = tf.cond(tf.random_uniform([], minval=0., maxval=1., dtype=tf.float32) < 0.5,
333-
lambda: (image, bboxes),
334-
lambda: ssd_random_expand(image, bboxes, tf.random_uniform([1], minval=1.1, maxval=4., dtype=tf.float32)[0]))
363+
lambda: (orgi_image, orgi_bboxes),
364+
lambda: ssd_random_expand(orgi_image, orgi_bboxes, tf.random_uniform([1], minval=1.1, maxval=4., dtype=tf.float32)[0]))
335365
# Distort image and bounding boxes.
336-
random_sample_image, labels, bboxes = ssd_random_sample_patch(image, labels, bboxes, ratio_list=[-0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.])
366+
random_sample_image, labels, bboxes = ssd_random_sample_patch(image, orgi_labels, bboxes, ratio_list=[-0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.])
337367
random_sample_image.set_shape([None, None, 3])
338-
return index+1, random_sample_image, labels, bboxes
368+
return index+1, random_sample_image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes
339369

340-
[index, image, labels, bboxes] = tf.while_loop(condition, body, [index, orgi_image, orgi_labels, orgi_bboxes], parallel_iterations=4, back_prop=False, swap_memory=True)
370+
[index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes] = tf.while_loop(condition, body, [index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes], parallel_iterations=4, back_prop=False, swap_memory=True)
341371

342372
valid_mask = check_bboxes(bboxes)
343373
labels, bboxes = tf.boolean_mask(labels, valid_mask), tf.boolean_mask(bboxes, valid_mask)

0 commit comments

Comments
 (0)