dex-lang
AD of partially active if crashes
Reproduction
def foo (a: Float) (b: Float) : Float =
  (b, z) = if (a < b)
    then (True, a * b)
    else (False, a * 2)
  if b
    then z * z
    else z * z * z
grad (\x . (foo x (x+1))) 1.0
> Compiler bug!
> Please report this at github.com/google-research/dex-lang/issues
>
> Not implemented
> CallStack (from HasCallStack):
> error, called at src/lib/Linearize.hs:588:18 in dex-0.1.0.0-7ifYmHkPRmbOSGL0nbf4g:Linearize
> notImplemented, called at src/lib/Linearize.hs:261:24 in dex-0.1.0.0-7ifYmHkPRmbOSGL0nbf4g:Linearize
The error arises while constructing a tangent type for the boolean. In this case it would, in principle, also be possible to notice that the first component of the pair is not actually active, split the pair, and avoid demanding a tangent for the boolean at all.
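For comparison, here is a minimal forward-mode (dual-number) sketch in Python of the behavior the repro expects. It is not how Dex's Linearize pass works, and the names Dual, foo, and derivative are hypothetical. Comparisons return ordinary Python bools, so the boolean half of the pair never receives a tangent; this models giving Bool a trivial tangent space. For foo x (x+1) at x = 1.0 the true branch is taken and the result is (x(x+1))^2, whose derivative is 2 · 2 · 3 = 12.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A Float paired with its tangent (forward-mode AD)."""
    primal: float
    tangent: float

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule.
            return Dual(self.primal * other.primal,
                        self.primal * other.tangent + self.tangent * other.primal)
        # Multiplying by a constant: the constant has no tangent of its own.
        return Dual(self.primal * other, self.tangent * other)

    __rmul__ = __mul__

    def __add__(self, other):
        if isinstance(other, Dual):
            return Dual(self.primal + other.primal, self.tangent + other.tangent)
        return Dual(self.primal + other, self.tangent)

    __radd__ = __add__

    def __lt__(self, other):
        # A comparison yields a plain bool: its tangent space is trivial,
        # so no tangent is ever constructed for it.
        other_primal = other.primal if isinstance(other, Dual) else other
        return self.primal < other_primal


def foo(a, b):
    # The pair (flag, z) mixes an inactive Bool with an active Float.
    # Only z carries a tangent; flag is an ordinary bool.
    # (flag plays the role of the shadowed b in the Dex repro.)
    flag, z = (True, a * b) if a < b else (False, a * 2)
    return z * z if flag else z * z * z


def derivative(f, x):
    # Seed the input tangent with 1.0 and read off the output tangent.
    return f(Dual(x, 1.0)).tangent


print(derivative(lambda x: foo(x, x + 1), 1.0))  # 12.0
```

Because the boolean lives outside the Dual type, the "partially active" pair is split for free here; the Dex issue is that linearization currently tries to build a tangent for the whole pair, Bool included.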
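This example is isomorphic to having the first if-expression return an ADT with a payload (a pair of a Bool and a Float carries the same information as a sum type with two Float-carrying constructors), which would be even more awkward to split. Still, this particular factoring may merit special treatment, because it arises when differentiating a state-effect computation that manipulates booleans.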
Or then again, maybe we should just implement tangents and cotangents for ADTs and be done with it.
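To make the "tangents for ADTs" option concrete, here is a hypothetical type-level sketch in Python (not Dex's implementation; FloatT, BoolT, UnitT, PairT, SumT, and tangent_type are all made-up names). Discrete types map to a unit tangent, pairs drop trivial components, and a sum maps to a sum of tangents whose tag is expected to match the primal value's constructor. That last rule is just one possible design choice.

```python
from dataclasses import dataclass
from typing import Union


# A tiny type language: Float, Bool, Unit, pairs, and binary sums (ADTs).
@dataclass(frozen=True)
class FloatT:
    pass


@dataclass(frozen=True)
class BoolT:
    pass


@dataclass(frozen=True)
class UnitT:
    pass


@dataclass(frozen=True)
class PairT:
    fst: "Ty"
    snd: "Ty"


@dataclass(frozen=True)
class SumT:
    left: "Ty"
    right: "Ty"


Ty = Union[FloatT, BoolT, UnitT, PairT, SumT]


def tangent_type(t: Ty) -> Ty:
    """Map a primal type to a (hypothetical) tangent type."""
    if isinstance(t, FloatT):
        return FloatT()
    if isinstance(t, (BoolT, UnitT)):
        # Discrete types have a trivial (unit) tangent space.
        return UnitT()
    if isinstance(t, PairT):
        a, b = tangent_type(t.fst), tangent_type(t.snd)
        # Split off trivial components, so a partially active pair
        # only demands tangents for its active parts.
        if a == UnitT():
            return b
        if b == UnitT():
            return a
        return PairT(a, b)
    if isinstance(t, SumT):
        a, b = tangent_type(t.left), tangent_type(t.right)
        if a == UnitT() and b == UnitT():
            return UnitT()
        # Sum of tangents; the tangent's tag should match the primal's tag.
        return SumT(a, b)
    raise TypeError(f"unknown type: {t!r}")


print(tangent_type(PairT(BoolT(), FloatT())))  # FloatT()
print(tangent_type(SumT(FloatT(), FloatT())))  # SumT(left=FloatT(), right=FloatT())
```

Under these rules the repro's Bool/Float pair needs only a Float tangent, while the isomorphic two-constructor ADT needs a tagged tangent, which illustrates why the ADT form is harder to split statically.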