threadpoolctl
limit the # of threads for jax
Hi, I know this is probably more of an issue on the jax side, and it has been discussed there (e.g. google/jax#743, google/jax#1539 and google/jax#6790), but I'm still wondering whether you know how to limit the number of threads for jax. Below is a simple snippet showing that jax currently does not respect the threadpool limits.
import jax.numpy as jnp
from threadpoolctl import threadpool_limits

ja = jnp.ones((1000, 1000))
with threadpool_limits(5):
    for _ in range(100):
        foo = ja @ ja
Hi @HerculesJack, according to this comment https://github.com/google/jax/issues/743#issuecomment-495031093, the threading mechanism used by jax is not one of those that threadpoolctl supports. It could be interesting to check whether Eigen's threadpool exposes symbols that allow controlling the number of threads.
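To check which thread pools threadpoolctl actually detects in a given process, a minimal sketch along these lines may help. It relies on `threadpool_info()` from threadpoolctl; the helper name `detected_threadpools` is made up for illustration, and it simply returns an empty list when threadpoolctl is not installed.

```python
import importlib
import importlib.util


def detected_threadpools(modules=()):
    """List (user_api, internal_api, num_threads) for each pool threadpoolctl sees.

    `modules` names packages to import first, because threadpoolctl can only
    inspect shared libraries that are already loaded in the process.
    """
    for name in modules:
        # Skip packages that are not installed instead of failing.
        if importlib.util.find_spec(name) is not None:
            importlib.import_module(name)

    if importlib.util.find_spec("threadpoolctl") is None:
        return []  # assumption: treat a missing threadpoolctl as "nothing detected"

    from threadpoolctl import threadpool_info

    return [
        (lib.get("user_api"), lib.get("internal_api"), lib.get("num_threads"))
        for lib in threadpool_info()
    ]
```

Calling `detected_threadpools(("numpy", "jax"))` and printing the result should list the BLAS/OpenMP pools that threadpoolctl can limit. If the linked comment is right, none of the entries would correspond to the Eigen threadpool that jax uses, which would explain why `threadpool_limits` has no visible effect on it.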
Note: if Eigen exposes some well-defined symbols to inspect and control the number of threads in its threadpool, then the mechanism implemented in #137 should make it possible to add support for tensorflow and jax to threadpoolctl.
I tried to check the Eigen documentation to see if it's the case but it seems to be down at the moment: https://www.tuxfamily.org/en/news/2023070900.