jax-windows-builder

ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

Open · tgbrooks opened this issue 2 years ago • 14 comments

A Stack Overflow user reported that these JAX builds raise the error in the title when running the following code:

import jax.numpy as jnp
a = jnp.array([[1, 2], [3, 5]])
b = jnp.array([1, 2])
x = jnp.linalg.solve(a, b)
print(x)

This appears to happen in versions 0.3.5 and 0.3.7 but not in 0.3.2. I can't check at the moment, but I believe similar errors also occurred when generating any random numbers with these versions.
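To check whether a given build is affected, one rough approach is to compute a reference answer with NumPy (which does not go through MLIR) and compare it with what JAX returns. This is only a sketch: the JAX numbers below are simply the ones reported later in this thread, and the tolerance is an assumption.

```python
import numpy as np

# Same system as the reproduction above, in float32 (JAX's default precision).
a = np.array([[1, 2], [3, 5]], dtype=np.float32)
b = np.array([1, 2], dtype=np.float32)

# NumPy's LAPACK-backed solver serves as the reference answer.
expected = np.linalg.solve(a, b)

# Result reported for a patched JAX build (copied from this thread, not recomputed).
jax_result = np.array([-1.0000002, 1.0000001], dtype=np.float32)

# Small float32 rounding differences are expected, so compare with a tolerance.
print(expected)
print(np.allclose(jax_result, expected, atol=1e-5))
```

On a working install, jax_result would instead be np.asarray(jnp.linalg.solve(a, b)); on an affected build that call raises the ValueError before any comparison happens.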

tgbrooks, Apr 27 '22 15:04

Thanks for the info. I will investigate once I have some free time.

cloudhan, Apr 28 '22 04:04

Windows bites!

Those lines:

blahblah
ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),

This is simply caused by np.arange(num_bd, -1, -1) defaulting to int32 on Windows.

If you add dtype=np.int64 to all those np arrays passed just before ir.IndexType.get(), it works fine. At least the linear system then solves to:

[-1.0000002  1.0000001]

¯\_(ツ)_/¯
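A minimal NumPy-only illustration of the dtype difference described above (no MLIR involved). The underlying assumption, as far as I understand MLIR, is that index-typed attributes are stored with 64-bit elements, so a 4-byte int32 buffer does not match the layout MLIR expects. num_bd is a made-up value for illustration.

```python
import numpy as np

num_bd = 3  # hypothetical rank; in jaxlib this value comes from the operand shape

# No explicit dtype: NumPy uses its platform default integer type.
# With NumPy 1.x on Windows this is int32 (4 bytes); elsewhere it is usually int64.
default_perm = np.arange(num_bd, -1, -1)

# The workaround from this thread: request 64-bit elements explicitly,
# so the buffer has the same layout on every platform.
fixed_perm = np.arange(num_bd, -1, -1, dtype=np.int64)

print(default_perm.dtype, default_perm.itemsize)  # platform dependent
print(fixed_perm.dtype, fixed_perm.itemsize)      # int64 8
print(fixed_perm.tolist())                        # [3, 2, 1, 0]
```

The values are identical either way; only the element width (and therefore the raw buffer handed to the MLIR bindings) differs.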

cloudhan, Apr 28 '22 08:04

(I hope this is fixed in JAX head at this point. Please let me know if it isn't...)

hawkinsp, May 12 '22 13:05

@hawkinsp I currently run into

external/org_tensorflow/tensorflow/core/tpu/tpu_initializer_helper.cc(18): fatal error C1083: Cannot open include file: 'dirent.h': No such file or directory

while building jaxlib.

If I instead install main-branch jax with pip install -e .[minimum-jaxlib] on top of the existing jaxlib==0.3.7, the following error occurs on import:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\jax-windows-builder\jax\jax\__init__.py", line 120, in <module>
    from jax.experimental.maps import soft_pmap as soft_pmap
  File "D:\jax-windows-builder\jax\jax\experimental\maps.py", line 26, in <module>
    from jax import numpy as jnp
  File "D:\jax-windows-builder\jax\jax\numpy\__init__.py", line 18, in <module>
    from jax.numpy import fft as fft
  File "D:\jax-windows-builder\jax\jax\numpy\fft.py", line 15, in <module>
    from jax._src.numpy.fft import (
  File "D:\jax-windows-builder\jax\jax\_src\numpy\fft.py", line 19, in <module>
    from jax import lax
  File "D:\jax-windows-builder\jax\jax\lax\__init__.py", line 359, in <module>
    from jax.lax import linalg as linalg
  File "D:\jax-windows-builder\jax\jax\lax\linalg.py", line 15, in <module>
    from jax._src.lax.linalg import (
  File "D:\jax-windows-builder\jax\jax\_src\lax\linalg.py", line 979, in <module>
    gpu_linalg.cuda_lu_pivots_to_permutation),
AttributeError: module 'jaxlib.cuda_linalg' has no attribute 'cuda_lu_pivots_to_permutation'

So it might take some time to verify the fix.
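The AttributeError above most likely comes from pairing a main-branch jax with an older jaxlib wheel. A small sketch for printing which versions are actually installed, using only the standard library (importlib.metadata, Python 3.8+). It does not check compatibility; it only reports what is installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# A main-branch jax paired with an older jaxlib can fail at import time,
# so print both versions side by side before importing jax itself.
for pkg in ("jax", "jaxlib"):
    print(pkg, installed_version(pkg))
```

Because it reads package metadata instead of importing jax, this still works when import jax itself fails.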

cloudhan, May 13 '22 05:05

@hawkinsp The upstream fix is verified.

cloudhan, May 18 '22 12:05

This issue will be left open because it affects the 0.3.5 and 0.3.7 artifacts.

cloudhan, May 18 '22 12:05

Any ETA on updating jaxlib builds?

adam-hartshorne, Jun 13 '22 11:06

Hey, getting the same problem :(

Jayy001, Jun 14 '22 15:06

Never mind, I fixed it with the simple solution provided by @cloudhan:

"If you add , dtype=np.int64 to all thoes np array just before ir.IndexType.get(), it works fine. At least the linear system solves to"

With this I was successfully able to run DALL-E Mini locally on my windows machine, thanks so much!

Jayy001, Jun 14 '22 16:06

Note that this change has already been made at JAX head.

hawkinsp, Jun 14 '22 16:06

"If you add , dtype=np.int64 to all thoes np array just before ir.IndexType.get(), it works fine. At least the linear system solves to"

I am a little confused: in which file do I need to make this change?

Peelz4Dead, Jun 15 '22 23:06

Wherever the error message points, go to that line and add the dtype there.
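In practice, "wherever the error points" usually means the deepest Python frame in the traceback, typically a file inside the installed jaxlib package, since the MLIR bindings themselves are compiled code. A sketch for pulling that location out programmatically with the standard traceback module; build_attr is a hypothetical stand-in for the jaxlib function that raises.

```python
import traceback


def innermost_frame(exc):
    """Return (filename, lineno, function name) of the deepest Python frame."""
    frames = traceback.extract_tb(exc.__traceback__)
    last = frames[-1]
    return last.filename, last.lineno, last.name


def build_attr():
    # Hypothetical stand-in for the jaxlib helper that raises the error.
    raise ValueError("DenseElementsAttr could not be constructed from the given buffer.")


try:
    build_attr()
except ValueError as exc:
    filename, lineno, func = innermost_frame(exc)
    print(f"edit {filename}, line {lineno} (in {func})")
```

With a real failure you would wrap the failing jnp call instead of build_attr; the reported file and line are where the dtype=np.int64 argument goes.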

Jayy001, Jun 16 '22 12:06

Thanks for the clarification! Could you give an example of what that looks like? I'm also struggling to patch this on my machine.

Jernik, Jun 16 '22 13:06

I got it working by editing my cuda_prng.py file and applying the suggested change (on line 76 for me):

ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1, dtype=np.int64),
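The same pattern applies to every np.arange (or similar array) passed into ir.DenseIntElementsAttr.get in the affected jaxlib files. One way to avoid repeating the dtype would be a small helper. This is only a sketch: descending_dims is a hypothetical name, not something jaxlib provides, and it builds only the NumPy array, not the MLIR attribute.

```python
import numpy as np


def descending_dims(ndims):
    """Hypothetical helper: dimensions ndims-1, ..., 0 as a 64-bit array.

    Mirrors the patched expression np.arange(ndims - 1, -1, -1, dtype=np.int64),
    so the result is int64 even where NumPy's default integer is int32
    (e.g. NumPy 1.x on Windows).
    """
    return np.arange(ndims - 1, -1, -1, dtype=np.int64)


perm = descending_dims(3)
print(perm.tolist(), perm.dtype)
```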

Peelz4Dead, Jun 16 '22 17:06