Yongchang Hao

Results 5 issues of Yongchang Hao

This should fix issue #9 for this version of code.

In the following code, the comment says weight A is initialized as usual (Kaiming initialization, as elsewhere in the code) and B is initialized to zeros. However, the behavior is...
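For reference, the convention that the comment describes can be sketched in plain Python. This is only an illustration under stated assumptions: `init_lora_pair` and `matmul` are hypothetical helpers that do not come from the repository, A and B are assumed to be the two factors of a low-rank update, and the Kaiming-uniform bound `sqrt(6 / fan_in)` is one common variant that may differ from what the code actually uses.

```python
import math
import random


def init_lora_pair(in_features, out_features, rank, seed=0):
    """Hypothetical helper: A is rank x in_features, B is out_features x rank."""
    rng = random.Random(seed)
    # Assumed Kaiming-uniform bound (gain sqrt(2)): sqrt(6 / fan_in).
    bound = math.sqrt(6.0 / in_features)
    A = [[rng.uniform(-bound, bound) for _ in range(in_features)]
         for _ in range(rank)]
    # B starts at zero, as the comment states.
    B = [[0.0] * rank for _ in range(out_features)]
    return A, B


def matmul(X, Y):
    """Naive matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)]
            for row in X]


A, B = init_lora_pair(in_features=8, out_features=4, rank=2)
delta = matmul(B, A)  # out_features x in_features update, all zeros at init
```

With B set to zeros, the product `B @ A` is zero at initialization regardless of how A is drawn, so the wrapped layer starts out unchanged. Swapping which factor is zero keeps that property but changes which factor receives a nonzero gradient on the first step.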

The doc says the dtype of `mu` is inferred from `grads` and `updates` if `mu_dtype=None`. But [this line](https://github.com/google/automl/blob/ecbc7a3b5ab1944d5f39a80a21fd07b7606583f0/lion/lion_optax.py#LL108C45-L108C45) actually casts `jnp.bfloat16` and `jnp.float16` to `jnp.float32` when `mu_dtype=None`. Example on GPUs: ```python >>> jax.__version__...

There is an index filter that keeps only the entries with `latents[:, 0] != 3`. https://github.com/YugeTen/fish/blob/333efa24572d99da0a4107ab9cc4af93a915d2a9/src/models/datasets.py#L266-L271 In the code above, this condition is applied to `val` and `test`, but not to...
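A minimal sketch of what such a filter does, in plain Python rather than the array indexing used in the repository. The `latents` rows and the `keep_indices` helper below are made up for illustration and are not taken from the dataset code.

```python
def keep_indices(latents, excluded=3):
    """Return positions whose first latent value differs from `excluded`."""
    return [i for i, row in enumerate(latents) if row[0] != excluded]


# Hypothetical latent rows; only the first column matters for the filter.
latents = [[0, 5], [3, 1], [1, 2], [3, 7], [2, 0]]

kept = keep_indices(latents)                        # rows 1 and 3 are dropped
first_values_filtered = {latents[i][0] for i in kept}
first_values_unfiltered = {row[0] for row in latents}
```

A split built from `kept` never contains the value 3 in the first latent column, while a split built without the filter still does, so the two splits cover different sets of values.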

I have a use case where there are multiple indexes on one device. The queries are also batched. For example, if there are N indexes, the query matrix has shape...

feature request