Control argument donation in `filter_jit` when first argument is `self`
Hello,
first of all, thank you for maintaining this fantastic library!
I was wondering if there is a generally recommended pattern to deal with buffer donation in a case like this:
```python
import equinox as eqx

class Example(eqx.Module):
    def do_stuff(self, x, aux):
        x = self.loads_of_logic(x, aux)
        return x
```
where I want to donate `x` but keep `aux` around to reuse later. Right now I'm using a helper function to drop the leading `self` and reorder the arguments, like this:
```python
model = Example()

def do_stuff(aux, x):
    return model.do_stuff(x, aux)

do_stuff_jit = eqx.filter_jit(do_stuff, donate="all-except-first")
```
It works, but it is a bit awkward when I need to call `do_stuff` from several different places in the code. Is there a better way to do this? Or would it be possible to support a finer-grained `donate_argnums` in `filter_jit` in the future?
I think something like this is probably what you're looking for:
```python
import equinox as eqx

@eqx.filter_jit(donate="all-except-first")
def _do_stuff(aux, x, example):
    return example.loads_of_logic(x, aux)

class Example(eqx.Module):
    def do_stuff(self, x, aux):
        return _do_stuff(aux, x, self)
```
so that, in particular, `do_stuff` can be called from multiple places without needing a wrapper at each call site.
This is a bit verbose, but I don't think the scenario you're describing comes up very often; at least, I haven't run into it before. A lot of code has only a single JIT decorator, around some top-level function (the point where we switch from 'normal software' to 'pile of mathematics'), and in my experience that function is usually a plain function rather than a method. :)