dex-lang

Support AD on ADTs that have tangent types equal to the primal types

Open chenzizhao opened this issue 3 years ago • 2 comments

Currently, grad does not work on data types such as List. For example:

ls = (AsList _ [1.,2.,3.])
def sumList ((AsList _ arr) : List Float): Float = sum arr
sumList ls
> 6.
(grad sumList) ls
dex: Not implemented
CallStack (from HasCallStack):
  error, called at src/lib/Autodiff.hs:507:18 in dex-0.1.0.0-RKG8rO926ZDrlxlc3NaeB:Autodiff
  notImplemented, called at src/lib/Autodiff.hs:386:19 in dex-0.1.0.0-RKG8rO926ZDrlxlc3NaeB:Autodiff

Discussion on Friday (@apaszke @danieldjohnson @duvenaud): for some data types, taking grad with respect to them does not require a user-supplied tangent type, since the tangent type is the same as the original type. Supporting AD on these types would be easier to implement than supporting grad on arbitrary, general data structures.
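To illustrate why such types are the easy case, here is a minimal sketch in plain Python, not Dex, and not how Dex's Autodiff.hs works. It assumes a list of floats stands in for `List Float`. Because the tangent type is again a list of floats, the gradient can be returned as a value of the primal type with no user-supplied tangent type. The names `Dual`, `lift`, `sum_list` and `grad_list` are hypothetical. For clarity the sketch uses forward-mode dual numbers with one pass per element, whereas `grad` in Dex is reverse mode.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Dual:
    """A number paired with its tangent (forward-mode AD)."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Product rule: d(a*b) = a*db + da*b
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def lift(x):
    """Treat a plain float as a constant (zero tangent)."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def sum_list(xs):
    """Hypothetical analogue of sumList: works on floats or Duals."""
    total = 0.0
    for x in xs:
        total = total + x
    return total


def grad_list(f: Callable, xs: List[float]) -> List[float]:
    """Gradient of a scalar function of a list of floats.

    The tangent type of a list of floats is again a list of floats of the
    same length, so the result has the same type as the input.
    """
    grads = []
    for i in range(len(xs)):
        # Seed a unit tangent in position i, zero elsewhere.
        duals = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        grads.append(lift(f(duals)).tangent)
    return grads


print(sum_list([1.0, 2.0, 3.0]))             # 6.0
print(grad_list(sum_list, [1.0, 2.0, 3.0]))  # [1.0, 1.0, 1.0]
```

The gradient of a sum has the same shape as its input: one 1.0 per list element. Producing this result needs nothing beyond the primal type itself, which is what makes this class of ADTs simpler to support than general data structures.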

chenzizhao, Jun 18 '21 18:06