Roman Novak

8 issues by Roman Novak

Internal change

It would be very convenient to be able to manually ask JAX to execute a specific RAM-heavy operation (e.g. matrix inversion or solving a linear system) on the CPU, where more RAM is...

enhancement
P2 (eventual)
NVIDIA GPU
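A minimal sketch of one way to do this with the existing API, assuming a CPU backend is available alongside the GPU. If the operands are committed to a CPU device with `jax.device_put`, the following `jnp.linalg.solve` runs on that device. The helper name `solve_on_cpu` and the matrix sizes are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np


def solve_on_cpu(a, b):
    """Hypothetical helper: solve a @ x = b on the host CPU.

    Assumes a CPU backend is available (e.g. JAX_PLATFORMS does not exclude it).
    """
    cpu = jax.devices("cpu")[0]
    # Committing both operands to the CPU device makes JAX run the solve there,
    # regardless of which device they were on before.
    a_cpu = jax.device_put(a, cpu)
    b_cpu = jax.device_put(b, cpu)
    return jnp.linalg.solve(a_cpu, b_cpu)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Shift the diagonal so the random system is well conditioned.
a = jax.random.normal(key_a, (4, 4)) + 4.0 * jnp.eye(4)
b = jax.random.normal(key_b, (4,))
x = solve_on_cpu(a, b)
print({d.platform for d in x.devices()})
```

The result stays committed to the CPU, so later operations on it also run there until it is moved back with another `jax.device_put`.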

An example where `np.einsum` is slower than manual matmuls/transpositions on both CPU and GPU (#1966 runs equally fast for me, but this example is consistently slower): https://colab.research.google.com/gist/romanngg/e63834765d00497e315455867a52eae1/einsum_is_slow.ipynb ``` import jax.numpy as...

enhancement
performance
XLA
P2 (eventual)
CPU
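A sketch of how the two formulations in this report might be compared. It assumes the slow case is an einsum that a single matmul plus a transpose can express. The subscripts `"ij,jk->ki"`, the shapes, and the `bench` helper are hypothetical stand-ins, not the contents of the linked notebook. The script prints timings but makes no claim about which version is faster.

```python
import timeit

import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 128))
b = jax.random.normal(key_b, (128, 64))


@jax.jit
def via_einsum(a, b):
    # Contract over j and emit the result transposed, with shape (k, i).
    return jnp.einsum("ij,jk->ki", a, b)


@jax.jit
def via_matmul(a, b):
    # The same contraction written as an explicit matmul followed by a transpose.
    return jnp.matmul(a, b).T


def bench(f, n=100):
    # Compile once outside the timed region, then block on each result so that
    # asynchronous dispatch does not hide the actual compute time.
    f(a, b).block_until_ready()
    return timeit.timeit(lambda: f(a, b).block_until_ready(), number=n) / n


print("einsum:", bench(via_einsum))
print("matmul:", bench(via_matmul))
```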

Here I create two arrays and stick one of them to a different GPU: ```python import jax jax.devices() [GpuDevice(id=0), GpuDevice(id=1)] a = jax.random.normal(jax.random.PRNGKey(1), (2, 3)) b = jax.random.normal(jax.random.PRNGKey(2), (2, 3))...

bug
P0 (urgent)
NVIDIA GPU
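A sketch of a defensive pattern for mixed-device operands: explicitly `jax.device_put` one array onto the other's device before combining them. It uses whatever devices are present, with the last device standing in for "a different GPU". On a single-device machine the transfer is a no-op. This illustrates the workaround, not the reported bug itself.

```python
import jax
import numpy as np

devices = jax.devices()
# Use a second device if one is present; on a single-device machine both are the same.
dev0, dev1 = devices[0], devices[-1]

a = jax.device_put(jax.random.normal(jax.random.PRNGKey(1), (2, 3)), dev0)
b = jax.device_put(jax.random.normal(jax.random.PRNGKey(2), (2, 3)), dev1)

# Bring b onto a's device explicitly, rather than relying on implicit
# placement when operands are committed to different devices.
b_on_dev0 = jax.device_put(b, dev0)
c = a + b_on_dev0

# Compare against a host-side NumPy computation of the same sum.
np.testing.assert_allclose(np.asarray(c), np.asarray(a) + np.asarray(b), rtol=1e-6)
print(c.devices())
```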

Example: ``` x = lax.conv_general_dilated( lhs=np.ones((1, 1, 1), np.bool_), rhs=np.ones((1, 1, 1), np.bool_), window_strides=(1,), padding='SAME', dimension_numbers=('NCH', 'HIO', 'NCH') ) ``` Gives ``` --------------------------------------------------------------------------- TypeError Traceback (most recent call last) in...

enhancement
P2 (eventual)
NVIDIA GPU
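Until boolean operands are supported, one possible workaround is to cast to a floating dtype, convolve, and threshold the result. This assumes the intended boolean semantics are OR-of-ANDs, as in NumPy's boolean matmul. `bool_conv` is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax import lax


def bool_conv(lhs, rhs, **kwargs):
    """Hypothetical boolean convolution computed via a float convolution.

    Assumes OR-of-ANDs semantics: an output element is True iff some window
    position has both the input and the kernel element set to True.
    """
    out = lax.conv_general_dilated(
        lhs.astype(jnp.float32), rhs.astype(jnp.float32), **kwargs)
    # A positive sum means at least one AND term was True.
    return out > 0


x = bool_conv(
    jnp.ones((1, 1, 1), jnp.bool_),
    jnp.ones((1, 1, 1), jnp.bool_),
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NCH", "HIO", "NCH"),
)
print(x.dtype, x.shape)
```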

Example: ```python from jax import lax from jax import numpy as np a = lax.dot_general(lhs=np.array([[True], [True]]), rhs=np.array([[True, True, True], [True, True, True]]), dimension_numbers=(((0,), (0,)), ((), ()))) ``` ```python a ```...

bug
XLA
P0 (urgent)
NVIDIA GPU
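A sketch of a workaround for the same `dot_general` call. It assumes the expected result follows NumPy's boolean matmul (OR of ANDs): compute in `int32`, then compare against zero. NumPy is imported as `onp` because the snippet above uses `np` for `jax.numpy`.

```python
import numpy as onp
import jax.numpy as jnp
from jax import lax

lhs = jnp.array([[True], [True]])            # shape (2, 1)
rhs = jnp.array([[True, True, True],
                 [True, True, True]])        # shape (2, 3)
dims = (((0,), (0,)), ((), ()))              # contract axis 0 of both, no batch dims

# The integer dot_general counts matching True pairs; "> 0" turns the count
# into an OR of ANDs.
counts = lax.dot_general(lhs.astype(jnp.int32), rhs.astype(jnp.int32), dims)
a = counts > 0

# Reference: NumPy's boolean matmul of lhs^T and rhs uses OR-of-ANDs semantics.
expected = onp.asarray(lhs).T @ onp.asarray(rhs)
print(a)
```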

Hello and thank you for the great library! I'm curious if `opt_einsum` can be generalized to let the user specify which inputs are the same, and use this info to...
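To illustrate the kind of saving that knowing which inputs are identical could enable (an assumption about where the truncated request is heading), consider `"ij,kj,kl,ml->im"`. When all four operands are the same `x`, the expression equals `(x @ x.T) @ (x @ x.T)`, so the Gram matrix can be formed once and reused, which takes 2 matmuls instead of 3. The shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (16, 8))

# Generic path: einsum treats the four operands as unrelated arrays.
generic = jnp.einsum("ij,kj,kl,ml->im", x, x, x, x)

# Exploiting that every operand is x: sum over j gives G = x x^T, sum over l
# gives G again, and the remaining sum over k is G @ G.
gram = x @ x.T
reused = gram @ gram
```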