
make config.jax_platforms = 'cpu' suppress no-cpu warning

Open mattjj opened this issue 2 years ago • 5 comments

also:

  • switch the warning to use warnings.warn rather than absl.logging.warning
  • fix a test that wasn't being run because "test" wasn't in the method name (it was also failing at HEAD because it relied on a technique that only works with warnings.warn, not absl.logging.warning)
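
To illustrate why the switch matters for testing, here is a minimal sketch (not the actual JAX test). It assumes the test technique in question is capturing warnings with `warnings.catch_warnings(record=True)`. Standard-library `logging` stands in for `absl.logging`, and the function names are hypothetical.

```python
import logging
import warnings

MESSAGE = "No GPU/TPU found, falling back to CPU."


def warn_via_warnings():
    # Hypothetical stand-in for library code that reports through warnings.warn.
    warnings.warn(MESSAGE)


def warn_via_logging():
    # Hypothetical stand-in for library code that reports through a logger
    # (stdlib logging is used here in place of absl.logging).
    logging.getLogger("example").warning(MESSAGE)


def captured_messages(fn):
    """Run fn and return the messages recorded by the warnings machinery."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # do not let filters hide the warning
        fn()
    return [str(w.message) for w in caught]
```

Only the `warnings.warn` path is visible to `catch_warnings`; the logger call produces a log record but nothing the warnings machinery can record, so a test written this way cannot observe it.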

mattjj avatar Apr 13 '22 22:04 mattjj

Line 299:

  platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None)

Line 235:

if FLAGS.jax_platforms:

Those lines should be fixed as well, shouldn't they?

In addition, I can see a lot of lines that check jax_platforms and then jax_platform_name (or jax_xla_backend), which could be refactored into a simple utility function.
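
A rough sketch of what such a utility might look like. Everything here is hypothetical: the function name, the plain-argument signature (instead of reading `FLAGS`), and the precedence order are assumptions for illustration, not the JAX implementation. Note that `jax_platforms` may hold a comma-separated list, which this sketch returns unchanged.

```python
from typing import Optional


def resolve_platform(
    platform: Optional[str] = None,
    jax_platforms: Optional[str] = None,
    jax_xla_backend: Optional[str] = None,
    jax_platform_name: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical helper: return the first non-empty setting, or None.

    Assumed precedence: explicit argument, then jax_platforms, then the
    older jax_xla_backend and jax_platform_name settings.
    """
    for candidate in (platform, jax_platforms, jax_xla_backend, jax_platform_name):
        if candidate:  # skips both None and the empty string
            return candidate
    return None
```

For example, `resolve_platform(jax_platforms="cpu", jax_platform_name="gpu")` picks the newer setting, while `resolve_platform()` falls through to `None`. Centralizing the lookup would mean a call site cannot forget one of the settings.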

wookayin avatar Apr 13 '22 23:04 wookayin

@wookayin perhaps so, but in this PR I'm just trying to fix a specific bug. It could be that those other things lead to bugs, or maybe they're intentional; unfortunately I don't have time to figure that out right now!

Are there bugs you have in mind?

mattjj avatar Apr 13 '22 23:04 mattjj

Please refer to https://github.com/google/jax/issues/6805#issuecomment-1098436118. I guess these were just missed when writing #8035, and all these are in a similar vein. I'm fine having them addressed in another issue/PR as another follow-up.

wookayin avatar Apr 13 '22 23:04 wookayin

Ah I see thanks. Sorry I didn't make the connection to your earlier comment; I'm juggling a few different changes at the moment and am paying minimal attention to this one.

Can you open an issue for the bug in https://github.com/google/jax/issues/6805#issuecomment-1098436118 ?

mattjj avatar Apr 13 '22 23:04 mattjj

@mattjj is there any blocker to merging this PR now?

ilemhadri avatar Sep 09 '22 20:09 ilemhadri

@mattjj Any update on this? If I understand correctly, JAX_PLATFORM_NAME is deprecated in favor of JAX_PLATFORMS. However, there's still some code that uses the former. This includes the code that prints the warning:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

carlosgmartin avatar Apr 19 '23 06:04 carlosgmartin

Closing as moot with PR https://github.com/google/jax/pull/17751, which means we never issue that warning. What we will do is specifically warn if we detect an NVIDIA GPU or a Google TPU on the machine, we're not using it, and you didn't provide a list of specific platforms to use.
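
The rule described above can be sketched as a simple predicate. This is only an illustration of the stated condition: the function and parameter names are hypothetical, and how the hardware is actually detected is left out entirely.

```python
from typing import Optional


def should_warn_unused_accelerator(
    accelerator_detected: bool,
    accelerator_in_use: bool,
    platforms_requested: Optional[str],
) -> bool:
    """Hypothetical predicate for the warning condition described above.

    Warn only when an NVIDIA GPU or Google TPU is present on the machine,
    it is not being used, and the user did not list specific platforms.
    """
    user_chose_platforms = bool(platforms_requested)  # None or "" means no choice
    return accelerator_detected and not accelerator_in_use and not user_chose_platforms
```

Under this reading, a user who explicitly asks for `cpu` on a GPU machine gets no warning, which is the behavior the original issue requested.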

hawkinsp avatar Sep 26 '23 17:09 hawkinsp