Alex McKinney

23 issues by Alex McKinney

Would it be possible to add tuple inputs to `scatter_dimension` in `jax.lax.psum_scatter`? For example, we have an array of shape `(batch_size, sequence_length, dim)` sharded as `(dp, sp, tp)` we pass...

enhancement

I was reading the paper [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) and found it pretty interesting. It proposes a few simple changes that could be useful when...

New scheduler
Good second issue
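One of the changes the paper proposes is rescaling the noise schedule so that the final timestep has zero signal-to-noise ratio (zero terminal SNR). The sketch below follows the paper's Algorithm 1 in plain Python instead of PyTorch. The function name `rescale_zero_terminal_snr` and the example linear schedule are illustrative assumptions, not the API of any particular scheduler library.

```python
import math


def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the last timestep has zero SNR (alpha_bar_T == 0).

    Hypothetical helper sketching Algorithm 1 from the paper: shift sqrt(alpha_bar)
    so its final value is 0, then scale it so its first value is unchanged.
    """
    # alpha_bar_t is the cumulative product of (1 - beta); keep its square root.
    alphas_bar_sqrt = []
    prod = 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alphas_bar_sqrt.append(math.sqrt(prod))

    first, last = alphas_bar_sqrt[0], alphas_bar_sqrt[-1]
    # Shift so the last value becomes 0, then scale so the first value is restored.
    scale = first / (first - last)
    alphas_bar_sqrt = [(a - last) * scale for a in alphas_bar_sqrt]

    # Convert back to betas: alpha_t = alpha_bar_t / alpha_bar_{t-1}, beta_t = 1 - alpha_t.
    alphas_bar = [a * a for a in alphas_bar_sqrt]
    new_betas = [1.0 - alphas_bar[0]]
    for prev, cur in zip(alphas_bar[:-1], alphas_bar[1:]):
        new_betas.append(1.0 - cur / prev)
    return new_betas


# Example: an assumed linear schedule over 1000 steps.
betas = [0.0001 + (0.02 - 0.0001) * t / 999 for t in range(1000)]
rescaled = rescale_zero_terminal_snr(betas)
```

After rescaling, the first beta is (up to rounding) unchanged and the last beta is exactly 1.0, so alpha_bar at the final timestep is zero and the model sees pure noise there.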

Description: The following code block, https://github.com/cupy/cupy/blob/ecf3d94ee4f8720371c9b1846c14afa06c90bcb2/cupy/_core/dlpack.pyx#L64-L80, does not contain an entry for `bfloat16`. This ultimately means when using something like `jax.from_dlpack` on a `cupy` array, it will fail if...

cat:bug
st:awaiting-author
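In DLPack, `bfloat16` is not a `kDLFloat` with 16 bits (that is `float16`); it has its own type code, `kDLBfloat` (value 4 in the `DLDataTypeCode` enum of `dlpack.h`). The sketch below is a hypothetical Python helper, `dtype_to_dldatatype`, that maps dtype names to `(code, bits, lanes)` triples to show why a dedicated entry is needed. It is not CuPy's actual Cython code.

```python
from typing import NamedTuple

# Type codes as defined by the DLDataTypeCode enum in dlpack.h.
kDLInt = 0
kDLUInt = 1
kDLFloat = 2
kDLBfloat = 4
kDLComplex = 5


class DLDataType(NamedTuple):
    code: int
    bits: int
    lanes: int = 1


def dtype_to_dldatatype(name: str) -> DLDataType:
    """Hypothetical mapping from a dtype name to a DLPack DLDataType."""
    # bfloat16 shares its 16-bit width with float16 but has a distinct type
    # code, so it needs its own entry; without it the lookup below falls
    # through to the TypeError.
    if name == "bfloat16":
        return DLDataType(kDLBfloat, 16)
    prefixes = (("uint", kDLUInt), ("int", kDLInt), ("float", kDLFloat), ("complex", kDLComplex))
    for prefix, code in prefixes:
        if name.startswith(prefix):
            return DLDataType(code, int(name[len(prefix):]))
    raise TypeError(f"dtype {name!r} has no DLPack equivalent")
```

For example, `dtype_to_dldatatype("bfloat16")` gives `DLDataType(code=4, bits=16, lanes=1)`, while `"float16"` gives code 2 with the same bit width.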