dex-lang
Support AD on ADTs that have tangent types equal to the primal types
Currently, `grad` does not work on data types such as `List`. For example:
ls = (AsList _ [1.,2.,3.])
def sumList ((AsList _ arr) : List Float): Float = sum arr
sumList ls
> 6.
(grad sumList) ls
dex: Not implemented
CallStack (from HasCallStack):
error, called at src/lib/Autodiff.hs:507:18 in dex-0.1.0.0-RKG8rO926ZDrlxlc3NaeB:Autodiff
notImplemented, called at src/lib/Autodiff.hs:386:19 in dex-0.1.0.0-RKG8rO926ZDrlxlc3NaeB:Autodiff
Discussion on Friday @apaszke @danieldjohnson @duvenaud
For some data types, taking `grad` with respect to them does not require a user-supplied tangent type, since the tangent type is the same as the primal type. Supporting this case would be easier to implement than supporting `grad` on arbitrary data structures.
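
As an illustration of the intended behaviour, here is a minimal sketch in JAX rather than Dex. It is an analogy only and says nothing about how Dex would implement this. `FloatList` and `sum_list` are hypothetical stand-ins for `List Float` and `sumList`. JAX treats a `NamedTuple` as a pytree, so the gradient comes back with the same structure as the input, which is the "tangent type equals primal type" case described above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FloatList(NamedTuple):
    """Hypothetical stand-in for Dex's `List Float`: a wrapper around an array."""
    arr: jnp.ndarray


def sum_list(ls: FloatList) -> jnp.ndarray:
    # Mirrors `sumList`: unwrap the container and sum its elements.
    return jnp.sum(ls.arr)


ls = FloatList(jnp.array([1.0, 2.0, 3.0]))

# Primal evaluation, analogous to `sumList ls`.
total = sum_list(ls)

# Gradient with respect to the whole container. No tangent type is supplied:
# the result is another FloatList holding one partial derivative per element.
grads = jax.grad(sum_list)(ls)
```

Here `grads` is a `FloatList` whose `arr` field is all ones, since each element contributes linearly to the sum. One difference from Dex is that the array length in this sketch is fixed by the input, whereas `List` hides its length behind `AsList`; the tangent of a Dex `List` would presumably need the same length as its primal.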