feat: clear jax cache on register
The idea is to "structurally" identify JAX: `jax.jit` produces a jaxlib object with a `clear_cache` attribute, which is sufficient to identify JAX-jitted functions. I've left notes for longer-term plans.
Needs tests.
Pull Request Test Coverage Report for Build 9419781118

Details

- 1 of 2 (50.0%) changed or added relevant lines in 1 file are covered.
- No unchanged relevant lines lost coverage.
- Overall coverage decreased (-0.08%) to 99.839%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| plum/dispatcher.py | 1 | 2 | 50.0% |
| Total: | 1 | 2 | 50.0% |

| Totals | |
|---|---|
| Change from base Build 9340121872: | -0.08% |
| Covered Lines: | 1237 |
| Relevant Lines: | 1239 |

💛 - Coveralls
@wesselb @PhilipVinc what do you think of this simple solution? There's definitely room for improvement, as indicated in the comment, but I think since this is private API it's good enough to get in as is, then iterate. WDYT?
I would be a bit against this.
The approach you propose works when directly jitting a dispatched function, like:

```python
@dispatch
@jit
def myfun(a: jax.Array, b: int):
    return a * b

@dispatch
@jit
def myfun(a: jax.Array, b: float):
    return a + b
```
However, it does not detect usages like:

```python
@dispatch
def myfun(a: jax.Array, b: int):
    return a * b

@dispatch
def myfun(a: jax.Array, b: float):
    return a + b

@jit
def my_algorithm(a, b):
    return 12 * myfun(a, b)

# This causes JAX to jit.
my_algorithm(a, b)

@dispatch
def myfun(a: jax.Array, b: float):
    return a + b + 10

# JAX has no idea that the function changed, so it will return the same result as before.
my_algorithm(a, b)
```
This also applies with your approach whenever a jitted function internally calls other dispatched functions.
This is the use-case that should be addressed in my opinion, otherwise it's too brittle.
@PhilipVinc some really good points!
Just for completeness, this also helps address the case
```python
@jit
@dispatch
def myfun(a: jax.Array, b: int):
    return a * b

@jit
@dispatch
def myfun(a: jax.Array, b: float):
    return a + b
```
(The decorator order is important jit-wise)
But you are correct that this doesn't work for "higher-order" functions that call other functions which are themselves multiply-dispatched, as you show later in your comment.
@wesselb, is this maybe related to your suggestion of a cache? Or is this a separate complication?
How does your PR work with `jit(dispatch(...))`?
The dispatch logic is executed before the jitting happens, so I do not understand how it would work correctly in this case.
See the opening comment of https://github.com/beartype/plum/issues/154#issue-2329848202.
jit compiles away the plum dispatching, which is great for speeding up the code. Prior to this PR, when a new dispatch was added, JAX did not update the plum dispatch it was going to use and kept its internally compiled dispatch. This is exactly what you point out with my_algorithm, but it applies here to myfun. With this PR, when a new dispatch is added to myfun, the cache is wiped and the next usage is forced to JIT again, keeping JAX in sync with plum.
I haven't tested your "higher-order" case, but I hope (🤞) adding jit to myfun would actually work with this PR, because my_algorithm wouldn't compile myfun away but would keep it as a distinct function:
```python
@jit
@dispatch
def myfun(a: jax.Array, b: int):
    return a * b

@jit
@dispatch
def myfun(a: jax.Array, b: float):
    return a + b

@jit
def my_algorithm(a, b):
    # myfun is not compiled away, depending on the JAX version
    # (see discussion in https://github.com/google/jax/issues/9298).
    return 12 * myfun(a, b)
```
If the jit on myfun had inline=True, then IDK what would happen.
But overall I agree this is still brittle and thus not an optimal solution!
But isn't this

```python
# Hooks to clear the JIT caches of various libraries.
# TODO: this needs to be systematized. This should probably work
# by entry points to allow for any JIT library to be added.
# JAX:
if "jax" in type(method).__module__ and hasattr(method, "clear_cache"):
    method.clear_cache()
```

only going to detect the case where you apply `@dispatch` on top of `@jit`?
```python
import jax

def dispatch(fun):
    print("Fun:", type(fun).__module__)
    return fun

@jax.jit
@dispatch
def test(x, y):
    return x + y

print("fun now:", type(test).__module__)
```
outputs

```
Fun: builtins
fun now: jaxlib.xla_extension
```

so in this case it would not be triggered.
Ah, good catch. I haven't written tests for this yet, which would have caught this mistake. This is the difference between the example I gave in #154 and this PR.
Yes, I need to refactor this to search through all the registrants, not just the current method.
We definitely do want to detect both `jit(dispatch(...))` and `dispatch(jit(...))`.
But neither approach will work for jitted functions that call a separate dispatched function.
Is there any way around this?!
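For what it's worth, a pure-Python sketch of that registrant-walking refactor (the `Function` class and its attributes are illustrative stand-ins, not plum's actual internals) could look like:

```python
# Sketch: on every registration, walk *all* known implementations and
# clear any per-function JIT cache they expose, rather than only
# inspecting the method currently being registered.

class Function:
    """Illustrative stand-in for a multiply-dispatched function."""

    def __init__(self):
        self.methods = []  # all registered implementations

    def register(self, impl):
        self.methods.append(impl)
        # Duck-typed check: a jax.jit-wrapped callable comes from a
        # jaxlib module and exposes `clear_cache()`.
        for m in self.methods:
            if "jax" in type(m).__module__ and hasattr(m, "clear_cache"):
                m.clear_cache()
        return impl
```

Because the check is duck-typed on the type's module name plus a `clear_cache` attribute, it catches `dispatch(jit(...))` registrants without plum ever importing JAX.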
The only way that I can think of, brutal and not ideal, is to clear all JAX caches with `jax.clear_caches()` every time we register a new function.
It's definitely not ideal, but as long as a user imports all modules up front (and so registers everything), it's not so bad. It only becomes annoying when new dispatches are registered afterwards.
Maybe @patrick-kidger can advise on how to best make plum and jax play well together?
AFAIK it's not really possible.
- Adding a new dispatch rule is mutating global state (the methods table);
- I do not know of any way to have JAX automatically re-JIT when changing global state.
Tbh I think supporting this is an antipattern anyway. Everything prior to jax.jit is basically the 'source code', which on your first run you then compile. Looking at other languages: trying to automatically recompile on detecting that your source code has changed is nonstandard in C++/Rust-type languages, and in Julia it caused a lot of headaches: the fact that they were handling this meant they were silently footgunning their compile times with lots of cache invalidations.
That aside I really don't like the idea that plum should try to do something based on specific third-party libraries. If everyone did this it would be incredibly hard to reason about what my code does. I would much rather they each just do their thing without trying to interfere with each other.
@patrick-kidger. Thanks for the info!
Are you suggesting that instead we just add something to the docs showing how jit + dispatch can work together, but also showing how this can be dangerous when adding new dispatches, and suggesting jax.clear_caches for that case?
Instead of clearing the cache, this PR could check if there is a non-empty cache and raise a warning.
That sounds much more reasonable to me!
Having the extent of a cross-library interaction be a "hey, did you really mean this?" warning sounds like a good compromise I think.
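A minimal sketch of such a warning check, assuming the jitted callable exposes a cache-size query (recent JAX versions have a private `_cache_size()` on jit-wrapped functions, but that is an internal detail and may change):

```python
import warnings

def warn_if_already_compiled(method):
    """Warn, without clearing anything, if `method` looks like a
    jit-wrapped callable that JAX has already compiled."""
    # Hedged duck-typing: absence of `_cache_size` means "cannot tell",
    # in which case we stay silent.
    cache_size = getattr(method, "_cache_size", None)
    if callable(cache_size) and cache_size() > 0:
        warnings.warn(
            "A new method was registered for a function that JAX has "
            "already compiled; the compiled code will not pick it up. "
            "Consider calling `jax.clear_caches()`.",
            stacklevel=2,
        )
```

The user stays in control: plum only says "hey, did you really mean this?" and leaves the cache untouched.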
> Are you suggesting that instead we just add something to the docs showing how jit + dispatch can work together, but also showing how this can be dangerous when adding new dispatches, and suggesting jax.clear_caches for that case? Instead of clearing the cache, this PR could check if there is a non-empty cache and raise a warning.
This also sounds like a good approach to me. :)
> @wesselb, is this maybe related to your suggestion of a cache? Or is this a separate complication?
What I had in mind was slightly different: the cache would remember all dispatch decisions. The next time the function is called, instead of running the resolver on the arguments, it would immediately return the previously resolved method from the cache. This way you could "compile away" the overhead of dispatch.
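A minimal sketch of that dispatch-decision cache (pure Python; the class and attribute names are illustrative, not plum's implementation):

```python
# Sketch: cache the resolved method per tuple of argument types so that
# repeated calls bypass the resolver entirely. Registering a new method
# must invalidate the cache, or stale decisions would be returned.

class CachedFunction:
    def __init__(self, resolver):
        self._resolver = resolver  # maps a tuple of types to a method
        self._cache = {}

    def __call__(self, *args):
        key = tuple(type(a) for a in args)
        try:
            method = self._cache[key]  # fast path: no resolution needed
        except KeyError:
            method = self._cache[key] = self._resolver(key)
        return method(*args)

    def invalidate(self):
        # Call whenever a new method is registered.
        self._cache.clear()
```

The fast path is a single dict lookup keyed on argument types, which is what lets the dispatch overhead be "compiled away" for hot call sites.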