
Commit 4a3e3e6

AntonioMacaronio, kerrj, and brentyi authored
Dataloading Revamp (#3216)
* initial debugging and testing works
* pwais changes with RayBatchStream to alleviate training
* few bugs to iron out with multiprocessing, specifically pickled collate_fn
* working version of RayBatchStream
* additional docstrings
* cleanup
* much more documentation
* successfully trained AEA-script2_seq2 closed_loop without OOM
* porting over aria dataset-size feature
* added logic to handle eviction of a worker's cached_collated_batch
* antonio's implementation of stream batches
* training on a dataset with 4000 images works!
* some configuration speedups, loops aren't actually needed!
* quick fix adjustment to aria
* removed unnecessary looping
* much faster training when adding i variable to collate every 5 ray bundles
* cleaned up unnecessary variables in Dataloader
* further cleanup
* adding caching of compressed images to RAM to reduce disk bottleneck
* added caching to RAM for masks
* found fast way to collate - many tricks applied
* quick update to aria to test on different datasets
* cleaned up the accelerated pil_to_numpy function
* cleaning up PR
* this commit was used to generate the time metrics and profiling metrics
* REAL commit used to run tests
* testing with nerfacto-big
* generated RayBundle collate and converting images from uint8s to float32 on GPU tests
* updating nerfacto to support uint8 easily, will need to figure out a way to contain this within the datamanager API
* datamanager updates, both splat and nerf
* must use writeable arrays because torch requires them
* cleaned up base_dataset, added pickle to utils, more code in full_image, and cleaner desc for base_datamanager
* lots of progress on a parallel FullImageDatamanager
* can train big splats with pre-assertion hack or ROI hack and 0 workers
* fixed all undistortion issues with ParallelImageDatamanager
* adding some downsampling and parallel tests with splatfacto!
* deleted commented code in dataloaders.py and added bugfix to shuffling
* testing splatfacto-big
* cleaned up base_pipeline.py
* cleaned up base_pipeline.py ACTUALLY THIS TIME, forgot to save last time
* cleaned up a lot of code
* process_project_aria back to main branch and some cleanup in full_image_datamanager
* clarifying docstrings
* further PR cleanup
* updating models
* further cleanup
* removed caching of images into bytestrings
* adding caching of compressed images to RAM, forgot that hardware matters
* removing oom methods, adding the ability to add a flag to dataloading
* removed CacheDataloader, moved RayBatchStream to dataloaders.py, new vanilla_datamanager rewritten
* fixing base_pipelines, deleting a weird datamanager_configs file that was accidentally created
* cleaning up next_train
* replaced parallel datamanager with new datamanager
* reverted the original base_datamanager.py, new datamanager replaced parallel_datamanager.py
* modified VanillaConfig, but VanillaDataManager is the same as before
* cleaning up, 2 datamanagers now - original and new parallel one
* able to train with new nerfstudio dataloader now
* side by side datamanagers, moved tons of logic into dataloaders.py and created new files for our parallel datamanagers
* added custom ray processing API to support implementations like LERF, cleaned up FullImageDatamanager to original because of new ParallelImageDatamanager
* adding functionality for ns-eval by adding FixedIndicesEvalDataloader to the setup_eval
* adding both ray API and image-view API to datamanagers for custom parallelization
* updating splatfacto config for 4k tests
* updating docstrings to be more descriptive
* new datamanager API breaks when setup_eval() has multiple workers, not sure why but single worker will have to do
* adding custom_view_processor to ImageBatchStream
* reverting full_images_datamanager to main branch
* removing nn.Module inheritance from Datamanager class
* don't need to move datamanager to device anymore since Datamanager is not a subclass of nn.Module
* finished integration test with nerfacto
* simplified config variables, integrated the parallelism/disk-data-loading all into one datamanager
* updated the splatfacto config to be simpler with the dataloading and now uses FullImageDatamanager (which has been changed)
* style checks and some cleanup
* new splatfacto test, cleaning up nerfacto integration test
* removing redundant parallel_full_images_datamanager, as the OG full_image_datamanager now has full parallelized support
* ruff linting and pyright fixing
* further pyright fixing
* another pyright fixing
* fixing pyright error, camera optimization no longer part of datamanager
* fixing one pyright
* fixing dataloading error when camera is not undistorted with dataloader
* fixing comments and updating style
* undoing a style change i made
* undoing another style change i made by accident
* fixing slow runtime
* fixing a more general camera undistortion bug
* move images to device properly
* minor improvements
* add print statement about >500 images, cleanup method configs
* make method configs consistent across nerfacto models
* adding description comments
* updating description
* resolving some pyright issues with export.py, explained in PR desc
* fixing pyright issues in base_pipeline.py
* ran pyright on exporter and base_pipeline.py without issues
* adding a git ignore to a clearly checked pyright issue
* typo
* fixing most ns-dev-test cases
* cleanup, passing final ns-dev-test
* oops, accidentally pushed the deletion of a docstring, undoing that
* another cleanup
* some fixes to eval pipeline
* lint
* add asserts for spawn
* lint
* cleaning up import statements in parallel_datamanager.py
* adding new developer documentation if users would like to migrate their custom datamanagers to support new features
* removing unnecessary to_device no-op
* further updates to documentation
* lint
* more docs
* docs
* remove comment
* add docs, fix depth dataset with parallel datamanager, fix mask sampling bug
* remove profiling
* more profile removal
* custom_view_processor->custom_image_processor
* doc clarification
* datamanager doc nit
* whitespace
* nits
* remove stuff from __post_init__, tune num workers more, add random offset in raybatchstream
* removing unnecessary assertion, updating docstring because DataManager is no longer an nn.Module
* clarifying configuration with num_images_to_sample_from and num_times_to_repeat_images, cleaning up functions
* adding logic so that nerfacto users can load_from_disk and customize image batch sizes and repeat parameters
* ruff formatting! whoops forgot to format
* fixing logic, now if users set load_from_disk to true, datamanager will use 50 and 10. If users set it and specify their own values, we support that as well
* adding separate datamanager config so that target can be removed in method_configs

---------

Co-authored-by: Justin Kerr <[email protected]>
Co-authored-by: Brent Yi <[email protected]>
1 parent: 189328e · commit: 4a3e3e6

16 files changed: +1083 −470 lines

docs/developer_guides/pipelines/datamanagers.md (+107 −3)
@@ -14,19 +14,28 @@
 
 ## What is a DataManager?
 
-The DataManager returns RayBundle and RayGT objects. Let's first take a look at the most important abstract methods required by the DataManager.
+The DataManager batches and returns two components from an input dataset:
+
+1. A representation of viewpoint (either cameras or rays).
+   - For splatting methods (`FullImageDataManager`): a `Cameras` object.
+   - For ray sampling methods (`VanillaDataManager`): a `RayBundle` object.
+2. A dictionary of ground truth data.
+   - For splatting methods (`FullImageDataManager`): the dictionary contains complete images.
+   - For ray sampling methods (`VanillaDataManager`): the dictionary contains per-ray information.
+
+Behaviors are defined by implementing the abstract methods required by the DataManager:
 
 ```python
 class DataManager(nn.Module):
     """Generic data manager's abstract class
     """
 
     @abstractmethod
-    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
+    def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
         """Returns the next batch of data for train."""
 
     @abstractmethod
-    def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
+    def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
         """Returns the next batch of data for eval."""
 
     @abstractmethod
@@ -94,3 +103,98 @@ See the code!
 ## Creating Your Own
 
 We currently don't have other implementations because most papers follow the VanillaDataManager implementation. However, it should be straightforward to add a VanillaDataManager with logic that progressively adds cameras, for instance, by relying on the step and modifying RayBundle and RayGT generation logic.
+
+## Disk Caching for Large Datasets
+As of January 2025, the FullImageDatamanager and ParallelImageDatamanager implementations support parallelized dataloading and loading data from disk, which avoids out-of-memory errors and supports very large datasets. To train a NeRF-based method on a dataset that doesn't fit in memory, add the `load_from_disk` flag to your `ns-train` command. For example, with nerfacto:
+```bash
+ns-train nerfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.load-from-disk
+```
+
+To train splatfacto on a dataset that doesn't fit in memory, set the device of `cache_images` to `"disk"`. For example, with splatfacto:
+```bash
+ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache-images disk
+```
+
+## Migrating Your DataManager to the New DataManager
+Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to support the new parallel features, you can migrate any custom dataloading logic to the new `custom_ray_processor()` API. This function takes in a full training batch (either an image batch or a ray bundle) and allows the user to modify or add to it. It provides an interface for attaching new information to the RayBundle (for ray-based methods), the Cameras object (for splatting-based methods), or the ground truth dictionary. It runs in a background process if disk caching is enabled; otherwise it runs in the main process. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager.
+
+Naively transferring code to `custom_ray_processor` may still OOM on very large datasets if initialization code requires computing something over the whole dataset. To take full advantage of parallelization, make sure your subclassed datamanager computes new information inside `custom_ray_processor`, or caches a subset of the whole dataset. This can also still be slow if pre-computation requires GPU-heavy steps on the same GPU used for training.
+
+**Note**: Because the parallel DataManager uses background processes, any member of the DataManager must be *picklable* to be used inside `custom_ray_processor`.
+
+```python
+class LERFDataManager(VanillaDataManager):
+    """Subclass VanillaDataManager to add extra data processing.
+
+    Args:
+        config: the DataManagerConfig used to instantiate class
+    """
+
+    config: LERFDataManagerConfig
+
+    def __init__(
+        self,
+        config: LERFDataManagerConfig,
+        device: Union[torch.device, str] = "cpu",
+        test_mode: Literal["test", "val", "inference"] = "val",
+        world_size: int = 1,
+        local_rank: int = 0,
+        **kwargs,
+    ):
+        super().__init__(
+            config=config, device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank, **kwargs
+        )
+        # Some code to initialize all the CLIP and DINO feature encoders.
+        self.image_encoder: BaseImageEncoder = kwargs["image_encoder"]
+        self.dino_dataloader = ...
+        self.clip_interpolator = ...
+
+    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
+        """Returns the next batch of data from the train dataloader.
+
+        In this custom DataManager we need to add on the data that LERF needs, namely CLIP and DINO features.
+        """
+        self.train_count += 1
+        image_batch = next(self.iter_train_image_dataloader)
+        assert self.train_pixel_sampler is not None
+        batch = self.train_pixel_sampler.sample(image_batch)
+        ray_indices = batch["indices"]
+        ray_bundle = self.train_ray_generator(ray_indices)
+        batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
+        batch["dino"] = self.dino_dataloader(ray_indices)
+        ray_bundle.metadata["clip_scales"] = clip_scale
+        # assume all cameras have the same focal length and image width
+        ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
+        ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
+        ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
+        ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
+        return ray_bundle, batch
+```
+
+To migrate this custom datamanager to the new datamanager, we'll subclass the new ParallelDataManager and shift the data customization from `next_train()` to `custom_ray_processor()`.
+`custom_ray_processor()` is called with a fully populated ray bundle and ground truth batch, just like the subclassed `next_train` in the code above. This code, however, runs in a background process.
+
+```python
+class LERFDataManager(ParallelDataManager, Generic[TDataset]):
+    """
+    __init__ stays the same
+    """
+
+    ...
+
+    def custom_ray_processor(
+        self, ray_bundle: RayBundle, batch: Dict
+    ) -> Tuple[RayBundle, Dict]:
+        """An API to add latents, metadata, or other further customization to the parallelized RayBundle dataloading process."""
+        ray_indices = batch["indices"]
+        batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
+        batch["dino"] = self.dino_dataloader(ray_indices)
+        ray_bundle.metadata["clip_scales"] = clip_scale
+
+        # Assume all cameras have the same focal length and image dimensions.
+        ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
+        ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
+        ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
+        ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
+        return ray_bundle, batch
+```
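
A practical aside on the picklability note above: with disk caching enabled, the DataManager is handed to spawned worker processes, so a non-picklable member (an open file handle, a lambda, a CUDA-backed cache) fails at worker startup. A minimal dev-time sketch for catching this early; the `check_picklable` helper is hypothetical, not a nerfstudio API:

```python
import pickle

def check_picklable(datamanager) -> None:
    """Hypothetical dev-time helper: verify each DataManager attribute can be
    pickled, since spawned dataloading workers receive a pickled copy of it."""
    for name, value in vars(datamanager).items():
        try:
            pickle.dumps(value)
        except Exception as err:  # pickle raises several different error types
            raise TypeError(f"DataManager attribute {name!r} is not picklable") from err
```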

nerfstudio/configs/method_configs.py (+4 −4)
@@ -28,7 +28,7 @@
 from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods
 from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig
 from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
-from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
+from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager, ParallelDataManagerConfig
 from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
 from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
 from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
@@ -220,7 +220,7 @@
     mixed_precision=True,
     pipeline=VanillaPipelineConfig(
         datamanager=VanillaDataManagerConfig(
-            _target=VanillaDataManager[DepthDataset],
+            _target=ParallelDataManager[DepthDataset],
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=4096,
             eval_num_rays_per_batch=4096,
@@ -302,7 +302,7 @@
 method_configs["mipnerf"] = TrainerConfig(
     method_name="mipnerf",
     pipeline=VanillaPipelineConfig(
-        datamanager=ParallelDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024),
+        datamanager=VanillaDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024),
         model=VanillaModelConfig(
             _target=MipNerfModel,
             loss_coefficients={"rgb_loss_coarse": 0.1, "rgb_loss_fine": 1.0},
@@ -375,7 +375,7 @@
     max_num_iterations=30000,
     mixed_precision=False,
     pipeline=VanillaPipelineConfig(
-        datamanager=ParallelDataManagerConfig(
+        datamanager=VanillaDataManagerConfig(
             dataparser=BlenderDataParserConfig(),
             train_num_rays_per_batch=4096,
             eval_num_rays_per_batch=4096,
nerfstudio/data/datamanagers/base_datamanager.py (+23 −47)
@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import cached_property
 from pathlib import Path
@@ -42,7 +41,6 @@
 
 import torch
 import tyro
-from torch import nn
 from torch.nn import Parameter
 from torch.utils.data.distributed import DistributedSampler
 from typing_extensions import TypeVar
@@ -56,44 +54,19 @@
 from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
 from nerfstudio.data.datasets.base_dataset import InputDataset
 from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
-from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
+from nerfstudio.data.utils.dataloaders import (
+    CacheDataloader,
+    FixedIndicesEvalDataloader,
+    RandIndicesEvalDataloader,
+    variable_res_collate,
+)
 from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
 from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
 from nerfstudio.model_components.ray_generators import RayGenerator
 from nerfstudio.utils.misc import IterableWrapper, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE
 
 
-def variable_res_collate(batch: List[Dict]) -> Dict:
-    """Default collate function for the cached dataloader.
-    Args:
-        batch: Batch of samples from the dataset.
-    Returns:
-        Collated batch.
-    """
-    images = []
-    imgdata_lists = defaultdict(list)
-    for data in batch:
-        image = data.pop("image")
-        images.append(image)
-        topop = []
-        for key, val in data.items():
-            if isinstance(val, torch.Tensor):
-                # if the value has same height and width as the image, assume that it should be collated accordingly.
-                if len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]:
-                    imgdata_lists[key].append(val)
-                    topop.append(key)
-        # now that iteration is complete, the image data items can be removed from the batch
-        for key in topop:
-            del data[key]
-
-    new_batch = nerfstudio_collate(batch)
-    new_batch["image"] = images
-    new_batch.update(imgdata_lists)
-
-    return new_batch
-
-
 @dataclass
 class DataManagerConfig(InstantiateConfig):
     """Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers;
@@ -111,7 +84,7 @@ class DataManagerConfig(InstantiateConfig):
     """Process images on GPU for speed at the expense of memory, if True."""
 
 
-class DataManager(nn.Module):
+class DataManager:
     """Generic data manager's abstract class
 
     This version of the data manager is designed be a monolithic way to load data and latents,
@@ -164,16 +137,16 @@ class DataManager(nn.Module):
     train_sampler: Optional[DistributedSampler] = None
     eval_sampler: Optional[DistributedSampler] = None
     includes_time: bool = False
+    test_mode: Literal["test", "val", "inference"] = "val"
 
     def __init__(self):
         """Constructor for the DataManager class.
 
         Subclassed DataManagers will likely need to override this constructor.
 
-        If you aren't manually calling the setup_train and setup_eval functions from an overriden
-        constructor, that you call super().__init__() BEFORE you initialize any
-        nn.Modules or nn.Parameters, but AFTER you've already set all the attributes you need
-        for the setup functions."""
+        If you aren't manually calling the setup_train() and setup_eval() functions from an overridden
+        constructor, please call super().__init__() in your subclass' __init__() method after
+        you've already set all the attributes you need for the setup functions."""
         super().__init__()
         self.train_count = 0
         self.eval_count = 0
@@ -311,18 +284,22 @@ class VanillaDataManagerConfig(DataManagerConfig):
     """Target class to instantiate."""
     dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
     """Specifies the dataparser used to unpack the data."""
+    cache_images_type: Literal["uint8", "float32"] = "float32"
+    """The image type returned from manager, caching images in uint8 saves memory"""
     train_num_rays_per_batch: int = 1024
     """Number of rays per batch to use per training iteration."""
-    train_num_images_to_sample_from: int = -1
-    """Number of images to sample during training iteration."""
-    train_num_times_to_repeat_images: int = -1
-    """When not training on all images, number of iterations before picking new
-    images. If -1, never pick new images."""
+    train_num_images_to_sample_from: Union[int, float] = float("inf")
+    """Number of images to load into CPU RAM to generate RayBundles from during training. If infinity, load
+    all images into CPU RAM to generate RayBundles."""
+    train_num_times_to_repeat_images: Union[int, float] = float("inf")
+    """Number of RayBundles to generate for a batch of images loaded into CPU RAM before sampling new images.
+    If infinity, never sample new images. Note: decreasing train_num_images_to_sample_from and increasing
+    train_num_times_to_repeat_images alleviates CPU bottleneck."""
     eval_num_rays_per_batch: int = 1024
     """Number of rays per batch to use per eval iteration."""
-    eval_num_images_to_sample_from: int = -1
+    eval_num_images_to_sample_from: Union[int, float] = float("inf")
     """Number of images to sample during eval iteration."""
-    eval_num_times_to_repeat_images: int = -1
+    eval_num_times_to_repeat_images: Union[int, float] = float("inf")
     """When not evaluating on all images, number of iterations before picking
     new images. If -1, never pick new images."""
     eval_image_indices: Optional[Tuple[int, ...]] = (0,)
@@ -331,8 +308,7 @@ class VanillaDataManagerConfig(DataManagerConfig):
     """Specifies the collate function to use for the train and eval dataloaders."""
     camera_res_scale_factor: float = 1.0
     """The scale factor for scaling spatial data such as images, mask, semantics
-    along with relevant information about camera intrinsics
-    """
+    along with relevant information about camera intrinsics"""
     patch_size: int = 1
     """Size of patch to sample from. If > 1, patch-based sampling will be used."""
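
Since `variable_res_collate` moved out of this module (its body is shown in the removed hunk above), here is a minimal usage sketch with invented shapes: tensors whose height and width match the sample's image, such as masks, are kept as per-sample lists so mixed-resolution batches collate without padding, while everything else goes through `nerfstudio_collate`:

```python
import torch
from nerfstudio.data.utils.dataloaders import variable_res_collate  # its new home

# Two samples at different resolutions; "mask" matches each image's HxW, so it
# is collected into a list alongside "image" instead of being stacked.
batch = [
    {"image": torch.rand(480, 640, 3), "mask": torch.ones(480, 640, 1), "image_idx": torch.tensor(0)},
    {"image": torch.rand(720, 1280, 3), "mask": torch.ones(720, 1280, 1), "image_idx": torch.tensor(1)},
]
collated = variable_res_collate(batch)
assert isinstance(collated["image"], list) and isinstance(collated["mask"], list)
assert collated["image_idx"].shape == (2,)  # non-image fields stack as usual
```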
