Chase Riley Roberts
@mattjj should be g2g now
I removed the test in `shard_map_test.py`
There were a bunch of stupid tiny errors on my end. @mattjj I tested this on an 8-GPU node and the tests pass, so we should be g2g now.
Basically what I'm experimenting with is a way to decompose a jax program into multiple jax+mpi programs that run distributed. ```python big_jaxpr = jax.make_jaxpr(some_func)(*args, **kwargs) many_smaller_jaxprs = decompose_problem(big_jaxpr) my_methods...
`Traceback`s in general are not picklable, sadly, since they hold references to the live stack frames. https://stackoverflow.com/questions/6132469/why-cant-i-pickle-an-errors-traceback-in-python Your other suggestion of making tracebacks in Jaxprs optional is already the case, I believe....
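For reference, a minimal standard-library sketch of the traceback problem mentioned above. The helper name `capture_traceback` is made up for illustration and nothing here is JAX-specific; it only shows that a raw traceback object is rejected by `pickle`, while its formatted text form round-trips.

```python
import pickle
import sys
import traceback


def capture_traceback():
    """Hypothetical helper: raise and catch an error, return its traceback object."""
    try:
        raise ValueError("boom")
    except ValueError:
        return sys.exc_info()[2]


tb = capture_traceback()

# A traceback object keeps a reference to a live frame (tb.tb_frame),
# and pickle refuses to serialize it.
try:
    pickle.dumps(tb)
    picklable = True
except TypeError:
    picklable = False

# The formatted form is a plain list of strings, so it pickles fine.
formatted = traceback.format_tb(tb)
restored = pickle.loads(pickle.dumps(formatted))
```

Whether dropping or stringifying the traceback is acceptable depends on what the consumer of the serialized object needs; the sketch only demonstrates the limitation.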
Actually, upon further experimenting, I'm not sure pickling is viable for the use case I described above. ```python >>> import jax >>> j = jax.make_jaxpr(lambda a, b: a + b)(1.0, 2.0)...
> I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem. I wish it was that simple. Normal pickle has its own struggles....
I work on jax full time now so I'm going to try and lead this. The two issues are the traceback thing and the global dictionary lookups. The traceback thing...
Ok so hacking in this change in `core.py` ```python # Table that stores all primitive definitions. # Needed so that primitives are treated as singletons # when using cloudpickle. _PRIMITIVES_TABLE_:...
What's the status of this?