George Stepaniants

7 comments by George Stepaniants

Thank you! So I ended up doing something like this to get jvp and vjp batched over both inputs and tangent vectors. Does this make sense?
```
d = 5
D...
```
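One way to sanity-check batched JVPs and VJPs is to compare them against central finite differences and against the adjoint identity <v, J u> = <J^T v, u>. The sketch below does this in NumPy rather than PyTorch, for a hypothetical toy `predict(x) = tanh(W @ x)` whose Jacobian is known in closed form; `W`, `out_dim`, and the batch sizes are assumptions for illustration, and only `d = 5` comes from the comment.

```python
import numpy as np

# Hypothetical toy model standing in for `predict`: f(x) = tanh(W @ x).
rng = np.random.default_rng(0)
d, out_dim = 5, 3                      # d = 5 as in the comment; out_dim is assumed
W = rng.standard_normal((out_dim, d))

def predict(x):
    return np.tanh(W @ x)

def jacobian(x):
    # Closed-form Jacobian of tanh(W x): diag(1 - tanh(W x)^2) @ W, shape (out_dim, d).
    s = 1.0 - np.tanh(W @ x) ** 2
    return s[:, None] * W

def fd_jvp(x, u, eps=1e-6):
    # Central finite difference approximating J(x) @ u (truncation error O(eps^2)).
    return (predict(x + eps * u) - predict(x - eps * u)) / (2 * eps)

xs = rng.standard_normal((4, d))        # batch of inputs
us = rng.standard_normal((6, d))        # batch of tangents (input space)
vs = rng.standard_normal((6, out_dim))  # batch of cotangents (output space)

# JVPs for every (input, tangent) pair: shape (n_x, n_u, out_dim).
jvp_exact = np.array([[jacobian(x) @ u for u in us] for x in xs])
jvp_fd = np.array([[fd_jvp(x, u) for u in us] for x in xs])
# VJPs v^T J(x) for every (input, cotangent) pair: shape (n_x, n_v, d).
vjp_exact = np.array([[v @ jacobian(x) for v in vs] for x in xs])

print("max |JVP - finite difference|:", np.abs(jvp_exact - jvp_fd).max())
# Adjoint identity links the two: <v, J u> == <J^T v, u>.
print(vs[0] @ jvp_exact[0, 0], vjp_exact[0, 0] @ us[0])
```

If the batched `jvp`/`vjp` results agree with `jvp_exact` and `vjp_exact` (or with the finite differences) for the same inputs, the batching is very likely set up correctly.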

Actually, this is a lot faster:
```
ft_jacobians = vmap(jacrev(predict))(xs)
ft_jvp2 = torch.einsum("ikl, jl -> ijk", ft_jacobians, us)
ft_vjp2 = torch.einsum("ikl, jk -> ijl", ft_jacobians, vs)
print(torch.norm(ft_jvp2[0, :, :] -...
```
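The two `einsum` contractions can be checked against explicit loops. This sketch uses `np.einsum` as a stand-in for `torch.einsum` (the subscript semantics are the same), with random arrays in place of real Jacobians; all shapes are assumptions. A plausible reason this version is faster: each Jacobian is computed once per input and then reused for every tangent and cotangent, so with many tangent vectors and small input/output dimensions one contraction is cheaper than a separate `jvp`/`vjp` call per pair.

```python
import numpy as np

# Assumed shapes, for illustration only:
#   ft_jacobians: (n_x, out_dim, in_dim), one Jacobian per input x_i
#   us:           (n_u, in_dim)  tangent vectors
#   vs:           (n_v, out_dim) cotangent vectors
rng = np.random.default_rng(1)
n_x, n_u, n_v, out_dim, in_dim = 4, 3, 2, 6, 5
ft_jacobians = rng.standard_normal((n_x, out_dim, in_dim))
us = rng.standard_normal((n_u, in_dim))
vs = rng.standard_normal((n_v, out_dim))

# JVP contracts the input index l: result[i, j, k] = sum_l J[i, k, l] * u[j, l]
ft_jvp2 = np.einsum("ikl,jl->ijk", ft_jacobians, us)
# VJP contracts the output index k: result[i, j, l] = sum_k J[i, k, l] * v[j, k]
ft_vjp2 = np.einsum("ikl,jk->ijl", ft_jacobians, vs)

# Reference: explicit loops over inputs and tangents/cotangents.
jvp_loop = np.array([[J @ u for u in us] for J in ft_jacobians])
vjp_loop = np.array([[v @ J for v in vs] for J in ft_jacobians])
print(np.abs(ft_jvp2 - jvp_loop).max(), np.abs(ft_vjp2 - vjp_loop).max())
```

The output index order `ijk` / `ijl` puts the input batch first and the tangent batch second, which matches indexing such as `ft_jvp2[0, :, :]` for "all tangents at the first input".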

Okay, great. I am running my code on the newest M1 Mac, and I've noticed that this problem appears when I run this MRE in Visual Studio Code. If...

Thanks, Patrick, this clarifies it! So, just to double-check: ForwardMode is very similar to RecursiveCheckpointAdjoint and DirectAdjoint (which are all discretize-then-optimize approaches), but it uses forward-mode AD...
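To make "discretize-then-optimize with forward-mode AD" concrete, here is a hypothetical toy, not Diffrax's implementation: explicit Euler for dy/dt = -k y, with the tangent dy/dk carried forward alongside the state at every step. Because the tangent is propagated step by step, nothing from earlier steps needs to be stored, unlike reverse mode, which has to revisit the steps backwards (hence checkpointing). The result is the exact derivative of the discrete solution y_N = (1 - k h)^N, not of the continuous ODE solution.

```python
def euler_forward_mode(k, y0=1.0, t1=1.0, n_steps=10):
    """Explicit Euler for dy/dt = -k*y, returning (y_N, dy_N/dk) via forward-mode tangents.

    Hypothetical toy example; the ODE, step count and names are assumptions.
    """
    h = t1 / n_steps
    y, t = y0, 0.0  # state and its tangent dy/dk (y0 does not depend on k)
    for _ in range(n_steps):
        # Differentiate the step y_next = y + h*(-k*y) w.r.t. k:
        #   t_next = t + h*(-y - k*t)
        # The tangent uses the *old* y, so update t before overwriting y.
        t = t + h * (-y - k * t)
        y = y + h * (-k * y)
    return y, t

y_N, dy_dk = euler_forward_mode(k=2.0)
print(y_N, dy_dk)  # approximately 0.1074 and -0.1342
```

For k = 2, h = 0.1 and N = 10 this reproduces y_N = 0.8^10 and dy_N/dk = -N h (1 - k h)^(N-1) = -0.8^9. Forward mode needs one such pass per input parameter, which is why it tends to pay off when there are few parameters and many outputs.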

I see, so the only numerical operation I have in the callback is calling my Equinox model's `__call__` function in a for loop. Should I jit the `__call__`...
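If the callback calls the model once per sample in a Python loop, one option is to replace the loop with a single batched call. In JAX that usually means `jax.vmap` over the per-sample call, wrapped once in `jax.jit` (or `eqx.filter_jit` for an Equinox module) outside the loop, so it is compiled once and reused. The sketch below shows only the loop-to-batch rewrite, using NumPy broadcasting as a stand-in for `vmap`; `TinyModel` is a hypothetical single-layer model, and no JIT compilation happens here.

```python
import numpy as np

class TinyModel:
    """Hypothetical stand-in for an Equinox model: a single dense layer with tanh."""

    def __init__(self, W, b):
        self.W, self.b = W, b

    def __call__(self, x):
        # Accepts one input of shape (in_dim,) or a batch of shape (n, in_dim),
        # thanks to NumPy broadcasting.
        return np.tanh(x @ self.W.T + self.b)

rng = np.random.default_rng(2)
model = TinyModel(rng.standard_normal((3, 4)), rng.standard_normal(3))
xs = rng.standard_normal((100, 4))

# Per-sample Python loop: 100 separate model calls.
loop_out = np.stack([model(x) for x in xs])
# One batched call over the whole array: same numbers, a single call.
batched_out = model(xs)
print(np.abs(loop_out - batched_out).max())
```

In JAX the model would typically be written for a single input, and `jax.vmap` would supply the batch dimension instead of relying on broadcasting.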

Just wanted to bump this question again. If we want a callback that efficiently plots the results of our model every epoch or so, should we be jitting the `__call__`...

That is roughly what I do: I only use my callback to make plots every epoch, with many batch steps in between. It is the "extract some value" step that...
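A minimal sketch of that structure, assuming a hypothetical `train` loop with a pure-Python stand-in for the update step: the callback runs once per epoch, after many batch steps, so any cost of pulling a value out of the training state is paid once per epoch rather than once per step. In JAX specifically, converting a device array to a Python or NumPy value (for example with `float(...)` or `np.asarray(...)`) waits for pending computation to finish, which is one reason to keep that extraction out of the inner loop.

```python
def train(n_epochs, steps_per_epoch, step_fn, callback):
    """Hypothetical training loop: many batch steps per epoch, one callback per epoch."""
    state = 0.0
    for epoch in range(n_epochs):
        for _ in range(steps_per_epoch):
            state = step_fn(state)  # hot path: runs on every batch step
        # Only here is a value pulled out of the state (e.g. to plot it).
        callback(epoch, state)
    return state

history = []
final = train(
    n_epochs=3,
    steps_per_epoch=100,
    step_fn=lambda s: s + 1.0,  # stand-in for a jitted update step
    callback=lambda epoch, s: history.append((epoch, s)),
)
print(history)  # [(0, 100.0), (1, 200.0), (2, 300.0)]
```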