foldTree does not optimize well
We have in Data.Tree
foldTree :: (a -> [b] -> b) -> Tree a -> b
Unfortunately, foldTree does not optimize as well as it could when b is a function.
As an example, suppose we want to calculate the sum of the depths of all nodes in a tree. We can write a recursive function manually:
depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0 where
  go (Node _ ts) depth acc = foldl' (\acc' t' -> go t' (depth+1) acc') (acc + depth) ts
-- depthSum_rec: OK (0.18s)
-- 5.18 ms ± 508 μs
Now let's use foldTree:
depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0 where
  f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks
-- depthSum_foldTree: OK (0.34s)
-- 43.6 ms ± 3.2 ms
That's a lot worse! The problem is that the list of partially applied functions [b] is actually built in memory; see GHC#23319. According to SPJ, this can't be easily improved.
Consider a different fold function which also folds over the [b] without creating it:
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)
Now we can write:
depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0 where
  f _ k depth acc = k depth (acc + depth)
  f' k1 k2 depth acc = k2 depth (k1 (depth+1) acc)
-- depthSum_foldTree2: OK (0.23s)
-- 5.16 ms ± 376 μs
As good as depthSum_rec! Could we have foldTree2 (perhaps with a better name) in Data.Tree?
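As a sanity check that the three definitions compute the same thing, here is a minimal self-contained sketch that runs them on a tiny tree. The `sample` tree and the `main` wrapper are illustrative additions, not part of the benchmark; everything else is copied from the definitions above.

```haskell
import Data.List (foldl')
import Data.Tree (Tree (..), foldTree)

-- The proposed fold: f replaces Node, c and z replace (:) and [] of the forest.
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0
  where
    go (Node _ ts) depth acc =
      foldl' (\acc' t' -> go t' (depth + 1) acc') (acc + depth) ts

depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0
  where
    f _ ks depth acc =
      foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) ks

depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0
  where
    -- A node adds its own depth, then hands its depth to its forest.
    f _ k depth acc = k depth (acc + depth)
    -- A forest cell runs the first child one level down, then the rest.
    f' k1 k2 depth acc = k2 depth (k1 (depth + 1) acc)

-- Hypothetical example tree: 'a' at depth 0, 'b' and 'c' at depth 1, 'd' at depth 2.
sample :: Tree Char
sample = Node 'a' [Node 'b' [Node 'd' []], Node 'c' []]

main :: IO ()
main = do
  print (depthSum_rec sample)       -- 0 + 1 + 2 + 1 = 4
  print (depthSum_foldTree sample)  -- 4
  print (depthSum_foldTree2 sample) -- 4
```

This only checks that the results agree; it says nothing about performance, which is what the benchmark numbers above are for.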
The benchmark setup, for completeness
import Data.List
import Data.Tree
import Test.Tasty.Bench
main :: IO ()
main = defaultMain
  [ env (pure binTree) $ \t -> bgroup ""
    [ bench "depthSum_rec" $ whnf depthSum_rec t
    , bench "depthSum_foldTree" $ whnf depthSum_foldTree t
    , bench "depthSum_foldTree2" $ whnf depthSum_foldTree2 t
    ]
  ]
binTree :: Tree Int
binTree = unfoldTree (\x -> (x, takeWhile (<1000000) [2*x + 1, 2*x + 2])) 1
The type of the function gives me no clue what it does. That makes me a bit suspicious. Can you write documentation that makes it easy for people to think about? What benefit does this have over doing the folding by hand?
The type of the function gives me no clue what it does.
It is just the replacement of all the constructors involved in a Tree. So foldTree2 Node (:) [] = id.
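To make the "replace the constructors" reading concrete, here is a small sketch with a few instantiations. The helper names (`rebuild`, `size`, `height`, `labels`) and the `sample` tree are made up for illustration; only foldTree2 itself comes from the proposal.

```haskell
import Data.Tree (Tree (..))

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

-- Node stays Node, (:) stays (:), [] stays []: the tree is rebuilt unchanged.
rebuild :: Tree a -> Tree a
rebuild = foldTree2 Node (:) []

-- Node becomes "1 + size of forest", (:) becomes (+), [] becomes 0.
size :: Tree a -> Int
size = foldTree2 (\_ n -> 1 + n) (+) 0

-- Node becomes "1 + height of forest", (:) becomes max, [] becomes 0.
height :: Tree a -> Int
height = foldTree2 (\_ h -> 1 + h) max 0

-- Node becomes (:), (:) becomes (++), [] stays []: labels in preorder.
labels :: Tree a -> [a]
labels = foldTree2 (:) (++) []

-- Hypothetical example tree.
sample :: Tree Char
sample = Node 'a' [Node 'b' [Node 'd' []], Node 'c' []]

main :: IO ()
main = do
  print (rebuild sample == sample) -- True
  print (size sample)              -- 4
  print (height sample)            -- 3
  putStrLn (labels sample)         -- abdc
```

In each case the first argument says what to do at a Node given the already-folded forest, and the second and third arguments say how to fold the forest itself, exactly as with foldr on a list.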
What benefit does this have over doing the folding by hand?
It lets us avoid writing a recursive function by hand, and the resulting code is often shorter or simpler. Another benefit is that it could participate in fold/build fusion. I have been thinking about this a bit, but perhaps it deserves a separate issue.
Closing since foldTree2 is probably not worth it.
Good performance can also be obtained from foldTree using oneShot as shown in https://gitlab.haskell.org/ghc/ghc/-/issues/23319#note_499967.
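For readers who do not want to follow the link, the following is a rough sketch of what a oneShot-based version might look like. It is my own reconstruction and an assumption about the general shape of the approach, not a copy of the linked note. `oneShot` is imported from GHC.Exts; the name `depthSum_oneShot` and the `sample` tree are hypothetical.

```haskell
import Data.List (foldl')
import Data.Tree (Tree (..), foldTree)
import GHC.Exts (oneShot)

-- Same algorithm as depthSum_foldTree, but the lambdas for depth and acc
-- are wrapped in oneShot, which tells GHC each one is applied at most once.
depthSum_oneShot :: Tree a -> Int
depthSum_oneShot t = foldTree f t 0 0
  where
    f _ ks = oneShot $ \depth -> oneShot $ \acc ->
      foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) ks

-- Hypothetical example tree: depths are 0, 1, 2 and 1.
sample :: Tree Char
sample = Node 'a' [Node 'b' [Node 'd' []], Node 'c' []]

main :: IO ()
main = print (depthSum_oneShot sample) -- 4
```

oneShot does not change the result, only the information available to the optimizer, so this sketch checks correctness only. Whether it matches the timings of depthSum_rec should be confirmed with the benchmark setup above.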