dm-haiku
Typing (type hinting) in Haiku?
Hey all!
I like using mypy as a smell test for my code, and I was wondering whether Haiku (and, more broadly, JAX) has well-supported built-in typing. I tried:
import jax.numpy as jnp
type(jnp.ones([3]))
I found Type Hints, but it seems to cover only a subset of the supported ops. Does the team have any recommendations?
Hi @IanQS, I am not sure that, as a library, we need to do anything to enable type hinting with mypy for users. We do provide a couple of aliases you may find useful with Haiku code; for example, hk.Params and hk.State refer to the pytree structures returned or expected by Haiku.
We make sure that Haiku passes pytype checks, and I think pytype offers functionality similar to mypy's.
No worries! I ended up reading through the examples and found that most things just revolve around:
Foo = jnp.ndarray
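A minimal sketch of that alias pattern in a function signature; the Array name (standing in for Foo) and the standardize helper are hypothetical. It also answers the runtime side of the original question: JAX arrays pass an isinstance check against jnp.ndarray, so one name serves both annotations and runtime checks.

```python
import jax.numpy as jnp

# Plain alias, as in the Haiku examples; "Array" plays the role of "Foo".
Array = jnp.ndarray


def standardize(x: Array, eps: float = 1e-6) -> Array:
    """Shift to zero mean and scale to unit standard deviation."""
    return (x - x.mean()) / (x.std() + eps)


y = standardize(jnp.arange(4.0))
is_array = isinstance(jnp.ones([3]), Array)
```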
You could look at https://github.com/deepmind/tensor_annotations if you want more complete type annotations (e.g. with rank and axis names). The JAX team also has some thoughts on import-time type checking in JAX. So far we have not adopted any of these solutions in our libraries. There is definitely some energy from users to improve the state of typing in JAX libraries; however, for now there is no canonical solution that pleases most people.
Ah, it's unfortunate that there are no built-ins. Are you at liberty to mention what DM uses internally? Are the types all aliased, or is DM using the tensor_annotations library? I figure that if I follow DM's paradigm, I can just trust the team's decision and go with it.
We tend to use aliases internally: some people alias jnp.ndarray and others alias chex.Array.
I know that the JAX team has plans to make RNG keys a distinct type (PRNGKeyArray), which is a very common thing users want to type hint (currently jax.random.PRNGKey is a function, not a type).
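Until keys have their own type, a common stopgap is to alias the underlying array type. A minimal sketch; the PRNGKey alias and the sample_noise helper are hypothetical names, not JAX API, and the alias is distinct from the jax.random.PRNGKey function.

```python
import jax
import jax.numpy as jnp

# Hypothetical alias: jax.random.PRNGKey is a function, so annotate
# keys with the array type it returns instead.
PRNGKey = jnp.ndarray


def sample_noise(key: PRNGKey, shape: tuple) -> jnp.ndarray:
    # Draw standard-normal noise of the requested shape.
    return jax.random.normal(key, shape)


key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = sample_noise(subkey, (2, 3))
```

The drawback is that a checker cannot tell a key from any other array, which is exactly what a dedicated key type would fix.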
I am not aware of a good way to type a function that accepts a PyTree of arrays in a way that will actually make the type checker do something useful.
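For illustration, a common fallback is an opaque alias. This minimal sketch (the PyTree alias and scale_tree helper are hypothetical) shows the limitation: the checker accepts any argument, so the annotation documents intent but verifies nothing.

```python
from typing import Any

import jax
import jax.numpy as jnp

# Opaque alias: readable, but a type checker treats it as Any.
PyTree = Any


def scale_tree(tree: PyTree, factor: float) -> PyTree:
    # Multiply every leaf of the pytree by the same factor.
    return jax.tree_util.tree_map(lambda leaf: leaf * factor, tree)


params = {"linear": {"w": jnp.ones([3, 4]), "b": jnp.zeros([4])}}
scaled = scale_tree(params, 2.0)
```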
Now that MyPy implements ParamSpec, adding an annotation for hk.transform is possible and would be very useful.
Currently, hk.transform(f) is untyped. It is possible to produce a strict type for it. Please let me know if that's a pull request you'd consider accepting.
This works for me:
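from __future__ import annotations
from typing import Callable, Concatenate, Generic, ParamSpec, Protocol, TypeVar
import haiku as hk
import jax
KeyArray = jax.Array  # assumption: older JAX releases spelled this jax.random.KeyArray
P = ParamSpec('P')
R = TypeVar('R')
class HaikuTransform(Protocol, Generic[P, R]):
    init: Callable[Concatenate[KeyArray | None, P], hk.Params]
    apply: Callable[Concatenate[hk.Params | None, KeyArray | None, P], R]
def hk_transform(function: Callable[P, R]) -> HaikuTransform[P, R]:
    return hk.transform(function)  # type: ignore  # hk.transform itself is untyped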
Hey @NeilGirdhar, thank you for flagging this!
It will probably be easier for someone at DeepMind to make this change based on your recommendation. This is because internally we use a monorepo, and changes to Haiku must pass unit tests for all of our users. When we have previously enhanced typing in Haiku, this has surfaced typing issues in non-open-source uses of Haiku, which need to be fixed before we can import the change.
I'll see if someone is available to pick this up.
Makes sense! Thanks for taking a look, @tomhennigan