Skip to content

Commit 5670f0f

Browse files
author
Anuj Rawat
committed
Fixing test that fails on AVX512
The operation categorical_crossentropy requires taking log as an intermediate step. Due to the rank (2) and shape (3, 3) of the tensors used in this example, on AVX2 and older builds, the log operation uses plog, Eigen's packet log method, whereas on AVX512 build, the log operation is not vectorized and ends up using std::log. Due to the precision mismatch between std::log and Eigen's plog, the results do not match exactly. The loss values comes out to be equal to [0.10536055 0.8046685 0.06187541], instead of [0.10536055 0.8046684 0.06187541]. This is an expected mismatch and should not fail the test. The absolutely correct way to test would be to compare hex values and make sure that the results are within the expected range of the ULP error. An easier fix would be to reduce the precision of the test to account for such mismatches between the implementation of operators in the underlying math libraries. We are taking the second approach and will compare results after rounding to 5 decimal places.
1 parent c52c529 commit 5670f0f

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tensorflow/python/keras/backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4482,12 +4482,11 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
44824482
[0.5 0.89 0.6 ]
44834483
[0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
44844484
>>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4485-
>>> print(loss)
4486-
tf.Tensor([0.10536055 0.8046684 0.06187541], shape=(3,), dtype=float32)
4485+
>>> print(np.around(loss, 5))
4486+
[0.10536 0.80467 0.06188]
44874487
>>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4488-
>>> print(loss)
4489-
tf.Tensor([1.1920929e-07 1.1920929e-07 1.1920930e-07], shape=(3,),
4490-
dtype=float32)
4488+
>>> print(np.around(loss, 5))
4489+
[0. 0. 0.]
44914490
44924491
"""
44934492
if from_logits:

0 commit comments

Comments
 (0)