
Incredibly slow transformer code #1229

Answered by jheek
Numeri asked this question in Q&A

Looking at the code, my first guess is that the problem is the placement of jax.jit: note how you end up calling jit on every iteration.

The second problem is that the jit does not capture the full train step, so you are missing some optimization opportunities.
Please try to rewrite your train code like this:

@jax.jit
def train_step(optimizer, batch):
  print("compiling train step...")  # This should print only once in the entire train script. Otherwise you have a re-compile bug
  def loss_fn(params):
    ...
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer, loss


def train():
  ...

  for step in range(num_steps):
    ...
    optimizer, loss = train_step(optimizer, batch)
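
For completeness, here is a minimal, self-contained version of the same pattern that you can run as-is. The tiny linear model, hand-rolled SGD update, and synthetic batch are stand-ins chosen for illustration (the actual transformer and optimizer are not shown in this thread); the point is that "compiling train step..." prints only during the first trace.

import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
  print("compiling train step...")  # executes only while tracing, i.e. once
  def loss_fn(p):
    preds = batch["x"] @ p["w"] + p["b"]
    return jnp.mean((preds - batch["y"]) ** 2)
  loss, grads = jax.value_and_grad(loss_fn)(params)
  # plain SGD step; a stand-in for optimizer.apply_gradient
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params, loss

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
batch = {"x": jnp.ones((8, 3)), "y": jnp.ones((8, 1))}

for step in range(5):
  params, loss = train_step(params, batch)  # prints on the first call only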

This discussion was converted from issue #1218 on April 09, 2021 07:43.