indaba-pracs-2022
[BUG-FIX][Intro-prac] Swapped jnp.equal with jnp.isclose.
The function check_squared_error used jnp.equal to compare floats. This makes the test fail when a student writes a native Python implementation of squared_error, because native Python floats are double precision (64-bit) while JAX defaults to single precision (32-bit). The two results can then differ slightly even though the student's implementation is correct, and an exact equality check rejects it.
MRE (minimal reproducible example):
This implementation of squared_error fails the check even though it is correct:
def squared_error(b, w, x, y):
    # first calculate f(x_i), also sometimes referred to as yhat
    yhat = w * x + b
    # then calculate the squared error
    error = (yhat - y) ** 2
    return error
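To make the failure concrete, here is a minimal sketch that evaluates the same squared error once with native Python floats (float64) and once in float32, JAX's default precision. It uses NumPy as a stand-in for jax.numpy so it runs without JAX installed, and it explicitly casts the Python result to float32 to mimic how JAX treats a weakly typed Python float in a comparison. The input values are made up and deliberately chosen so that float32 addition rounds the tiny bias term away, which makes the discrepancy easy to see.

```python
import numpy as np

def squared_error(b, w, x, y):
    # same arithmetic as the student implementation above
    yhat = w * x + b
    return (yhat - y) ** 2

# Hypothetical inputs (b, w, x, y): b is below float32 resolution near 1.0
args = (1e-8, 1.0, 1.0, 1.0)

# Native Python floats: every operation is carried out in float64
error_py = squared_error(*args)

# Reference computed entirely in float32, mimicking JAX's default dtype
reference = squared_error(*(np.float32(a) for a in args))

# Cast the Python result to float32 before comparing, as JAX would
exact = bool(np.equal(np.float32(error_py), reference))
close = bool(np.isclose(np.float32(error_py), reference))

print(exact, close)  # False True
```

In float32, 1.0 + 1e-8 rounds to exactly 1.0, so the reference error is 0.0, while the float64 computation keeps a small positive error of roughly 1e-16. Exact equality rejects that difference; a tolerance-based check accepts it.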
Fix:
# before fix:
assert jnp.equal(error, correct_error[i]), msg
# after fix:
assert jnp.isclose(error, correct_error[i]), msg
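For context, a tolerance-based checker might be structured roughly as follows. This is a hypothetical sketch, not the notebook's actual check_squared_error: the test cases, the way correct_error is built, and the message format are all assumptions, and NumPy again stands in for jax.numpy (jnp.isclose follows the same semantics and defaults as np.isclose).

```python
import numpy as np

def check_squared_error(squared_error_fn):
    """Hypothetical checker: compare a squared_error implementation against
    float32 reference values using a tolerance instead of exact equality."""
    # Hypothetical test cases, each given as (b, w, x, y)
    cases = [(0.5, 2.0, 3.0, 1.0), (1e-8, 1.0, 1.0, 1.0), (-0.1, 0.3, 7.0, 2.0)]
    # Reference errors computed in float32, JAX's default precision
    correct_error = [
        (np.float32(w) * np.float32(x) + np.float32(b) - np.float32(y)) ** 2
        for b, w, x, y in cases
    ]
    for i, (b, w, x, y) in enumerate(cases):
        error = squared_error_fn(b, w, x, y)
        msg = f"case {i}: got {error}, expected {correct_error[i]}"
        # isclose tolerates the small float32/float64 discrepancy
        assert np.isclose(error, correct_error[i]), msg
    return True
```

np.isclose(a, b) treats two values as equal when |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08. That is loose enough to absorb float32/float64 rounding differences, but a genuinely wrong implementation (for example one returning the absolute error instead of its square) still fails.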