warp
[QUESTION] Optimization example with JAX
Hi, it's very exciting to see the recent improvements in JAX integration! I noticed there is an example showing how to use Warp with PyTorch for optimization (defining a loss in Warp and using PyTorch optimizers). Could you provide a similar example that uses optax in JAX instead? Specifically, I would like to know how to:
- Define a loss function in Warp
- Compute gradients using Warp's tape
- Optimize parameters using JAX's optimizers
Thank you!
Hi @itk22, we're planning to add JAX backward pass interoperability in an upcoming release (#515). This will include optimization examples similar to the PyTorch ones.
Great, looking forward to this!