Checkify exception only reports error from one device (not all)

Open billmark opened this issue 9 months ago • 7 comments

Description

When running with four hosts and four devices on each host, I see an "errs" returned by pmap of checkify that looks like the following:

Error(at mapped index 0: before: pmean_input_ok failed step @12290 (`check` failed)
at mapped index 1: after:neg_delta_params has NaN @12290 (`check` failed)
at mapped index 2: after:neg_delta_params has NaN @12290 (`check` failed)
at mapped index 3: after:neg_delta_params has NaN @12290 (`check` failed))

However, calling errs.throw() (as recommended in the JAX docs) shows only one of these four errors:

 Top-level exception: after:neg_delta_params has NaN @12290 (`check` failed)
  ...
  jax._src.checkify.FailedCheckError: after:neg_delta_params has NaN @12290 (`check` failed)

I consider this behavior to be a bug. No reasonable person would expect the exception string to omit the errors from three out of four devices on that host. The exception string should contain all four errors.
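
As a possible stopgap, here is a minimal sketch of raising an exception that carries the full batched message. It uses jax.vmap instead of pmap so it runs on a single device, and the function step, the check message, and the helper throw_all are hypothetical names for illustration only. It relies on errs.get(), which returns the error message as a string (or None when no check failed). This is not a fix for throw() itself, and it only covers whatever is in the errs object on the calling process.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def step(x):
    # Hypothetical per-example check standing in for the real NaN checks.
    checkify.check(x > 0, "x must be positive")
    return x * 2.0


# Map the checkified function so the returned Error is batched,
# similar in spirit to pmap-of-checkify in the report above.
checked_step = jax.vmap(checkify.checkify(step))
errs, out = checked_step(jnp.array([1.0, -1.0, -2.0, 3.0]))

# errs.get() returns the combined message, or None if no check failed.
full_message = errs.get()


def throw_all(errs):
    """Raise one exception whose text is the full message from errs.get()."""
    message = errs.get()
    if message is not None:
        raise RuntimeError(message)
```

Calling throw_all(errs) in place of errs.throw() should then surface the same text that printing errs shows, at the cost of losing the FailedCheckError exception type.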

System info (python version, jaxlib version, accelerator, etc.)

HEAD at google as of May 15, 2024. Running on TPU. (Four hosts, four devices per host).

billmark, May 15 '24 18:05