Annotate tree_util

Open NeilGirdhar opened this issue 3 years ago • 11 comments

NeilGirdhar avatar Jan 03 '22 07:01 NeilGirdhar

Nice! Thanks as always, Neil.

I have suggestions which fall into two main categories:

  1. whitespace

Of course, I'm happy to fix all whitespace to match your style 😄

  2. having more specific output type annotations (lists rather than sequences).

I'm interested in whether there are any strong opinions against 2, but it seems preferable to be specific about output types (unless we think they are likely to change).

I don't have any strong opinion, but I should have commented on this in the change. I initially wrote the annotations using List, but it turns out that the underlying jaxlib functions are already annotated using Sequence. This caused typing errors in the tree_util code.

I feel like jaxlib and tree_util should match: either jaxlib should change its annotations to List, or else tree_util should use Sequence to match.

Promising Sequence means that you could one day choose a different sequence type than list. For example, you might provide some alternative structure that supports easier tree traversal. Maybe that was the motivation of whoever annotated jaxlib: keep options open. I agree with the jaxlib author that promising less is usually better.
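A minimal sketch of that trade-off, using a hypothetical `flatten_leaves` function (an illustration only, not part of jax or jaxlib): annotating the return type as `Sequence` would let the implementation return a tuple instead of a list without breaking callers that only iterate, index, or take the length.

```python
from typing import Any, Sequence


def flatten_leaves(tree: Any) -> Sequence[Any]:
    """Hypothetical flattener for nested lists/tuples; illustration only."""
    if isinstance(tree, (list, tuple)):
        leaves = []
        for child in tree:
            leaves.extend(flatten_leaves(child))
        # Returning a tuple still satisfies Sequence[Any]; a List[Any]
        # annotation would have ruled this out.
        return tuple(leaves)
    return (tree,)


leaves = flatten_leaves([1, (2, 3), [4]])
# Callers that rely only on the Sequence protocol keep working.
assert len(leaves) == 4 and leaves[0] == 1
```

The cost is that callers can no longer assume list-only behavior on the result.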

Please let me know what you prefer.

NeilGirdhar avatar Jan 08 '22 00:01 NeilGirdhar

I agree with the jaxlib author that promising less is usually better.

I'm not so sure about this in general. Why not promise even less and say it's just an Iterable? We want to promise the right amount, and IMO there's no harm in continuing to say it's a list, as our docs already did. Or at least, there wouldn't be except for the other issue you mentioned.

but it turns out that the underlying jaxlib functions are already annotated using Sequence. This caused typing errors in the tree_util code.

Ah I see. Let's keep it how your PR does it then! I don't mean to add any extra work.

mattjj avatar Jan 08 '22 04:01 mattjj

I'm not so sure about this in general. Why not promise even less and say it's just an Iterable? We want to promise the right amount,

Yes, you're absolutely right of course.

Ah I see. Let's keep it how your PR does it then! I don't mean to add any extra work.

Okay, thanks! I'll make the change above, squash, and let you know when it's done.

NeilGirdhar avatar Jan 08 '22 04:01 NeilGirdhar

You're so generous and easy to work with, Neil! It's always a pleasure, and I always learn something. You're the best.

mattjj avatar Jan 08 '22 04:01 mattjj

Done, and thank you, you all are a pleasure to work with as well!

NeilGirdhar avatar Jan 08 '22 04:01 NeilGirdhar

Unfortunately, internal typechecks (i.e., typechecks spanning all Google code that uses these functions!) ran into some issues with this.

As one example, jraph adds together the leaves returned by tree_flatten calls. But that doesn't typecheck against Sequence, since you can't add Sequences together!
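To illustrate in plain Python (the values below are hypothetical stand-ins for `tree_flatten` leaves, not jraph's actual code): the `Sequence` ABC defines no `__add__`, so a checker rejects `a + b` for two `Sequence` values, and at runtime the addition only works when both operands happen to be the same concrete type.

```python
from collections.abc import Sequence as AbcSequence
from typing import Sequence

# Stand-ins for the leaves of two tree_flatten calls.
leaves_a: Sequence[int] = [1, 2]
leaves_b: Sequence[int] = (3, 4)  # a tuple is also a valid Sequence

# The Sequence ABC provides no __add__, which is why `leaves_a + leaves_b`
# fails static type checking.
sequence_supports_add = hasattr(AbcSequence, "__add__")

# At runtime, adding a list and a tuple fails as well.
try:
    leaves_a + leaves_b  # type: ignore[operator]
    mixed_add_ok = True
except TypeError:
    mixed_add_ok = False
```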

Maybe we could put these annotations under if TYPE_CHECKING? Got a better idea?

mattjj avatar Jan 08 '22 05:01 mattjj

As one example, jraph adds together the leaves returned by tree_flatten calls. But that doesn't typecheck against Sequence, since you can't add Sequences together!

That makes sense. That's why I had to change a couple of instances of concatenation in the JAX code to use generalized unpacking.
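A minimal sketch of that kind of rewrite, with a hypothetical `concat_leaves` helper rather than the actual JAX code: unpacking into a new list only requires the operands to be iterable, so it type checks against `Sequence` and always produces a list.

```python
from typing import List, Sequence


def concat_leaves(a: Sequence[int], b: Sequence[int]) -> List[int]:
    # `a + b` would not type check here; unpacking works for any iterables.
    return [*a, *b]


combined = concat_leaves([1, 2], (3, 4))
```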

Maybe we could put these annotations under if TYPE_CHECKING? Got a better idea?

Unfortunately, that won't solve the type checking errors since those run under TYPE_CHECKING too. I can see a few options:

  1. Change tree_util annotations to use lists, and change jaxlib to match,
  2. Change tree_util annotations to use lists, and add casts throughout so that jaxlib doesn't complain, or
  3. Change the client code to cast results to list where necessary.

Option 1 means changing jaxlib too (a couple of small changes here). Also, just glancing over the interface, it's a bit weird that flatten returns a sequence, but flatten_up_to returns a list. Maybe someone thought they could do some optimization by returning immutable structures without casting to list? On the other hand, it might be worth making the interface homogeneous.

Option 2 adds extra casts that weren't there before, and it is awkward to have two mismatched interfaces. It's the least work, but my least favorite.

Option 3 depends on how many type errors you ended up with.
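A rough sketch of what options 2 and 3 could look like in code. The names `_lib_flatten`, `flatten_option2`, and `client_option3` are hypothetical and stand in for a jaxlib function, a tree_util wrapper, and client code respectively. One caveat worth keeping in mind: `typing.cast` is a no-op at runtime, so option 2 is only sound if the underlying function really does return a list.

```python
from typing import List, Sequence, cast


def _lib_flatten(xs: Sequence[int]) -> Sequence[int]:
    """Hypothetical stand-in for a library function annotated with Sequence."""
    return list(xs)


def flatten_option2(xs: Sequence[int]) -> List[int]:
    # Option 2: the wrapper promises a list and casts the library result.
    # cast() only informs the type checker; it does not convert anything.
    return cast(List[int], _lib_flatten(xs))


def client_option3(a: Sequence[int], b: Sequence[int]) -> List[int]:
    # Option 3: the wrapper keeps Sequence, and each client converts
    # explicitly wherever it needs list behavior such as `+`.
    return list(_lib_flatten(a)) + list(_lib_flatten(b))
```

Option 1 needs no extra code at call sites, since both layers would simply be annotated with `List`.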

Another question goes back to your point above. What do you want in an ideal world? Would you rather promise list, so that people can easily concatenate results, or promise Sequence, so that you're free to use another container down the road?

There may be other options, but those are the obvious ones to me. Please let me know what you decide and I'll repair this change.

NeilGirdhar avatar Jan 08 '22 07:01 NeilGirdhar

Thanks for that super clear analysis.

I think Option 3 would be pretty annoying.

I like Option 1 but I wonder if @hawkinsp has a different opinion. Peter, WDYT?

mattjj avatar Jan 08 '22 17:01 mattjj

@hawkinsp If you like, I can make a corresponding pull request to tensorflow's jaxlib to make the interface promise lists?

NeilGirdhar avatar Feb 02 '22 18:02 NeilGirdhar

Sure, that's fine. Just tag me in the PR so I see it.

hawkinsp avatar Feb 03 '22 15:02 hawkinsp

This should pass when TensorFlow 2.10's version of jaxlib is used, thanks to https://github.com/tensorflow/tensorflow/pull/54330 being compiled into it.

NeilGirdhar avatar Sep 06 '22 22:09 NeilGirdhar

@mattjj Would it be possible to get this merged now that the tensorflow type annotations are in jaxlib?

NeilGirdhar avatar Oct 26 '22 18:10 NeilGirdhar

(taking over review/merge, since this is related to #12049)

jakevdp avatar Oct 26 '22 20:10 jakevdp