indaba-pracs-2022

[BUG-FIX][Intro-prac] Swapped jnp.equal with jnp.isclose.

Open • anasbekheit opened this issue 1 year ago • 1 comment

The function check_squared_error used jnp.equal to compare floats. Exact comparison can make the test fail even for a correct solution if the student implements squared_error with native Python arithmetic: native Python floats are double precision (64-bit), whereas JAX defaults to single precision (32-bit), so the two results can differ by a rounding error.
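To illustrate the mismatch, here is a minimal sketch with contrived inputs (b = 1.0, w = 1e-8, x = 1.0, y = 1.0, chosen only for illustration and not taken from the prac's tests). It assumes JAX's default configuration with 64-bit mode disabled. In 64-bit Python arithmetic the tiny weight survives the addition to 1.0; in float32 it is rounded away, so the two squared errors differ: jnp.equal rejects the pair while jnp.isclose accepts it.

```python
import jax.numpy as jnp

# Contrived inputs (an illustrative assumption, not from the prac's tests).
b, w, x, y = 1.0, 1e-8, 1.0, 1.0

# Native Python: every value is a 64-bit float, so 1e-8 survives "+ 1.0".
yhat = w * x + b
py_error = (yhat - y) ** 2

# JAX default: 32-bit floats. 1e-8 is below half an ulp of 1.0 in float32,
# so w * x + b rounds to exactly 1.0 and the error becomes zero.
jb, jw, jx, jy = (jnp.float32(v) for v in (b, w, x, y))
jax_yhat = jw * jx + jb
jax_error = (jax_yhat - jy) ** 2

print(py_error)   # a tiny positive number (about 1e-16)
print(jax_error)  # exactly zero
print(bool(jnp.equal(py_error, jax_error)))    # False
print(bool(jnp.isclose(py_error, jax_error)))  # True
```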

MRE (minimal reproducible example):

This implementation of squared_error is correct but can still fail the check:

def squared_error(b, w, x, y):
    # first calculate f(x_i), also sometimes referred to as yhat
    yhat = w * x + b
    # then calculate the squared error
    error = (yhat - y) ** 2
    return error
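The arithmetic in this implementation is right; what changes is the precision it runs in. As a small sketch (the inputs are arbitrary, picked so that both results are exactly 6.25), calling the same code with Python floats returns a 64-bit Python float, while calling it with JAX arrays returns a float32 array, assuming JAX's default x64 setting.

```python
import jax.numpy as jnp

def squared_error(b, w, x, y):
    # Same arithmetic as the MRE implementation above.
    yhat = w * x + b
    error = (yhat - y) ** 2
    return error

# Python floats in: the result is a native 64-bit Python float.
py_result = squared_error(0.5, 2.0, 3.0, 4.0)
print(type(py_result))  # <class 'float'>

# JAX arrays in: the result is a JAX array in the default 32-bit dtype.
jax_result = squared_error(jnp.array(0.5), jnp.array(2.0), jnp.array(3.0), jnp.array(4.0))
print(jax_result.dtype)  # float32
```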

Fix:

        # before fix:
        assert jnp.equal(error, correct_error[i]), msg

        # after fix:
        assert jnp.isclose(error, correct_error[i]), msg
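The surrounding test code is not part of this snippet, so the following is a hypothetical reconstruction: check_squared_error_sketch, its inputs and correct_error arguments, and the test cases are illustrative assumptions, not the prac's actual code. It shows where the jnp.isclose assertion sits and that it still rejects a wrong formula. jnp.isclose defaults to rtol=1e-05 and atol=1e-08, which absorbs float32-versus-float64 rounding; other tolerances can be passed explicitly if the check needs to be stricter or looser.

```python
import jax.numpy as jnp

def check_squared_error_sketch(squared_error_fn, inputs, correct_error):
    # Hypothetical stand-in for check_squared_error; not the prac's real code.
    for i, (b, w, x, y) in enumerate(inputs):
        error = squared_error_fn(b, w, x, y)
        msg = f"case {i}: expected {correct_error[i]}, got {error}"
        # Tolerance-based comparison (defaults rtol=1e-05, atol=1e-08) accepts
        # float64 results that differ from float32 references only by rounding.
        assert jnp.isclose(error, correct_error[i]), msg
    print("All squared_error checks passed.")

def squared_error(b, w, x, y):
    # Same native-Python arithmetic as the MRE implementation.
    yhat = w * x + b
    return (yhat - y) ** 2

# Illustrative test cases (b, w, x, y); the first one triggers float32 rounding.
inputs = [(1.0, 1e-8, 1.0, 1.0), (0.5, 2.0, 3.0, 4.0)]

# Reference errors computed in JAX's default float32 precision.
correct_error = jnp.array(
    [squared_error(*(jnp.float32(v) for v in case)) for case in inputs]
)

# Passes with jnp.isclose; an exact jnp.equal check would reject case 0.
check_squared_error_sketch(squared_error, inputs, correct_error)
```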

anasbekheit • Dec 02 '23 23:12
