
Commit bd88727

clean up
1 parent 512ab92 commit bd88727


46 files changed (+700, -835 lines)

README.md (+12 -9)
````diff
@@ -1,11 +1,11 @@
 # Overview
-PyTorch version: 0.4.1 | Python 3.6.5
+PyTorch 0.4.1 | Python 3.6.5
 
 Annotated implementations with comparative introductions for minimax, non-saturating, wasserstein, wasserstein gradient penalty, least squares, deep regret analytic, bounded equilibrium, relativistic, f-divergence, Fisher, and information generative adversarial networks (GANs), and standard, variational, and bounded information rate variational autoencoders (VAEs).
 
-Paper links are supplied at the beginning of each file with a short summary of the paper. See src folder for files to run via terminal, or notebooks folder for Jupyter notebook visualizations via your local browser. The main file changes can be see in the train, train_D, and train_G of the Trainer class, although changes are not completely limited to only these two areas (e.g. Wasserstein GAN clamps weight in the train function, BEGAN gives multiple outputs from train_D, fGAN has a slight modification in viz_loss function to indicate method used in title).
+Paper links are supplied at the beginning of each file with a short summary of the paper. See src folder for files to run via terminal, or notebooks folder for Jupyter notebook visualizations via your local browser. The main file changes can be see in the ```train```, ```train_D```, and ```train_G``` of the Trainer class, although changes are not completely limited to only these two areas (e.g. Wasserstein GAN clamps weight in the train function, BEGAN gives multiple outputs from train_D, fGAN has a slight modification in viz_loss function to indicate method used in title).
 
-All code in this repository operates in a generative, unsupervised manner on binary (black and white) MNIST. The architectures are compatible with a variety of datatypes (1D, 2D, 3D) and plotting functions work with binary/RGB images too. If a GPU is detected, the models use it. Otherwise, they default to CPU.
+All code in this repository operates in a generative, unsupervised manner on binary (black and white) MNIST. The architectures are compatible with a variety of datatypes (1D, 2D, square 3D images). Plotting functions work with binary/RGB images. If a GPU is detected, the models use it. Otherwise, they default to CPU. VAE Trainer classes contain methods to visualize latent space representations (see ```make_all``` function).
 
 # Usage
 To initialize an environment:
@@ -20,7 +20,7 @@ For playing around in Jupyer notebooks:
 jupyter notebook
 ```
 
-To run from Terminal / Bash:
+To run from Terminal:
 ```
 cd src
 python bir_vae.py
@@ -34,9 +34,10 @@ Suppose we have a non-saturating GAN and we wanted to implement a least-squares
 
 [Original](https://github.com/shayneobrien/generative-models/blob/master/src/ns_gan.py#L166-L208) (NSGAN)
 ```
-def train_D(self):
-
+def train_D(self, images):
+    ...
     D_loss = -torch.mean(torch.log(DX_score + 1e-8) + torch.log(1 - DG_score + 1e-8))
+
     return D_loss
 ```
 ```
@@ -49,15 +50,17 @@ def train_G(self, images):
 
 [New](https://github.com/shayneobrien/generative-models/blob/master/src/ls_gan.py#L166-L209) (LSGAN)
 ```
-def train_D(self, images, a=0, b=1):
+def train_D(self, images):
     ...
-    D_loss = (0.50 * torch.mean((DX_score - b)**2)) + (0.50 * torch.mean((DG_score - a)**2))
+    D_loss = (0.50 * torch.mean((DX_score - 1.)**2)) + (0.50 * torch.mean((DG_score - 0.)**2))
+
     return D_loss
 ```
 ```
 def train_G(self, images):
     ...
     G_loss = 0.50 * torch.mean((DG_score - 1.)**2)
+
     return G_loss
 ```
 
@@ -72,7 +75,7 @@ All models were trained for 25 epochs with hidden dimension 400, latent dimensio
 [MMGAN](https://arxiv.org/abs/1406.2661) | <img src = 'viz/MMGAN/reconst_1.png' height = '150px'> | <img src = 'viz/MMGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/MMGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/MMGAN_loss.png' height = '150px'>
 [NSGAN](https://arxiv.org/abs/1406.2661) | <img src = 'viz/NSGAN/reconst_1.png' height = '150px'> | <img src = 'viz/NSGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/NSGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/NSGAN_loss.png' height = '150px'>
 [WGAN](https://arxiv.org/abs/1701.07875) | <img src = 'viz/WGAN/reconst_1.png' height = '150px'> | <img src = 'viz/WGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/WGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/WGAN_loss.png' height = '150px'>
-[WGANGP](https://arxiv.org/abs/1704.00028) | <img src = 'viz/WGANGP/reconst_1.png' height = '150px'> | <img src = 'viz/WGANGP/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/WGANGP_gif.gif' height = '150px'> | <img src = 'viz/losses/WGANGP_loss.png' height = '150px'>
+[WGPGAN](https://arxiv.org/abs/1704.00028) | <img src = 'viz/WGPGAN/reconst_1.png' height = '150px'> | <img src = 'viz/WGPGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/WGPGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/WGPGAN_loss.png' height = '150px'>
 [DRAGAN](https://arxiv.org/abs/1705.07215) | <img src = 'viz/DRAGAN/reconst_1.png' height = '150px'> | <img src = 'viz/DRAGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/DRAGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/DRAGAN_loss.png' height = '150px'>
 [BEGAN](https://arxiv.org/abs/1703.10717) | <img src = 'viz/BEGAN/reconst_1.png' height = '150px'> | <img src = 'viz/BEGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/BEGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/BEGAN_loss.png' height = '150px'>
 [LSGAN](https://arxiv.org/abs/1611.04076) | <img src = 'viz/LSGAN/reconst_1.png' height = '150px'> | <img src = 'viz/LSGAN/reconst_25.png' height = '150px'> | <img src = 'viz/gifs/LSGAN_gif.gif' height = '150px'> | <img src = 'viz/losses/LSGAN_loss.png' height = '150px'>
````
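To make the swap concrete, here is a self-contained comparison of the two discriminator objectives from the README snippets: `DX_score` is D(x) on a real batch and `DG_score` is D(G(z)) on a generated batch, with random tensors standing in for actual model outputs. This is a minimal sketch; the helper function names are illustrative, not the repository's API.

```python
import torch

def nsgan_D_loss(DX_score, DG_score):
    """ Non-saturating GAN discriminator loss (the 'Original' snippet);
    the 1e-8 terms guard against log(0). """
    return -torch.mean(torch.log(DX_score + 1e-8)
                       + torch.log(1 - DG_score + 1e-8))

def lsgan_D_loss(DX_score, DG_score):
    """ LSGAN discriminator loss (the 'New' snippet): push scores on
    real images toward 1. and scores on generated images toward 0. """
    return (0.50 * torch.mean((DX_score - 1.)**2)) \
         + (0.50 * torch.mean((DG_score - 0.)**2))

# Hypothetical scores for a batch of 64 images
DX_score, DG_score = torch.rand(64, 1), torch.rand(64, 1)
print(nsgan_D_loss(DX_score, DG_score).item(),
      lsgan_D_loss(DX_score, DG_score).item())
```

Note that the commit also hard-codes LSGAN's targets (the `a=0, b=1` keyword arguments become literal `0.` and `1.`); this only removes unused flexibility, since the computed loss is identical for the default arguments.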

src/ae.py (+6 -6)
```diff
@@ -164,17 +164,17 @@ def evaluate(self, iterator):
         return np.mean([self.compute_batch(batch).item() for batch in iterator])
 
     def reconstruct_images(self, images, epoch, save=True):
-        """Reconstruct a fixed input at each epoch for progress visualization"""
+        """ Reconstruct a fixed input at each epoch for progress viz """
         # Reshape images, pass through model, reshape reconstructed output
         batch = to_cuda(images.view(images.shape[0], -1))
         reconst_images = self.model(batch)
         reconst_images = reconst_images.view(images.shape).squeeze()
 
         # Plot
         plt.close()
-        size_figure_grid, k = int(reconst_images.shape[0]**0.5), 0
-        fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
-        for i, j in product(range(size_figure_grid), range(size_figure_grid)):
+        grid_size, k = int(reconst_images.shape[0]**0.5), 0
+        fig, ax = plt.subplots(grid_size, grid_size, figsize=(5, 5))
+        for i, j in product(range(grid_size), range(grid_size)):
             ax[i,j].get_xaxis().set_visible(False)
             ax[i,j].get_yaxis().set_visible(False)
             ax[i,j].imshow(reconst_images[k].data.numpy(), cmap='gray')
@@ -187,10 +187,10 @@ def reconstruct_images(self, images, epoch, save=True):
             os.makedirs(outname)
         torchvision.utils.save_image(self.debugging_image.data,
                                      outname + 'real.png',
-                                     nrow=size_figure_grid)
+                                     nrow=grid_size)
         torchvision.utils.save_image(reconst_images.unsqueeze(1).data,
                                      outname + 'reconst_%d.png' %(epoch),
-                                     nrow=size_figure_grid)
+                                     nrow=grid_size)
 
     def viz_loss(self):
         """ Visualize reconstruction loss """
```

src/bayes_gan.py (+5 -6)
```diff
@@ -1,9 +1,10 @@
 # TODO
-""" (BayesGAN)
+""" (BayesGAN) https://arxiv.org/abs/1705.09558
+Bayesian GAN
 
 From the authors:
-"
-Bayesian GAN (Saatchi and Wilson, 2017) is a Bayesian formulation of Generative
+
+"Bayesian GAN (Saatchi and Wilson, 2017) is a Bayesian formulation of Generative
 Adversarial Networks (Goodfellow, 2014) where we learn the distributions of the
 generator parameters $\theta_g$ and the discriminator parameters $\theta_d$
 instead of optimizing for point estimates. The benefits of the Bayesian approach
@@ -13,7 +14,7 @@
 
 We learn Bayesian GAN via an approximate inference algorithm called Stochastic
 Gradient Hamiltonian Monte Carlo (SGHMC) which is a gradient-based MCMC methods
-whose samples approximate the true posterior distributions of $\theta_g$ and 
+whose samples approximate the true posterior distributions of $\theta_g$ and
 $\theta_d$. The Bayesian GAN training process starts from sampling noise $z$
 from a fixed distribution(typically standard d-dim normal). The noise is fed
 to the generator where the parameters $\theta_g$ are sampled from the posterior
@@ -27,6 +28,4 @@
 
 SGHMC is fancy for using point estimates (as in most GANs) to infer the
 posteriors.
-
-https://arxiv.org/pdf/1705.09558.pdf
 """
```

src/be_gan.py (+50 -33)
```diff
@@ -1,8 +1,6 @@
-""" (BEGAN)
+""" (BEGAN) https://arxiv.org/abs/1703.10717
 Boundary Equilibrium GAN
 
-https://arxiv.org/abs/1703.10717
-
 BEGAN uses an autoencoder as a discriminator and optimizes a lower bound of the
 Wasserstein distance between auto-encoder loss distributions on real and fake
 data (as opposed to the sample distributions of the generator and real data).
@@ -44,7 +42,7 @@
 from itertools import product
 from tqdm import tqdm
 
-from utils import *
+from .utils import *
 
 
 class Generator(nn.Module):
@@ -89,6 +87,8 @@ def __init__(self, image_size, hidden_dim, z_dim):
         self.G = Generator(image_size, hidden_dim, z_dim)
         self.D = Discriminator(image_size, hidden_dim)
 
+        self.shape = int(image_size ** 0.5)
+
 
 class BEGANTrainer:
     def __init__(self, model, train_iter, val_iter, test_iter, viz=False):
@@ -114,21 +114,26 @@ def train(self, num_epochs, G_lr=1e-4, D_lr=1e-4, D_steps=1,
 
         Inputs:
             num_epochs: int, number of epochs to train for
-            G_lr: float, learning rate for generator's optimizer (default 1e-4)
-            D_lr: float, learning rate for discriminator's optimizer (default 1e-4)
-            D_steps: int, ratio for how often to train D compared to G (default 1)
-            GAMMA: float, balance equilibrium between G and D objectives (default 0.50)
-            LAMBDA: float, weight D loss for updating K (default 1e-3)
-            K: float, how much to initially emphasize loss(D(G(z))) in total D loss (default 0.00)
+            G_lr: float, learning rate for generator's optimizer
+            D_lr: float, learning rate for discriminator's optimizer
+            D_steps: int, ratio for how often to train D compared to G
+            GAMMA: float, balance equilibrium between G and D objectives
+            LAMBDA: float, weight D loss for updating K
+            K: float, how much to emphasize loss(D(G(z))) in initial D loss
         """
 
         # Adam optimizers
-        G_optimizer = optim.Adam(params=[p for p in self.model.G.parameters() if p.requires_grad], lr=G_lr)
-        D_optimizer = optim.Adam(params=[p for p in self.model.D.parameters() if p.requires_grad], lr=D_lr)
-
-        # Reduce learning rate by factor of 2 if convergence_metric stops decreasing by a threshold for last five epochs
-        G_scheduler = ReduceLROnPlateau(G_optimizer, factor=0.50, threshold=0.01, patience=5*len(self.train_iter))
-        D_scheduler = ReduceLROnPlateau(D_optimizer, factor=0.50, threshold=0.01, patience=5*len(self.train_iter))
+        G_optimizer = optim.Adam(params=[p for p in self.model.G.parameters()
+                                 if p.requires_grad], lr=G_lr)
+        D_optimizer = optim.Adam(params=[p for p in self.model.D.parameters()
+                                 if p.requires_grad], lr=D_lr)
+
+        # Reduce learning rate by factor of 2 if convergence_metric stops
+        # decreasing by a threshold for last five epochs
+        G_scheduler = ReduceLROnPlateau(G_optimizer, factor=0.50, threshold=0.01,
+                                        patience=5*len(self.train_iter))
+        D_scheduler = ReduceLROnPlateau(D_optimizer, factor=0.50, threshold=0.01,
+                                        patience=5*len(self.train_iter))
 
         # Approximate steps/epoch given D_steps per epoch
         # --> roughly train in the same way as if D_step (1) == G_step (1)
@@ -179,14 +184,15 @@ def train(self, num_epochs, G_lr=1e-4, D_lr=1e-4, D_steps=1,
                 # Save relevant output for progress logging
                 G_losses.append(G_loss.item())
 
-                # PROPORTIONAL CONTROL THEORY: Dynamically update K, log convergence measure
-                convergence_measure = (DX_loss + torch.abs(GAMMA*DX_loss - DG_loss)).item()
+                # PROPORTIONAL CONTROL THEORY: Dynamically update K,
+                # log convergence measure
+                convergence = (DX_loss+torch.abs(GAMMA*DX_loss-DG_loss)).item()
                 K_update = (K + LAMBDA*(GAMMA*DX_loss - DG_loss)).item()
                 K = min(max(0, K_update), 1)
 
                 # Learning rate scheduler
-                D_scheduler.step(convergence_measure)
-                G_scheduler.step(convergence_measure)
+                D_scheduler.step(convergence)
+                G_scheduler.step(convergence)
 
             # Save losses
             self.Glosses.extend(G_losses)
@@ -195,7 +201,7 @@ def train(self, num_epochs, G_lr=1e-4, D_lr=1e-4, D_steps=1,
             # Progress logging
             print ("Epoch[%d/%d], G Loss: %.4f, D Loss: %.4f, K: %.4f, Convergence Measure: %.4f"
                    %(epoch, num_epochs, np.mean(G_losses),
-                     np.mean(D_losses), K, convergence_measure))
+                     np.mean(D_losses), K, convergence))
             self.num_epochs += 1
 
             # Visualize generator progress
@@ -207,10 +213,11 @@ def train_D(self, images, K):
         """ Run 1 step of training for discriminator
 
         Input:
-            images: batch of images (reshaped to [batch_size, 784])
+            images: batch of images (reshaped to [batch_size, -1])
             K: how much to emphasize loss(D(G(z))) in total D loss
         Output:
-            D_loss: BEGAN loss for discriminator, E[||x - AE(x)||1] - K*E[G(z) - AE(G(z))]
+            D_loss: BEGAN loss for discriminator,
+                    E[||x-AE(x)||1] - K*E[G(z) - AE(G(z))]
         """
 
         # Reconstruct the images using D (autoencoder), get reconstruction loss
@@ -239,7 +246,8 @@ def train_G(self, images):
             G_loss: BEGAN loss for G, E[||G(z) - AE(G(Z))||1]
         """
 
-        # Get noise, classify it using G, then reconstruct the output of G using D (autoencoder).
+        # Get noise, classify it using G, then reconstruct the output of G
+        # using D (autoencoder).
         noise = self.compute_noise(images.shape[0], self.model.z_dim) # z
         G_output = self.model.G(noise) # G(z)
         DG_reconst = self.model.D(G_output) # D(G(z))
@@ -250,7 +258,7 @@ def train_G(self, images):
         return G_loss
 
     def compute_noise(self, batch_size, z_dim):
-        """ Compute random noise for the generator to learn to make images from """
+        """ Compute random noise for input into Generator G """
         return to_cuda(torch.randn(batch_size, z_dim))
 
     def process_batch(self, iterator):
@@ -272,13 +280,16 @@ def generate_images(self, epoch, num_outputs=36, save=True):
         images = self.model.G(noise)
 
         # Reshape to proper image size
-        images = images.view(images.shape[0], 28, 28)
+        images = images.view(images.shape[0],
+                             self.model.shape,
+                             self.model.shape,
+                             -1).squeeze()
 
         # Plot
         plt.close()
-        size_figure_grid = int(num_outputs**0.5)
-        fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
-        for i, j in product(range(size_figure_grid), range(size_figure_grid)):
+        grid_size = int(num_outputs**0.5)
+        fig, ax = plt.subplots(grid_size, grid_size, figsize=(5, 5))
+        for i, j in product(range(grid_size), range(grid_size)):
             ax[i,j].get_xaxis().set_visible(False)
             ax[i,j].get_yaxis().set_visible(False)
             ax[i,j].cla()
@@ -291,7 +302,7 @@ def generate_images(self, epoch, num_outputs=36, save=True):
             os.makedirs(outname)
         torchvision.utils.save_image(images.unsqueeze(1).data,
                                      outname + 'reconst_%d.png'
-                                     %(epoch), nrow=size_figure_grid)
+                                     %(epoch), nrow=grid_size)
 
     def viz_loss(self):
         """ Visualize loss for the generator, discriminator """
@@ -300,9 +311,15 @@ def viz_loss(self):
         plt.style.use('ggplot')
         plt.rcParams["figure.figsize"] = (8,6)
 
-        # Plot Discriminator loss in red, Generator loss in green
-        plt.plot(np.linspace(1, self.num_epochs, len(self.Dlosses)), self.Dlosses, 'r')
-        plt.plot(np.linspace(1, self.num_epochs, len(self.Dlosses)), self.Glosses, 'g')
+        # Plot Discriminator loss in red
+        plt.plot(np.linspace(1, self.num_epochs, len(self.Dlosses)),
+                 self.Dlosses,
+                 'r')
+
+        # Plot Generator loss in green
+        plt.plot(np.linspace(1, self.num_epochs, len(self.Dlosses)),
+                 self.Glosses,
+                 'g')
 
         # Add legend, title
         plt.legend(['Discriminator', 'Generator'])
```
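The renamed `convergence` variable implements BEGAN's proportional control, visible in the hunk above: each step nudges K by LAMBDA*(GAMMA*DX_loss - DG_loss) and clamps it to [0, 1], while the convergence measure is DX_loss + |GAMMA*DX_loss - DG_loss|. A self-contained sketch of one update with made-up loss values (not repository output):

```python
# Illustrative autoencoder losses on real / generated batches
DX_loss, DG_loss = 0.80, 0.30
GAMMA, LAMBDA, K = 0.50, 1e-3, 0.0

# Equilibrium target: DG_loss == GAMMA * DX_loss (here 0.40)
balance = GAMMA * DX_loss - DG_loss        # 0.10 -> fakes reconstructed too well
K = min(max(0, K + LAMBDA * balance), 1)   # K rises; D presses harder on fakes
convergence = DX_loss + abs(balance)       # lower is better
print(K, convergence)                      # 0.0001 0.9
```

The measure is what both `ReduceLROnPlateau` schedulers step on, so the learning rates halve whenever it plateaus.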
