
foldTree does not optimize well

Open meooow25 opened this issue 2 years ago • 2 comments

We have in Data.Tree

foldTree :: (a -> [b] -> b) -> Tree a -> b 

Unfortunately, foldTree does not optimize as well as it could when b is a function.

As an example, suppose we want to calculate the sum of the depths of all nodes in a tree. We can write a recursive function manually:

depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0 where
    go (Node _ ts) depth acc = foldl' (\acc' t' -> go t' (depth+1) acc') (acc + depth) ts

--    depthSum_rec: OK (0.18s)
--      5.18 ms ± 508 μs

Now let's use foldTree:

depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0 where
    f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks

--    depthSum_foldTree: OK (0.34s)
--      43.6 ms ± 3.2 ms

That's a lot worse, roughly 8x slower! The problem is that the list of partially applied functions [b] is manifested; see GHC#23319. According to SPJ, this can't be easily improved.
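To make the problem concrete, here is a rough sketch of what depthSum_foldTree amounts to once foldTree is inlined. It assumes foldTree is defined as `go (Node x ts) = f x (map go ts)`; the name depthSumInlined is made up for illustration, and this is not actual GHC output.

```haskell
import Data.List (foldl')
import Data.Tree (Tree(..))

-- Hypothetical hand-inlined version of depthSum_foldTree, for illustration only.
depthSumInlined :: Tree a -> Int
depthSumInlined t = go t 0 0
  where
    go :: Tree a -> Int -> Int -> Int
    go (Node _ ts) =
      -- The list of continuations is built here, before depth and acc
      -- are supplied. Each element is a partially applied 'go'.
      let ks = map go ts
      in \depth acc -> foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) ks
```

In this shape, `go` takes one argument and returns a function, and `ks` sits outside the lambda. Compare depthSum_rec, where `go` takes all three arguments at once and no intermediate list is needed.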


Consider a different fold function which also folds over the [b] without creating it:

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)

Now we can write:

depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0 where
    f _ k depth acc = k depth (acc + depth)
    f' k1 k2 depth acc = k2 depth (k1 (depth+1) acc)

--    depthSum_foldTree2: OK (0.23s)
--      5.16 ms ± 376 μs

As good as depthSum_rec! Could we have foldTree2 (perhaps with a better name) in Data.Tree?


The benchmark setup, for completeness:

import Data.List
import Data.Tree
import Test.Tasty.Bench

main :: IO ()
main = defaultMain
    [ env (pure binTree) $ \t -> bgroup ""
        [ bench "depthSum_rec" $ whnf depthSum_rec t
        , bench "depthSum_foldTree" $ whnf depthSum_foldTree t
        , bench "depthSum_foldTree2" $ whnf depthSum_foldTree2 t
        ]
    ]

binTree :: Tree Int
binTree = unfoldTree (\x -> (x, takeWhile (<1000000) [2*x + 1, 2*x + 2])) 1

meooow25, Apr 30 '23 16:04

The type of the function gives me no clue what it does. That makes me a bit suspicious. Can you write documentation that makes it easy for people to think about? What benefit does this have over doing the folding by hand?

treeowl, Apr 30 '23 17:04

> The type of the function gives me no clue what it does.

It is just the replacement of all the constructors involved in a Tree. So foldTree2 Node (:) [] = id.
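A small sketch of the constructor-replacement reading, using the foldTree2 definition from the original post. The helper names (rebuild, treeSize, preorder, foldTreeVia) are made up for illustration and are not part of containers.

```haskell
import Data.Tree (Tree(..))

-- foldTree2 as proposed in the original post.
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go
  where
    go (Node x ts) = f x (foldr (c . go) z ts)

-- Replacing Node, (:) and [] with themselves rebuilds the same tree.
rebuild :: Tree a -> Tree a
rebuild = foldTree2 Node (:) []

-- Node becomes "1 + children", (:) becomes (+), [] becomes 0.
treeSize :: Tree a -> Int
treeSize = foldTree2 (\_ n -> 1 + n) (+) 0

-- Node becomes (:), (:) becomes (++), [] stays []: a pre-order traversal.
preorder :: Tree a -> [a]
preorder = foldTree2 (:) (++) []

-- The existing foldTree can be recovered by keeping the child list as a list.
foldTreeVia :: (a -> [b] -> b) -> Tree a -> b
foldTreeVia f = foldTree2 f (:) []
```

For example, on `Node 1 [Node 2 [Node 4 []], Node 3 []]`, treeSize gives 4 and preorder gives [1,2,4,3].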

> What benefit does this have over doing the folding by hand?

It lets us avoid writing a recursive function, which often ends up shorter or simpler. Another benefit is that it could participate in fold/build fusion. I have been thinking about this a bit, but perhaps it deserves a separate issue.

meooow25, Apr 30 '23 17:04

Closing since foldTree2 is probably not worth it.

Good performance can also be obtained from foldTree using oneShot as shown in https://gitlab.haskell.org/ghc/ghc/-/issues/23319#note_499967.
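A minimal sketch of how oneShot might be applied to the foldTree version, assuming `oneShot` from GHC.Exts. This is not necessarily the exact code in the linked note, the name depthSumOneShot is made up, and it has not been benchmarked here.

```haskell
import Data.List (foldl')
import Data.Tree (Tree(..), foldTree)
import GHC.Exts (oneShot)

-- Hypothetical variant of depthSum_foldTree. oneShot is semantically the
-- identity; it only tells GHC that each lambda is called at most once.
depthSumOneShot :: Tree a -> Int
depthSumOneShot t = foldTree f t 0 0
  where
    f _ ks = oneShot (\depth -> oneShot (\acc ->
      foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) ks))
```

Since oneShot does not change the result, this computes the same sum as depthSum_foldTree; any speedup depends on GHC using the hint.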

meooow25, May 18 '25 19:05