
Implementation of lookahead #242

Open · wants to merge 4 commits into master

11 changes: 11 additions & 0 deletions README.md
@@ -490,3 +490,14 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{chavdarova2021taming,
    title = {Taming GANs with Lookahead-Minmax},
    author = {Tatjana Chavdarova and Matteo Pagliardini and Sebastian U. Stich and Francois Fleuret and Martin Jaggi},
    year = {2021},
    eprint = {2006.14567},
    archivePrefix = {arXiv},
    primaryClass = {stat.ML}
}
```
10 changes: 8 additions & 2 deletions stylegan2_pytorch/cli.py
@@ -116,7 +116,10 @@ def train_from_folder(
    calculate_fid_num_images = 12800,
    clear_fid_cache = False,
    seed = 42,
    log = False,
    lookahead = False,
    lookahead_alpha = 0.5,
    lookahead_k = 5
):
    model_args = dict(
        name = name,
@@ -155,7 +158,10 @@ def train_from_folder(
        calculate_fid_num_images = calculate_fid_num_images,
        clear_fid_cache = clear_fid_cache,
        mixed_prob = mixed_prob,
        log = log,
        lookahead = lookahead,
        lookahead_alpha = lookahead_alpha,
        lookahead_k = lookahead_k
    )

    if generate:
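Not part of the diff, but for context: the three new options are plain keyword arguments on `train_from_folder`, which the package exposes through a Fire-based CLI, so they can be set either from Python or as command-line flags. A minimal sketch (the `data` path is a placeholder):

```python
# Sketch only: the new lookahead options flow from train_from_folder into
# model_args and on to the Trainer / StyleGAN2 constructors.
# Assumed shell equivalent via the Fire CLI:
#   stylegan2_pytorch --lookahead --lookahead_alpha 0.5 --lookahead_k 5
from stylegan2_pytorch.cli import train_from_folder

train_from_folder(
    data = './data',         # placeholder path to a folder of training images
    lookahead = True,        # wrap both G and D optimizers in Lookahead
    lookahead_alpha = 0.5,   # slow-weight interpolation factor
    lookahead_k = 5          # sync slow and fast weights every 5 steps
)
```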
52 changes: 48 additions & 4 deletions stylegan2_pytorch/stylegan2_pytorch.py
@@ -4,6 +4,7 @@
import fire
import json

from collections import defaultdict
from tqdm import tqdm
from math import floor, log2
from random import random
@@ -284,6 +285,30 @@ def slerp(val, low, high):
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

# lookahead
class Lookahead(torch.optim.Optimizer):
    def __init__(self, optimizer, alpha = 0.5):
        self.optimizer = optimizer
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)

    def lookahead_step(self):
        for group in self.param_groups:
            for fast in group["params"]:
                param_state = self.state[fast]
                if "slow_params" not in param_state:
                    # initialize the slow weights to the current fast weights
                    param_state["slow_params"] = torch.zeros_like(fast.data)
                    param_state["slow_params"].copy_(fast.data)
                slow = param_state["slow_params"]
                # slow <- slow + alpha * (fast - slow)
                slow += (fast.data - slow) * self.alpha
                fast.data.copy_(slow)

    def step(self, closure = None):
        # the fast (inner) update is delegated to the wrapped optimizer
        loss = self.optimizer.step(closure)
        return loss

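For clarity, a small self-contained sketch (editorial, not part of the diff) of how the wrapper behaves: the wrapped Adam performs the fast inner steps, and `lookahead_step()` pulls the parameters back toward the slow copy every `k` steps. The toy loss and constants are illustrative only.

```python
import torch

# toy parameter and target, purely illustrative
w = torch.nn.Parameter(torch.zeros(2))
target = torch.tensor([1.0, -1.0])

opt = Lookahead(torch.optim.Adam([w], lr = 0.1), alpha = 0.5)

k = 5
for step in range(1, 21):
    loss = ((w - target) ** 2).sum()
    opt.optimizer.zero_grad()   # zero grads on the wrapped optimizer directly
    loss.backward()
    opt.step()                  # fast (inner) Adam update
    if step % k == 0:
        opt.lookahead_step()    # slow <- slow + alpha * (fast - slow); fast <- slow
```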
# losses

def gen_hinge_loss(fake, real):
@@ -677,7 +702,7 @@ def forward(self, x):
        return x.squeeze(), quantize_loss

class StyleGAN2(nn.Module):
    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0):
    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0, lookahead = False, lookahead_alpha = 0.5):
        super().__init__()
        self.lr = lr
        self.steps = steps
@@ -710,6 +735,11 @@ def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8
        self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
        self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))

        if lookahead:
            # wrap both optimizers with the lookahead optimizer
            self.G_opt = Lookahead(self.G_opt, alpha = lookahead_alpha)
            self.D_opt = Lookahead(self.D_opt, alpha = lookahead_alpha)

        # init weights
        self._init_weights()
        self.reset_parameter_averaging()
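Schematically, the update both wrapped optimizers apply at each synchronization point is the usual lookahead interpolation, performed jointly for the generator parameters θ and the discriminator parameters φ (notation here is editorial, not the paper's):

```latex
\theta_{\mathrm{slow}} \leftarrow \theta_{\mathrm{slow}} + \alpha \,(\theta_{\mathrm{fast}} - \theta_{\mathrm{slow}}),
\qquad
\phi_{\mathrm{slow}} \leftarrow \phi_{\mathrm{slow}} + \alpha \,(\phi_{\mathrm{fast}} - \phi_{\mathrm{slow}})
```

with the fast weights then reset to the slow weights; updating both players at the same step is the joint update of Chavdarova et al.'s Lookahead-Minmax.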
@@ -792,6 +822,9 @@ def __init__(
        rank = 0,
        world_size = 1,
        log = False,
        lookahead = False,
        lookahead_alpha = 0.5,
        lookahead_k = 5,
        *args,
        **kwargs
    ):
@@ -880,6 +913,10 @@ def __init__(

        self.logger = aim.Session(experiment=name) if log else None

        self.lookahead = lookahead
        self.lookahead_k = lookahead_k
        self.lookahead_alpha = lookahead_alpha

    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'
@@ -894,7 +931,7 @@ def hparams(self):

    def init_GAN(self):
        args, kwargs = self.GAN_params
        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, lookahead = self.lookahead, lookahead_alpha = self.lookahead_alpha, *args, **kwargs)

        if self.is_ddp:
            ddp_kwargs = {'device_ids': [self.rank]}
@@ -1112,15 +1149,22 @@ def train(self):
        self.GAN.G_opt.step()

        # calculate moving averages
        if self.lookahead and (self.steps + 1) % self.lookahead_k == 0:
            # joint lookahead update for both optimizers
            self.GAN.D_opt.lookahead_step()
            self.GAN.G_opt.lookahead_step()

            if self.is_main:
                self.GAN.EMA()

        if apply_path_penalty and not np.isnan(avg_pl_length):
            self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
            self.track(self.pl_mean, 'PL')

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
        if self.is_main and not self.lookahead and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
        if self.is_main and not self.lookahead and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors
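One note on the schedule in the hunk above: with the default `lookahead_k = 5`, the condition `(self.steps + 1) % self.lookahead_k == 0` fires on steps 4, 9, 14, and so on, and the pre-existing EMA and parameter-averaging schedules are explicitly skipped (`not self.lookahead`) while lookahead is active. A quick illustrative check of the trigger condition:

```python
# illustrative only: which step counters trigger the joint lookahead update
lookahead_k = 5
print([s for s in range(20) if (s + 1) % lookahead_k == 0])   # [4, 9, 14, 19]
```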