drjit
Differentiating through loops
I've been playing around with Dr.Jit. It seems to work very well, but I think there's something fairly fundamental that I'm not understanding.
I tried to implement a very basic differentiable renderer based on the sphere tracing example in the docs. (I know that I'm not handling discontinuities properly and this won't work well). I do actually get fairly decent results if I use a regular python loop in the sphere tracing routine, but whenever I try to use the "recorded loop" construct and then ask it to compute gradients I get the message:
> loop_process_state(): one of the supplied loop state variables of type Float is attached to the AD graph (i.e., grad_enabled(..) is true). However, propagating derivatives through multiple iterations of a recorded loop is not supported (and never will be). Please see the documentation on differentiating loops for details and suggested alternatives.
I haven't found any documentation on "differentiating loops". Is there something you could point me to that explains how I would differentiate through a loop? I did look at the differentiable sphere tracing code, but it's hard for me to disentangle the discontinuity handling from the basic "differentiable loop" code.
Thanks!
Dear @daseyb,
The documentation is still in progress and will probably take quite some time (see https://github.com/mitsuba-renderer/mitsuba3/discussions/123) :( On differentiating recorded loops: AD is supported only within each iteration. Building an AD graph across iterations, as the message says, will never be supported.
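To make "AD only within each iteration" concrete, here is a toy, pure-Python sketch. It uses no Dr.Jit and is not Mitsuba's implementation; the `Dual` class, `albedo`, and the whole setup are hypothetical. It follows the spirit of path replay backpropagation: a path accumulates radiance over several bounces. The gradient comes from replaying the loop and differentiating each bounce on its own, while the loop-carried state (`beta`, `collected`) stays as plain, non-differentiated floats.

```python
# Toy illustration (not Mitsuba/Dr.Jit code): derivatives are taken only
# *inside* each loop iteration; the loop-carried state is plain floats.


class Dual:
    """Minimal forward-mode AD number: val + der * eps."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __mul__(self, other):
        if not isinstance(other, Dual):
            other = Dual(other)
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)


def albedo(theta):
    # Hypothetical differentiable parameterization of the per-bounce albedo.
    return theta * theta


def radiance(theta, bounces, emitted=1.0):
    """Primal pass: L = sum_k beta_k * emitted, with beta_k = albedo^k."""
    beta, total = 1.0, 0.0
    for _ in range(bounces):
        beta *= albedo(theta)
        total += beta * emitted
    return total


def grad_per_iteration(theta, bounces, emitted=1.0):
    """Second pass: replay the loop, differentiating one bounce at a time."""
    total = radiance(theta, bounces, emitted)
    beta, collected, grad = 1.0, 0.0, 0.0
    for _ in range(bounces):
        # AD is used only within this iteration: d(albedo)/d(theta).
        rho = albedo(Dual(theta, 1.0))
        # All radiance collected from this bounce onward is scaled once by rho.
        remaining = total - collected
        grad += rho.der * remaining / rho.val
        # Advance the (non-differentiated) loop state.
        beta *= rho.val
        collected += beta * emitted
    return grad


print(grad_per_iteration(0.7, 5))
```

The only derivative taken per iteration is d(albedo)/d(theta), so nothing has to be kept alive across iterations. The primal result from the first pass is what lets each iteration know how much of the final radiance its albedo scales.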
- How useful can it be with such strict constraints?
  It turns out that this is all we need in current differentiable rendering algorithms, e.g. `prb_basic.py` in the mitsuba3 repo.
- What if I really want to propagate gradients through iterations?
  We can still unroll the loops. Disabling `loopRecord` is equivalent to using a regular Python loop.
- Why do we need recorded loops then?
  A recorded loop compiles the loop body only once. It tracks the operations performed in each iteration and knows which variables change (the `lambda: (var1, var2)` in the loop declaration), so the compiled code can be reused for every iteration.
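For contrast, here is what "propagating gradients through iterations" means, as a pure-Python sketch. It uses no Dr.Jit; the `Dual` class, `sphere_trace`, and the circle setup are hypothetical. An ordinary Python loop is unrolled, so the derivative of the hit distance with respect to the radius flows through all iterations. This matches the behavior the original post got from a regular Python loop, and it is what disabling `loopRecord` gives you in Dr.Jit.

```python
import math


class Dual:
    """Minimal forward-mode AD number: val + der * eps."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)


def dsqrt(x):
    root = math.sqrt(x.val)
    return Dual(root, x.der / (2.0 * root))


def sphere_trace(radius, center=(2.0, 0.5), iterations=64):
    """March a 2D ray from the origin along +x toward a circle.

    The Python loop is unrolled, so the derivative carried by `t`
    flows through every iteration (like disabling loop recording).
    """
    cx, cy = center
    t = Dual(0.0)
    for _ in range(iterations):
        dx = t - cx                      # current ray point is (t, 0)
        dist = dsqrt(dx * dx + cy * cy) - radius
        t = t + dist                     # step by the signed distance
    return t


hit = sphere_trace(Dual(1.0, 1.0))  # seed d(radius)/d(radius) = 1
print(hit.val, hit.der)
```

For this setup the exact hit is t = 2 - sqrt(r² - 0.25), so dt/dr = -r / sqrt(r² - 0.25), and the unrolled loop converges to that value. The catch is that the result depends on every iteration. A reverse-mode graph (or an unrolled trace) therefore grows with the iteration count, and a recorded loop avoids that by compiling the body only once.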
The official documentation will explain this in much more detail. Hope this helps for now.
Thanks for the reply! That clarifies things. I'll try to look at examples in the mitsuba3 repo to get some idea of what the "proper" way to do this is. I hope everyone enjoys their vacation, I can imagine you need it after all the big recent releases! :)