Why is there no normal categorical cross-entropy loss? #4668
-
This question concerns Optax just as much as Flax, but since the issue seems to stem from expectations set by Flax models, I am asking it here. I was following this example in the JAX AI stack documentation and noticed something odd. When implementing a classifier for images of digits, the model does not apply a softmax to its output logits by default. Instead, the softmax is invoked when calling the loss. Given that …
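To illustrate the pattern I mean, here is a rough sketch (my own, not the tutorial's exact code; I'm using flax.linen here, and the class and function names are just placeholders): the model returns raw logits, and the softmax only shows up inside the loss.

```python
import flax.linen as nn
import optax

class Classifier(nn.Module):  # hypothetical model, for illustration only
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        return nn.Dense(self.num_classes)(x)  # raw logits, no softmax here

def loss_fn(params, model, images, labels):
    logits = model.apply(params, images)
    # The softmax only appears here, inside the loss applied to the logits.
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
```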
Replies: 1 comment
-
Computing cross-entropy directly from logits rather than from probabilities avoids redundant computation and improves numerical stability. PyTorch takes a similar approach (see CrossEntropyLoss and NLLLoss), so this isn't unique to optax or flax.
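As a hypothetical sketch (not from this discussion), the snippet below contrasts the two approaches with `optax.softmax_cross_entropy_with_integer_labels`; the extreme logit values are chosen only to make the underflow visible.

```python
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[120.0, 0.0, -120.0]])  # raw model outputs for one example
labels = jnp.array([2])                     # true class index

# Stable: the loss works directly on the logits
# (log-softmax / log-sum-exp internally).
stable = optax.softmax_cross_entropy_with_integer_labels(logits, labels)

# Naive: apply softmax first, then take the log of the selected probability.
# The probability of class 2 underflows to 0 in float32, so the log blows up.
probs = jax.nn.softmax(logits)
naive = -jnp.log(probs[jnp.arange(labels.shape[0]), labels])

print(stable)  # [240.] -- the correct cross-entropy
print(naive)   # [inf]  -- the information was lost in the softmax
```

Computing the loss from logits keeps everything in log space, which is why the stable version recovers the correct value where the naive softmax-then-log version cannot.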