18
18
19
19
import random
20
20
import warnings
21
+ from collections import defaultdict
21
22
from dataclasses import dataclass , field
22
23
from typing import Dict , Optional , Type , Union
23
24
@@ -335,8 +336,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
335
336
336
337
# only sample within the mask, if the mask is in the batch
337
338
all_indices = []
338
- all_images = []
339
- all_depth_images = []
339
+ all_images = defaultdict (list )
340
340
341
341
assert num_rays_per_batch % 2 == 0 , "num_rays_per_batch must be divisible by 2"
342
342
num_rays_per_image = divide_rays_per_image (num_rays_per_batch , num_images )
@@ -350,10 +350,11 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
350
350
)
351
351
indices [:, 0 ] = i
352
352
all_indices .append (indices )
353
- all_images .append (batch ["image" ][i ][indices [:, 1 ], indices [:, 2 ]])
354
- if "depth_image" in batch :
355
- all_depth_images .append (batch ["depth_image" ][i ][indices [:, 1 ], indices [:, 2 ]])
356
353
354
+ for key , value in batch .items ():
355
+ if key in ["image_idx" , "mask" ]:
356
+ continue
357
+ all_images [key ].append (value [i ][indices [:, 1 ], indices [:, 2 ]])
357
358
else :
358
359
for i , num_rays in enumerate (num_rays_per_image ):
359
360
image_height , image_width , _ = batch ["image" ][i ].shape
@@ -363,26 +364,19 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
363
364
indices = self .sample_method (num_rays , 1 , image_height , image_width , device = device )
364
365
indices [:, 0 ] = i
365
366
all_indices .append (indices )
366
- all_images .append (batch ["image" ][i ][indices [:, 1 ], indices [:, 2 ]])
367
- if "depth_image" in batch :
368
- all_depth_images .append (batch ["depth_image" ][i ][indices [:, 1 ], indices [:, 2 ]])
367
+ for key , value in batch .items ():
368
+ if key in ["image_idx" , "mask" ]:
369
+ continue
370
+ all_images [key ].append (value [i ][indices [:, 1 ], indices [:, 2 ]])
369
371
370
372
indices = torch .cat (all_indices , dim = 0 )
371
373
372
- c , y , x = (i .flatten () for i in torch .split (indices , 1 , dim = - 1 ))
373
- collated_batch = {
374
- key : value [c , y , x ]
375
- for key , value in batch .items ()
376
- if key not in ("image_idx" , "image" , "mask" , "depth_image" ) and value is not None
377
- }
378
-
379
- collated_batch ["image" ] = torch .cat (all_images , dim = 0 )
380
- if "depth_image" in batch :
381
- collated_batch ["depth_image" ] = torch .cat (all_depth_images , dim = 0 )
374
+ collated_batch = {key : torch .cat (all_images [key ], dim = 0 ) for key in all_images }
382
375
383
376
assert collated_batch ["image" ].shape [0 ] == num_rays_per_batch
384
377
385
378
# Needed to correct the random indices to their actual camera idx locations.
379
+ c = indices [..., 0 ].flatten ()
386
380
indices [:, 0 ] = batch ["image_idx" ][c ]
387
381
collated_batch ["indices" ] = indices # with the abs camera indices
388
382
0 commit comments