Ji Yu

3 issues by Ji Yu

**Describe the bug** Running mnist_cnn.py in the example directory crashes the instance at the end of the first epoch. This was previously reported on a Colab GPU instance. But I can...

bug

Ran into this bug in a rare edge case in loss_and_logs.py:
```
def compute(self) -> tp.Tuple[jnp.ndarray, Logs, Logs]:
    if self.losses is not None:
        loss, losses_logs = self.losses.compute()
    else:
        loss =...
```

Bug: Instances of the same tx.Module have different tree_structure
```
class T(tx.Module):
    pass

t1 = T()
t2 = T()
jax.tree_structure(t1) == jax.tree_structure(t2)
>>> False
```
Patch is simple: ``` diff...
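The issue's actual patch is truncated above, so the following is only a hypothetical sketch of one general way two instances of the same pytree class can end up with unequal `tree_structure`. JAX compares the auxiliary data returned by `tree_flatten` with `==` when comparing tree definitions. If that auxiliary data holds a fresh object that only has identity-based equality, every instance gets a distinct structure. The classes `IdentityMeta`, `ValueMeta`, `Broken` and `Fixed` are illustrative names, not part of treex.

```python
from dataclasses import dataclass

import jax


class IdentityMeta:
    """Hypothetical metadata object: default __eq__ compares by identity."""


@dataclass(frozen=True)
class ValueMeta:
    """Hypothetical metadata object: frozen dataclass compares by value and is hashable."""
    name: str = "default"


@jax.tree_util.register_pytree_node_class
class Broken:
    def __init__(self):
        # A new IdentityMeta per instance, so aux_data never compares equal.
        self.meta = IdentityMeta()

    def tree_flatten(self):
        # No leaves; the metadata is passed as aux_data.
        return (), self.meta

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = cls.__new__(cls)
        obj.meta = aux_data
        return obj


@jax.tree_util.register_pytree_node_class
class Fixed:
    def __init__(self):
        # Value-comparable metadata: equal across instances with the same fields.
        self.meta = ValueMeta()

    def tree_flatten(self):
        return (), self.meta

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = cls.__new__(cls)
        obj.meta = aux_data
        return obj


same_broken = jax.tree_util.tree_structure(Broken()) == jax.tree_util.tree_structure(Broken())
same_fixed = jax.tree_util.tree_structure(Fixed()) == jax.tree_util.tree_structure(Fixed())
print(same_broken, same_fixed)
```

Keeping auxiliary data hashable and value-comparable also matters beyond equality checks, since `jax.jit` uses the tree structure as part of its cache key; identity-based aux data would force retracing for every new instance.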