equinox

Control argument donation in `filter_jit` when the first argument is `self`

Open · davidmarttila opened this issue 2 months ago · 1 comment

Hello,

First of all, thank you for maintaining this fantastic library!

I was wondering if there is a generally recommended pattern to deal with buffer donation in a case like this:

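```python
import equinox as eqx

class Example(eqx.Module):

  def do_stuff(self, x, aux):
    # `loads_of_logic` stands in for the real computation (not shown here)
    x = self.loads_of_logic(x, aux)
    return x
```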

where I'd want to donate `x` but keep `aux` for later reuse. Right now I'm using a helper function to get rid of the leading `self` and reorder the arguments like this:

```python
model = Example()

# `aux` comes first, so it is kept; `x` is donated
def do_stuff(aux, x):
  return model.do_stuff(x, aux)

do_stuff_jit = eqx.filter_jit(do_stuff, donate="all-except-first")
```

It works, but it is a bit awkward when I need to call `do_stuff` from several different places in the code. Is there a better way to do this? Or would it be possible to support a finer-grained `donate_argnums` in `filter_jit` in the future?
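For comparison, plain `jax.jit` already offers per-argument control through `donate_argnums`. The sketch below assumes the model can be reduced to plain arrays, since `jax.jit` needs every non-static argument to be a pytree of JAX-compatible leaves (a restriction that `filter_jit` relaxes by treating non-array leaves as static). `weight` and `loads_of_logic` are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp


def loads_of_logic(weight, x, aux):
    # Hypothetical stand-in for the model's computation.
    return x * weight + aux


# Donate only positional argument 1 (`x`); `weight` and `aux` are not donated.
do_stuff_jit = jax.jit(loads_of_logic, donate_argnums=1)

weight = jnp.array(2.0)
aux = jnp.ones(3)
y = do_stuff_jit(weight, jnp.arange(3.0), aux)  # y == [1., 3., 5.]
```

The array passed as `x` should not be read after the call: on backends that support donation, its buffer may be reused for `y`. On backends that cannot use the donation, JAX may emit a warning instead.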

davidmarttila · Oct 21 '25 15:10

I think something like this is probably what you're looking for:

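```python
import equinox as eqx

# With "all-except-first", `aux` (the first argument) is not donated;
# the arrays in `x` and `example` are.
@eqx.filter_jit(donate="all-except-first")
def _do_stuff(aux, x, example):
    return example.loads_of_logic(x, aux)

class Example(eqx.Module):
    def do_stuff(self, x, aux):
        return _do_stuff(aux, x, self)
```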

so that, in particular, `do_stuff` can be called from multiple places without needing a wrapper at each call site.

This is a bit verbose, but I think the scenario you're looking at here doesn't come up very often; at least, it hasn't come up for me before. A lot of code only has a single JIT decorator around some top-level function (the bit where we switch from 'normal software' to 'pile of mathematics'), and in my experience that function is usually a normal function rather than a method. :)

patrick-kidger · Oct 24 '25 23:10