
Commit 8eb5b5d

[Fix] VFI: update param_scheduler, hooks, and epoch_base_runner

1 parent 5dc5589 commit 8eb5b5d

6 files changed: +25 −42 lines changed


configs/video_interpolators/cain/cain_b5_g1b32_vimeo90k_triplet.py (+11 −14)

@@ -120,11 +120,7 @@
 ]
 test_evaluator = val_evaluator
 
-# 1604 iters == 1 epoch
-epoch_length = 1604
-
-train_cfg = dict(
-    type='IterBasedTrainLoop', max_iters=300_000, val_interval=epoch_length)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
@@ -138,23 +134,24 @@
 # learning policy
 param_scheduler = dict(
     type='ReduceLR',
-    by_epoch=False,
+    by_epoch=True,
     mode='min',
     factor=0.5,
     patience=5,
-    cooldown=0,
-    verbose=True)
+    cooldown=0)
 
 default_hooks = dict(
     checkpoint=dict(
-        type='CheckpointHook',
-        interval=epoch_length * 4,
-        save_optimizer=True,
-        by_epoch=False),
+        type='CheckpointHook', interval=1, save_optimizer=True, by_epoch=True),
     timer=dict(type='IterTimerHook'),
-    logger=dict(type='LoggerHook', interval=1),
+    logger=dict(type='LoggerHook', interval=100),
     sampler_seed=dict(type='DistSamplerSeedHook'),
     # visualization=dict(type='EditVisualizationHook'),
     param_scheduler=dict(
-        type='ReduceLRSchedulerHook', by_epoch=False, val_metric='MAE'),
+        type='ReduceLRSchedulerHook',
+        by_epoch=True,
+        interval=1,
+        val_metric='MAE'),
 )
+
+log_processor = dict(type='LogProcessor', by_epoch=True)
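
What the runner switch means in practice, as a short sketch assembled from the hunks above (the 1604 figure comes from the removed comment; nothing here is new code in the repo):

# Before: iteration-based. Intervals count iterations, so with
# 1604 iters == 1 epoch, the checkpoint interval of epoch_length * 4
# meant one checkpoint every 4 epochs, and val_interval=1604 validated
# once per epoch.
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=300_000, val_interval=1604)

# After: epoch-based. Intervals count epochs, so interval=1 means one
# checkpoint per epoch, and by_epoch=True has to be set consistently on
# the scheduler, the hooks, and the log processor.
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500)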

configs/video_interpolators/flavr/flavr_in4out1_g8b4_vimeo90k_septuplet.py (+9 −16)

@@ -131,10 +131,7 @@
 ]
 test_evaluator = val_evaluator
 
-epoch_length = 2020
-
-train_cfg = dict(
-    type='IterBasedTrainLoop', max_iters=1_000_000, val_interval=epoch_length)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 
@@ -146,30 +143,26 @@
 ))
 
 # learning policy
-# 1604 iters == 1 epoch
-total_iters = 1000000
-lr_config = dict(
+param_scheduler = dict(
     type='ReduceLR',
-    by_epoch=False,
+    by_epoch=True,
     mode='min',
     factor=0.5,
     patience=10,
-    cooldown=20,
-    verbose=True)
+    cooldown=20)
 
 default_hooks = dict(
     checkpoint=dict(
-        type='CheckpointHook',
-        interval=epoch_length * 2,
-        save_optimizer=True,
-        by_epoch=False),
+        type='CheckpointHook', interval=1, save_optimizer=True, by_epoch=True),
     timer=dict(type='IterTimerHook'),
     logger=dict(type='LoggerHook', interval=100),
     sampler_seed=dict(type='DistSamplerSeedHook'),
     # visualization=dict(type='EditVisualizationHook'),
     param_scheduler=dict(
         type='ReduceLRSchedulerHook',
-        by_epoch=False,
-        interval=epoch_length,
+        by_epoch=True,
+        interval=1,
         val_metric='MAE'),
 )
+
+log_processor = dict(type='LogProcessor', by_epoch=True)
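
Both configs now drive the learning rate from the validation MAE once per epoch. The ReduceLR / ReduceLRSchedulerHook pair shares its parameter names (mode, factor, patience, cooldown) with torch.optim.lr_scheduler.ReduceLROnPlateau, so the latter serves as a reasonable stand-in to illustrate the plateau semantics; this is an analogy, not the repo's own code path:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10, cooldown=20)

for mae in [0.30] * 40:  # a validation metric that has plateaued
    # With by_epoch=True and interval=1 above, the hook would feed each
    # epoch's val metric to the scheduler once per epoch, like this:
    scheduler.step(mae)

# The lr is halved once the metric fails to improve for `patience` epochs.
print(optimizer.param_groups[0]['lr'])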

configs/video_interpolators/tof/base_tof_vfi_nobn_1xb1_vimeo90k_triplet.py (+1 −1)

@@ -89,7 +89,7 @@
 ))
 
 # learning policy
-lr_config = dict(
+param_scheduler = dict(
     type='StepLR',
     by_epoch=False,
     gamma=0.5,
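
Only the key name changes here: lr_config is the legacy MMCV-style entry, while the MMEngine runner reads param_scheduler. Unlike the CAIN and FLAVR configs above, this one stays iteration-based (by_epoch=False).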

mmedit/optimizer/scheduler/reduce_lr_scheduler.py (+1 −9)

@@ -47,8 +47,6 @@ class ReduceLR(_ParamScheduler):
         eps (float, optional): Minimal decay applied to lr. If the difference
             between new and old lr is smaller than eps, the update is
             ignored. Default: 1e-8.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
         begin (int): Step at which to start updating the learning rate.
             Defaults to 0.
         end (int): Step at which to stop updating the learning rate.
@@ -68,7 +66,6 @@ def __init__(self,
                  cooldown: int = 0,
                  min_lr: float = 0.,
                  eps: float = 1e-8,
-                 verbose: bool = False,
                  **kwargs):
 
         super().__init__(optimizer=optimizer, param_name='lr', **kwargs)
@@ -99,7 +96,6 @@ def __init__(self,
         self.mode_worse = None  # the worse value for the chosen mode
         self.min_lr = min_lr
         self.eps = eps
-        self.verbose = verbose
         self.last_epoch = 0
         self._init_is_better(self.mode)
         self._reset()
@@ -130,11 +126,7 @@ def _get_value(self):
         for group in self.optimizer.param_groups:
             regular_lr = group[self.param_name]
             if regular_lr - regular_lr * self.factor > self.eps:
-                new_lr = max(regular_lr * self.factor, self.min_lr)
-                if self.verbose:
-                    print(f'Reducing learning rate of {group} from '
-                          f'{regular_lr:.4e} to {new_lr:.4e}.')
-                regular_lr = new_lr
+                regular_lr = max(regular_lr * self.factor, self.min_lr)
             results.append(regular_lr)
         return results
 
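With verbose gone, _get_value reduces to a pure computation. A minimal standalone re-implementation of that rule (a hypothetical helper for illustration, not part of the module):

def reduce_lrs(lrs, factor=0.5, min_lr=0.0, eps=1e-8):
    # Mirror of the loop above: decay each lr by `factor`, floor it at
    # `min_lr`, and skip updates whose decay would be smaller than `eps`.
    results = []
    for lr in lrs:
        if lr - lr * factor > eps:
            lr = max(lr * factor, min_lr)
        results.append(lr)
    return results

print(reduce_lrs([1e-4, 1e-9]))  # [5e-05, 1e-09]: the tiny lr is untouched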

mmedit/registry.py (+3 −1)

@@ -78,13 +78,15 @@ def register_all_modules(init_default_scope: bool = True) -> None:
         Defaults to True.
     """  # noqa
     import mmedit.datasets  # noqa: F401,F403
+    import mmedit.hooks  # noqa: F401,F403
     import mmedit.metrics  # noqa: F401,F403
     import mmedit.models  # noqa: F401,F403
+    import mmedit.optimizer  # noqa: F401,F403
     import mmedit.transforms  # noqa: F401,F403
 
     if init_default_scope:
         never_created = DefaultScope.get_current_instance() is None \
-            or not DefaultScope.check_instance_created('mmedit')
+                        or not DefaultScope.check_instance_created('mmedit')
         if never_created:
             DefaultScope.get_instance('mmedit', scope_name='mmedit')
             return
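
The two new imports matter because registration happens at import time: the register_module() decorators in mmedit.hooks and mmedit.optimizer presumably only run when those packages are imported, so without them a config naming ReduceLRSchedulerHook or ReduceLR would fail to resolve in the registry. A sketch of the intended usage (assuming register_all_modules is exposed from mmedit.registry, as the hunk header suggests):

from mmedit.registry import register_all_modules

# Triggers the noqa imports above, which in turn execute the
# register_module() decorators, then sets 'mmedit' as the default scope.
register_all_modules(init_default_scope=True)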

tools/dist_train.sh (−1)

@@ -16,5 +16,4 @@ python -m torch.distributed.launch \
     --master_port=$PORT \
     $(dirname "$0")/train.py \
     $CONFIG \
-    --seed 0 \
     --launcher pytorch ${@:3}
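
With the hard-coded --seed 0 removed, a seed can presumably still be set per run through the trailing ${@:3} arguments, which are forwarded verbatim to train.py.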
