horde-ad
Find a way to implement mapping and other higher-order operations on dual numbers efficiently, despite the function space not being a differentiable type
Edit2: this should be completely rewritten, and the version in the Google doc from 9 March is not coherent either. For a start, please ignore this ticket description and see the more modest problem statement in the last GitHub comment of this ticket.
Edit: this is partially blocked by the performance tickets, because the viability of the solutions depends on their performance. This should be completely rewritten; see https://github.com/Mikolaj/mostly-harmless/discussions/16, which says:
Show how SPJ's order-of-magnitude speedup of Index0 and similar functions (from the 9 March section of the Google doc) helps, and how it is still many orders of magnitude slower than a manual gradient, which keeps higher-order dual number functions infeasible for now. Present benchmark results. Try again to update in place the vectors, matrices and tensors held by the vectors of parameters. Currently only the vectors of parameters are updated in place, which guarantees optimal performance of the primal component at the scalar level, but is slow because it uses (matrices of deltas representing scalars) instead of (deltas representing matrices). A previous attempt showed minuscule gains, but there were no indexing operations or higher-order functions on them (recover comments about this from old commits in the repo).
End Edit.
E.g., this crucial tensor operation does not seem possible to include in our engine (used as in "reduce any matrix within the tensor to a vector using `f` and return a tensor of one rank lower with the results of the reduction").
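```haskell
-- n occurs only under type families, so it is passed to OS.rerank
-- by type application (needs AllowAmbiguousTypes, TypeApplications).
rerankS :: forall n i o sh r s.
           ( Numeric r, Numeric s, Drop n sh ~ i, Shape sh, KnownNat n
           , Shape o, Shape (Take n sh ++ o) )
        => (OS.Array i r -> OS.Array o s)
        -> DualNumber (OS.Array sh r)
        -> DualNumber (OS.Array (Take n sh ++ o) s)
-- The primal is easy; what delta expression to build from u' is the open problem.
rerankS f (D u u') = D (OS.rerank @n f u) undefined
```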
But even the following fully general mapping on vectors seems impossible to implement.
```haskell
map1 :: (Numeric r, Numeric s)
     => (r -> s) -> DualNumber (Vector r) -> DualNumber (Vector s)
map1 f (D u u') = D (V.map f u) (Map1 f u')

eval1 :: Vector s -> Delta1 s -> ST st ()
eval1 r = \case
  -- g :: s -> r is what we would need here, but we don't have it
  Map1 (f :: r -> s) d -> eval1 (V.map g r) d
```
Here `g :: s -> r`, which would be needed but which we don't have, does something analogous to `f` but on different types (it's not an inverse!). See the appendix for some related weak ideas.
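For comparison, when the mapped function is scalar-to-scalar and its derivative is available alongside it, the usual reverse-mode rule for an elementwise map does not need any `g :: s -> r`: the input cotangent is the output cotangent multiplied elementwise by `f'` evaluated at the primal input. A minimal sketch, assuming plain lists of `Double` in place of `Vector r` and a hypothetical helper `mapVJP` that receives `f'` explicitly (the ticket's `map1` takes only `f`, which is exactly the difficulty):

```haskell
-- Sketch only: lists of Double stand in for `Vector r`, and the
-- derivative f' is passed explicitly (hypothetical helper, not horde-ad API).

-- Reverse-mode rule for an elementwise map x_i |-> f x_i:
-- the input cotangent is ct_i * f'(u_i). It needs the primal input u,
-- not a function from outputs back to inputs.
mapVJP :: (Double -> Double) -> [Double] -> [Double] -> [Double]
mapVJP f' u ct = zipWith (\ui ci -> ci * f' ui) u ct

-- Example: f = sin, so f' = cos, evaluated at u = [0, pi].
exampleGrad :: [Double]
exampleGrad = mapVJP cos [0, pi] [1, 1]
```

Here `exampleGrad` is `[1.0, -1.0]` up to floating-point rounding. In a delta-expression setting, the node would therefore have to carry the primal-dependent factors `f' u_i` rather than `f` itself.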
For the specific case where the mapped function is an endomorphism (`f :: r -> r`), the following works at least when `f` is scaling, and perhaps when `f` is linear or more general (TODO: investigate):
```haskell
eval1 :: Vector r -> Delta1 r -> ST st ()
eval1 r = \case
  -- push the cotangent through f itself
  Map1 (f :: r -> r) d -> eval1 (V.map f r) d
```
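A small check of when pushing the cotangent through `f` is right, assuming lists of `Double` in place of `Vector r` and two hypothetical helpers: `ruleApplyF` is the rule above, and `ruleCorrect` is the standard reverse-mode rule (cotangent times `f'` at the primal). They agree for scaling but not for a nonlinear `f` such as squaring:

```haskell
-- Sketch: compare the rule above with the standard reverse-mode rule.
-- Lists stand in for vectors; helper names are hypothetical.

-- Rule from the ticket: new cotangent = map f cotangent.
ruleApplyF :: (Double -> Double) -> [Double] -> [Double]
ruleApplyF f ct = map f ct

-- Standard rule, given the derivative f' and the primal input u.
ruleCorrect :: (Double -> Double) -> [Double] -> [Double] -> [Double]
ruleCorrect f' u ct = zipWith (\ui ci -> ci * f' ui) u ct

u0, ct0 :: [Double]
u0 = [1, 2, 3]
ct0 = [0.5, 1, 2]

-- Scaling f x = 3 * x, so f' x = 3: both rules give [1.5, 3, 6].
scalingAgrees :: Bool
scalingAgrees = ruleApplyF (3 *) ct0 == ruleCorrect (const 3) u0 ct0

-- Squaring f x = x * x, so f' x = 2 * x: [0.25, 1, 4] vs [1, 4, 12].
squaringAgrees :: Bool
squaringAgrees = ruleApplyF (\x -> x * x) ct0 == ruleCorrect (2 *) u0 ct0
```

An affine `f` with a nonzero offset, such as `\x -> 3 * x + 1`, also fails the check, since it maps a zero cotangent to a nonzero one.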
However, a slightly more complex specific variant is again hard to unravel:
```haskell
map21 :: Numeric r
      => (Vector r -> r) -> DualNumber (Matrix r) -> DualNumber (Vector r)
map21 f (D u u') = D (V.fromList $ map f $ HM.toRows u) (Map21 f u')

eval1 :: Vector r -> Delta1 r -> ST st ()
eval1 r = \case
  -- what matrix cotangent should be passed on to eval2?
  Map21 (f :: Vector r -> r) (d :: Delta2 r) -> eval2 ??? d
```
Lack of compositionality?
In other words, we seem to lack compositionality. We have `sumElements0 :: DualNumber (Vector r) -> DualNumber r`, but it seems we can't express summing all rows of a matrix as `map21 sumElements0`, or as anything else, no matter how many general dual number functions (and delta expression constructors to implement them) on matrices we add. The only way seems to be to add a new delta expression constructor that sums the rows of a matrix and has nothing in common with `sumElements0`.
Edit: Performing row selection `n` times, applying `sumElements0` to each row and converting the vector of dual numbers into a dual number of a vector doesn't count, because it drops one level too low, which has the cost of creating `n` new delta expressions in place of a constant number. It would be good to benchmark this pseudo-solution and confirm that it's asymptotically worse, both in time and in memory, than a sensible solution for whatever simple case we can obtain one for.
The pseudo-solution, with as many repeated evaluations as the matrix has rows, could look as follows:
```haskell
map21 :: Numeric r
      => (Vector r -> r) -> (Delta1 r -> Delta0 r) -> DualNumber (Matrix r)
      -> DualNumber (Vector r)
map21 f f' (D u u') = D (V.fromList $ map f $ HM.toRows u) (Map21 f' u')

eval1 :: Vector r -> Delta1 r -> ST st ()
eval1 r = \case
  Map21 (f' :: Delta1 r -> Delta0 r) (d :: Delta2 r) ->
    -- one new RowIndex1 node and one eval0 call per row
    mapM_ (\irow -> eval0 (r V.! irow) (f' (RowIndex1 d irow)))
          [0 .. V.length r - 1]
```
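To make the contrast between the per-row pseudo-solution and a dedicated "sum rows" constructor concrete, here is a toy model. It assumes a hypothetical miniature delta language over lists (`Input2`, `RowIndex1`, `SumRows1`, `SumElements0`), loosely following the names in the ticket but not horde-ad's actual types, with a single matrix input so that backward evaluation simply returns that input's gradient:

```haskell
-- Hypothetical toy model loosely following the ticket; NOT horde-ad's
-- actual Delta types. A matrix is a list of rows, and there is a single
-- matrix input, so backward evaluation just returns its gradient.
type Matrix = [[Double]]

data Delta2 = Input2                    -- the matrix parameter
data Delta1 = RowIndex1 Delta2 Int      -- row i of a matrix
            | SumRows1 Delta2 Int       -- dedicated "sum rows"; Int = ncols
data Delta0 = SumElements0 Delta1 Int   -- sum of a vector; Int = its length

-- Backward evaluation: take a cotangent, return the contribution
-- to the gradient of Input2 (shape nrows x ncols).
eval2 :: Matrix -> Delta2 -> Matrix
eval2 c Input2 = c

eval1 :: (Int, Int) -> [Double] -> Delta1 -> Matrix
eval1 (nrows, ncols) v (RowIndex1 d i) =
  -- the cotangent v lands in row i; all other rows get zeros
  eval2 [if j == i then v else replicate ncols 0 | j <- [0 .. nrows - 1]] d
eval1 _ v (SumRows1 d ncols) =
  -- each row's cotangent is broadcast across that row
  eval2 [replicate ncols c | c <- v] d

eval0 :: (Int, Int) -> Double -> Delta0 -> Matrix
eval0 sh c (SumElements0 d n) = eval1 sh (replicate n c) d

addM :: Matrix -> Matrix -> Matrix
addM = zipWith (zipWith (+))

-- Pseudo-solution: two delta nodes and one separate evaluation per row.
gradPerRow :: (Int, Int) -> [Double] -> Matrix
gradPerRow sh@(nrows, ncols) r =
  foldr addM zeros
    [ eval0 sh (r !! i) (SumElements0 (RowIndex1 Input2 i) ncols)
    | i <- [0 .. nrows - 1] ]
  where zeros = replicate nrows (replicate ncols 0)

-- Dedicated constructor: a single delta node, one evaluation.
gradDedicated :: (Int, Int) -> [Double] -> Matrix
gradDedicated sh@(_, ncols) r = eval1 sh r (SumRows1 Input2 ncols)
```

Both `gradPerRow (2, 3) [1, 2]` and `gradDedicated (2, 3) [1, 2]` give `[[1,1,1],[2,2,2]]`; the difference is that the first allocates `2 * nrows` delta nodes and evaluates each one separately. The toy also pads every per-row contribution to a full matrix, which an in-place `ST` accumulator would avoid, so it says nothing about actual benchmark numbers.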
The application of this function could be (glossing over the extra argument to `SumElements0`, which would not be required in a dependently typed version):

```haskell
map21 HM.sumElements SumElements0 some_matrix
```
which underlines the lack of compositionality, given that

```haskell
sumElements0 :: IsScalar r => DualNumber (Vector r) -> DualNumber r
sumElements0 (D u u') = D (HM.sumElements u) (SumElements0 u' (V.length u))
```
Note also how different that is from `map1` above. I think we are bumping our heads against the problems with exponential objects in this category (in other words, the lack of a differentiable function type). From what I investigated, these are very hard problems, so it would be ideal to sidestep them somehow.
BTW, does this only work for functions `f :: DualNumber (Vector r) -> DualNumber r` that have the property that `f (D u u') = h u &&& g u'` (that is, `D (h u) (g u')`) for some `h` and `g`? The exponential object in the `DualNumber` category restricted to such functions is trivial to define element-wise, but many important functions are not of this form.
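To make the distinction concrete, here is a sketch in a simplified model, which is an assumption and not horde-ad's representation: forward mode, where the dual component is a tangent of the same shape as the primal rather than a delta expression, and lists stand in for vectors. `sumElementsD` splits into a primal half and a tangent half, so a row-wise map over a matrix is just each half mapped over the rows (`map21Split`, a hypothetical name). The squared norm `normSqD` does not split, because its tangent `2 * <u, u'>` mentions `u`:

```haskell
-- Hypothetical simplified model (an assumption, not horde-ad's types):
-- forward mode, where a dual number pairs a primal with a tangent of the
-- same shape instead of a delta expression. Lists stand in for vectors.
data Dual a = D a a deriving (Eq, Show)

-- sumElements splits: the primal half uses only u, the tangent half only u'.
sumElementsD :: Dual [Double] -> Dual Double
sumElementsD (D u u') = D (sum u) (sum u')

-- The squared norm does not split: its tangent 2 * <u, u'> needs u too.
normSqD :: Dual [Double] -> Dual Double
normSqD (D u u') = D (sum (map (\x -> x * x) u)) (2 * sum (zipWith (*) u u'))

-- For a function given by its two halves h and g, mapping over the rows
-- of a matrix is just mapping each half over the corresponding rows.
map21Split :: ([Double] -> Double) -> ([Double] -> Double)
           -> Dual [[Double]] -> Dual [Double]
map21Split h g (D m m') = D (map h m) (map g m')
```

For example, `map21Split sum sum (D [[1,2],[3,4]] [[1,0],[0,1]])` is `D [3.0,7.0] [1.0,1.0]`, the row-wise version of `sumElementsD`. There is no such pair of halves for `normSqD`, which is the kind of function the question above is about.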
Appendix: some weak ideas
To implement `map1`, perhaps we'd need

```haskell
map1 :: (Numeric r, Numeric s)
     => DualNumber (r -> s) -> DualNumber (Vector r) -> DualNumber (Vector s)
```
where `DualNumber (r -> s)` contains a function `r -> s` and a function `s -> r` that does something analogous (note the contravariance). But I think exponential objects in that category are nothing like that.
Or `DualNumber (r -> s)` could contain a function `r -> s` and a function `Delta1 r -> Delta1 s` (note the lack of contravariance; would the user provide that function?), giving
```haskell
eval1 :: Vector s -> Delta1 s -> ST st ()
eval1 r = \case
  -- transform the delta expression first, then evaluate at type s
  Map1 (g :: Delta1 r -> Delta1 s) d -> eval1 r (g d)
```