Skip to content

Commit c55daa1

Browse files
committed
refactoring / add lsgan (wip) / use basemodel / fix dcgan bugs
1 parent d1d8b9f commit c55daa1

File tree

9 files changed

+341
-88
lines changed

9 files changed

+341
-88
lines changed

README.md

+8-3
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
* MNIST 에 대해서는 해 봤지만 사실 MNIST 는 데이터셋이 간단해서 구별이 잘 안 감
77
* 따라서 CelebA 나 LSUN 등 좀 복잡한 데이터셋에 대해서 해 보면서 몇몇 코드들 리팩토링도 하고...
88

9+
## Checks
10+
11+
* Data read 가 좀 이상함. 20 epoch 을 주면 10 epoch 만 돌고 끝남. epoch 계산을 내가 잘못하고 있나? 체크해보자.
12+
* LSGAN hyperparams
13+
914
## GANs
1015

11-
* [ ] DCGAN
12-
* [ ] LSGAN
16+
* [x] DCGAN
17+
* [ ] LSGAN - WIP
1318
* [ ] EBGAN
1419
* [ ] WGAN
1520
* [ ] WGAN-GP
@@ -26,4 +31,4 @@
2631
## Resuable code
2732

2833
* `utils.py`
29-
* `inputpipe.py`
34+
* `inputpipe.py`

basemodel.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# coding: utf-8
2+
3+
'''
4+
BaseModel for Generative Adversarial Netowrks.
5+
이 모델을 상속받아서 _build_train_graph, _discriminator, _generator 세가지만 구현해주면 된다.
6+
'''
7+
8+
import tensorflow as tf
9+
slim = tf.contrib.slim
10+
11+
12+
class BaseModel(object):
13+
def __init__(self, input_pipe, z_dim=100, name='basemodel'):
14+
'''
15+
training mode: input_pipe = input pipeline
16+
generation mode: input_pipe = None
17+
'''
18+
self.name = name
19+
training = input_pipe is not None
20+
# check: DCGAN specified BN-params?
21+
self.bn_params = {
22+
"decay": 0.99,
23+
"epsilon": 1e-5,
24+
"scale": True,
25+
"is_training": training
26+
}
27+
self.z_dim = z_dim
28+
if training == True:
29+
self._build_train_graph(input_pipe)
30+
else:
31+
self._build_gen_graph()
32+
33+
34+
def _build_gen_graph(self):
35+
'''build computational graph for generation (evaluation)
36+
'''
37+
with tf.variable_scope(self.name):
38+
self.z = tf.placeholder(tf.float32, [None, self.z_dim])
39+
self.fake_sample = self._generator(self.z)
40+
41+
42+
def _build_train_graph(self, X):
43+
'''build computational graph for training
44+
'''
45+
pass

config.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import dcgan, lsgan
2+
3+
4+
model_zoo = ['DCGAN', 'LSGAN', 'WGAN', 'WGAN-GP', 'BEGAN']
5+
6+
def get_model(mtype, name, input_pipe):
7+
model = None
8+
if mtype == 'DCGAN':
9+
model = dcgan.DCGAN
10+
elif mtype == 'LSGAN':
11+
model = lsgan.LSGAN
12+
elif mtype == 'WGAN':
13+
pass
14+
elif mtype == 'WGAN-GP':
15+
pass
16+
elif mtype == 'BEGAN':
17+
pass
18+
else:
19+
assert False, mtype + ' is not in the model zoo'
20+
21+
assert model, mtype + ' is work in progress'
22+
23+
return model(input_pipe=input_pipe, name=name)
24+
25+
26+
def pprint_args(FLAGS):
27+
print("\nParameters:")
28+
for attr, value in sorted(vars(FLAGS).items()):
29+
print("{}={}".format(attr.upper(), value))
30+
print("")

dcgan.py

+9-38
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# coding: utf-8
22
import tensorflow as tf
33
slim = tf.contrib.slim
4-
from utils import *
4+
from utils import expected_shape
5+
import ops
6+
from basemodel import BaseModel
57

68
'''
79
일단 MNIST 는 무시하자... 귀찮다.
@@ -19,44 +21,13 @@
1921
init - normal dist + stddev 0.02
2022
'''
2123

