futhark
Vectorised AD
I would like the following functions to be made available:
val jvp2_vec 'a 'b [n] : (a -> b) -> a -> [n]a -> (b, [n]b)
val vjp2_vec 'a 'b [n] : (a -> b) -> a -> [n]b -> (b, [n]a)
The names are open to bikeshedding. The idea is to let AD compute multiple (co)tangents in one go. This can avoid `n` executions of the primal function. In some cases the compiler might be able to optimise the replicated work, but I wouldn't want to rely on it in all cases.
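To pin down the intended semantics, here is a small Python sketch (not Futhark) of what the two functions should compute. It is only a reference model: `f` is a hypothetical example function, and `jacobian_f` is its hand-written Jacobian, standing in for the linearisation AD would derive. The point is that `f` is evaluated once per call regardless of how many (co)tangents are passed, and that each cotangent, which lives in the output space `b`, is mapped to a result in the input space `a`, just as with the existing `vjp2`.

```python
import math


def f(x):
    # Hypothetical example: f(x0, x1) = (x0 * x1, sin x0).
    x0, x1 = x
    return (x0 * x1, math.sin(x0))


def jacobian_f(x):
    # Hand-derived Jacobian of f (rows = outputs, columns = inputs),
    # standing in for the linearisation that AD would compute.
    x0, x1 = x
    return [[x1, x0],
            [math.cos(x0), 0.0]]


def jvp2_vec(f, jac, x, tangents):
    # One primal evaluation; every tangent t (in the input space a)
    # is mapped to J t (in the output space b).
    J = jac(x)
    outs = [tuple(sum(Jij * tj for Jij, tj in zip(row, t)) for row in J)
            for t in tangents]
    return f(x), outs


def vjp2_vec(f, jac, x, cotangents):
    # One primal evaluation; every cotangent c (in the output space b)
    # is mapped to J^T c (in the input space a).
    J = jac(x)
    n_in = len(J[0])
    outs = [tuple(sum(J[i][j] * c[i] for i in range(len(J)))
                  for j in range(n_in))
            for c in cotangents]
    return f(x), outs
```

For example, `jvp2_vec(f, jacobian_f, (2.0, 3.0), [(1.0, 0.0), (0.0, 1.0)])` returns the primal result together with the two columns of the Jacobian, while `vjp2_vec` with the same seed vectors returns its two rows.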
I think this is fairly straightforward to implement: we just need to teach the AD passes that the (co)tangent of a primal variable of type `t` is not necessarily of type `t`, but can also be an array of type `[n]t` (where `n` is a constant in any instance of AD).
zfnmxt, Sun 3 Jul 2022 at 16:51: Is this inspired by the work Martin is/was doing?
Yes!