Solving ODEs with huge inputs / multi-device
Hello,
I am looking to use Diffrax for a multi-device (multi-host to be more precise) setup.
I have set everything up correctly, but I am not getting the same results as with a single-device solver.
What I want to know is: is Diffrax supposed to work out of the box with sharded global arrays? The function I am giving to ODETerm takes a sharded array, and on its own it does give me the correct results (sharding and custom partitioning are done).
I can provide an MWE if needed.
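For reference, here is the shape of the setup I mean (the mesh layout and vector field below are purely illustrative, not my actual code):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import diffrax

# One mesh axis across all devices; y0 is a sharded global array.
mesh = Mesh(jax.devices(), axis_names=("x",))
y0 = jax.device_put(jnp.ones(1024), NamedSharding(mesh, P("x")))

def vector_field(t, y, args):
    # Element-wise, so the sharding of y should be able to propagate through.
    return -y

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Euler(),
    t0=0.0, t1=1.0, dt0=0.01,
    y0=y0,
)
```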
Second (but related) question.
Is there any way I can save snapshots (SaveAt) to disk and discard them, rather than keeping the whole solution in memory?
Thank you.
In theory Diffrax should work under parallelism! In practice it's not a very common combination, so perhaps something has gone wrong!
The first thing I'm usually suspicious of is our usage of equinox.error_if, which does fairly magical things. You could try disabling that (with an environment variable; see its docs) and see?
As for saving results: probably yes, by using a jax.pure_callback. If you want to do that as you go along then you could pass it to SaveAt(fn=...). You'll need to return a dummy value at minimum though, since if the output of the callback is unused then it will be DCE'd.
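A minimal sketch of what I mean (the _save_to_disk helper and filename scheme are just illustrative):

```python
import diffrax
import jax
import jax.numpy as jnp
import numpy as np

def _save_to_disk(t, y):
    # Runs on the host with concrete values; write the snapshot out and
    # return a dummy scalar so the callback isn't dead-code-eliminated.
    np.save(f"snapshot_t{float(t):.6f}.npy", np.asarray(y))
    return np.float32(0.0)

def save_fn(t, y, args):
    return jax.pure_callback(_save_to_disk, jax.ShapeDtypeStruct((), jnp.float32), t, y)

term = diffrax.ODETerm(lambda t, y, args: -y)
saveat = diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 11), fn=save_fn)
sol = diffrax.diffeqsolve(term, diffrax.Euler(), t0=0.0, t1=1.0, dt0=0.01,
                          y0=jnp.ones(4), saveat=saveat)
# sol.ys now holds only the dummy scalars; the snapshots live on disk.
```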
Thank you very much for your answer.
I have been testing a bit, and I found that even Euler with a ConstantStepSize does this:
E0705 10:38:57.589597 3858372 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {maximal device=0} to {devices=[1,1,2,2,1,1]<=[2,2]T(1,0)} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
It is very difficult to see why just by looking at the HLO, since there is no all-gather.
However, by setting EQX_ON_ERROR=nan everything works like a charm (and quite fast), even with a PIDController.
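For anyone else landing here, this is how I set it (I set it before the first Equinox/Diffrax import, since as far as I can tell the variable is read at import time):

```python
# Either export it in the shell before launching:
#   EQX_ON_ERROR=nan python my_script.py
# or set it in Python before Equinox/Diffrax are imported:
import os
os.environ["EQX_ON_ERROR"] = "nan"

import diffrax  # imported only after setting the variable
```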
When this happens to me, it is almost always an element-wise operation between a non-addressable array and a (supposedly) fully replicated array. The solution was either shard_map with an empty PartitionSpec on the fully replicated array, or lax.with_sharding_constraint, as sketched below.
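A sketch of those two workarounds, with a trivial x * scale standing in for the real element-wise op:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("x",))

# Option 1: constrain the (supposedly) replicated operand explicitly,
# so the partitioner doesn't fall back to full rematerialization.
@jax.jit
def f_constrained(x, scale):
    scale = jax.lax.with_sharding_constraint(scale, NamedSharding(mesh, P()))
    return x * scale

# Option 2: shard_map, with an empty PartitionSpec for the replicated operand.
f_mapped = shard_map(
    lambda x, scale: x * scale,
    mesh=mesh,
    in_specs=(P("x"), P()),  # x sharded along the mesh axis; scale replicated
    out_specs=P("x"),
)
```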
I will close the issue, because with EQX_ON_ERROR=nan Diffrax is usable in a multi-host setup.
Thank you again.