lecture-jax

Vectorized B in opt_savings

Open Smit-create opened this issue 2 years ago • 5 comments

https://github.com/QuantEcon/lecture-jax/pull/88#issuecomment-1667001277

Smit-create, Aug 12 '23 04:08

Deploy Preview for incomparable-parfait-2417f8 ready!

Latest commit: 46f56ac2465be3ce8cc4b761bd1fd41390af469c
Latest deploy log: https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/64d73893cf01ac0007bb17ac
Deploy Preview: https://deploy-preview-94--incomparable-parfait-2417f8.netlify.app

netlify[bot], Aug 12 '23 04:08

🚀 Deployed on https://64d73ad6aee1735ec98f0328--incomparable-parfait-2417f8.netlify.app

github-actions[bot], Aug 12 '23 04:08

@jstac, this is the new version of #88.

The timings on this PR are:

HPI: 0.03126978874206543
VFI: 0.9759271144866943
OPI: 0.3162047863006592

And on the deployed version: https://jax.quantecon.org/opt_savings.html

HPI: 0.03399658203125
VFI: 0.90895676612854
OPI: 0.3122248649597168

HPI and OPI timings are essentially unchanged; VFI varies a little more (about 0.98 here vs. 0.91 on the deployed version).

Please have a look. Thank you.

Smit-create, Aug 12 '23 04:08

Thanks @Smit-create!

Regarding

B_vec_wp = jax.vmap(B, in_axes=(None, None, None, 0, 0, None))
B_vec_y_wp = jax.vmap(B_vec_wp, in_axes=(None, None, 0, None, None, 0))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))

I would have expected

B_vec_wp = jax.vmap(B, 
    in_axes=(None, None, None, 0, None, None))
B_vec_y_wp = jax.vmap(B_vec_wp, 
    in_axes=(None, None, 0, None, None, None))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, 
    in_axes=(None, 0, None, None, None, None))

because we are not vectorizing in v or Q. Does that also work, or am I wrong?

jstac, Aug 12 '23 07:08

Does that also work, or am I wrong?

I'm not sure if that would be correct.

For the function:

$$B(w, y, w') = u(Rw + y - w') + \beta \sum_{y'} v(w', y') \, Q(y, y')$$
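To make the discussion concrete, here is a minimal sketch of the unvectorized B as a JAX function. It assumes, hypothetically, that B takes six arguments in the order (params, w, y, wp, v, Q) implied by the six-entry in_axes tuples, that params bundles R and β, and that u is log utility; none of these details are taken from the lecture. Under the vmap chain described next, B only ever sees one row of v and one row of Q, so the sum over y′ is a plain dot product.

```python
import jax.numpy as jnp

# Hypothetical parameters: gross interest rate R and discount factor β,
# bundled as the first argument, which every vmap leaves unmapped.
params = (1.1, 0.96)

def u(c):
    # Hypothetical log utility; the lecture's actual u may differ.
    return jnp.log(c)

def B(params, w, y, wp, v_row, Q_row):
    """One entry B(w, y, w') of the Bellman right-hand side.

    w, y, wp : scalars (current wealth, current income, next-period wealth)
    v_row    : v(w', .), a 1-D array of length y_size
    Q_row    : Q(y, .), a 1-D array of length y_size
    """
    R, β = params
    c = R * w + y - wp
    EV = jnp.sum(v_row * Q_row)  # sum over y' of v(w', y') Q(y, y')
    return u(c) + β * EV

# Usage: a single (w, y, w') point with hypothetical rows of v and Q
v_row = jnp.array([1.0, 2.0, 3.0])
Q_row = jnp.array([0.2, 0.3, 0.5])
val = B(params, 0.5, 1.5, 1.0, v_row, Q_row)
print(val)
```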

We start with w′. It indexes axis 0 of wp and axis 0 of v, so we map over both together:

B_vec_wp = jax.vmap(B, in_axes=(None, None, None, 0, 0, None))

Next we take y, which indexes axis 0 of y_grid and axis 0 of Q:

B_vec_y_wp = jax.vmap(B_vec_wp, in_axes=(None, None, 0, None, None, 0))

Finally, we do the same for w:

B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))

So B_vec_w_y_wp computes exactly the mathematical function $B(w, y, w') = u(Rw + y - w') + \beta \sum_{y'} v(w', y') \, Q(y, y')$ at every grid point.

The loop version would look like this:

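import numpy as np

# w, y, wp are grid indices here; R, β, u, w_grid, y_grid, v, Q are defined elsewhere
res = np.empty((w_size, y_size, w_size))
for w in range(w_size):
    for y in range(y_size):
        for wp in range(w_size):
            # the expectation pairs the row v[wp, :] with the row Q[y, :]
            res[w, y, wp] = u(R*w_grid[w] + y_grid[y] - w_grid[wp]) + β*np.sum(v[wp, :]*Q[y, :])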

That is exactly what we want.
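As a sanity check, the sketch below builds the three-level vmap from this thread on small hypothetical grids (with w_size != y_size on purpose) and compares it against the triple loop. The parameters, grids, v, Q and the log utility are illustrative assumptions, not the lecture's values.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical setup for illustration; not the lecture's parameters.
R, β = 1.1, 0.96
params = (R, β)
w_size, y_size = 4, 3                         # deliberately different sizes
w_grid = jnp.linspace(0.1, 1.0, w_size)
y_grid = jnp.linspace(1.5, 2.0, y_size)       # y > max(w_grid) keeps consumption positive
v = jnp.arange(w_size * y_size, dtype=jnp.float32).reshape(w_size, y_size) / 10
Q = jnp.full((y_size, y_size), 1.0 / y_size)  # each row of Q sums to 1

def B(params, w, y, wp, v_row, Q_row):
    # Hypothetical log utility; v_row = v[wp, :], Q_row = Q[y, :]
    R, β = params
    return jnp.log(R * w + y - wp) + β * jnp.sum(v_row * Q_row)

# Map w' over w_grid and, in lockstep, over the rows of v
B_vec_wp = jax.vmap(B, in_axes=(None, None, None, 0, 0, None))
# Map y over y_grid and, in lockstep, over the rows of Q
B_vec_y_wp = jax.vmap(B_vec_wp, in_axes=(None, None, 0, None, None, 0))
# Map w over w_grid; v and Q pass through unchanged
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))

res_vmap = B_vec_w_y_wp(params, w_grid, y_grid, w_grid, v, Q)
print(res_vmap.shape)  # (4, 3, 4), indexed as [w, y, w']

# Reference triple loop over grid indices
w_np, y_np, v_np, Q_np = (np.asarray(a) for a in (w_grid, y_grid, v, Q))
res_loop = np.empty((w_size, y_size, w_size))
for i in range(w_size):
    for j in range(y_size):
        for ip in range(w_size):
            c = R * w_np[i] + y_np[j] - w_np[ip]
            res_loop[i, j, ip] = np.log(c) + β * np.sum(v_np[ip, :] * Q_np[j, :])

print(np.allclose(res_vmap, res_loop, atol=1e-5))
```

Each vmap prepends its mapped axis, so the result is laid out as [w, y, w′], matching res[w, y, wp] in the loop.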

Your suggestion

B_vec_wp = jax.vmap(B, 
    in_axes=(None, None, None, 0, None, None))
B_vec_y_wp = jax.vmap(B_vec_wp, 
    in_axes=(None, None, 0, None, None, None))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, 
    in_axes=(None, 0, None, None, None, None))

If we don't map over v and Q, B receives the full arrays: v with shape (w_size, y_size) and Q with shape (y_size, y_size). These are not broadcastable when w_size != y_size, so B fails. Even when the sizes match, the result is wrong, because the equation becomes:

$$B(w, y, w') = u(Rw + y - w') + \beta \sum_{i, j} v(i, j) \, Q(i, j)$$

The loop version would look like this:

res = np.empty((w_size, y_size, w_size))
for w in range(w_size):
    for y in range(y_size):
        for wp in range(w_size):
            # v has shape (w_size, y_size) and Q has shape (y_size, y_size), so v*Q
            # fails unless w_size == y_size; even when it runs, the sum is over all
            # (i, j) and deviates from the correct mathematical definition of B.
            res[w, y, wp] = u(R*w_grid[w] + y_grid[y] - w_grid[wp]) + β*np.sum(v*Q)
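The sketch below runs the alternative in_axes on a similar hypothetical setup. With w_size != y_size the product v * Q raises a broadcasting error; with square v and Q the call succeeds, but the continuation term collapses to the single constant β Σ_{i,j} v(i, j) Q(i, j) for every (w, y, w′). The grids, parameters and log utility are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup for illustration; not the lecture's parameters.
R, β = 1.1, 0.96
params = (R, β)
w_size, y_size = 4, 3
w_grid = jnp.linspace(0.1, 1.0, w_size)
y_grid = jnp.linspace(1.5, 2.0, y_size)   # keeps consumption positive
v = jnp.ones((w_size, y_size))
Q = jnp.full((y_size, y_size), 1.0 / y_size)

def B(params, w, y, wp, v, Q):
    # Hypothetical log utility; v and Q arrive here as full 2-D arrays
    R, β = params
    return jnp.log(R * w + y - wp) + β * jnp.sum(v * Q)

# The suggested in_axes: v and Q are never mapped
B_vec_wp = jax.vmap(B, in_axes=(None, None, None, 0, None, None))
B_vec_y_wp = jax.vmap(B_vec_wp, in_axes=(None, None, 0, None, None, None))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))

# Case 1: w_size != y_size, so v * Q multiplies shapes (4, 3) and (3, 3)
try:
    B_vec_w_y_wp(params, w_grid, y_grid, w_grid, v, Q)
    failed = False
except (TypeError, ValueError) as err:
    failed = True
    print("broadcasting error:", err)

# Case 2: square v and Q, so the call runs but the β-term ignores y and w'
n = y_size
w_sq = jnp.linspace(0.1, 1.0, n)
v_sq = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
Q_sq = jnp.full((n, n), 1.0 / n)
res = B_vec_w_y_wp(params, w_sq, y_grid, w_sq, v_sq, Q_sq)
c = R * w_sq[:, None, None] + y_grid[None, :, None] - w_sq[None, None, :]
beta_term = res - jnp.log(c)  # the same value in every [w, y, w'] cell
print(beta_term.min(), beta_term.max())
```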

Therefore I think that it isn't right.

Smit-create, Aug 12 '23 07:08

Closing as complete.

jstac, May 18 '24 17:05