AdamW implementation does not truly decouple learning rate and weight decay #1849

@leenachennuru


Describe the bug

The AdamW implementation (see here) does not truly decouple the weight decay and learning rate parameters in the sense of the AdamW paper. This coupling often complicates hyperparameter tuning, since tuning the learning rate also changes the effective weight decay used to train the model.

The implementation computes the updates as

$w_{t} = (1 - \eta_{\text{effective}} \lambda)\, w_{t-1} - \eta_{\text{effective}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

where $\eta_{\text{effective}} = \eta_t \eta_{\text{max}}$, with $\eta_t$ denoting the scheduler multiplier and $\eta_{\text{max}}$ the max/base learning rate.

This clearly couples LR and WD and is not in line with the paper, which proposes computing the update as

$w_{t} = (1 - \eta_t \lambda)\, w_{t-1} - \eta_t \eta_{\text{max}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
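For concreteness, here is a minimal single-step sketch of the two update rules in plain NumPy, not tied to any particular library; all hyperparameter values and the bias-corrected moment estimates `m_hat` / `v_hat` are illustrative placeholders:

```python
import numpy as np

# Single AdamW step on a toy parameter vector. Placeholder values throughout;
# m_hat / v_hat stand in for the bias-corrected first and second moment estimates.
lr_max, lr_sched, wd, eps = 1e-3, 0.5, 1e-2, 1e-8   # eta_max, eta_t, lambda, epsilon
lr_eff = lr_sched * lr_max                           # eta_effective = eta_t * eta_max

w = np.array([0.1, -0.2, 0.3])
m_hat = np.array([0.01, 0.02, -0.01])
v_hat = np.array([1e-4, 4e-4, 1e-4])
adam_step = m_hat / (np.sqrt(v_hat) + eps)

# Current behavior: the decay term is scaled by the full effective LR,
# so changing eta_max also changes the effective weight decay.
w_coupled = (1.0 - lr_eff * wd) * w - lr_eff * adam_step

# Paper's fully decoupled rule: the decay term is scaled by the schedule only.
w_decoupled = (1.0 - lr_sched * wd) * w - lr_eff * adam_step
```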

For easier and more intuitive tuning, it would be useful to enable the completely decoupled version of AdamW via the simple fix $\tilde{\lambda} = (\eta_{\text{effective}} / \eta_{\text{max}})\, \lambda$, with the update $w_{t} = (1 - \tilde{\lambda})\, w_{t-1} - \eta_{\text{effective}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$; a sketch follows below.
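Continuing the sketch above (reusing its placeholder names), the proposed rescaling reproduces the decoupled rule exactly, since $\tilde{\lambda} = (\eta_{\text{effective}} / \eta_{\text{max}})\, \lambda = \eta_t \lambda$:

```python
# Proposed fix, continuing the sketch above: rescale lambda by
# eta_effective / eta_max (= eta_t) and apply it directly to the weights.
wd_rescaled = (lr_eff / lr_max) * wd                  # equals lr_sched * wd
w_fixed = (1.0 - wd_rescaled) * w - lr_eff * adam_step

# Identical to the paper's update (1 - eta_t * lambda) * w - eta_eff * step.
assert np.allclose(w_fixed, w_decoupled)
```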

Note: This bug also exists in the AdamW implementations in PyTorch and Optax and has already been highlighted a few times across different papers, libraries, and blogs. More links below for reference.

  1. Mosaic ML Library
  2. Optimi
  3. Paper: How to set AdamW's weight decay as you scale model and dataset size
  4. Fabian Schaipp's blog
