jax-windows-builder
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
A Stack Overflow user reported that these JAX builds raise the error in the title when running the following code:
import jax.numpy as jnp
a = jnp.array([[1, 2], [3, 5]])
b = jnp.array([1, 2])
x = jnp.linalg.solve(a, b)
print(x)
This appears to happen in versions 0.3.7 and 0.3.5 but not 0.3.2. I can't check at the moment, but I believe similar errors also occurred when generating random numbers with these versions.
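A minimal sketch for checking whether the installed jaxlib is one of the versions reported as affected in this thread. It uses only the standard-library importlib.metadata; the helper names are hypothetical, and stripping a "+..." local version suffix assumes the Windows wheels may carry one.

```python
from importlib.metadata import PackageNotFoundError, version

# jaxlib versions reported as affected in this thread.
AFFECTED_JAXLIB = {"0.3.5", "0.3.7"}


def is_affected(jaxlib_version):
    # Hypothetical helper: drop any local version suffix (e.g. "+cuda11...")
    # and compare only the public version number.
    public = jaxlib_version.split("+", 1)[0]
    return public in AFFECTED_JAXLIB


def installed_jaxlib():
    # Hypothetical helper: return the installed jaxlib version, or None.
    try:
        return version("jaxlib")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jaxlib()
    if v is None:
        print("jaxlib is not installed")
    else:
        print(v, "affected" if is_affected(v) else "not in the affected list")
```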
Thanks for the info. I'll investigate once I have some free time.
Windows bites!
Those errors come from lines like
ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
and are simply caused by np.arange(num_bd, -1, -1) defaulting to int32 on Windows.
If you add dtype=np.int64 to each of those NumPy arrays created just before ir.IndexType.get(), it works fine. At least the linear system solves to
[-1.0000002 1.0000001]
¯\_(ツ)_/¯
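The layout mismatch described in the error message can be illustrated with NumPy alone. This sketch forces int32 to mimic the default integer width that older NumPy (before 2.0) used on Windows, and int64 to mimic the width that dtype=np.int64 guarantees everywhere; num_bd = 3 is an illustrative value, and the assumption that the MLIR binding expects 64-bit elements for an index-typed attribute is inferred from the fix above.

```python
import numpy as np

num_bd = 3  # illustrative value only

# Same values built with the 32-bit width (old Windows default) and
# with the explicit 64-bit width used by the fix.
as_windows_default = np.arange(num_bd, -1, -1, dtype=np.int32)
as_int64 = np.arange(num_bd, -1, -1, dtype=np.int64)

# Identical contents, but 4 vs 8 bytes per element, so the raw buffers differ.
print(as_windows_default.tolist(), as_windows_default.itemsize, as_windows_default.nbytes)  # [3, 2, 1, 0] 4 16
print(as_int64.tolist(), as_int64.itemsize, as_int64.nbytes)  # [3, 2, 1, 0] 8 32
```

A reader that interprets the 16-byte int32 buffer as 64-bit elements would see only two elements with the wrong values, which is consistent with the "buffer layout does not match" error.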
(I hope this is fixed in JAX head at this point. Please let me know if it isn't...)
@hawkinsp I'm currently hitting this error while building jaxlib:
external/org_tensorflow/tensorflow/core/tpu/tpu_initializer_helper.cc(18): fatal error C1083: Cannot open include file: 'dirent.h': No such file or directory
If I instead use pip install -e .[minimum-jaxlib] to install main-branch jax against the existing jaxlib==0.3.7, the following error occurs on import:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "D:\jax-windows-builder\jax\jax\__init__.py", line 120, in <module>
from jax.experimental.maps import soft_pmap as soft_pmap
File "D:\jax-windows-builder\jax\jax\experimental\maps.py", line 26, in <module>
from jax import numpy as jnp
File "D:\jax-windows-builder\jax\jax\numpy\__init__.py", line 18, in <module>
from jax.numpy import fft as fft
File "D:\jax-windows-builder\jax\jax\numpy\fft.py", line 15, in <module>
from jax._src.numpy.fft import (
File "D:\jax-windows-builder\jax\jax\_src\numpy\fft.py", line 19, in <module>
from jax import lax
File "D:\jax-windows-builder\jax\jax\lax\__init__.py", line 359, in <module>
from jax.lax import linalg as linalg
File "D:\jax-windows-builder\jax\jax\lax\linalg.py", line 15, in <module>
from jax._src.lax.linalg import (
File "D:\jax-windows-builder\jax\jax\_src\lax\linalg.py", line 979, in <module>
gpu_linalg.cuda_lu_pivots_to_permutation),
AttributeError: module 'jaxlib.cuda_linalg' has no attribute 'cuda_lu_pivots_to_permutation'
So it may take some time to verify the fix.
@hawkinsp The upstream fix has been verified.
This issue will be left open because it affects the 0.3.5 and 0.3.7 artifacts.
Any ETA on updating jaxlib builds?
Hey, getting the same problem :(
Never mind, fixed with the simple solution provided by @cloudhan:
If you add , dtype=np.int64 to all those np array just before ir.IndexType.get(), it works fine. At least the linear system solves to
With this I was successfully able to run DALL-E Mini locally on my windows machine, thanks so much!
Note that this change has already been made at JAX head.
"If you add , dtype=np.int64 to all those np array just before ir.IndexType.get(), it works fine. At least the linear system solves to"
I am a little confused: which file do I need to make this change in?
Wherever the error message points, go to that line and add dtype=np.int64 there.
Thanks for the clarification! Could you give an example of what that looks like? I'm also struggling to patch this on my machine.
I got it working by editing my cuda_prng.py file and applying the suggested change (on line 76 for me):
ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1, dtype=np.int64),
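For anyone patching several call sites, the pattern can be factored into a small helper. This is a sketch only: reversed_axes is a hypothetical name, not part of jaxlib, and it just reproduces the patched np.arange expression so the 64-bit dtype is applied consistently.

```python
import numpy as np


def reversed_axes(ndims):
    # Hypothetical helper: returns [ndims - 1, ..., 0] as a 64-bit array,
    # matching the patched expression np.arange(ndims - 1, -1, -1, dtype=np.int64).
    # The explicit dtype keeps the buffer 64-bit even where NumPy's default
    # integer type is 32-bit (older NumPy on Windows).
    return np.arange(ndims - 1, -1, -1, dtype=np.int64)


# In the patched jaxlib file, the result would be passed to
# ir.DenseIntElementsAttr.get(...) where np.arange(ndims - 1, -1, -1) was before.
print(reversed_axes(4).tolist(), reversed_axes(4).dtype)
```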