Find a way to implement mapping and other higher order operations on dual numbers efficiently, despite function space not being a differentiable type
Edit2: this should be completely rewritten; the version in the Google doc from 9 March is not coherent either. For a start, please ignore this ticket description and see a more modest problem statement in the last GitHub comment of this ticket.
Edit: this is partially blocked by the performance tickets, because the viability of the solutions depends on their performance; this should be completely rewritten, see https://github.com/Mikolaj/mostly-harmless/discussions/16 that says
Show how SPJ's order of magnitude speedup of Index0 and similar functions (from 9 March section of the google doc) helps and how it's still many orders of magnitude slower than manual gradient, keeping higher order dual number functions not yet feasible. Present benchmark results. Try again to update in-place the vectors, matrices and tensors held by the vectors of parameters --- currently only the vectors of parameters are updated in-place, which guarantees optimal performance of the primal component at scalar level, but is slow due to (matrices of deltas representing scalars) instead of (deltas representing matrices). Previous attempt has shown minuscule gains, but there were no indexing operations nor higher order functions on them (recover comments about this from old commits in the repo).
End Edit.
E.g., this crucial tensor operation does not seem possible to include in our engine (used as in "reduce any matrix within the tensor to a vector using f and return a one-rank-lower tensor with the results of the reduction"):
```haskell
rerankS :: forall n i o sh r s.
           ( Numeric r, Numeric s, Drop n sh ~ i, Shape sh, KnownNat n
           , Shape o, Shape (Take n sh ++ o) )
        => (OS.Array i r -> OS.Array o s)
        -> DualNumber (OS.Array sh r)
        -> DualNumber (OS.Array (Take n sh ++ o) s)
rerankS f (D u u') = D (OS.rerank f u) undefined
```
But even the following fully general mapping on vectors seems impossible to implement.
```haskell
map1 :: (Numeric r, Numeric s)
     => (r -> s) -> DualNumber (Vector r) -> DualNumber (Vector s)
map1 f (D u u') = D (V.map f u) (Map1 f u')
```

```haskell
eval1 :: Vector s -> Delta1 s -> ST st ()
eval1 r = \case
  Map1 (f :: r -> s) d -> eval1 (V.map g r) d
```
Here g :: s -> r, which would be needed but which we don't have, is a function that does something analogous to f but on different types (it is not an inverse!). See the appendix for some related weak ideas.
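One way to sidestep the missing g, assuming the caller can also supply the scalar derivative of f, is to capture the pointwise derivatives at the primal values when the delta expression is built, so the backward pass never needs a function of type s -> r. A sketch only: `map1WithDerivative` is a hypothetical name, and `Scale1` is assumed to have the pointwise-product semantics `Scale1 k d -> eval1 (k * r) d` mentioned in the scribbles further down.

```haskell
-- Sketch: by the chain rule the cotangent of u_i is r_i * f'(u_i),
-- so storing V.map f' u in the delta term suffices; restricted to
-- the case s ~ r for simplicity.
map1WithDerivative
  :: Numeric r
  => (r -> r)  -- f
  -> (r -> r)  -- f', the derivative of f, supplied by the caller
  -> DualNumber (Vector r) -> DualNumber (Vector r)
map1WithDerivative f f' (D u u') = D (V.map f u) (Scale1 (V.map f' u) u')
```

This only covers differentiable f with a known derivative, of course, not arbitrary mapped functions.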
For the specific case where the mapped function is an automorphism, the following works at least when f is a scaling, and perhaps when f is linear or satisfies weaker conditions (TODO: investigate):
```haskell
eval1 :: Vector r -> Delta1 r -> ST s ()
eval1 r = \case
  Map1 (f :: r -> r) d -> eval1 (V.map f r) d
```
However, a slightly more complex variant is again hard to unravel:
```haskell
map21 :: Numeric r
      => (Vector r -> r) -> DualNumber (Matrix r) -> DualNumber (Vector r)
map21 f (D u u') = D (V.fromList $ map f $ HM.toRows u) (Map21 f u')
```

```haskell
eval1 :: Vector r -> Delta1 r -> ST s ()
eval1 r = \case
  Map21 (f :: Vector r -> r) (d :: Delta2 r) -> eval2 ??? d
```
Lack of compositionality?
In other words, we seem to lack compositionality. We have `sumElements0 :: DualNumber (Vector r) -> DualNumber r`, but it seems we can't express summing all the rows of a matrix as `map21 sumElements0`, or in any other way, no matter how many general dual-number functions (and delta-expression constructors to implement them) on matrices we add. The only way seems to be to add a new delta-expression constructor that sums the rows of a matrix and shares no parts with sumElements0.
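A sketch of what such a dedicated constructor might look like (`SumRows1` is hypothetical, not in the engine):

```haskell
-- Hypothetical delta constructor dedicated to summing matrix rows.
sumRows1 :: Numeric r => DualNumber (Matrix r) -> DualNumber (Vector r)
sumRows1 (D u u') =
  D (V.fromList $ map HM.sumElements $ HM.toRows u)
    (SumRows1 u' (HM.cols u))

-- In eval1 (sketch), the scalar cotangent r V.! i of each row sum is
-- spread uniformly over row i:
--   SumRows1 (d :: Delta2 r) ncols ->
--     eval2 (HM.outer r (HM.konst 1 ncols)) d
```

Note how none of this reuses sumElements0 or its SumElements0 constructor, which is exactly the lack of compositionality at issue.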
Edit: Performing row selection n times, applying sumElements0 to each row and converting the resulting vector of dual numbers into a dual number of a vector doesn't count, because it drops one level too low, at the cost of creating n new delta expressions in place of a constant number. It would be good to benchmark this pseudo-solution and confirm it's asymptotically worse both in time and memory than a sensible solution, for whatever simple case such a solution can be obtained.
The pseudo-solution with as many repeated evaluations as the number of rows of the matrix could look as follows:

```haskell
map21 :: Numeric r
      => (Vector r -> r) -> (Delta1 r -> Delta0 r) -> DualNumber (Matrix r)
      -> DualNumber (Vector r)
map21 f f' (D u u') = D (V.fromList $ map f $ HM.toRows u) (Map21 f' u')
```

```haskell
eval1 :: Vector r -> Delta1 r -> ST s ()
eval1 r = \case
  Map21 (f' :: Delta1 r -> Delta0 r) (d :: Delta2 r) ->
    -- note: the row indices run only up to V.length r - 1
    mapM_ (\irow -> eval0 (r V.! irow) (f' (RowIndex1 d irow)))
          [0 .. V.length r - 1]
```
The application of this function could be (glossing over the extra argument to SumElements0 that would not be required in a dependently typed version):

```haskell
map21 HM.sumElements SumElements0 some_matrix
```
which underlines the lack of compositionality, given that

```haskell
sumElements0 :: IsScalar r => DualNumber (Vector r) -> DualNumber r
sumElements0 (D u u') = D (HM.sumElements u) (SumElements0 u' (V.length u))
```
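For comparison, a sketch (my reconstruction, not necessarily the engine's actual code) of the matching eval0 case: the scalar cotangent of the sum is spread uniformly over the summed vector.

```haskell
eval0 :: r -> Delta0 r -> ST s ()
eval0 !r = \case
  ...
  -- sketch: distribute the cotangent r over all n summed elements
  SumElements0 d n -> eval1 (HM.konst r n) d
```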
Note also how different that is from map1 above. I think we are bumping our heads against the problems with exponential objects in this category (the lack of a differentiable function type, in other words). From what I investigated, these are very hard problems, so it would be ideal to sidestep them somehow.
BTW, does this only work for functions `f :: DualNumber (Vector r) -> DualNumber r` that have the property that `f (D u u') = h u &&& g u'` for some h and g? The exponential object in the DualNumber category restricted to such functions is trivial to define element-wise, but many important functions are not of this form.
Appendix: some weak ideas
To implement map1, perhaps we'd need:

```haskell
map1 :: (Numeric r, Numeric s)
     => DualNumber (r -> s) -> DualNumber (Vector r) -> DualNumber (Vector s)
```
where DualNumber (r -> s) contains a function r -> s and a function s -> r that does something analogous (note the contravariance). But I think exponential objects in that category are nothing like that.
Or DualNumber (r -> s) could contain a function r -> s and a function Delta1 r -> Delta1 s (note the lack of contravariance; would the user provide that function?), giving:

```haskell
eval1 :: Vector s -> Delta1 s -> ST st ()
eval1 r = \case
  Map1 (g :: Delta1 r -> Delta1 s) d -> eval1 r (g d)
```
This is a modest problem statement about map1 only, but it's still too complex. See the last comment in this ticket instead.
Function map1 is an instance of the general problem of bad asymptotic performance of higher-order functions due to their use of indexing (the Index0 delta constructor here):
https://github.com/Mikolaj/horde-ad/blob/10149a18c07942f01525399520fab73348d57992/src/HordeAd/Core/DualNumber.hs#L312-L322
Below is an optimization of Index0 that helps (but not enough) in map1. It is part of eval0, the delta-expression evaluation function for rank 0:
```haskell
eval0 :: r -> Delta0 r -> ST s ()
eval0 !r = \case
  ...
```
https://github.com/Mikolaj/horde-ad/blob/10149a18c07942f01525399520fab73348d57992/src/HordeAd/Internal/Delta.hs#L505-L514
The problem is stated in the code snippet above: even with the optimization shown there (which does help with runtime), this is still asymptotically quadratic in memory.
Various variants of map1 (on rank 0, 1, 2, foldr, with f known to be +, etc.) are written as tests and benchmarked, but I haven't yet performed benchmarks on a collection of data points to verify the asymptotic behaviour, because various kinds of noise would require a large collection of points and long runtimes.
The interesting part: normally when we encounter performance problems, we add custom delta-expression constructors. But here this is problematic, because the constructors would probably need to be higher-order and that probably doesn't fit our theoretical framework. The disjoint musings in the original ticket description are about that (and this ties in with Tom Smeding's setup that potentially permits differentiable function space types).
Below the line is a copy of scribbles in the Google doc related to this example.
Conclusion: Mikolaj will finish a concrete example in Tom’s style (perhaps just variations of SumElements, because we already have benchmarks for those, with and without specialized Delta constructors), with 1, 2, .. n and then try to extend in-place updating to fix both the runtime and memory asymptotic inefficiency in this simplified setting and report in the ticket.
Also, let’s find an example where `OneHot` does not suffice (but `TwoHot` might, etc.). Here it is: for the case of `Scale1 k d -> eval1 (k * r) d` we’d need `OneHotTimesK`, and similarly for `Dot0 v vd -> eval1 (HM.scale r v) vd`, and then there’s `Append1 d k e -> eval1 (V.take k r) d >> eval1 (V.drop k r) e` and others. So we’d need a complex datatype, not just `OneHot`, just as we have for rank 2, see #18. However, regardless of the datatype we use, we’d still need to update in place or allocate records as many times as there are Index operations. OTOH, quite possibly updating in place during the evaluation of the scaling factor is easier than updating in place all the vectors assigned to delta-variables.
Re mutable values of the delta-variables, the first problem is that orthotope tensors are not mutable, so we’d need to store everything in mutable vectors and maintain indexing translations outside. Alternatively, thaw and freeze the vectors inside the current tensor implementation, but this is very risky, e.g., with sharing and either laziness or slices and also with any concurrency.
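For reference, the copying (safe) variant of this pattern from the vector package looks as follows (a sketch, assuming RankNTypes; the zero-copy VS.unsafeThaw / VS.unsafeFreeze are the risky ones, precisely because of sharing):

```haskell
import Control.Monad.ST (ST, runST)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VM

-- VS.thaw and VS.freeze both copy the buffer, so sharing and laziness
-- in the rest of the program are unaffected; the O(n) copies are the
-- price paid for safety.
modifyCopying :: VS.Vector Double
              -> (forall s. VM.MVector s Double -> ST s ())
              -> VS.Vector Double
modifyCopying v act = runST $ do
  mv <- VS.thaw v   -- copies
  act mv
  VS.freeze mv      -- copies again
```

(The vector package's VS.modify bundles this pattern and may avoid a copy when it is safe to do so.)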
Actually, it turns out, the scaling factor needs to be of the same datatype as the values assigned to delta-variables. So making the latter mutable conflicts with #18. Probably we should make the datatype very symbolic and, when it’s finally converted to a concrete vector, matrix or tensor, do the conversion in one big mutable swoop. This should be beneficial independently of the higher order differentiable functions question, so we can revisit the question (#12) after it’s done. Comfortable benchmarking is needed to verify at each step the optimization doesn’t degrade performance in some areas for some random reasons.
Actually, again, the datatype generalising OneHot is not a solution, because for vectors of 10^8 elements, we would need to generate a value of the datatype that deep. Therefore in-place updating the vectors that are assigned to delta-variables seems to be the only solution. This should be possible, because we know not only how many delta variables are generated, but also the size of vectors assigned to individual variables (at least for the typed tensors case), so this can all be allocated up-front and only once. If so, the datatype for the scaling factor is probably not that important and instead we should focus on optimizations taking advantage of all subterms of delta-terms being variables (once we strictly follow how the paper does it, which makes sense), such as the one proposed by Simon for Index0.
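The up-front allocation mentioned above could be sketched as follows (assuming all sizes are known before the backward pass; `allocateCotangents` is a hypothetical helper):

```haskell
import Control.Monad.ST (ST)
import qualified Data.Vector.Storable.Mutable as VM

-- One zero-initialised mutable accumulator per delta variable,
-- allocated once; evaluation then only ever updates them in place.
allocateCotangents :: [Int] -> ST s [VM.MVector s Double]
allocateCotangents = mapM (\n -> VM.replicate n 0)
```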
However, the prohibitive cost appears even earlier, when, e.g., while summing a vector, we generate delta-expressions of the form Add0 (Index0 (Var1 …) …) (Add0 (Index0 (Var1 …) …) (Add0 …)) with as many Index0 terms as there are elements in, say, a 10^8-element vector. That’s again the story of an unboxed vector vs a boxed representation of the same or some corresponding data, which is ten or twenty times larger and also can’t be processed using low level bulk operations. This is not an asymptotic degradation, but it’s a large multiplicative constant applied to the existing performance bottleneck (memory at the end of forward pass, due to the large size of initial data and/or of parameters reflected in the size of the tape).
A modest problem statement about map1M only
This implementation of mapping over a rank-1 differentiable value (a vector) has a prohibitive cost in terms of the size of the stored delta expression.
https://github.com/Mikolaj/horde-ad/blob/8d83a16d43224e992cdc5cbc25c6cc555b90611d/src/HordeAd/Core/DualNumber.hs#L325-L333
To simplify, let's focus on a similar problem (for which we have benchmarks) of summing a vector (that would be foldr1M, not map1M). There we generate delta-expressions of the form Add0 (Index0 (Var1 …) …) (Add0 (Index0 (Var1 …) …) (Add0 …)) with as many Index0 terms as there are elements in, say, a 10^8-element vector. That’s again the story of an unboxed vector vs a boxed representation of the same or some corresponding data, which is ten or twenty times larger and also can’t be processed using low level bulk operations. This is not an asymptotic degradation, but it’s a large multiplicative constant applied to the existing performance bottleneck (memory at the end of forward pass, due to the large size of initial data and/or of parameters reflected in the size of the tape).
Updating in place doesn't help here. Fancy datatypes to represent the scaling factor don't help. Noticing that, with the translation from the paper, the argument delta expression of Index0 is always a variable doesn't help either.
The interesting part is that normally when we encounter performance problems, we add custom delta-expression constructors. But here this is not easy, because the constructors corresponding to map1M (or foldr1M) would probably need to be higher-order and that probably doesn't fit our theoretical framework. The disjoint musings in the original ticket description are about that (and this ties in with Tom Smeding's setup that potentially permits differentiable function space types).
Related (the first few very closely) links from Tom Smeding:
- https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1
- https://github.com/tomsmeding/accelerate/tree/no-explode/src/Data/Array/Accelerate/Trafo/AD
- https://dl.acm.org/doi/pdf/10.1145/3341701
- https://arxiv.org/abs/2207.03418
- https://github.com/tomsmeding/ad-dualrev-th
- https://arxiv.org/abs/2103.15776
The last link in that list is to CHAD, which is a very interesting algorithm with the following upsides:
- It has a neat categorical foundation.
- No ID generation, so no forced sequentiality: parallelism is preserved.
- Higher-order functions are differentiated naturally.
- Neat, not-too-complicated (though more complicated than the dual-numbers algorithms) source-code transformation.
And the following downsides:
- It exhibits exponential blowup in the number of nested lambda expressions.
- This is (probably -- still need to write this out, but discussed to satisfaction with Matthijs) "fixable" by performing closure conversion, i.e. bundling every function object with its closure, thereby making it a closed function. Downside is that the resulting program cannot be type checked in any reasonable type system, because you get multiple existentials in disparate parts of the program that end up needing to be equal -- and said equality is forgotten by the existential wrappers. In an untyped setting, it could work though. (Any type system theorists wanting to take a crack at this?)
- It badly needs treatment for excessive use of `+` and `0` at large types.
- This can (probably -- same story) be fixed using a clever Cayley transformation (i.e. the difference-lists trick: representing a monoid M using functions M -> M, where m : M maps to `\m' -> m <> m'`).
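The Cayley transformation mentioned above is, in Haskell terms, just Data.Monoid's Endo (a sketch):

```haskell
import Data.Monoid (Endo (..))

-- Represent a monoid element m by its left action (m <>).  Then <>
-- becomes function composition, which is O(1) regardless of the size
-- of the carrier, and mempty at a large type costs nothing until the
-- representation is reified with fromCayley.
toCayley :: Monoid m => m -> Endo m
toCayley m = Endo (m <>)

fromCayley :: Monoid m => Endo m -> m
fromCayley e = appEndo e mempty
```

E.g. fromCayley (toCayley [1] <> toCayley [2]) rebuilds [1, 2]; with lists this is exactly the difference-lists trick.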
The other links are things we discussed in the meeting:
- my Accelerate thesis (pdf, code) handles second-order array combinators if the scalar-level expressions don't have loops. Furthermore, array indexing in the combination function of `fold` is not supported, and supporting `scan` requires a little help from the underlying implementation; the derivative of `scan` is not hard to write down, but doesn't express cleanly in parallel `fold`/`scan`/`map` etc. The main contribution of the thesis is how array indexing can be handled, and compiled to plain second-order array programming code, in this restricted case: only in `generate`/`map`/`zipWith`, and not in an expression-level loop.
- The paper by Shaikhha et al. basically attempts to do statically what we (i.e. you and us) do dynamically. It breaks down (i.e. degrades to vectorised-forward-AD complexity) in the presence of dynamic control flow or sufficiently complicated index arithmetic.