22-
def lrelu(inputs, leak=0.2, scope="lrelu"):
23-
"""
24-
https://github.com/tensorflow/tensorflow/issues/4079
25-
"""
26-
with tf.variable_scope(scope):
27-
f1 = 0.5 * (1 + leak)
28-
f2 = 0.5 * (1 - leak)
29-
return f1 * inputs + f2 * abs(inputs)
30-
31-
32-
class DCGAN(object):
33-
def __init__(self, input_pipe, z_dim=100, name='dcgan'):
24+
class DCGAN(BaseModel):
25+
def __init__(self, input_pipe, z_dim=100, name='dcgan2'):
3426
'''
3527
training mode: input_pipe = input pipeline
3628
generation mode: input_pipe = None
3729
'''
38-
self.name = name
39-
training = bool(input_pipe)
40-
# check: DCGAN specified BN-params?
41-
self.bn_params = {
42-
"decay": 0.99,
43-
"epsilon": 1e-5,
44-
"scale": True,
45-
"is_training": training
46-
}
47-
self.z_dim = z_dim
48-
if training == True:
49-
self._build_train_graph(input_pipe)
50-
else:
51-
self._build_gen_graph()
52-
53-
54-
def _build_gen_graph(self):
55-
'''build computational graph for generation (evaluation)
56-
'''
57-
with tf.variable_scope(self.name):
58-
self.z = tf.placeholder(tf.float32, [None, self.z_dim])
59-
self.fake_sample = self._generator(self.z)
30+
super(DCGAN, self).__init__(input_pipe=input_pipe, z_dim=z_dim, name=name)
6031

6132

6233
def _build_train_graph(self, X):
@@ -115,7 +86,7 @@ def _discriminator(self, X, reuse=False):
11586
with tf.variable_scope('D', reuse=reuse):
11687
net = X
11788

118-
with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=lrelu,
89+
with slim.arg_scope([slim.conv2d], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=ops.lrelu,
11990
normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params):
12091
net = slim.conv2d(net, 64, normalizer_fn=None)
12192
expected_shape(net, [32, 32, 64])
@@ -141,13 +112,13 @@ def _generator(self, z, reuse=False):
141112

142113
with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5,5], stride=2, padding='SAME', activation_fn=tf.nn.relu,
143114
normalizer_fn=slim.batch_norm, normalizer_params=self.bn_params):
144-
net = slim.conv2d_transpose(net, 512, normalizer_fn=None)
115+
net = slim.conv2d_transpose(net, 512)
145116
expected_shape(net, [8, 8, 512])
146117
net = slim.conv2d_transpose(net, 256)
147118
expected_shape(net, [16, 16, 256])
148119
net = slim.conv2d_transpose(net, 128)
149120
expected_shape(net, [32, 32, 128])
150-
net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh)
121+
net = slim.conv2d_transpose(net, 3, activation_fn=tf.nn.tanh, normalizer_fn=None)
151122
expected_shape(net, [64, 64, 3])
152123

153124
return net

eval.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dcgan import DCGAN
44
import numpy as np
55
import utils
6+
from config import get_model, pprint_args
67
import os, glob
78
import scipy.misc
89
from argparse import ArgumentParser
@@ -12,6 +13,7 @@
1213
def build_parser():
1314
parser = ArgumentParser()
1415
parser.add_argument('--model', help='DCGAN / LSGAN / WGAN / WGAN-GP / BEGAN', required=True) # DRAGAN, CramerGAN
16+
parser.add_argument('--name', help='default: model')
1517

1618
return parser
1719

@@ -39,17 +41,19 @@ def get_all_checkpoints(ckpt_dir, force=False):
3941
return ckpts
4042

4143

42-
def eval(model, sample_shape=[4,4], load_all_ckpt=True):
43-
dir_name = 'eval_' + model.name
44+
def eval(model, name, sample_shape=[4,4], load_all_ckpt=True):
45+
if name == None:
46+
name = model.name
47+
dir_name = 'eval_' + name
4448
if tf.gfile.Exists(dir_name):
4549
tf.gfile.DeleteRecursively(dir_name)
4650
tf.gfile.MkDir(dir_name)
4751

4852
# training=False => generator 만 생성
4953
restorer = tf.train.Saver(slim.get_model_variables())
5054
with tf.Session() as sess:
51-
# ckpt = tf.train.get_checkpoint_state('./checkpoints/' + model.name)
52-
ckpts = get_all_checkpoints('./checkpoints/' + model.name, force=load_all_ckpt)
55+
# ckpt = tf.train.get_checkpoint_state('./checkpoints/' + name)
56+
ckpts = get_all_checkpoints('./checkpoints/' + name, force=load_all_ckpt)
5357
size = sample_shape[0] * sample_shape[1]
5458

5559
z_ = sample_z([size, model.z_dim])
@@ -88,7 +92,7 @@ def to_gif(dir_name='eval'):
8892
parser = build_parser()
8993
FLAGS = parser.parse_args()
9094
FLAGS.model = FLAGS.model.upper()
91-
utils.pprint_args(FLAGS)
95+
pprint_args(FLAGS)
9296

93-
model = utils.get_model(FLAGS.model, training=False, X=None)
94-
eval(model, sample_shape=[4,4], load_all_ckpt=True)
97+
model = get_model(FLAGS.model, FLAGS.name, input_pipe=None)
98+
eval(model, name=FLAGS.name, sample_shape=[4,4], load_all_ckpt=True)

0 commit comments

Comments
 (0)