Skip to content

Commit bdd58e1

Browse files
committed
add training curve
1 parent 1a3fc37 commit bdd58e1

9 files changed

+51
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ logs
1313
# data
1414
data
1515
stats
16+
images

Readme.md

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ For more defails about calculating Inception Score and FID using pytorch can be
8383
```
8484
Though the training procedures of different GANs are almost identical, I still separate different methods into different files for clear reading.
8585
86+
## Learning curve
87+
![inception_score_curve](https://drive.google.com/uc?export=view&id=12JTJS5--2dDjFyVhHJ-b264Qp3S-v8xS)
88+
![fid_curve](https://drive.google.com/uc?export=view&id=1P4e_DEyW4wvFubPSu5t_i2gVRoecGqs5)
89+
8690
## Change Log
8791
- 2021-04-16
8892
- Update pytorch to 1.8.1

images/dcgan_cifar10.png

-135 KB
Binary file not shown.

images/sngan_cifar10_cnn.png

-144 KB
Binary file not shown.

images/sngan_cifar10_res.png

-142 KB
Binary file not shown.

images/wgan_cifar10_cnn.png

-145 KB
Binary file not shown.

images/wgangp_cifar10_cnn.png

-142 KB
Binary file not shown.

images/wgangp_cifar10_res.png

-141 KB
Binary file not shown.

tools/plotcurve.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import os
2+
import glob
3+
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
7+
8+
if __name__ == '__main__':
9+
IS = []
10+
FID = []
11+
for path in glob.glob('./logs/*.csv'):
12+
df = pd.read_csv(path)
13+
_, name, _, tag = os.path.splitext(
14+
os.path.basename(path))[0].split('-')
15+
if '_CIFAR10_' in name:
16+
name = name.replace('_CIFAR10_', '(')
17+
name = name + ')'
18+
if name.endswith('_CIFAR10'):
19+
name = name.replace('_CIFAR10', '')
20+
if tag == 'inception_score' or tag == 'Inception_Score':
21+
IS.append((name, df.values[:, 1], df.values[:, 2]))
22+
elif tag == 'fid_score' or tag == 'FID':
23+
FID.append((name, df.values[:, 1], df.values[:, 2]))
24+
else:
25+
raise ValueError("???")
26+
IS = sorted(IS, key=lambda x: x[2][-1], reverse=True)
27+
FID = sorted(FID, key=lambda x: x[2][-1])
28+
29+
for name, x, y in IS:
30+
plt.plot(x, y, label=name)
31+
plt.legend()
32+
plt.xlabel('Step', fontsize=16)
33+
plt.ylabel('Inception Score', fontsize=16)
34+
# plt.title('Inception Score', fontsize=16)
35+
plt.savefig('./IS.png')
36+
plt.cla()
37+
38+
for name, x, y in FID:
39+
plt.plot(x, y, label=name)
40+
plt.legend()
41+
plt.ylim(0, 100)
42+
plt.yticks(range(0, 101, 10))
43+
plt.ylabel('FID', fontsize=16)
44+
plt.xlabel('Step', fontsize=16)
45+
# plt.title('FID curve', fontsize=16)
46+
plt.savefig('./FID.png')

0 commit comments

Comments
 (0)