chex.variants(with_pmap=True) ignores `static_argnames`

Open • fvisin opened this issue • 9 comments

The `_with_pmap` function accepts `static_argnums` as a parameter, but not `static_argnames`. This is inconsistent with other variants, such as `with_jit` and `with_device`. Crucially, it prevents testing methods that require arguments to be passed by name (e.g., Distrax's `Distribution.sample()`).

More generally, it would be best if all variants accepted the same parameters where possible (i.e., where a parameter is not specific to a single variant). I would also suggest checking all keys in `**unused_kwargs` against a list of allowed parameters (i.e., the union of the parameters of all variant functions) to prevent silent errors due to, e.g., misspellings.

fvisin • Nov 23 '21 10:11
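A minimal sketch of the suggested check, in plain Python. The decorator `check_unknown_kwargs`, the `ALLOWED_VARIANT_KWARGS` set and the toy `with_pmap` signature are hypothetical illustrations, not chex's actual code, and the allowed set is only a guess at what the union might contain.

```python
import functools

# Hypothetical union of the keyword arguments accepted by *any* variant.
# Illustrative only; chex's real parameter lists may differ.
ALLOWED_VARIANT_KWARGS = frozenset({
    "static_argnums", "static_argnames", "in_axes", "axis_name",
    "n_devices", "devices", "backend", "broadcast_args_to_devices",
    "reduce_fn",
})


def check_unknown_kwargs(variant_fn):
    """Rejects keyword arguments that no variant understands (e.g. typos)."""
    @functools.wraps(variant_fn)
    def wrapper(*args, **kwargs):
        unknown = set(kwargs) - ALLOWED_VARIANT_KWARGS
        if unknown:
            raise TypeError(
                f"{variant_fn.__name__}() got unknown variant argument(s): "
                f"{sorted(unknown)}")
        return variant_fn(*args, **kwargs)
    return wrapper


@check_unknown_kwargs
def with_pmap(fn, *, static_argnums=(), in_axes=0, **unused_kwargs):
    # Toy stand-in: a real variant would build a pmapped version of `fn`.
    # Anything left in `unused_kwargs` (e.g. `static_argnames`) is dropped.
    return fn
```

With this in place, a misspelled `static_argnum=` raises `TypeError` instead of being swallowed. Note that `static_argnames` belongs to the union (because `with_jit` accepts it), so `with_pmap(fn, static_argnames=("n",))` still passes this check; on its own, the check would not reveal that `with_pmap` ignores that argument.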

Hi, I would like to work on this issue if no one else has picked it up. Thanks!

broper2 • Nov 23 '21 16:11

Hi @fvisin, thanks for opening this issue. And thanks for volunteering @broper2!

`static_argnames` kwarg is not supported in the original `pmap` because of `in_axes` incompatibility with kwargs (see the docstring). I would prefer not to add this to chex to avoid possible confusion / problems with forward compatibility. WDYT?

Re: checking `**_unused_kwargs` -- that's exactly what's accomplished by `@check_variant_arguments` :)

hbq1 • Nov 23 '21 21:11

> that's exactly what's accomplished by `@check_variant_arguments` :)

...how did I miss that!? Apologies for the noise :facepalm:

> `static_argnames` kwarg is not supported in the original `pmap` because of `in_axes` incompatibility with kwargs (see the docstring). I would prefer not to add this to chex to avoid possible confusion / problems with forward compatibility. WDYT?

That makes sense. Do you think it would be reasonable to fail with a more specific error message when `static_argnames` is set, though? It took me a while to realise that the argument was being ignored; it would have helped to know that a possible reason for the test failing was that `static_argnames` is ignored by `with_pmap`.

fvisin • Nov 24 '21 10:11

That's a good point! We can prohibit users from setting both `static_argnames` and `static_argnums` at the same time in the variants' args, wdyt?

hbq1 • Nov 24 '21 16:11
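A minimal sketch of the mutual-exclusion check proposed above, assuming a hypothetical helper `validate_static_args` that receives the variant's keyword arguments as a dict before dispatch (it is not part of chex):

```python
def _is_set(value):
    # `static_argnums=0` is a valid setting, so don't rely on truthiness.
    return value is not None and value != () and value != []


def validate_static_args(variant_kwargs):
    """Raises ValueError if both static_argnums and static_argnames are set."""
    if (_is_set(variant_kwargs.get("static_argnums"))
            and _is_set(variant_kwargs.get("static_argnames"))):
        raise ValueError(
            "Pass either `static_argnums` or `static_argnames` to the "
            "variant, not both.")
```

Empty tuples and `None` count as "not set", so the default `static_argnums=()` combined with `static_argnames` is still accepted.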

That's also a good idea, but I don't think it would have solved my specific case!

I was setting only `static_argnames`, but since `with_pmap` was ignoring it, I didn't immediately realise why the test was failing. When `static_argnames` is set and `with_pmap` fails, it would be handy if the error message suggested that the failure might be caused by `static_argnames` being ignored by `with_pmap`, since the original `pmap` does not support it. Does that make sense?

fvisin • Nov 24 '21 17:11
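A rough sketch of the behaviour requested above, in plain Python. The toy `with_pmap` and `_STATIC_ARGNAMES_HINT` are hypothetical: the stand-in simply forwards the call instead of calling `jax.pmap`, and only mimics the fact that `static_argnames` is accepted and then ignored. Wrapping the original error in a `RuntimeError` is one possible design, assumed here for simplicity.

```python
import functools

_STATIC_ARGNAMES_HINT = (
    "with_pmap ignores `static_argnames` because jax.pmap does not support "
    "it; this may be why the call failed. Consider passing static arguments "
    "positionally via `static_argnums` instead.")


def with_pmap(fn, *, static_argnums=(), **unused_kwargs):
    """Toy variant: forwards the call unchanged but adds a failure hint.

    A real implementation would call jax.pmap(fn, ...) here; this stand-in
    never uses `static_argnums` and ignores `static_argnames`.
    """
    ignored_argnames = unused_kwargs.get("static_argnames")

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            if ignored_argnames:
                # Chain the original error so its traceback is preserved.
                raise RuntimeError(_STATIC_ARGNAMES_HINT) from e
            raise

    return wrapper
```

Adding the hint at call time means it appears only when something actually goes wrong. Emitting `warnings.warn` as soon as `static_argnames` is passed would be a simpler alternative, at the cost of warning even when the test passes.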

Ah I see, it all makes sense now, TY!

@broper2 are you interested in preparing the fix?

hbq1 • Nov 24 '21 17:11

Thanks @hbq1, will get this PR up within the week.

broper2 • Nov 24 '21 17:11

Hi @hbq1, don't want to leave you hanging... but I have had some other stuff come up. It may be best to have someone else take this issue and prepare the PR.

broper2 • Nov 29 '21 21:11

NP, thanks for letting me know @broper2!

hbq1 • Nov 30 '21 10:11