Skip to content

Commit 69209e6

Browse files
[Enhancement] Revise file structure and add ut (#1145)
* refactor mmedit/models/utils/ * rename layers to base_archs * move fid-inception and inception_utils to evaluation/funcitonal * merge batch process to edit data processor * move misc to utils * rename loops.py to gen_loops.py * move log_processor to runner * move trans utils to utils * move some function from metric_utils to mmedit/utils * add unit test for inception utils.py * add unit test for average model * add unit test for base translation model * add more ut for sagan generator * add more unit test for biggan generator * add unit test for io_utils and revise the name of variable * add more unit test for biggan deep generator * remove useless ut files and revise some uts * add unit test for stylegan3 * add more unit test for cyclegan * add more unit test for stylegan2 module * skip unit test of CLIP loss on win-cuda env * omit ops under stylegan2 folder * fix ut of alpha.py to avoid the randomness * update base_archs * update base_archs * update ut * fix import error * add unit test for singan-modules * add unit test for sn module * add unit test for stylegan1 and stylegan3 * revise unit test of arcFace * add unit test for wgan-module * complete unit test of gen metric * remove useless unit test files * remove useless uts and revise ut checking script * revise omit file list in CI configs * update dev_scripts Co-authored-by: zenggyh1900 <[email protected]>
1 parent ae31d9b commit 69209e6

File tree

177 files changed

+2250
-1518
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

177 files changed

+2250
-1518
lines changed

.dev_scripts/README.md

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
## Check UT
2+
3+
Please check your UT by the following scripts:
4+
5+
```python
6+
cd mmediting/
7+
python .dev_script/update_ut.py
8+
```
9+
10+
Then, you will find some redundant UT, missing UT and blank UT.
11+
Please create UTs according to your package code implementation.

.dev_scripts/update_ut.py

+43-27
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,67 @@
11
import os
22
import os.path as osp
33
from argparse import ArgumentParser
4+
from fnmatch import fnmatch
45
from glob import glob
56

67
from tqdm import tqdm
78

89
parser = ArgumentParser()
910
parser.add_argument('--src', type=str, default='mmedit')
1011
parser.add_argument('--dst', type=str, default='tests')
12+
parser.add_argument(
13+
'--exclude',
14+
nargs='+',
15+
default=[
16+
'mmedit/.mim', 'mmedit/registry.py', 'mmedit/version.py',
17+
'__pycache__', '__init__', '**/__init__.py',
18+
'**/stylegan3_ops/*', '**/conv2d_gradfix.py', '**/grid_sample_gradfix.py',
19+
'**/misc.py', '**/upfirdn2d.py',
20+
'**/all_gather_layer.py', '**/typing.py'
21+
])
1122
args = parser.parse_args()
1223

1324

25+
def check_exclude(fn):
26+
for pattern in args.exclude:
27+
if fnmatch(fn, pattern):
28+
return True
29+
return False
30+
31+
1432
def update_ut():
1533

16-
folders = [f for f in os.listdir(args.src) if osp.isdir(f'mmedit/{f}')]
1734
target_ut = []
1835
missing_ut = []
1936
blank_ut = []
2037

21-
for subf in folders:
22-
if subf == '.mim' or subf == '__pycache__':
38+
file_list = glob('mmedit/**/*.py', recursive=True)
39+
40+
for f in tqdm(file_list):
41+
if check_exclude(f):
2342
continue
2443

25-
file_list = glob(f'mmedit/{subf}/**/*.py', recursive=True)
26-
27-
for f in tqdm(file_list, desc=f'mmedit/{subf}'):
28-
if osp.splitext(osp.basename(f))[0] != '__init__':
29-
30-
dirname = osp.dirname(f)
31-
dirname = dirname.replace('__', '')
32-
dirname = dirname.replace('mmedit', 'tests')
33-
dirname = dirname.replace('/', '/test_')
34-
os.makedirs(dirname, exist_ok=True)
35-
36-
basename = osp.basename(f)
37-
basename = 'test_' + basename
38-
39-
dst_path = osp.join(dirname, basename)
40-
target_ut.append(dst_path)
41-
if not osp.exists(dst_path):
42-
missing_ut.append(dst_path)
43-
fp = open(dst_path, 'a')
44-
fp.close()
45-
else:
46-
text_lines = open(dst_path, 'r').readlines()
47-
if len(text_lines) <= 3:
48-
blank_ut.append(dst_path)
44+
if osp.splitext(osp.basename(f))[0] != '__init__':
45+
46+
dirname = osp.dirname(f)
47+
dirname = dirname.replace('__', '')
48+
dirname = dirname.replace('mmedit', 'tests')
49+
dirname = dirname.replace('/', '/test_')
50+
os.makedirs(dirname, exist_ok=True)
51+
52+
basename = osp.basename(f)
53+
basename = 'test_' + basename
54+
55+
dst_path = osp.join(dirname, basename)
56+
target_ut.append(dst_path)
57+
if not osp.exists(dst_path):
58+
missing_ut.append(dst_path)
59+
fp = open(dst_path, 'a')
60+
fp.close()
61+
else:
62+
text_lines = open(dst_path, 'r').readlines()
63+
if len(text_lines) <= 3:
64+
blank_ut.append(dst_path)
4965

5066
existing_ut = glob('tests/test_*/**/*.py', recursive=True)
5167
additional_ut = list(set(existing_ut) - set(target_ut))

.github/workflows/merge_stage_test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ jobs:
9797
- name: Run unittests and generate coverage report
9898
run: |
9999
coverage run --branch --source mmedit -m pytest tests/
100-
coverage xml --omit="**/stylegan3_ops/*,**/conv2d_gradfix.py"
100+
coverage xml --omit="**/stylegan3_ops/*,**/conv2d_gradfix.py,**/grid_sample_gradfix.py,**/misc.py,**/upfirdn2d.py,**all_gather_layer.py"
101101
coverage report -m
102102
# Only upload coverage report for python3.7 && pytorch1.8.1 cpu
103103
- name: Upload coverage to Codecov

.github/workflows/pr_stage_test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
- name: Run unittests and generate coverage report
4646
run: |
4747
coverage run --branch --source mmedit -m pytest tests/
48-
coverage xml --omit="**/stylegan3_ops/*,**/conv2d_gradfix.py"
48+
coverage xml --omit="**/stylegan3_ops/*,**/conv2d_gradfix.py,**/grid_sample_gradfix.py,**/misc.py,**/upfirdn2d.py,**all_gather_layer.py"
4949
coverage report -m
5050
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
5151
- name: Upload coverage to Codecov

demo/inpainting_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66

77
from mmedit.apis import init_model, inpainting_inference
8-
from mmedit.engine import tensor2img
8+
from mmedit.utils import tensor2img
99

1010

1111
def parse_args():

demo/restoration_demo.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import torch
77

88
from mmedit.apis import init_model, restoration_inference
9-
from mmedit.engine import tensor2img
10-
from mmedit.utils import modify_args
9+
from mmedit.utils import modify_args, tensor2img
1110

1211

1312
def parse_args():

demo/restoration_video_demo.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import torch
99

1010
from mmedit.apis import init_model, restoration_video_inference
11-
from mmedit.engine import tensor2img
12-
from mmedit.utils import modify_args
11+
from mmedit.utils import modify_args, tensor2img
1312

1413
VIDEO_EXTENSIONS = ('.mp4', '.mov')
1514

mmedit/datasets/transforms/__init__.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,24 @@
2525
RandomJPEGCompression, RandomNoise,
2626
RandomResize, RandomVideoCompression)
2727
from .random_down_sampling import RandomDownSampling
28-
from .trans_utils import (adjust_gamma, bbox2mask, brush_stroke_mask,
29-
get_irregular_mask, random_bbox)
3028
from .trimap import (FormatTrimap, GenerateTrimap,
3129
GenerateTrimapWithDistTransform, TransformTrimap)
3230
from .values import CopyValues, SetValues
3331

