Monad.Make() generates non-tail-recursive `all` and `all_unit`
We have a custom logging-ish monad and just ran into an issue on OS X where we ran out of stack space calling `Our_monad.all` on a list with ~100,000 elements.
I noticed that `Monad.Make` generates this:
https://github.com/janestreet/base/blob/master/src/monad.ml#L59
It's structured in basically the way you'd write a tail-recursive loop, but the recursive call to `loop` happens inside the function passed to `>>=`, so it isn't a tail call. It's not entirely clear to me whether it's possible to make this tail-recursive, but it's worth thinking about.
(Our workaround was adding custom `all` and `all_unit` functions.)
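For comparison, here is a stdlib-only sketch of roughly the shape of the generated `all` (paraphrased, not copied from the linked file), run against a hypothetical `Writer` monad that pairs a value with a `string list` log. The call to `loop` sits inside the function passed to `>>=`, and this `bind` still has to append logs after that function returns, so each list element keeps one more stack frame alive until the end of the list is reached.

```ocaml
(* Hypothetical stand-in for a value-plus-log monad; the log is a
   [string list] instead of a real error-log type. [@] is used only for
   brevity and is not itself stack-safe on very long logs. *)
module Writer = struct
  type 'a t = 'a * string list

  let return x = (x, [])

  (* [f v] is not a tail call: the logs are appended after it returns. *)
  let bind (v, log) ~f =
    let v', log' = f v in
    (v', log @ log')

  let ( >>= ) t f = bind t ~f
end

(* Paraphrase of the shape of the [all] that Monad.Make generates.
   [loop] looks tail-recursive, but it is called from inside [f], and
   [Writer.bind] keeps its frame until [f] returns. *)
let all (ts : 'a Writer.t list) : 'a list Writer.t =
  let open Writer in
  let rec loop vs = function
    | [] -> return (List.rev vs)
    | t :: ts -> t >>= fun v -> loop (v :: vs) ts
  in
  loop [] ts
```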
I tried messing around with it, but it seems hard to get a tail-recursive implementation of `all` for different kinds of `bind` (e.g., it changes depending on whether `bind` calls its function argument before returning or whether that function is supposed to be called later). Is your `bind` function itself tail-recursive?
My code looks like this:
```ocaml
type 'a t = 'a * Error_log.t

let create value log = value, log

let merge_logs (value, log_a) ~error_log:log_b =
  let log = Error_log.merge log_a log_b in
  create value log

include Monad.Make (struct
    type 'a t = 'a * Error_log.t

    let return x = create x Error_log.empty

    let map (value, log) ~f = f value, log
    let map = `Custom map

    let bind (value, error_log) ~f =
      f value
      |> merge_logs ~error_log
  end)
```
I came up with this version of `all` that seems to use a constant amount of stack space at the expense of calling `bind` twice as often:
```ocaml
let all =
  let rec loop acc = function
    | [] ->
      map acc ~f:List.rev
    | t :: ts ->
      let acc =
        acc >>= fun acc ->
        t >>| fun v ->
        v :: acc
      in
      loop acc ts
  in
  fun ts ->
    loop (return []) ts
```
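The accumulator version can be tried the same way. This sketch redefines the hypothetical `Writer` stand-in (adding `map` and `>>|`) and ports the `all` above: each function passed to `>>=` does a single `>>|` and returns, so the only recursion left is `loop`, which is a genuine tail call.

```ocaml
(* Hypothetical value-plus-log monad standing in for the one above.
   [@] is used only for brevity; a real log type would need its own
   stack-safe merge. *)
module Writer = struct
  type 'a t = 'a * string list

  let return x = (x, [])
  let map (v, log) ~f = (f v, log)

  let bind (v, log) ~f =
    let v', log' = f v in
    (v', log @ log')

  let ( >>= ) t f = bind t ~f
  let ( >>| ) t f = map t ~f
end

(* Accumulate inside the monad: each step binds the accumulator, maps one
   element onto it, and returns, so no frame outlives its own step. *)
let all (ts : 'a Writer.t list) : 'a list Writer.t =
  let open Writer in
  let rec loop acc = function
    | [] -> map acc ~f:List.rev
    | t :: ts ->
      let acc = acc >>= fun acc -> t >>| fun v -> v :: acc in
      loop acc ts
  in
  loop (return []) ts
```

With this monad, `all [(1, ["a"]); (2, ["b"])]` gives `([1; 2], ["a"; "b"])`, the same result as the generated version.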
I'm not sure I'm properly taking into account the various things `bind` can do, though. Are there situations where this is worse (in terms of stack space)?
Yeah, this can blow the stack if you have a monad whose `bind` defers applying its argument, and then you try to do something like evaluate the end result of the monad:
```ocaml
(* Monad whose bind defers calling [f] *)
module M' = struct
  module T = struct
    type 'a t =
      | Return : 'a -> 'a t
      | Bind : 'a t * ('a -> 'b t) -> 'b t

    let return x = Return x
    let bind t ~f = Bind (t, f)
    let map = `Define_using_bind
  end

  include T
  include Monad.Make (T)

  let rec eval : 'a. 'a t -> 'a =
    fun (type a) (t : a t) ->
      match t with
      | Return x -> x
      | Bind (t, f) -> eval (f (eval t))
  ;;
end
```
In this case, the alternative `all` (something like the code below) causes a stack overflow where the original one was stack-safe.
```ocaml
let all_tailrec ts =
  let rec loop vs = function
    | [] -> vs >>| List.rev
    | t :: ts ->
      loop (t >>= fun v -> vs >>| fun vs -> v :: vs) ts
  in
  loop (return []) ts
```
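To try both versions without Base, the sketch below rebuilds `M'` with the stdlib only: `>>=` and `>>|` are written by hand, with `>>|` defined via `>>=` the way `` `Define_using_bind `` derives `map`, and `all` uses the same paraphrased shape as the generated one. In `all`, the left side of every `Bind` is one of the input elements, so `eval` only nests as deep as a single element. In `all_tailrec`, the inner `Bind` built by `>>|` has the previous accumulator on its left, so `eval` nests once per element.

```ocaml
(* Monad whose bind defers calling [f]; mirrors [M'] above. *)
type 'a t =
  | Return : 'a -> 'a t
  | Bind : 'a t * ('a -> 'b t) -> 'b t

let return x = Return x
let ( >>= ) t f = Bind (t, f)

(* [map] defined in terms of [bind], as with `Define_using_bind. *)
let ( >>| ) t f = t >>= fun x -> return (f x)

(* The inner [eval t] is not a tail call, so a term whose left spine
   of [Bind]s is deep needs a deep stack. *)
let rec eval : type a. a t -> a = function
  | Return x -> x
  | Bind (t, f) -> eval (f (eval t))

(* Shape of the generated [all]: the left side of each [Bind] is an
   input element, and the rest is reached through a tail call. *)
let all ts =
  let rec loop vs = function
    | [] -> return (List.rev vs)
    | t :: ts -> t >>= fun v -> loop (v :: vs) ts
  in
  loop [] ts

(* Accumulator version: [vs >>| ...] puts the previous accumulator on the
   left of a [Bind], so [eval] nests one non-tail call per element. *)
let all_tailrec ts =
  let rec loop vs = function
    | [] -> vs >>| List.rev
    | t :: ts -> loop (t >>= fun v -> vs >>| fun vs -> v :: vs) ts
  in
  loop (return []) ts
```

On a short list both agree, e.g. `eval (all_tailrec [Return 1; Return 2; Return 3])` is `[1; 2; 3]`; the difference only shows up as stack depth on long lists, and where that overflows depends on the platform's stack size.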
I suspect this `all_divide_and_conquer` will use logarithmic stack space in both cases, by the way (even though it's probably slower than both):
```ocaml
module Make (M : Monad.S) : sig
  val all_divide_and_conquer : 'a M.t list -> 'a list M.t
end = struct
  let all_divide_and_conquer =
    fun ts ->
      List.map ts ~f:(M.map ~f:(fun x -> [x]))
      |> List.reduce_balanced ~f:(fun a b ->
        M.bind a ~f:(fun a ->
          M.map b ~f:(fun b -> (a @ b))))
      |> function
      | None -> M.return []
      | Some res -> res
end
```
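To experiment without Base, here is a stdlib-only port. `reduce_balanced` below is a hypothetical stand-in for Base's `List.reduce_balanced` (same idea: combine neighbours in pairwise rounds, so the tree of `bind`s is about log2 n deep), `MONAD` stands in for `Monad.S`, and the option monad is just a convenient test instance.

```ocaml
(* Minimal monad signature, standing in for Base's Monad.S. *)
module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val bind : 'a t -> f:('a -> 'b t) -> 'b t
  val map : 'a t -> f:('a -> 'b) -> 'b t
end

(* Hypothetical stand-in for Base's List.reduce_balanced: each round
   combines neighbours pairwise (keeping left-to-right order), halving
   the list, so the resulting tree of [f] calls has depth O(log n). *)
let reduce_balanced l ~f =
  let rec pairs acc = function
    | a :: b :: rest -> pairs (f a b :: acc) rest
    | [ a ] -> List.rev (a :: acc)
    | [] -> List.rev acc
  in
  let rec go = function
    | [] -> None
    | [ x ] -> Some x
    | l -> go (pairs [] l)
  in
  go l

module Make (M : MONAD) = struct
  let all_divide_and_conquer ts =
    List.map (M.map ~f:(fun x -> [ x ])) ts
    |> reduce_balanced ~f:(fun a b ->
         M.bind a ~f:(fun a -> M.map b ~f:(fun b -> a @ b)))
    |> function
    | None -> M.return []
    | Some res -> res
end

module Option_monad = struct
  type 'a t = 'a option
  let return x = Some x
  let bind t ~f = Option.bind t f
  let map t ~f = Option.map f t
end

module All_option = Make (Option_monad)
```

For example, `All_option.all_divide_and_conquer [Some 1; Some 2; Some 3]` is `Some [1; 2; 3]`, and any `None` in the list makes the result `None`. Since `a @ b` copies the left list at every level of the tree, the total work is O(n log n), which fits the guess that this is slower than the linear versions.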
Gack, I keep misclicking. Sorry for the spurious notifications.
(Not sure if you still want my opinion on this, but I'd be in favor of making the standard version as safe as possible and assuming that people who want to hyper-optimize this won't be using the built-in implementation anyway.)