Skip to content

Commit d110e53

Browse files
committed
50k vs cifar10.train.npz
1 parent 526b939 commit d110e53

10 files changed

+36
-33
lines changed

Readme.md

+18-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Collections of GANs
22

3-
Pytorch implementation of unsupervised GANs.
3+
Pytorch implementation of basic unsupervised GANs on CIFAR10.
44

5-
For more defails about calculating Inception Score and FID using pytorch can be found here [pytorch_gan_metrics](https://github.com/w86763777/pytorch-gan-metrics)
5+
For more defails about calculating Inception Score and FID using pytorch can be found here [pytorch_gan_metrics](https://github.com/w86763777/pytorch-gan-metrics).
66

77
## Models
88
- [x] DCGAN
@@ -18,15 +18,15 @@ For more defails about calculating Inception Score and FID using pytorch can be
1818
```
1919

2020
## Results
21-
21+
The FID is calculated by 50k generated images and CIFAR10 train set.
2222
|Model |Dataset|Inception Score|FID |
2323
|--------------|:-----:|:--------------:|:---:|
24-
|DCGAN |CIFAR10|5.91(0.15) |47.46|
25-
|WGAN(CNN) |CIFAR10|6.46(0.24) |38.98|
26-
|WGAN-GP(CNN) |CIFAR10|7.69(0.19) |22.81|
27-
|WGAN-GP(ResNet)|CIFAR10|7.80(0.20) |21.48|
28-
|SNGAN(CNN) |CIFAR10|7.64(0.21) |21.86|
29-
|SNGAN(ResNet) |CIFAR10|8.21(0.17) |19.11|
24+
|DCGAN |CIFAR10|6.01(0.05) |42.72|
25+
|WGAN(CNN) |CIFAR10|6.62(0.09) |40.03|
26+
|WGAN-GP(CNN) |CIFAR10|7.66(0.10) |19.83|
27+
|WGAN-GP(ResNet)|CIFAR10|7.95(0.14) |16.95|
28+
|SNGAN(CNN) |CIFAR10|7.84(0.12) |17.81|
29+
|SNGAN(ResNet) |CIFAR10|8.31(0.10) |14.32|
3030

3131
## Examples
3232
- DCGAN
@@ -54,13 +54,16 @@ For more defails about calculating Inception Score and FID using pytorch can be
5454
![sngan_res_gif](https://drive.google.com/uc?export=view&id=1et3V7NbLEqH6aOWzkOQceNcnfY3WBOGz) ![sngan_res_png](https://drive.google.com/uc?export=view&id=1neYWCexP8kY2eixMpztNL50TKFLXZcBL)
5555

5656
## Reproduce
57-
- Download [cifar10.test.npz](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) for calculating FID. Then, create folder `stats` for the npz files
57+
- Download [cifar10.train.npz](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) for calculating FID. Then, create folder `stats` for the npz files
5858
```
5959
stats
60-
└── cifar10.test.npz
60+
└── cifar10.train.npz
6161
```
6262

6363
- Train from scratch
64+
65+
Different methods are separated into different files for clear reading.
66+
6467
```bash
6568
# DCGAN
6669
python dcgan.py --flagfile ./configs/DCGAN_CIFAR10.txt
@@ -75,18 +78,18 @@ For more defails about calculating Inception Score and FID using pytorch can be
7578
# SNGAN(ResNet)
7679
python sngan.py --flagfile ./configs/SNGAN_CIFAR10_RES.txt
7780
```
78-
Though the training procedures of different GANs are almost identical, I still separate different methods into different files for clear reading.
81+
7982

80-
## Learning curve
83+
## Learning Curves
8184
![inception_score_curve](https://drive.google.com/uc?export=view&id=12JTJS5--2dDjFyVhHJ-b264Qp3S-v8xS)
8285
![fid_curve](https://drive.google.com/uc?export=view&id=1P4e_DEyW4wvFubPSu5t_i2gVRoecGqs5)
8386

8487
## Change Log
8588
- 2021-01-10
8689
- Update pytorch to 1.10.1 and CUDA 11.3
87-
- Calculate FID and Inception by `pytorch_gan_metrics`
90+
- Use `pytorch_gan_metrics` to calculate FID and Inception Score
8891
- Use 50k generated images and CIFAR10 train set to calculate FID
89-
- Fix default parameters
92+
- Fix default parameters especially for `wgan.py`
9093

9194
- 2021-04-16
9295
- Update pytorch to 1.8.1

configs/DCGAN_CIFAR10.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
--arch=cnn32
22
--batch_size=128
33
--dataset=cifar10
4-
--fid_cache=./stats/cifar10.test.npz
4+
--fid_cache=./stats/cifar10.train.npz
55
--logdir=./logs/DCGAN_CIFAR10
66
--loss=bce
77
--lr_D=0.0002
88
--lr_G=0.0002
99
--n_dis=1
10-
--num_images=10000
10+
--num_images=50000
1111
--record
1212
--sample_step=500
1313
--sample_size=64

configs/SNGAN_CIFAR10_CNN.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
--arch=cnn32
22
--batch_size=128
33
--dataset=cifar10
4-
--fid_cache=./stats/cifar10.test.npz
4+
--fid_cache=./stats/cifar10.train.npz
55
--logdir=./logs/SNGAN_CIFAR10_CNN
66
--loss=hinge
77
--lr_D=0.0002
88
--lr_G=0.0002
99
--n_dis=1
10-
--num_images=10000
10+
--num_images=50000
1111
--record
1212
--sample_step=500
1313
--sample_size=64

configs/SNGAN_CIFAR10_RES.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
--arch=res32
22
--batch_size=64
33
--dataset=cifar10
4-
--fid_cache=./stats/cifar10.test.npz
4+
--fid_cache=./stats/cifar10.train.npz
55
--logdir=./logs/SNGAN_CIFAR10_RES
66
--loss=hinge
77
--lr_D=0.0002
88
--lr_G=0.0002
99
--n_dis=5
10-
--num_images=10000
10+
--num_images=50000
1111
--record
1212
--sample_step=500
1313
--sample_size=64

configs/WGANGP_CIFAR10_CNN.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
--arch=cnn32
33
--batch_size=128
44
--dataset=cifar10
5-
--fid_cache=./stats/cifar10.test.npz
5+
--fid_cache=./stats/cifar10.train.npz
66
--logdir=./logs/WGANGP_CIFAR10_CNN
77
--loss=was
88
--lr_D=0.0002
99
--lr_G=0.0002
1010
--n_dis=1
11-
--num_images=10000
11+
--num_images=50000
1212
--record
1313
--sample_step=500
1414
--sample_size=64

configs/WGANGP_CIFAR10_RES.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
--arch=res32
33
--batch_size=64
44
--dataset=cifar10
5-
--fid_cache=./stats/cifar10.test.npz
5+
--fid_cache=./stats/cifar10.train.npz
66
--logdir=./logs/WGANGP_CIFAR10_RES
77
--loss=was
88
--lr_D=0.0002
99
--lr_G=0.0002
1010
--n_dis=5
11-
--num_images=10000
11+
--num_images=50000
1212
--record
1313
--sample_step=500
1414
--sample_size=64

configs/WGAN_CIFAR10_CNN.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
--arch=cnn32
33
--batch_size=128
44
--dataset=cifar10
5-
--fid_cache=./stats/cifar10.test.npz
5+
--fid_cache=./stats/cifar10.train.npz
66
--logdir=./logs/WGAN_CIFAR10_CNN
77
--loss=was
88
--lr_D=0.0002
99
--lr_G=0.0002
1010
--n_dis=1
11-
--num_images=10000
11+
--num_images=50000
1212
--record
1313
--sample_step=500
1414
--sample_size=64

dcgan.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@
4848
flags.DEFINE_integer('sample_size', 64, "sampling size of images")
4949
flags.DEFINE_string('logdir', './logs/DCGAN_CIFAR10', 'logging folder')
5050
flags.DEFINE_bool('record', True, "record inception score and FID")
51-
flags.DEFINE_string('fid_cache', './stats/cifar10.test.npz', 'FID cache')
51+
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', 'FID cache')
5252
# generate
5353
flags.DEFINE_bool('generate', False, 'generate images')
5454
flags.DEFINE_string('pretrain', None, 'path to test model')
5555
flags.DEFINE_string('output', './outputs', 'path to output dir')
56-
flags.DEFINE_integer('num_images', 10000, 'the number of generated images')
56+
flags.DEFINE_integer('num_images', 50000, 'the number of generated images')
5757

5858
device = torch.device('cuda:0')
5959

wgan.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@
5151
flags.DEFINE_integer('sample_size', 64, "sampling size of images")
5252
flags.DEFINE_string('logdir', './logs/WGAN_CIFAR10_CNN', 'logging folder')
5353
flags.DEFINE_bool('record', True, "record inception score and FID")
54-
flags.DEFINE_string('fid_cache', './stats/cifar10.test.npz', 'FID cache')
54+
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', 'FID cache')
5555
# generate
5656
flags.DEFINE_bool('generate', False, 'generate images')
5757
flags.DEFINE_string('pretrain', None, 'path to test model')
5858
flags.DEFINE_string('output', './outputs', 'path to output dir')
59-
flags.DEFINE_integer('num_images', 10000, 'the number of generated images')
59+
flags.DEFINE_integer('num_images', 50000, 'the number of generated images')
6060

6161
device = torch.device('cuda:0')
6262

wgangp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@
5151
flags.DEFINE_integer('sample_size', 64, "sampling size of images")
5252
flags.DEFINE_string('logdir', './logs/WGANGP_CIFAR10_RES', 'logging folder')
5353
flags.DEFINE_bool('record', True, "record inception score and FID")
54-
flags.DEFINE_string('fid_cache', './stats/cifar10.test.npz', 'FID cache')
54+
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', 'FID cache')
5555
# generate
5656
flags.DEFINE_bool('generate', False, 'generate images')
5757
flags.DEFINE_string('pretrain', None, 'path to test model')
5858
flags.DEFINE_string('output', './outputs', 'path to output dir')
59-
flags.DEFINE_integer('num_images', 10000, 'the number of generated images')
59+
flags.DEFINE_integer('num_images', 50000, 'the number of generated images')
6060

6161
device = torch.device('cuda:0')
6262

0 commit comments

Comments
 (0)