Skip to content

Commit bce6d72

Browse files
ayushidalmiafacebook-github-bot
authored andcommitted
Make lightning reproducible
Summary: X-link: facebookresearch/d2go#661 X-link: fairinternal/detectron2#603 Pull Request resolved: #5273 In this diff we make changes to ensure we can control reproducibility in d2go: - update setup.py to enforce deterministic performance if set via config - set lightning parameters if deterministic is passed: ``` { "sync_batchnorm": True, "deterministic": True, "replace_sampler_ddp": False, } ``` - allow passing prefetch_factor, pin_memory, persistent_memory as args to batch dataloader. - minor fix in training sampler Differential Revision: D55767128 fbshipit-source-id: eeab50c95969a91c58f1773473b6fc666494cc16
1 parent 3eef7a5 commit bce6d72

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

detectron2/data/build.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ def build_batch_data_loader(
301301
collate_fn=None,
302302
drop_last: bool = True,
303303
single_gpu_batch_size=None,
304+
prefetch_factor=2,
305+
persistent_workers=False,
306+
pin_memory=False,
304307
seed=None,
305308
**kwargs,
306309
):
@@ -375,8 +378,11 @@ def build_batch_data_loader(
375378
num_workers=num_workers,
376379
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
377380
worker_init_fn=worker_init_reset_seed,
381+
prefetch_factor=prefetch_factor if num_workers > 0 else None,
382+
persistent_workers=persistent_workers,
383+
pin_memory=pin_memory,
378384
generator=generator,
379-
**kwargs
385+
**kwargs,
380386
)
381387

382388

detectron2/data/samplers/distributed_sampler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def __iter__(self):
6161

6262
def _infinite_indices(self):
6363
g = torch.Generator()
64-
g.manual_seed(self._seed)
64+
if self._seed is not None:
65+
g.manual_seed(self._seed)
6566
while True:
6667
if self._shuffle:
6768
yield from torch.randperm(self._size, generator=g).tolist()

detectron2/utils/env.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def seed_all_rng(seed=None):
4242
np.random.seed(seed)
4343
torch.manual_seed(seed)
4444
random.seed(seed)
45+
torch.cuda.manual_seed_all(str(seed))
4546
os.environ["PYTHONHASHSEED"] = str(seed)
4647

4748

0 commit comments

Comments
 (0)