3432
__all__ = [
35-
'random_bbox', 'get_irregular_mask', 'brush_stroke_mask', 'bbox2mask',
36-
'adjust_gamma', 'BinarizeImage', 'Clip', 'ColorJitter', 'CopyValues',
37-
'Crop', 'CropLike', 'DegradationsWithShuffle', 'LoadImageFromFile',
38-
'LoadMask', 'Flip', 'FixedCrop', 'GenerateCoordinateAndCell',
39-
'GenerateFacialHeatmap', 'GenerateFrameIndices',
40-
'GenerateFrameIndiceswithPadding', 'GenerateSegmentIndices',
41-
'GetMaskedImage', 'GetSpatialDiscountMask', 'MATLABLikeResize',
42-
'MirrorSequence', 'ModCrop', 'Normalize', 'PackEditInputs',
43-
'PairedRandomCrop', 'RandomAffine', 'RandomBlur', 'RandomDownSampling',
44-
'RandomJPEGCompression', 'RandomMaskDilation', 'RandomNoise',
45-
'RandomResize', 'RandomResizedCrop', 'RandomRotation', 'RandomTransposeHW',
46-
'RandomVideoCompression', 'RescaleToZeroOne', 'Resize', 'SetValues',
47-
'TemporalReverse', 'ToTensor', 'UnsharpMasking', 'CropAroundCenter',
48-
'CropAroundFg', 'GenerateSeg', 'CropAroundUnknown', 'GenerateSoftSeg',
49-
'FormatTrimap', 'TransformTrimap', 'GenerateTrimap',
33+
'BinarizeImage', 'Clip', 'ColorJitter', 'CopyValues', 'Crop', 'CropLike',
34+
'DegradationsWithShuffle', 'LoadImageFromFile', 'LoadMask', 'Flip',
35+
'FixedCrop', 'GenerateCoordinateAndCell', 'GenerateFacialHeatmap',
36+
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding',
37+
'GenerateSegmentIndices', 'GetMaskedImage', 'GetSpatialDiscountMask',
38+
'MATLABLikeResize', 'MirrorSequence', 'ModCrop', 'Normalize',
39+
'PackEditInputs', 'PairedRandomCrop', 'RandomAffine', 'RandomBlur',
40+
'RandomDownSampling', 'RandomJPEGCompression', 'RandomMaskDilation',
41+
'RandomNoise', 'RandomResize', 'RandomResizedCrop', 'RandomRotation',
42+
'RandomTransposeHW', 'RandomVideoCompression', 'RescaleToZeroOne',
43+
'Resize', 'SetValues', 'TemporalReverse', 'ToTensor', 'UnsharpMasking',
44+
'CropAroundCenter', 'CropAroundFg', 'GenerateSeg', 'CropAroundUnknown',
45+
'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap',
5046
'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg',
5147
'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile',
5248
'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad'

mmedit/datasets/transforms/alpha.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mmengine.utils import is_list_of, is_tuple_of
1111

1212
from mmedit.registry import TRANSFORMS
13-
from .trans_utils import random_choose_unknown
13+
from mmedit.utils import random_choose_unknown
1414

1515

1616
@TRANSFORMS.register_module()

mmedit/datasets/transforms/crop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.nn.modules.utils import _pair
1010

1111
from mmedit.registry import TRANSFORMS
12-
from .trans_utils import random_choose_unknown
12+
from mmedit.utils import random_choose_unknown
1313

1414

1515
@TRANSFORMS.register_module()

mmedit/datasets/transforms/fgbg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mmengine.fileio import FileClient
1111

1212
from mmedit.registry import TRANSFORMS
13-
from .trans_utils import add_gaussian_noise, adjust_gamma
13+
from mmedit.utils import add_gaussian_noise, adjust_gamma
1414

1515

1616
@TRANSFORMS.register_module()

mmedit/datasets/transforms/generate_assistant.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mmcv.transforms.base import BaseTransform
55

66
from mmedit.registry import TRANSFORMS
7-
from .trans_utils import make_coord
7+
from mmedit.utils import make_coord
88

99
try:
1010
import face_alignment

mmedit/datasets/transforms/loading.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mmengine.fileio import FileClient, list_from_file
99

1010
from mmedit.registry import TRANSFORMS
11-
from .trans_utils import (bbox2mask, brush_stroke_mask, get_irregular_mask,
11+
from mmedit.utils import (bbox2mask, brush_stroke_mask, get_irregular_mask,
1212
random_bbox)
1313

1414

mmedit/engine/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .hooks import * # noqa: F401, F403
3-
from .logging import * # noqa: F401, F403
4-
from .misc import * # noqa: F401, F403
53
from .optimizers import * # noqa: F401, F403
64
from .runner import * # noqa: F401, F403
75
from .schedulers import * # noqa: F401, F403

mmedit/engine/logging/__init__.py

-4
This file was deleted.

mmedit/engine/runner/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from .loops import GenTestLoop, GenValLoop
2+
from .gen_loops import GenTestLoop, GenValLoop
3+
from .log_processor import GenLogProcessor
34
from .multi_loops import MultiTestLoop, MultiValLoop
45

5-
__all__ = ['MultiValLoop', 'MultiTestLoop', 'GenTestLoop', 'GenValLoop']
6+
__all__ = [
7+
'MultiValLoop', 'MultiTestLoop', 'GenTestLoop', 'GenValLoop',
8+
'GenLogProcessor'
9+
]
File renamed without changes.
+7-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .fid_inception import InceptionV3
23
from .gaussian_funcs import gauss_gradient
4+
from .inception_utils import (disable_gpu_fuser_on_pt19, load_inception,
5+
prepare_inception_feat, prepare_vgg_feat)
36

4-
__all__ = ['gauss_gradient']
7+
__all__ = [
8+
'gauss_gradient', 'InceptionV3', 'disable_gpu_fuser_on_pt19',
9+
'load_inception', 'prepare_vgg_feat', 'prepare_inception_feat'
10+
]

mmedit/evaluation/metrics/inception_utils.py renamed to mmedit/evaluation/functional/inception_utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from torch.utils.data.dataset import Dataset
2222
from torchvision.models.inception import inception_v3
2323

24-
from mmedit.models import InceptionV3
25-
from mmedit.utils import MMGEN_CACHE_DIR, download_from_url
24+
from mmedit.utils import MMEDIT_CACHE_DIR, download_from_url
25+
from . import InceptionV3
2626

2727
ALLOWED_INCEPTION = ['StyleGAN', 'PyTorch']
2828
TERO_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' # noqa
@@ -131,7 +131,7 @@ def _load_inception_from_url(inception_url: str) -> nn.Module:
131131
print_log(f'Try to download Inception Model from {inception_url}...',
132132
'current')
133133
try:
134-
path = download_from_url(inception_url, dest_dir=MMGEN_CACHE_DIR)
134+
path = download_from_url(inception_url, dest_dir=MMEDIT_CACHE_DIR)
135135
print_log('Download Finished.', 'current')
136136
return _load_inception_from_path(path)
137137
except Exception as e:
@@ -336,7 +336,7 @@ def prepare_inception_feat(dataloader: DataLoader,
336336
if inception_pkl is None:
337337
inception_pkl, args = get_inception_feat_cache_name_and_args(
338338
dataloader, metric, real_nums, capture_mean_cov, capture_all)
339-
inception_pkl = osp.join(MMGEN_CACHE_DIR, inception_pkl)
339+
inception_pkl = osp.join(MMEDIT_CACHE_DIR, inception_pkl)
340340
else:
341341
args = dict()
342342
if osp.exists(inception_pkl):
@@ -510,7 +510,7 @@ def prepare_vgg_feat(dataloader: DataLoader,
510510
# cannot load or download from file, extract manually
511511
if vgg_pkl is None:
512512
vgg_pkl, args = get_vgg_feat_cache_name_and_args(dataloader, metric)
513-
vgg_pkl = osp.join(MMGEN_CACHE_DIR, vgg_pkl)
513+
vgg_pkl = osp.join(MMEDIT_CACHE_DIR, vgg_pkl)
514514
else:
515515
args = dict()
516516
if osp.exists(vgg_pkl):

mmedit/evaluation/metrics/fid.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from torch.utils.data.dataloader import DataLoader
1111

1212
from mmedit.registry import METRICS
13+
from ..functional import (disable_gpu_fuser_on_pt19, load_inception,
14+
prepare_inception_feat)
1315
from .base_gen_metric import GenerativeMetric
14-
from .inception_utils import (disable_gpu_fuser_on_pt19, load_inception,
15-
prepare_inception_feat)
1616

1717

1818
@METRICS.register_module('FID-Full')

mmedit/evaluation/metrics/inception_score.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from torch.utils.data.dataloader import DataLoader
1313

1414
from mmedit.registry import METRICS
15+
# from .inception_utils import disable_gpu_fuser_on_pt19, load_inception
16+
from ..functional import disable_gpu_fuser_on_pt19, load_inception
1517
from .base_gen_metric import GenerativeMetric
16-
from .inception_utils import disable_gpu_fuser_on_pt19, load_inception
1718

1819

1920
@METRICS.register_module('IS')

0 commit comments

Comments
 (0)