
Tail call optimization

Open • PgBiel opened this issue 10 months ago • 1 comment

I'm opening this issue to gather information on whether Glistix could generate functions with tail call optimization (TCO), which would allow deeper recursion.

While proper support would require some help from upstream, there are some workarounds we can try.

References:

  • Inspiration for a potential workaround: https://discourse.nixos.org/t/tail-call-optimization-in-nix-today/17763
  • Upstream issue: nix#8430

PgBiel • Apr 12 '25 23:04

Used without modifications, the solution proposed in the linked Discourse post is generally impractical because it is too slow (as demonstrated below). This is presumably due to its workarounds for supporting functions with any number of arguments, and to the many extra function calls involved in general.

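However, I've found a minimized version that needs some cooperation from the original function. Instead of calling itself, the function returns either { done = result; } or { call = f: f nextArgs; }. With that change, the minimized version is as fast as the original function, and at deeper recursion depths it is sometimes even faster: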

let
  tco = f: o:
    o.done
    or (let
        closure =
          builtins.genericClosure {
            startSet = [{ key = 0; returned = o; }];
            operator = (item: let
              returned = item.returned;
              result =
                if returned ? done
                then []
                else [{ key = item.key + 1; returned = returned.call f; }];
            in result);
          };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # example without TCO
  countDown = n: if n == 0 then n else countDown (n - 1);

  # example with TCO
  countDown2 = let
    f = n: if n == 0 then { done = n; } else { call = f: f (n - 1); };
  in n: tco f (f n);
in { inherit tco countDown countDown2; }

We can show that it produces the desired effect:

$ nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 999999'
error:
       … while evaluating a branch condition
         at /tmp/tco/ex.nix:20:18:
           19|   # example without TCO
           20|   countDown = n: if n == 0 then n else countDown (n - 1);
             |                  ^
           21|

       error: stack overflow; max-call-depth exceeded
       at /tmp/tco/ex.nix:20:53:
           19|   # example without TCO
           20|   countDown = n: if n == 0 then n else countDown (n - 1);
             |                                                     ^
           21|

$ nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 999999'
0
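The "cooperation" protocol can be summarized as follows:

- A step function returns { done = result; } when it has finished.
- Otherwise it returns { call = f: ...; }, describing its next tail call.
- tco keeps invoking call until a done appears.

The sketch below suggests that the same protocol might also cover mutual recursion, because a call closure is free to ignore the f it receives and invoke a different step function. The names isEvenStep and isOddStep are hypothetical illustrations, not part of the proposal. The tco definition is the one above, lightly condensed but otherwise unchanged.

```nix
let
  # Trampoline from above: o is either { done = value; } or { call = f: ...; }
  tco = f: o:
    o.done
    or (let
        closure = builtins.genericClosure {
          startSet = [{ key = 0; returned = o; }];
          operator = item:
            if item.returned ? done
            then []
            else [{ key = item.key + 1; returned = item.returned.call f; }];
        };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # Hypothetical mutually recursive step functions. Each tail call is returned
  # as a `call` closure that ignores its argument and calls the other function.
  isEvenStep = n: if n == 0 then { done = true; } else { call = _: isOddStep (n - 1); };
  isOddStep = n: if n == 0 then { done = false; } else { call = _: isEvenStep (n - 1); };
in {
  # tco only forwards its first argument to `call`, so null is fine here
  isEven = n: tco null (isEvenStep n);
  isOdd = n: tco null (isOddStep n);
}
```

Because each step is driven by builtins.genericClosure rather than by nested calls, isEven 100000 should not run into max-call-depth.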

The non-TCO function works for up to 9999 recursive calls, so we can use hyperfine to compare both versions (unscientifically) at that depth, the deepest one both can handle. Here's the output on my local machine:

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 9999'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 9999'
  Time (mean ± σ):      29.7 ms ±   1.0 ms    [User: 12.7 ms, System: 16.6 ms]
  Range (min … max):    28.4 ms …  33.6 ms    96 runs

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 9999'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 9999'
  Time (mean ± σ):      27.4 ms ±   0.9 ms    [User: 16.3 ms, System: 10.8 ms]
  Range (min … max):    26.0 ms …  32.4 ms    101 runs

The TCO version is faster for this function: the overhead it adds is smaller than the overhead it removes by not filling up the stack. It is also just as fast at very low depths:

Depth 0

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 0'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 0'
  Time (mean ± σ):      18.7 ms ±   0.6 ms    [User: 10.3 ms, System: 8.2 ms]
  Range (min … max):    17.4 ms …  20.3 ms    145 runs

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 0'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 0'
  Time (mean ± σ):      18.7 ms ±   0.7 ms    [User: 10.4 ms, System: 8.1 ms]
  Range (min … max):    17.4 ms …  21.3 ms    138 runs

Depth 2

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 2'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown 2'
  Time (mean ± σ):      18.9 ms ±   0.9 ms    [User: 10.3 ms, System: 8.4 ms]
  Range (min … max):    17.5 ms …  22.9 ms    143 runs

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 2'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown2 2'
  Time (mean ± σ):      18.8 ms ±   0.7 ms    [User: 10.3 ms, System: 8.2 ms]
  Range (min … max):    17.4 ms …  21.2 ms    141 runs

We see a similar pattern for the emphasize function from the Discourse post, which recursively builds a very large string. Here, the TCO version is approximately as fast as the original function.

Code

let
  tco = f: o:
    o.done
    or (let
        closure =
          builtins.genericClosure {
            startSet = [{ key = 0; returned = o; }];
            operator = (item: let
              returned = item.returned;
              result =
                if returned ? done
                then []
                else [{ key = item.key + 1; returned = returned.call f; }];
            in result);
          };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # example without TCO
  emphasize = acc: before: after: n:
      if n == 0
      then before + " " + acc + after
      else emphasize (acc + "very ") before after (n - 1);

  # example with TCO
  emphasize2 = let
    f = acc: before: after: n:
    if n == 0
    then { done = before + " " + acc + after; }
    else { call = f: f (acc + "very ") before after (n - 1); };
  in acc: before: after: n: tco f (f acc before after n);
in { inherit tco emphasize emphasize2; }

Larger max recursion depth

$ nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize "" "This is" "cool" 50000'
error:
       … while evaluating a branch condition
         at /tmp/tco/ex2.nix:21:7:
           20|   emphasize = acc: before: after: n:
           21|       if n == 0
             |       ^
           22|       then before + " " + acc + after

       error: stack overflow; max-call-depth exceeded
       at /tmp/tco/ex2.nix:23:54:
           22|       then before + " " + acc + after
           23|       else emphasize (acc + "very ") before after (n - 1);
             |                                                      ^
           24|

$ nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize2 "" "This is" "cool" 50000'
"This is very very [ ... 50000x ... ] very cool"

Depth 9999

$ hyperfine "nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize \"\" \"This is\" \"cool\" 9999'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize "" "This is" "cool" 9999'
  Time (mean ± σ):     154.7 ms ±   1.5 ms    [User: 31.5 ms, System: 121.9 ms]
  Range (min … max):   152.5 ms … 157.4 ms    19 runs

$ hyperfine "nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize2 \"\" \"This is\" \"cool\" 9999'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize2 "" "This is" "cool" 9999'
  Time (mean ± σ):     156.0 ms ±   3.1 ms    [User: 36.2 ms, System: 118.6 ms]
  Range (min … max):   152.3 ms … 165.5 ms    18 runs

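Here the TCO version is slightly slower, although the difference (156.0 ms vs. 154.7 ms) is within the measured standard deviation.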

Depth 2

$ hyperfine "nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize \"\" \"This is\" \"cool\" 2'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize "" "This is" "cool" 2'
  Time (mean ± σ):      18.7 ms ±   0.6 ms    [User: 10.3 ms, System: 8.2 ms]
  Range (min … max):    17.5 ms …  21.0 ms    146 runs

$ hyperfine "nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize2 \"\" \"This is\" \"cool\" 2'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize2 "" "This is" "cool" 2'
  Time (mean ± σ):      18.6 ms ±   0.7 ms    [User: 10.1 ms, System: 8.3 ms]
  Range (min … max):    17.5 ms …  21.4 ms    143 runs

Unfortunately, as noted in the Discourse post, this last function still hits a stack overflow fairly early. Presumably, the string itself is getting too large.

$ nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize2 "" "This is" "cool" 70000'
error: stack overflow (possible infinite recursion)
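Another possible cause, which we have not verified, is laziness rather than string size. The accumulator argument acc + "very " is never forced until the very end. Each iteration therefore wraps the previous unevaluated thunk in a new one, and forcing the final string walks a chain of tens of thousands of nested thunks. That walk could overflow the evaluator's stack regardless of TCO. If this is the cause, forcing the accumulator at every step with builtins.seq might help.

Below is a sketch of that variant. The name emphasizeStrict is hypothetical, and we have not measured whether it actually raises the limit.

```nix
let
  # Same trampoline as in the listing above (lightly condensed)
  tco = f: o:
    o.done
    or (let
        closure = builtins.genericClosure {
          startSet = [{ key = 0; returned = o; }];
          operator = item:
            if item.returned ? done
            then []
            else [{ key = item.key + 1; returned = item.returned.call f; }];
        };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # Like emphasize2's step function, but builtins.seq forces the new
  # accumulator before the `call` record is returned, so no chain of
  # unevaluated `acc + "very "` thunks accumulates across iterations.
  f = acc: before: after: n:
    if n == 0
    then { done = before + " " + acc + after; }
    else
      let acc' = acc + "very ";
      in builtins.seq acc' { call = f: f acc' before after (n - 1); };
in {
  emphasizeStrict = acc: before: after: n: tco f (f acc before after n);
}
```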

The problem with diagnostics

Which leads us to our next point: diagnostics. Without TCO, the built-in stack overflow error points at where the recursion happened, as seen above. With TCO, infinite recursion gives no indication at all of where it might have occurred. This is pretty annoying for debugging. As a workaround, we could have a Glistix flag to temporarily disable TCO. In addition, a separate flag could impose a maximum recursion depth for debugging purposes.
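Here is a rough sketch of what the "maximum recursion depth" debugging mode could compile to. It is a hypothetical variant of tco, and tcoLimited and maxSteps are illustrative names rather than an existing Glistix option. The trampoline is unchanged, except that it throws a descriptive error once more than maxSteps tail calls have been made, instead of looping until resources run out.

```nix
let
  # Hypothetical debug variant of tco: aborts after maxSteps tail calls.
  tcoLimited = maxSteps: f: o:
    o.done
    or (let
        closure = builtins.genericClosure {
          startSet = [{ key = 0; returned = o; }];
          operator = item:
            if item.returned ? done
            then []
            # item.key counts the tail calls made so far
            else if item.key >= maxSteps
            then throw "tco: exceeded ${toString maxSteps} tail calls (possible infinite recursion)"
            else [{ key = item.key + 1; returned = item.returned.call f; }];
        };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # A step function that never reaches `done`
  loop = n: { call = f: f (n + 1); };
  countDownStep = n: if n == 0 then { done = n; } else { call = f: f (n - 1); };
in {
  loopForever = tcoLimited 1000 loop (loop 0);
  countDown = n: tcoLimited 1000 countDownStep (countDownStep n);
}
```

The thrown message could also include the generated function's name, which would partly recover the location information lost with TCO.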

Comparison with the Discourse post's solution

  1. For countDown:
Code

let
  tco = f: o:
    o.done
    or (let
        closure =
          builtins.genericClosure {
            startSet = [{ key = 0; returned = o; }];
            operator = (item: let
              returned = item.returned;
              result =
                if returned ? done
                then []
                else [{ key = item.key + 1; returned = returned.call f; }];
            in result);
          };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # Discourse post version
  tco2 = let
      lib = (import <nixpkgs> {}).lib;
      argCount = f:
          let
            # N.B. since we are only interested if the result of calling is a function
            # as opposed to a normal value or evaluation failure, we never need to
            # check success, as value will be false (i.e. not a function) in the
            # failure case.
            called = builtins.tryEval (
              f (builtins.throw "You should never see this error message")
            );
          in
          if !(builtins.isFunction f || builtins.isFunction (f.__functor or null))
          then 0
          else 1 + argCount called.value;
      unapply =
          let
              unapply' = acc: n: f: x:
              if n == 1
              then f (acc ++ [ x ])
              else unapply' (acc ++ [ x ]) (n - 1) f;
          in
          unapply' [ ];
      apply = f: args: builtins.foldl' (f: x: f x) f args;
      tailCallOpt = f:
          let
            argc = argCount (lib.fix f);

            # This function simulates being f for f's self reference. Instead of
            # recursing, it will just return the arguments received as a specially
            # tagged set, so the recursion step can be performed later.
            fakef = unapply argc (args: {
              __tailCall = true;
              inherit args;
            });
            # Pass fakef to f so that it'll be called instead of recursing, ensuring
            # only one recursion step is performed at a time.
            encodedf = f fakef;

            # This is the main function, implementing the “optimized” recursion
            opt = args:
              let
                steps = builtins.genericClosure {
                  # This is how we encode a (tail) call: A set with final == false
                  # and the list of arguments to pass to be found in args.
                  startSet = [
                    {
                      key = "0";
                      id = 0;
                      final = false;
                      inherit args;
                    }
                  ];

                  operator =
                    { id, final, ... }@state:
                    let
                      # Generate a new, unique key to make genericClosure happy
                      newIds = {
                        key = toString (id + 1);
                        id = id + 1;
                      };

                      # Perform recursion step
                      call = apply encodedf state.args;

                      # If call encodes a new call, return the new encoded call,
                      # otherwise signal that we're done.
                      newState =
                        if builtins.isAttrs call && call.__tailCall or false
                        then newIds // {
                          final = false;
                          inherit (call) args;
                        } else newIds // {
                          final = true;
                          value = call;
                        };
                    in

                    if final
                    then [ ] # end condition for genericClosure
                    else [ newState ];
                };
              in
              # The returned list contains intermediate steps we need to ignore
              (builtins.head (builtins.filter (x: x.final) steps)).value;
          in
          # make it look like a normal function again
          unapply argc opt;
    in tailCallOpt;

  # example without TCO
  countDown = n: if n == 0 then n else countDown (n - 1);

  # example with our proposed TCO
  countDown2 = let
    f = n: if n == 0 then { done = n; } else { call = f: f (n - 1); };
  in n: tco f (f n);

  # example with Discourse post's TCO
  countDown3 = let
    f = f: n: if n == 0 then n else f (n - 1);
  in tco2 f;
in { inherit tco countDown countDown2 countDown3; }

Larger max recursion depth

$ nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown3 999999'
0

Depth 9999

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown3 9999'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown3 9999'
  Time (mean ± σ):     292.6 ms ±   4.3 ms    [User: 151.6 ms, System: 72.6 ms]
  Range (min … max):   287.2 ms … 302.8 ms    10 runs

This is much slower than the roughly 30 ms measured for the other two versions at this depth.

Depth 2

$ hyperfine "nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown3 2'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex.nix; in x.countDown3 2'
  Time (mean ± σ):     278.4 ms ±   6.1 ms    [User: 138.4 ms, System: 71.2 ms]
  Range (min … max):   272.8 ms … 294.1 ms    10 runs

Even at depth 2, this is impractically slow for our purposes (about 278 ms, compared to about 19 ms for the other two versions).

  2. For emphasize:
Code


let
  tco = f: o:
    o.done
    or (let
        closure =
          builtins.genericClosure {
            startSet = [{ key = 0; returned = o; }];
            operator = (item: let
              returned = item.returned;
              result =
                if returned ? done
                then []
                else [{ key = item.key + 1; returned = returned.call f; }];
            in result);
          };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);


  tco2 = let
    lib = (import <nixpkgs> {}).lib;
    argCount = f:
        let
          # N.B. since we are only interested if the result of calling is a function
          # as opposed to a normal value or evaluation failure, we never need to
          # check success, as value will be false (i.e. not a function) in the
          # failure case.
          called = builtins.tryEval (
            f (builtins.throw "You should never see this error message")
          );
        in
        if !(builtins.isFunction f || builtins.isFunction (f.__functor or null))
        then 0
        else 1 + argCount called.value;
    unapply =
        let
            unapply' = acc: n: f: x:
            if n == 1
            then f (acc ++ [ x ])
            else unapply' (acc ++ [ x ]) (n - 1) f;
        in
        unapply' [ ];
    apply = f: args: builtins.foldl' (f: x: f x) f args;
    tailCallOpt = f:
        let
          argc = argCount (lib.fix f);

          # This function simulates being f for f's self reference. Instead of
          # recursing, it will just return the arguments received as a specially
          # tagged set, so the recursion step can be performed later.
          fakef = unapply argc (args: {
            __tailCall = true;
            inherit args;
          });
          # Pass fakef to f so that it'll be called instead of recursing, ensuring
          # only one recursion step is performed at a time.
          encodedf = f fakef;

          # This is the main function, implementing the “optimized” recursion
          opt = args:
            let
              steps = builtins.genericClosure {
                # This is how we encode a (tail) call: A set with final == false
                # and the list of arguments to pass to be found in args.
                startSet = [
                  {
                    key = "0";
                    id = 0;
                    final = false;
                    inherit args;
                  }
                ];

                operator =
                  { id, final, ... }@state:
                  let
                    # Generate a new, unique key to make genericClosure happy
                    newIds = {
                      key = toString (id + 1);
                      id = id + 1;
                    };

                    # Perform recursion step
                    call = apply encodedf state.args;

                    # If call encodes a new call, return the new encoded call,
                    # otherwise signal that we're done.
                    newState =
                      if builtins.isAttrs call && call.__tailCall or false
                      then newIds // {
                        final = false;
                        inherit (call) args;
                      } else newIds // {
                        final = true;
                        value = call;
                      };
                  in

                  if final
                  then [ ] # end condition for genericClosure
                  else [ newState ];
              };
            in
            # The returned list contains intermediate steps we need to ignore
            (builtins.head (builtins.filter (x: x.final) steps)).value;
        in
        # make it look like a normal function again
        unapply argc opt;
  in tailCallOpt;

  # example without TCO
  emphasize = acc: before: after: n:
      if n == 0
      then before + " " + acc + after
      else emphasize (acc + "very ") before after (n - 1);

  # example with our proposed TCO
  emphasize2 = let
    f = acc: before: after: n:
    if n == 0
    then { done = before + " " + acc + after; }
    else { call = f: f (acc + "very ") before after (n - 1); };
  in acc: before: after: n: tco f (f acc before after n);

  # example with Discourse post's TCO
  emphasize3 = let
      f = self: acc: before: after: n:
      if n == 0
      then before + " " + acc + after
      else self (acc + "very ") before after (n - 1);
    in tco2 f;
in { inherit tco emphasize emphasize2 emphasize3; }

Larger max recursion depth

$ nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize3 "" "This is" "cool" 50000'
"This is very very [ ... 50000x ... ] very cool"

It hits the same limit of roughly 60,000 to 70,000 calls.

Depth 9999

$ hyperfine "nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize3 9999'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize3 9999'
  Time (mean ± σ):     277.5 ms ±   4.4 ms    [User: 139.6 ms, System: 69.0 ms]
  Range (min … max):   273.6 ms … 286.0 ms    10 runs

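Once again considerably slower. Note, however, that this command and the depth 2 command below pass only the depth to emphasize3, which expects four arguments (acc, before, after and n). Each command therefore most likely evaluates to a partially applied function, so these timings mostly reflect fixed setup cost rather than the recursion itself. They should be rerun with "" "This is" "cool" before the depth, as in the other emphasize benchmarks.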

Depth 2

$ hyperfine "nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize3 2'" --warmup 100
Benchmark 1: nix eval --impure --expr 'let x = import ./ex2.nix; in x.emphasize3 2'
  Time (mean ± σ):     276.7 ms ±   2.0 ms    [User: 134.8 ms, System: 73.5 ms]
  Range (min … max):   273.1 ms … 279.1 ms    10 runs

Same.

Overall, our proposed TCO solution appears to be nearly transparent performance-wise, apart from the diagnostics problem.
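One caveat about this comparison: tco2 imports <nixpkgs> only to obtain lib.fix, which it uses to count the function's arguments. We have not measured how much of its roughly 260 ms fixed cost (visible even at depth 2) comes from that import rather than from the TCO machinery itself. To rerun the comparison without nixpkgs, lib = (import <nixpkgs> {}).lib could be dropped and lib.fix replaced with a local definition. The sketch below uses the same one-line definition as nixpkgs' lib/fixed-points.nix; countDownOpen is just an illustrative name.

```nix
let
  # Same definition as nixpkgs' lib.fix
  fix = f: let x = f x; in x;

  # Open-recursive function in the style tco2 expects: `self` is the
  # recursive reference that tco2 later replaces with its fake function.
  countDownOpen = self: n: if n == 0 then n else self (n - 1);
in {
  inherit fix countDownOpen;
  # Tying the knot with fix yields an ordinary (non-TCO) recursive function,
  # which is all tco2's argCount needs.
  countDown = fix countDownOpen;
}
```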

Is it worth it?

In conclusion: with forced TCO, we lose helpful diagnostics about infinite recursion, which can suck. In exchange, our solution is faster in some cases and about as fast in the others.

I think the best way forward is to implement compilation of tail calls for the Nix target using this solution, but behind a toggle. The toggle could start out disabled by default until we're sure this is what we want.

In addition, we should be able to compile in an artificial recursion limit to help debug errors.
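The following sketch suggests what compiling tail calls with this solution might involve. It is a hand-written lowering, not actual Glistix output, and ackStep is a hypothetical name. It highlights one subtlety: only calls in tail position can become call records. A non-tail recursive call, such as the inner call of the Ackermann function, needs its result before the current step can finish. It therefore has to go through the wrapped function and still consumes stack.

```nix
let
  # Same trampoline as in the listings above (lightly condensed)
  tco = f: o:
    o.done
    or (let
        closure = builtins.genericClosure {
          startSet = [{ key = 0; returned = o; }];
          operator = item:
            if item.returned ? done
            then []
            else [{ key = item.key + 1; returned = item.returned.call f; }];
        };
        last = builtins.elemAt closure (builtins.length closure - 1);
      in last.returned.done);

  # Hypothetical lowering of
  #   ack m n = if m == 0 then n + 1
  #             else if n == 0 then ack (m - 1) 1
  #             else ack (m - 1) (ack m (n - 1))
  # Tail calls become `call` records; the inner, non-tail call uses `ack`.
  ackStep = m: n:
    if m == 0 then { done = n + 1; }
    else if n == 0 then { call = f: f (m - 1) 1; }
    else { call = f: f (m - 1) (ack m (n - 1)); };

  ack = m: n: tco ackStep (ackStep m n);
in { inherit ack; }
```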

PgBiel • Apr 13 '25 00:04