Add environment variable for setting logging level
This is how to turn on INFO-level logging for jax right now:
from absl import logging
logging.set_verbosity(logging.INFO)
It'd be better to have an env var like JAX_LOG_LEVEL=INFO so turning on logging doesn't require code changes. This would also help transition to regular Python logging.
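Until such a variable exists, a user-side workaround is a few lines at the top of a script. The sketch below makes three assumptions. It uses the name `JAX_LOG_LEVEL` proposed above. It relies on a hypothetical helper, `set_jax_log_level_from_env`, which is not part of JAX. It assumes JAX logs through the standard-library `logging` module under the `"jax"` logger name.

```python
import logging
import os
from typing import Optional

# Accepted level names; anything else is ignored rather than raising.
_LEVELS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}


def set_jax_log_level_from_env(var: str = "JAX_LOG_LEVEL") -> Optional[int]:
    """Hypothetical helper: apply the level named in `var` to the "jax" logger.

    Returns the numeric level that was applied, or None if the variable is
    unset or holds an unrecognized name.
    """
    name = os.environ.get(var, "").strip().upper()
    level = _LEVELS.get(name)
    if level is not None:
        logging.getLogger("jax").setLevel(level)
    return level


# Usage: run the script as `JAX_LOG_LEVEL=INFO python script.py`.
set_jax_log_level_from_env()
```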
Related: #6308
We've migrated away from ABSL logging, so the way you do this now is actually just the regular Python mechanism:
e.g., this turns on INFO logs for JAX:
import logging
logging.getLogger("jax").setLevel(logging.INFO)
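Logger names form a dot-separated hierarchy. Setting the level on `"jax"` therefore also governs any child logger, such as `"jax._src.dispatch"`, whose own level is left unset. There is one caveat, which assumes neither JAX nor your application has already attached a handler. Python's fallback handler (`logging.lastResort`) only prints WARNING and above, so INFO records can stay invisible until a handler is configured, for example with `logging.basicConfig`. The child logger name below is illustrative.

```python
import logging

# Attach a stderr handler to the root logger so INFO records are actually printed;
# the root logger's own level stays at WARNING, so other libraries stay quiet.
logging.basicConfig(level=logging.WARNING)

# Opt JAX's logger subtree into INFO.
logging.getLogger("jax").setLevel(logging.INFO)

# A child logger (illustrative name) with no level of its own inherits INFO.
child = logging.getLogger("jax._src.dispatch")
print(logging.getLevelName(child.getEffectiveLevel()))  # INFO
child.info("this record propagates to the root handler and is printed")
```

Records propagate to ancestor handlers without being re-checked against the ancestors' logger levels. The root staying at WARNING therefore does not block INFO records that originate under `"jax"`.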
I'm not sure we need a flag anymore (I guess we still could), but it would make sense to document this.
Is there a doc on how to set the logging level? This was the only thing that came up when I searched for "Jax docs logging level".
I am assuming what @hawkinsp mentioned is still accurate. CC @skye
Edit: I confirmed that this still works as of Aug 23, 2024:
import logging
logging.getLogger("jax").setLevel(logging.INFO)
I would still very much like an env flag too since that's how I set logging levels for all other components.
Hello, does logging.debug() work on its own, to avoid calling .getLogger() on every operation?
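Here is a minimal standard-library sketch related to that question. Module-level functions such as `logging.debug()` log through the root logger, so they neither read nor change the `"jax"` logger's level. `logging.getLogger(name)` always returns the same object for the same name, so fetching the logger once and keeping a reference avoids repeated `getLogger` calls. The variable names are illustrative.

```python
import logging

# getLogger returns one shared object per name, so look it up once...
jax_logger = logging.getLogger("jax")
assert jax_logger is logging.getLogger("jax")

# ...and reuse the reference whenever the level needs to change.
jax_logger.setLevel(logging.DEBUG)

# logging.debug() emits a record on the *root* logger (configuring a default
# handler first if the root has none); it does not touch the "jax" logger.
root = logging.getLogger()
logging.debug("emitted on the root logger, filtered by the root's level")
```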
We can probably close this issue; it's already supported: https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAX_LOGGING_LEVEL&type=code
This is how you do it:
export JAX_LOGGING_LEVEL="DEBUG"
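The same variable can also be set from inside Python, for example in a notebook, by using `os.environ`. This sketch assumes JAX reads `JAX_LOGGING_LEVEL` when it is first imported. Under that assumption, the variable has to be set before the first `import jax` in the process. For a one-off shell run, the equivalent is `JAX_LOGGING_LEVEL=DEBUG python your_script.py`, where the script name is hypothetical.

```python
import os

# Set before the first `import jax` in this process; assumption: JAX reads
# JAX_LOGGING_LEVEL at import time, so setting it afterwards may have no effect.
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"

# import jax  # a real script would import JAX only after the line above
```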