
[QUESTION] Optimization example with JAX

Open · itk22 opened this issue 9 months ago · 2 comments

Hi, it's very exciting to see the recent improvements in JAX integration! I noticed there is an example showing how to use Warp with PyTorch for optimization (defining a loss in Warp and using PyTorch optimizers). Could you provide a similar example that uses Optax with JAX instead? Specifically, I would like to know how to:

  • Define a loss function in Warp
  • Compute gradients using Warp's tape
  • Update the parameters using a JAX optimizer (e.g., from Optax)

Thank you!

itk22 · Mar 31 '25 13:03

Hi @itk22, we're planning to add JAX backward pass interoperability in an upcoming release (#515). This will include optimization examples similar to the PyTorch ones.

nvlukasz · Apr 01 '25 17:04

Great, looking forward to this!

itk22 · Apr 01 '25 21:04