lecture-jax
Vectorized B in opt_savings
https://github.com/QuantEcon/lecture-jax/pull/88#issuecomment-1667001277
Deploy Preview for incomparable-parfait-2417f8 ready!
| Name | Link |
|---|---|
| Latest commit | 46f56ac2465be3ce8cc4b761bd1fd41390af469c |
| Latest deploy log | https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/64d73893cf01ac0007bb17ac |
| Deploy Preview | https://deploy-preview-94--incomparable-parfait-2417f8.netlify.app |
🚀 Deployed on https://64d73ad6aee1735ec98f0328--incomparable-parfait-2417f8.netlify.app
@jstac, this is the new version of #88.
The timings on this PR are:
HPI: 0.03126978874206543
VFI: 0.9759271144866943
OPI: 0.3162047863006592
And on the deployed version: https://jax.quantecon.org/opt_savings.html
HPI: 0.03399658203125
VFI: 0.90895676612854
OPI: 0.3122248649597168
Not much variation in HPI and OPI compared to VFI.
Please have a look. Thank you.
Thanks @Smit-create!
Regarding
B_vec_wp = jax.vmap(B, in_axes=(None, None, None, 0, 0, None))
B_vec_y_wp = jax.vmap(B_vec_wp, in_axes=(None, None, 0, None, None, 0))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))
I would have expected
B_vec_wp = jax.vmap(B,
in_axes=(None, None, None, 0, None, None))
B_vec_y_wp = jax.vmap(B_vec_wp,
in_axes=(None, None, 0, None, None, None))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp,
in_axes=(None, 0, None, None, None, None))
because we are not vectorizing in v or Q. Does that also work, or am I wrong?
> Does that also work, or am I wrong?
I'm not sure if that would be correct.
For the function:
B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
We start with w′: since it enters as axis 0 of the wp argument and axis 0 of v, we write:
B_vec_wp = jax.vmap(B, in_axes=(None, None, None, 0, 0, None))
Next we pick y, which is axis 0 of y_grid and axis 0 of Q, and write:
B_vec_y_wp = jax.vmap(B_vec_wp, in_axes=(None, None, 0, None, None, 0))
Finally, we do the same for w:
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))
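For concreteness, here is a minimal shape check of this composition. It is only a sketch: the argument order (constants, w, y, wp, v, Q) is inferred from the in_axes patterns above, and the stand-in B below is not the lecture's actual function, it just has the same argument structure.

```python
import jax
import jax.numpy as jnp

# Stand-in with the assumed argument order (constants, w, y, wp, v, Q).
# At the scalar level, v is the row v[w′, :] and Q is the row Q[y, :].
def B(constants, w, y, wp, v, Q):
    β, R = constants
    return R * w + y - wp + β * jnp.sum(v * Q)

B_vec_wp     = jax.vmap(B,          in_axes=(None, None, None, 0, 0, None))
B_vec_y_wp   = jax.vmap(B_vec_wp,   in_axes=(None, None, 0, None, None, 0))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))

w_size, y_size = 5, 3
f32 = jnp.float32
out = jax.eval_shape(
    B_vec_w_y_wp,
    (0.96, 1.04),                                 # constants (β, R)
    jax.ShapeDtypeStruct((w_size,), f32),         # w grid
    jax.ShapeDtypeStruct((y_size,), f32),         # y grid
    jax.ShapeDtypeStruct((w_size,), f32),         # w′ grid
    jax.ShapeDtypeStruct((w_size, y_size), f32),  # v
    jax.ShapeDtypeStruct((y_size, y_size), f32),  # Q
)
print(out.shape)   # (5, 3, 5), i.e. (w_size, y_size, w_size)
```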
So the vectorized B computes exactly the mathematical function B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′).
The loop version would look like this:
res = np.empty((w_size, y_size, w_size))
for w in range(w_size):
    for y in range(y_size):
        for wp in range(w_size):
            res[w, y, wp] = u(R*w_grid[w] + y_grid[y] - w_grid[wp]) + β*(np.sum(v[wp, :]*Q[y, :]))
That is the same as what we want.
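As a self-contained sanity check (again assuming the argument order (constants, w, y, wp, v, Q), a CRRA u, and toy grids chosen so consumption stays positive), the vmapped version and the triple loop agree:

```python
import numpy as np
import jax
import jax.numpy as jnp

β, R, γ = 0.96, 1.04, 2.0
constants = (β, R, γ)

def u(c):
    return c**(1 - γ) / (1 - γ)

def B(constants, w, y, wp, v, Q):
    # Scalar version: v is the row v[w′, :] and Q is the row Q[y, :]
    β, R, γ = constants
    return u(R * w + y - wp) + β * jnp.sum(v * Q)

B_vec_wp     = jax.vmap(B,          in_axes=(None, None, None, 0, 0, None))
B_vec_y_wp   = jax.vmap(B_vec_wp,   in_axes=(None, None, 0, None, None, 0))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp, in_axes=(None, 0, None, None, None, None))

w_size, y_size = 5, 3
w_grid = np.linspace(1.0, 2.0, w_size)
y_grid = np.linspace(1.1, 1.5, y_size)        # keeps R*w + y - w′ > 0
v = np.arange(w_size * y_size).reshape(w_size, y_size) / 10
Q = np.full((y_size, y_size), 1 / y_size)

# Triple-loop reference
res = np.empty((w_size, y_size, w_size))
for w in range(w_size):
    for y in range(y_size):
        for wp in range(w_size):
            res[w, y, wp] = u(R*w_grid[w] + y_grid[y] - w_grid[wp]) \
                            + β*(np.sum(v[wp, :]*Q[y, :]))

out = B_vec_w_y_wp(constants, w_grid, y_grid, w_grid, v, Q)
print(out.shape)                                      # (5, 3, 5)
print(np.allclose(res, np.asarray(out), atol=1e-5))   # True
```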
Your suggestion was:
B_vec_wp = jax.vmap(B,
in_axes=(None, None, None, 0, None, None))
B_vec_y_wp = jax.vmap(B_vec_wp,
in_axes=(None, None, 0, None, None, None))
B_vec_w_y_wp = jax.vmap(B_vec_y_wp,
in_axes=(None, 0, None, None, None, None))
If we don't vectorize over v and Q, B will fail: v has shape (w_size, y_size) and Q has shape (y_size, y_size), and these are not broadcastable when w_size != y_size. It is also mathematically wrong, because the equation would become:
B(w, y, w′) = u(Rw + y - w′) + β Σ_(i, j) v(i, j) Q(i, j)
The loop version would look like this:
res = np.empty((w_size, y_size, w_size))
for w in range(w_size):
    for y in range(y_size):
        for wp in range(w_size):
            # Fails when v and Q are not broadcastable, i.e. when w_size != y_size.
            # Even when it does run, it deviates from the correct definition.
            res[w, y, wp] = u(R*w_grid[w] + y_grid[y] - w_grid[wp]) + β*(np.sum(v*Q))
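To make the failure concrete, here is the v * Q step in isolation with the shapes described above (nothing else from B is needed to trigger it):

```python
import jax.numpy as jnp

w_size, y_size = 5, 3                 # deliberately different sizes
v = jnp.ones((w_size, y_size))        # value function, shape (w_size, y_size)
Q = jnp.ones((y_size, y_size))        # transition matrix, shape (y_size, y_size)

try:
    jnp.sum(v * Q)                    # the un-vectorized v * Q step
except Exception as e:
    print(e)                          # JAX reports incompatible shapes for broadcasting
```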
Therefore I think that it isn't right.
Closing as complete.