yuanz
@johannahaffner thank you.

> This is expected. The values of the leaves are the same, but you have two new trees now, that makes two objects that JAX is treating...

Maybe my question wasn't clear. My point is that if you define `_norm` as an instance method instead of a `Module` field, it will not be compared (like `__call__` of...
That part I understand: `query_proj` is a `Module`, so it is handled recursively. However, `_norm` is supposed to be static and stateless inside `WeightNorm`. Since `tree_equal` compares arrays...
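
To make the distinction concrete, here is a minimal sketch in plain JAX. It is not Equinox's actual implementation: `naive_tree_equal`, `make_with_field`, and `WithMethod` are hypothetical names, and a freshly built `functools.partial` only stands in for whatever `WeightNorm` stores in `_norm`. The idea is that a callable stored as a field becomes a pytree leaf and gets compared, while a method never enters the tree at all.

```python
import functools

import jax
import jax.numpy as jnp


def naive_tree_equal(x, y):
    """Hypothetical leaf-by-leaf comparison; a stand-in, not `eqx.tree_equal`."""
    leaves_x, treedef_x = jax.tree_util.tree_flatten(x)
    leaves_y, treedef_y = jax.tree_util.tree_flatten(y)
    if treedef_x != treedef_y:
        return False
    for p, q in zip(leaves_x, leaves_y):
        if isinstance(p, jax.Array) and isinstance(q, jax.Array):
            # Arrays are compared by value.
            if not bool(jnp.array_equal(p, q)):
                return False
        elif p != q:
            # Non-array leaves fall back to `!=`; for two distinct
            # `functools.partial` objects this is an identity check.
            return False
    return True


def make_with_field():
    # `_norm` is stored as data, so it shows up as a leaf of the tree.
    return {
        "weight": jnp.ones(3),
        "_norm": functools.partial(jnp.linalg.norm, axis=-1),
    }


@jax.tree_util.register_pytree_node_class
class WithMethod:
    """`_norm` is a method here, so it is never part of the flattened tree."""

    def __init__(self, weight):
        self.weight = weight

    def _norm(self, x):
        return jnp.linalg.norm(x, axis=-1)

    def tree_flatten(self):
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


field_equal = naive_tree_equal(make_with_field(), make_with_field())
method_equal = naive_tree_equal(WithMethod(jnp.ones(3)), WithMethod(jnp.ones(3)))
print(field_equal, method_equal)
```

In this sketch the two field-based trees compare unequal even though their arrays match, because each construction creates a new `partial` object, whereas the method-based trees compare equal.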