chex
chex.variants(with_pmap=True) ignores `static_argnames`
The `_with_pmap` function accepts `static_argnums` as a parameter, but not `static_argnames`. This is inconsistent with other variants, such as `with_jit` and `with_device`. Crucially, this prevents testing methods that require arguments to be passed by name (e.g., Distrax's `Distribution.sample()`).
More generally, it would be best if all variants accepted the same parameters where possible (i.e., where not specific to a single variant), and I would suggest checking all keys in `**unused_kwargs` against a list of allowed parameters (i.e., the union of the parameters of all variant functions) to prevent silent errors due to, e.g., misspellings.
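The suggested check could look roughly like the sketch below. It is a minimal, hypothetical illustration in plain Python: `_with_jit_sketch`, `_with_pmap_sketch`, `allowed_variant_kwargs` and `check_kwargs` are made-up names, and their parameter lists are not the real chex signatures. Note that a union-based check catches misspellings such as `static_argnmes`, but would not by itself flag `static_argnames` being passed to the pmap variant, because another variant accepts that name.

```python
import inspect
from typing import Any, Callable, Iterable, Set


# Hypothetical stand-ins for two variant functions. Their parameter lists are
# illustrative only and are NOT the real chex signatures.
def _with_jit_sketch(fn, *, static_argnums=(), static_argnames=(), **unused_kwargs):
    return fn


def _with_pmap_sketch(fn, *, static_argnums=(), in_axes=0, axis_name="i", **unused_kwargs):
    return fn


def allowed_variant_kwargs(variant_fns: Iterable[Callable[..., Any]]) -> Set[str]:
    """Union of the keyword-only parameter names accepted by the given variants."""
    allowed: Set[str] = set()
    for variant_fn in variant_fns:
        for name, param in inspect.signature(variant_fn).parameters.items():
            # Only named keyword-only params count; **unused_kwargs (VAR_KEYWORD) is skipped.
            if param.kind is inspect.Parameter.KEYWORD_ONLY:
                allowed.add(name)
    return allowed


def check_kwargs(kwargs: dict, allowed: Set[str]) -> None:
    """Raise instead of silently ignoring keywords that no variant understands."""
    unknown = sorted(set(kwargs) - allowed)
    if unknown:
        raise TypeError(f"Unknown variant argument(s) {unknown}; allowed: {sorted(allowed)}")


ALLOWED = allowed_variant_kwargs([_with_jit_sketch, _with_pmap_sketch])
check_kwargs({"static_argnums": (1,)}, ALLOWED)  # known name: passes silently
```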
Hi, I would like to work on this issue if no one else has picked it up. Thanks!
Hi @fvisin, thanks for opening this issue. And thanks for volunteering @broper2!
`static_argnames` kwarg is not supported in the original `pmap` because of `in_axes` incompatibility with kwargs (see the docstring). I would prefer not to add this to chex to avoid possible confusion / problems with forward compatibility. WDYT?
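One possible workaround in the meantime (an assumption, not verified against chex) is to wrap a function whose arguments must be passed by name in a positional adapter, so that the static argument can be selected with `static_argnums` instead. `sample` and `sample_positional` below are hypothetical stand-ins for a keyword-only API such as a `sample()` method.

```python
from typing import Tuple


def sample(*, seed: int, sample_shape: Tuple[int, ...] = ()) -> str:
    """Hypothetical keyword-only function, standing in for e.g. a `sample()` method."""
    return f"seed={seed}, shape={sample_shape}"


def sample_positional(seed: int, sample_shape: Tuple[int, ...]) -> str:
    """Positional adapter: `sample_shape` is now argument 1 and can be marked static by index."""
    return sample(seed=seed, sample_shape=sample_shape)


# Assumption: in a chex test, the adapter would then be passed to the variant with
# static_argnums=(1,) rather than static_argnames=("sample_shape",).
print(sample_positional(0, (2, 3)))  # prints "seed=0, shape=(2, 3)"
```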
Re: checking `**_unused_kwargs` -- that's exactly what's accomplished by `@check_variant_arguments` :)
> that's exactly what's accomplished by `@check_variant_arguments` :)
...how did I miss that!? Apologies for the noise :facepalm:
> `static_argnames` kwarg is not supported in the original `pmap` because of `in_axes` incompatibility with kwargs (see the docstring). I would prefer not to add this to chex to avoid possible confusion / problems with forward compatibility. WDYT?
That makes sense. Do you think it would be reasonable to fail with a more specific error message when `static_argnames` is set, though? It took me a while to realise that the argument was being ignored; it would have helped to know that a possible reason for the test failing was that `static_argnames` is ignored by `with_pmap`.
That's a good point! We can prohibit users from using both `static_argnames` and `static_argnums` at the same time in variants' args, wdyt?
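A check along those lines might look like the following minimal sketch; `validate_static_args` is a hypothetical helper operating on a variant's keyword arguments, not an existing chex function. On its own it would not flag a call that sets only `static_argnames`.

```python
from typing import Any, Dict


def validate_static_args(variant_kwargs: Dict[str, Any]) -> None:
    """Hypothetical check: reject variant kwargs that set both static_argnums and static_argnames."""
    # Empty tuples / None count as "not set".
    has_nums = bool(variant_kwargs.get("static_argnums"))
    has_names = bool(variant_kwargs.get("static_argnames"))
    if has_nums and has_names:
        raise ValueError(
            "Pass either `static_argnums` or `static_argnames` to a variant, not both."
        )


validate_static_args({"static_argnums": (1,)})  # only one of the two is set: OK
```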
That's also a good idea, but I don't think it would have solved my specific case!
I was setting only `static_argnames`, but since `with_pmap` was ignoring it, I didn't immediately realise why the test was failing. When `static_argnames` is set and `with_pmap` fails, it would be handy if the error message suggested that the failure might be caused by `static_argnames` being ignored by `with_pmap`, since it is not supported in the original `pmap`. Does it make sense?
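A minimal sketch of such a hint, assuming the variant is free to wrap the function it returns: `with_static_argnames_hint` is a hypothetical helper, not part of chex. It leaves the function untouched when `static_argnames` is empty and otherwise re-raises any failure as a `RuntimeError` whose message names the likely cause, chaining the original exception via `from` so its traceback is preserved.

```python
import functools
from typing import Any, Callable, Sequence


def with_static_argnames_hint(
    fn: Callable[..., Any], static_argnames: Sequence[str] = ()
) -> Callable[..., Any]:
    """Hypothetical helper: if static_argnames was set (and is ignored by the pmap
    variant), re-raise any failure with a hint pointing at that likely cause."""
    if not static_argnames:
        return fn  # Nothing was ignored, so no hint is needed.

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return fn(*args, **kwargs)
        except Exception as err:
            raise RuntimeError(
                f"{fn.__name__} failed. Note: static_argnames={tuple(static_argnames)} "
                "is ignored by with_pmap because pmap does not support it; "
                "consider static_argnums instead."
            ) from err

    return wrapper


def add_one(x):
    return x + 1


wrapped = with_static_argnames_hint(add_one, static_argnames=("n",))
print(wrapped(1))  # successful calls pass through unchanged: prints 2
```

Raising a new exception with `from err` avoids reconstructing the original exception type, which can fail for exception classes with custom constructors.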
Ah I see, it all makes sense now, TY!
@broper2 are you interested in preparing the fix?
Thanks @hbq1, will get this PR up within the week.
Hi @hbq1, I don't want to leave you hanging... but I've had some other stuff come up. It may be best to have someone else take this issue and prepare the PR.
NP, thanks for letting me know @broper2!