Alex McKinney

23 issues by Alex McKinney

I stumbled across the google/jax#4862 discussion on parameter initialisation differences between PyTorch and Flax. Would it be worth adding a note to the documentation that highlights these weight initialisation differences? Users porting...
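For reference, PyTorch's `nn.Linear` draws both its weight and its bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) by default, whereas Flax's `nn.Dense` uses a LeCun-normal kernel and a zero bias. A minimal sketch of the two schemes in plain JAX follows; the function names are hypothetical and the kernel uses Flax's `(fan_in, fan_out)` layout, which is the transpose of PyTorch's weight.

```python
import math

import jax
import jax.numpy as jnp


def torch_style_linear_init(key, fan_in, fan_out, dtype=jnp.float32):
    """Hypothetical helper: mimic PyTorch nn.Linear's default init.

    Weight and bias are both sampled from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
    """
    bound = 1.0 / math.sqrt(fan_in)
    w_key, b_key = jax.random.split(key)
    kernel = jax.random.uniform(
        w_key, (fan_in, fan_out), dtype=dtype, minval=-bound, maxval=bound
    )
    bias = jax.random.uniform(
        b_key, (fan_out,), dtype=dtype, minval=-bound, maxval=bound
    )
    return kernel, bias


def flax_default_style_linear_init(key, fan_in, fan_out, dtype=jnp.float32):
    """Hypothetical helper: mimic Flax nn.Dense defaults (LeCun-normal kernel, zero bias)."""
    kernel = jax.nn.initializers.lecun_normal()(key, (fan_in, fan_out), dtype)
    bias = jnp.zeros((fan_out,), dtype)
    return kernel, bias
```

For the kernel alone, the PyTorch-style distribution is equivalent to `jax.nn.initializers.variance_scaling(1/3, "fan_in", "uniform")`; the bias needs a custom function because Flax bias initialisers only receive the bias shape, not `fan_in`.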

Currently, using `optax.MultiSteps` with the rest of equinox is not possible. It will fail with the error `TypeError: Cannot interpret value of type as an abstract array; it does not...

Labels: question

Related #18897

Reproducing code:

```python
import jax
from jax.experimental import pallas as pl

# take first two elements
def kernel(x_ref, o_ref):
    x = x_ref[...]
    o_ref[...] = jax.lax.slice(x, (0,), (2,))

def...
```

Labels: enhancement, pallas

Just correcting what I think is a typo. To avoid breaking compatibility, should I alias `FlaxWhisperPipline = FlaxWhisperPipeline` in the `__init__.py` file? Thanks!

Although this repository can already be used to sample from pretrained models, there are no test scripts to verify that its results match those of the original implementation. Particularly risky...

As I want this repo to match (or exceed) the efficiency of the original CUDA implementation, reaching that will likely involve a custom kernel to avoid materialising the full...

Feature request: Extend the `imagefolder` dataloading script to support loading multiple images per dataset entry. This only really makes sense if a metadata file is present. Currently you can...

Labels: enhancement

Greetings, I am working on a reimplementation of this paper in JAX. You can find it [here](https://github.com/vvvm23/mezo-jax). I am wondering which hyperparameters you found to work best for each dataset....

Hi all, I noticed in [this commit](https://github.com/google-deepmind/optax/commit/7d43c5c0cc1ab343229c3c394b3179a1404e97e8) that the Sophia optimiser (see [paper](https://arxiv.org/abs/2305.14342)) has been integrated into `optax`. However, the estimate of the diagonal Hessian it uses is far simpler...
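For comparison, the Sophia paper describes a Hutchinson-style estimator of the diagonal Hessian, diag(H) ≈ E[u ⊙ (H u)] for random probe vectors u. A minimal JAX sketch of that estimator follows, assuming `params` is a single flat array, `loss_fn` maps it to a scalar, and Gaussian probes; the name `hutchinson_diag_hessian` is hypothetical and this is not the `optax` implementation.

```python
import jax
import jax.numpy as jnp


def hutchinson_diag_hessian(loss_fn, params, key, num_samples=1):
    """Hypothetical helper: Hutchinson estimate of diag(H) for a flat `params` array."""
    grad_fn = jax.grad(loss_fn)

    def one_sample(k):
        # Gaussian probe u ~ N(0, I) with the same shape as params.
        u = jax.random.normal(k, params.shape, params.dtype)
        # Hessian-vector product H u, computed as a forward-mode derivative of the gradient.
        _, hvp = jax.jvp(grad_fn, (params,), (u,))
        # Elementwise u * (H u) has expectation diag(H).
        return u * hvp

    keys = jax.random.split(key, num_samples)
    # Average the per-probe estimates to reduce variance.
    return jnp.mean(jax.vmap(one_sample)(keys), axis=0)
```

Each probe costs one Hessian-vector product, so the estimate is cheap but noisy; averaging over more probes trades compute for lower variance.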