-
-
I'm working on a project with a deadline, and seem to have hit a wall. My code is running many orders of magnitude slower than I had hoped, and I don't understand why. I'm making a Transformer model, and I tried to profile it, but I haven't been able to pin down the cause of the slow speeds. I really do apologize for the lack of a minimal working example. Would anyone experienced mind quickly glancing at my repo and looking for obvious mistakes? I know this isn't how GitHub usually works, but if anyone can help me get this working, I'd love to Venmo enough for a dinner out on the town – I'm feeling very desperate. Thank you in advance!
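For anyone who finds this thread later: profiling a JAX step with the built-in profiler might look like the sketch below. The matrix-multiply workload and the trace directory path are just stand-ins; the resulting trace can be inspected in TensorBoard.

```python
import jax
import jax.numpy as jnp

# Trace a small window of work; the log directory is an example path.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((1024, 1024))
    # block_until_ready() makes sure the async computation finishes
    # inside the trace window rather than after it.
    y = (x @ x).block_until_ready()
```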
Replies: 4 comments 1 reply
-
Looking at the code, my first guess is that your problem is the placement of `jax.jit`. Note how you end up calling jit each iteration. The second problem is that the jit does not capture the full train step, so you are missing some optimization opportunities. Please try to rewrite your train code like this:

@jax.jit
def train_step(optimizer, batch):
    print("compiling train step...")  # This should print only once in the entire train script. Otherwise you have a re-compile bug.
    def loss_fn(...):
        ...
    loss, grad = jax.value_and_grad(loss_fn)(...)  # value_and_grad returns a function; call it
    optimizer = optimizer.apply_gradient(grad)
    return optimizer, loss

def train():
    ...
    for step in range(num_steps):
        ...
        optimizer, loss = train_step(optimizer, batch)
        ...
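A minimal runnable version of that pattern, in case it helps anyone reading along. The toy scalar model and hand-written SGD update are stand-ins for the real optimizer and loss; the point is only that `train_step` is jitted as a whole and compiles exactly once.

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # This print happens at trace time, so it fires once per compilation,
    # not once per call -- seeing it repeatedly would mean a re-compile bug.
    print("compiling train step...")
    def loss_fn(p):
        x, y = batch
        return jnp.mean((x * p - y) ** 2)
    loss, grad = jax.value_and_grad(loss_fn)(params)
    params = params - 0.1 * grad  # plain SGD update as a stand-in optimizer
    return params, loss

params = jnp.array(1.0)
# Toy data where the optimum is params == 2.0.
batch = (jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0]))
for step in range(100):
    params, loss = train_step(params, batch)
```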
-
@jheek Thank you so much for the help! That sped things up considerably once past the initial compilation. The compilation itself takes a whole 10 minutes (611 seconds, last run), which is longer than it was before. Is there anything I can do about that?
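For later readers hitting the same compile-time issue: recent JAX versions ship a persistent compilation cache that can amortize a long compile across runs. A sketch, assuming the `jax_compilation_cache_dir` config option exists in the installed JAX version (older releases exposed this through `jax.experimental.compilation_cache` instead):

```python
import os
import tempfile
import jax
import jax.numpy as jnp

# Point JAX at a directory where compiled programs can be cached;
# subsequent runs of the same program can then skip recompilation.
cache_dir = os.path.join(tempfile.gettempdir(), "jax_cache")
jax.config.update("jax_compilation_cache_dir", cache_dir)

@jax.jit
def f(x):
    return x * 2

y = f(jnp.ones(4))  # first run compiles; later runs may hit the cache
```

Note that very short compiles may not be written to the cache at all; the feature pays off for large programs like a full Transformer train step.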
-
I'm surprised the compilation is that slow!
-
Oh, dear, so I misunderstood something fairly essential to transformers, then. That's unfortunate. Is a beam search still performed in that token-by-token manner? Also, @jheek, if you don't mind, email me at ____ so I can Venmo or PayPal you! I've really appreciated your help.
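For anyone landing here later, the token-by-token idea can be sketched with a greedy decoder. The `toy_logits` model below is purely hypothetical; a real beam search works the same step-by-step way but keeps the k best partial sequences at each step instead of only the argmax.

```python
import jax
import jax.numpy as jnp

def greedy_decode(logits_fn, prompt, max_new_tokens):
    # logits_fn(tokens) -> next-token logits; a stand-in for a real model.
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        logits = logits_fn(jnp.array(tokens))
        # Greedy choice: take the single highest-scoring next token.
        tokens.append(int(jnp.argmax(logits)))
    return tokens

# Hypothetical "model" over a 5-token vocabulary: it always prefers
# the token after the last one, wrapping around.
def toy_logits(tokens):
    nxt = (tokens[-1] + 1) % 5
    return jax.nn.one_hot(nxt, 5)

print(greedy_decode(toy_logits, [0], 4))  # → [0, 1, 2, 3, 4]
```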