
Error on any model - DenseElementsAttr could not be constructed from the given buffer

Open • Devion opened this issue 2 years ago • 1 comment

It doesn't seem to matter which model I pick (mini, mega, etc.); they all fail with the same error. Jaxlib 1.3.7 is installed.

> D:\dalle\dalle-playground\backend>python app.py --port 8080 --model_version mega --save_to_disk true

--> Starting DALL-E Server. This might take up to two minutes.
wandb: Currently logged in as: anony-moose-315442. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.12.18
wandb: Run data is saved locally in D:\dalle\dalle-playground\backend\wandb\run-20220615_222754-1tbnwomv
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run lilac-field-9
wandb:  View project at https://wandb.ai/anony-moose-315442/dalle-playground-backend?apiKey=cb614629a73707160e3c82c72823ef954bcad91c
wandb:  View run at https://wandb.ai/anony-moose-315442/dalle-playground-backend/runs/1tbnwomv?apiKey=cb614629a73707160e3c82c72823ef954bcad91c
wandb: WARNING Do NOT share these links with anyone. They can be used to claim your runs.
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:5.4
Some of the weights of DalleBart were initialized in float16 precision from the model checkpoint at C:\Users\blood\AppData\Local\Temp\tmp_nwoes5j:
[('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_2', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 
'scale'), ('model', 'encoder', 'embed_positions', 'embedding'), ('model', 'encoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'final_ln', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'scale'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:2.7
Traceback (most recent call last):
  File "D:\dalle\dalle-playground\backend\app.py", line 61, in <module>
    dalle_model.generate_images("warm-up", 1)
  File "D:\dalle\dalle-playground\backend\dalle_model.py", line 92, in generate_images
    key, subkey = jax.random.split(key)
  File "C:\Python\Python39\site-packages\jax\_src\random.py", line 194, in split
    return _return_prng_keys(wrapped, _split(key, num))
  File "C:\Python\Python39\site-packages\jax\_src\random.py", line 180, in _split
    return key._split(num)
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 203, in _split
    return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 474, in threefry_split
    return _threefry_split(key, int(num))  # type: ignore
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 479, in _threefry_split
    return lax.reshape(threefry_2x32(key, counts), (num, 2))
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 467, in threefry_2x32
    x = threefry2x32_p.bind(key1, key2, x[0], x[1])
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:\dalle\dalle-playground\backend\app.py", line 61, in <module>
    dalle_model.generate_images("warm-up", 1)
  File "D:\dalle\dalle-playground\backend\dalle_model.py", line 92, in generate_images
    key, subkey = jax.random.split(key)
  File "C:\Python\Python39\site-packages\jax\_src\random.py", line 194, in split
    return _return_prng_keys(wrapped, _split(key, num))
  File "C:\Python\Python39\site-packages\jax\_src\random.py", line 180, in _split
    return key._split(num)
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 203, in _split
    return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 474, in threefry_split
    return _threefry_split(key, int(num))  # type: ignore
  File "C:\Python\Python39\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\Python\Python39\site-packages\jax\_src\api.py", line 473, in cache_miss
    out_flat = xla.xla_call(
  File "C:\Python\Python39\site-packages\jax\core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File "C:\Python\Python39\site-packages\jax\core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "C:\Python\Python39\site-packages\jax\core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File "C:\Python\Python39\site-packages\jax\_src\dispatch.py", line 182, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "C:\Python\Python39\site-packages\jax\linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "C:\Python\Python39\site-packages\jax\_src\dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "C:\Python\Python39\site-packages\jax\_src\profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "C:\Python\Python39\site-packages\jax\_src\dispatch.py", line 340, in lower_xla_callable
    module, keepalive = mlir.lower_jaxpr_to_module(
  File "C:\Python\Python39\site-packages\jax\interpreters\mlir.py", line 556, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "C:\Python\Python39\site-packages\jax\interpreters\mlir.py", line 810, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "C:\Python\Python39\site-packages\jax\interpreters\mlir.py", line 929, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 400, in _threefry2x32_gpu_lowering
    return threefry2x32_lowering(
  File "C:\Python\Python39\site-packages\jaxlib\cuda_prng.py", line 76, in threefry2x32_lowering
    layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
jax._src.traceback_util.UnfilteredStackTrace: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:\dalle\dalle-playground\backend\app.py", line 61, in <module>
    dalle_model.generate_images("warm-up", 1)
  File "D:\dalle\dalle-playground\backend\dalle_model.py", line 92, in generate_images
    key, subkey = jax.random.split(key)
  File "C:\Python\Python39\site-packages\jax\_src\random.py", line 194, in split
    return _return_prng_keys(wrapped, _split(key, num))
  File "C:\Python\Python39\site-packages\jax\_src\random.py", line 180, in _split
    return key._split(num)
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 203, in _split
    return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 474, in threefry_split
    return _threefry_split(key, int(num))  # type: ignore
  File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 400, in _threefry2x32_gpu_lowering
    return threefry2x32_lowering(
  File "C:\Python\Python39\site-packages\jaxlib\cuda_prng.py", line 76, in threefry2x32_lowering
    layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
wandb: Waiting for W&B process to finish... (failed 1). Press Ctrl-C to abort syncing.
wandb:
wandb: Synced lilac-field-9: https://wandb.ai/anony-moose-315442/dalle-playground-backend/runs/1tbnwomv?apiKey=cb614629a73707160e3c82c72823ef954bcad91c
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20220615_222754-1tbnwomv\logs

Devion • Jun 15 '22 20:06

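This error comes from the default integer type on Windows: there, np.arange returns 32-bit integers unless a dtype is given, while the jaxlib lowering code that builds the MLIR attribute needs 64-bit ones. The fix for the issue has been reported here: https://github.com/google/jax/pull/10592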
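To see the difference on your own machine, here is a small check using only NumPy. It is a sketch, not part of jaxlib: the value of ndims is a made-up example, and the default dtype it prints depends on your platform and NumPy version (older NumPy releases on Windows default to 32-bit integers, while NumPy 2.x defaults to 64-bit there).

```python
import numpy as np

# Hypothetical rank, chosen only for illustration.
ndims = 2

# Same expression as in the traceback: dtype is left to the platform default.
default_layout = np.arange(ndims - 1, -1, -1)

# Same expression with the dtype pinned explicitly to 64-bit integers.
explicit_layout = np.arange(ndims - 1, -1, -1, dtype=np.int64)

print("default dtype: ", default_layout.dtype)   # platform-dependent
print("explicit dtype:", explicit_layout.dtype)  # always int64
print("values:        ", explicit_layout.tolist())
```

If the first line prints int32, your installation is affected by the mismatch described above.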

You can fix it locally with these steps:

  1. Open this file in a text editor: "C:\Python\Python39\site-packages\jaxlib\cuda_prng.py"
  2. On line 75 or 76 you will find: layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
  3. Change it to: layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1, dtype=np.int64),
  4. Save the file and run the server again.
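
If you prefer not to edit by hand, the same change can be sketched as a plain string replacement. The helper name patch_layout_line is hypothetical and not part of jaxlib, and the example below runs on a sample string rather than on the real file, so it assumes the line in your copy matches the one in the traceback exactly.

```python
# Text fragments taken from the traceback and from step 3 above.
OLD = "np.arange(ndims - 1, -1, -1),"
NEW = "np.arange(ndims - 1, -1, -1, dtype=np.int64),"


def patch_layout_line(source: str) -> str:
    """Return `source` with the layout expression pinned to int64.

    Hypothetical helper: a pure text substitution. Already-patched text
    is returned unchanged because OLD no longer matches it.
    """
    return source.replace(OLD, NEW)


# Demonstration on a sample line (not the real file).
sample = "    layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),\n"
patched = patch_layout_line(sample)
print(patched, end="")
```

To apply it for real, you would read cuda_prng.py into a string, pass it through the helper, and write the result back, ideally after making a backup copy of the file.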

MarlonColhado • Jun 16 '22 03:06