Skip to content

Commit 5003d0e

Browse files
authored
sample pixels from all data in variable-resolution batches (#2772)
* sample pixels from all data in variable-resolution batches
* sort import block
* sort imports
* reorder imports
* clean up warnings
* simplify index extraction
1 parent 2091a0d commit 5003d0e

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

nerfstudio/data/pixel_samplers.py

+12 -18
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
import random
2020
import warnings
21+
from collections import defaultdict
2122
from dataclasses import dataclass, field
2223
from typing import Dict, Optional, Type, Union
2324

@@ -335,8 +336,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
335336

336337
# only sample within the mask, if the mask is in the batch
337338
all_indices = []
338-
all_images = []
339-
all_depth_images = []
339+
all_images = defaultdict(list)
340340

341341
assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2"
342342
num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images)
@@ -350,10 +350,11 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
350350
)
351351
indices[:, 0] = i
352352
all_indices.append(indices)
353-
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
354-
if "depth_image" in batch:
355-
all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])
356353

354+
for key, value in batch.items():
355+
if key in ["image_idx", "mask"]:
356+
continue
357+
all_images[key].append(value[i][indices[:, 1], indices[:, 2]])
357358
else:
358359
for i, num_rays in enumerate(num_rays_per_image):
359360
image_height, image_width, _ = batch["image"][i].shape
@@ -363,26 +364,19 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
363364
indices = self.sample_method(num_rays, 1, image_height, image_width, device=device)
364365
indices[:, 0] = i
365366
all_indices.append(indices)
366-
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
367-
if "depth_image" in batch:
368-
all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])
367+
for key, value in batch.items():
368+
if key in ["image_idx", "mask"]:
369+
continue
370+
all_images[key].append(value[i][indices[:, 1], indices[:, 2]])
369371

370372
indices = torch.cat(all_indices, dim=0)
371373

372-
c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
373-
collated_batch = {
374-
key: value[c, y, x]
375-
for key, value in batch.items()
376-
if key not in ("image_idx", "image", "mask", "depth_image") and value is not None
377-
}
378-
379-
collated_batch["image"] = torch.cat(all_images, dim=0)
380-
if "depth_image" in batch:
381-
collated_batch["depth_image"] = torch.cat(all_depth_images, dim=0)
374+
collated_batch = {key: torch.cat(all_images[key], dim=0) for key in all_images}
382375

383376
assert collated_batch["image"].shape[0] == num_rays_per_batch
384377

385378
# Needed to correct the random indices to their actual camera idx locations.
379+
c = indices[..., 0].flatten()
386380
indices[:, 0] = batch["image_idx"][c]
387381
collated_batch["indices"] = indices # with the abs camera indices
388382

0 commit comments

Comments
 (0)