dalle-playground
Error on any model - DenseElementsAttr could not be constructed from the given buffer
It doesn't seem to matter which model I pick (mini, mega, etc.); they all fail with the same error. I have Jaxlib 1.3.7 installed.
> D:\dalle\dalle-playground\backend>python app.py --port 8080 --model_version mega --save_to_disk true
--> Starting DALL-E Server. This might take up to two minutes.
wandb: Currently logged in as: anony-moose-315442. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.12.18
wandb: Run data is saved locally in D:\dalle\dalle-playground\backend\wandb\run-20220615_222754-1tbnwomv
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run lilac-field-9
wandb: View project at https://wandb.ai/anony-moose-315442/dalle-playground-backend?apiKey=cb614629a73707160e3c82c72823ef954bcad91c
wandb: View run at https://wandb.ai/anony-moose-315442/dalle-playground-backend/runs/1tbnwomv?apiKey=cb614629a73707160e3c82c72823ef954bcad91c
wandb: WARNING Do NOT share these links with anyone. They can be used to claim your runs.
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:5.4
Some of the weights of DalleBart were initialized in float16 precision from the model checkpoint at C:\Users\blood\AppData\Local\Temp\tmp_nwoes5j:
[('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_2', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 
'scale'), ('model', 'encoder', 'embed_positions', 'embedding'), ('model', 'encoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'final_ln', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'scale'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:2.7
Traceback (most recent call last):
File "D:\dalle\dalle-playground\backend\app.py", line 61, in <module>
dalle_model.generate_images("warm-up", 1)
File "D:\dalle\dalle-playground\backend\dalle_model.py", line 92, in generate_images
key, subkey = jax.random.split(key)
File "C:\Python\Python39\site-packages\jax\_src\random.py", line 194, in split
return _return_prng_keys(wrapped, _split(key, num))
File "C:\Python\Python39\site-packages\jax\_src\random.py", line 180, in _split
return key._split(num)
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 203, in _split
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 474, in threefry_split
return _threefry_split(key, int(num)) # type: ignore
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 479, in _threefry_split
return lax.reshape(threefry_2x32(key, counts), (num, 2))
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 467, in threefry_2x32
x = threefry2x32_p.bind(key1, key2, x[0], x[1])
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "D:\dalle\dalle-playground\backend\app.py", line 61, in <module>
dalle_model.generate_images("warm-up", 1)
File "D:\dalle\dalle-playground\backend\dalle_model.py", line 92, in generate_images
key, subkey = jax.random.split(key)
File "C:\Python\Python39\site-packages\jax\_src\random.py", line 194, in split
return _return_prng_keys(wrapped, _split(key, num))
File "C:\Python\Python39\site-packages\jax\_src\random.py", line 180, in _split
return key._split(num)
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 203, in _split
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 474, in threefry_split
return _threefry_split(key, int(num)) # type: ignore
File "C:\Python\Python39\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "C:\Python\Python39\site-packages\jax\_src\api.py", line 473, in cache_miss
out_flat = xla.xla_call(
File "C:\Python\Python39\site-packages\jax\core.py", line 1765, in bind
return call_bind(self, fun, *args, **params)
File "C:\Python\Python39\site-packages\jax\core.py", line 1781, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "C:\Python\Python39\site-packages\jax\core.py", line 678, in process_call
return primitive.impl(f, *tracers, **params)
File "C:\Python\Python39\site-packages\jax\_src\dispatch.py", line 182, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "C:\Python\Python39\site-packages\jax\linear_util.py", line 285, in memoized_fun
ans = call(fun, *args)
File "C:\Python\Python39\site-packages\jax\_src\dispatch.py", line 230, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "C:\Python\Python39\site-packages\jax\_src\profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "C:\Python\Python39\site-packages\jax\_src\dispatch.py", line 340, in lower_xla_callable
module, keepalive = mlir.lower_jaxpr_to_module(
File "C:\Python\Python39\site-packages\jax\interpreters\mlir.py", line 556, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "C:\Python\Python39\site-packages\jax\interpreters\mlir.py", line 810, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
File "C:\Python\Python39\site-packages\jax\interpreters\mlir.py", line 929, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 400, in _threefry2x32_gpu_lowering
return threefry2x32_lowering(
File "C:\Python\Python39\site-packages\jaxlib\cuda_prng.py", line 76, in threefry2x32_lowering
layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
jax._src.traceback_util.UnfilteredStackTrace: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "D:\dalle\dalle-playground\backend\app.py", line 61, in <module>
dalle_model.generate_images("warm-up", 1)
File "D:\dalle\dalle-playground\backend\dalle_model.py", line 92, in generate_images
key, subkey = jax.random.split(key)
File "C:\Python\Python39\site-packages\jax\_src\random.py", line 194, in split
return _return_prng_keys(wrapped, _split(key, num))
File "C:\Python\Python39\site-packages\jax\_src\random.py", line 180, in _split
return key._split(num)
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 203, in _split
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 474, in threefry_split
return _threefry_split(key, int(num)) # type: ignore
File "C:\Python\Python39\site-packages\jax\_src\prng.py", line 400, in _threefry2x32_gpu_lowering
return threefry2x32_lowering(
File "C:\Python\Python39\site-packages\jaxlib\cuda_prng.py", line 76, in threefry2x32_lowering
layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
wandb: Waiting for W&B process to finish... (failed 1). Press Ctrl-C to abort syncing.
wandb:
wandb: Synced lilac-field-9: https://wandb.ai/anony-moose-315442/dalle-playground-backend/runs/1tbnwomv?apiKey=cb614629a73707160e3c82c72823ef954bcad91c
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20220615_222754-1tbnwomv\logs
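This issue is caused by the default integer data type on Windows: there, np.arange returns 32-bit integers unless a dtype is given, while the layout attribute built in jaxlib's cuda_prng.py expects 64-bit values. A fix has been submitted upstream here: https://github.com/google/jax/pull/10592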
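A minimal sketch comparing the original expression from cuda_prng.py with the patched one. The value ndims = 2 is an arbitrary illustrative choice, not taken from jaxlib:

```python
import numpy as np

ndims = 2  # illustrative rank; any ndims >= 1 shows the same dtype behaviour

# Original expression: the dtype follows NumPy's platform-default integer type.
default_layout = np.arange(ndims - 1, -1, -1)

# Patched expression: the dtype is pinned to 64-bit on every platform.
fixed_layout = np.arange(ndims - 1, -1, -1, dtype=np.int64)

print(default_layout, default_layout.dtype)
print(fixed_layout, fixed_layout.dtype)
```

Both arrays hold the same values (the reversed axis order); only the element width differs. With the NumPy 1.x releases of that time, the first line reports int32 on Windows and int64 on Linux. NumPy 2.x switched the Windows default to 64-bit, so both lines report int64 there.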
Until that fix reaches your installed version, you can patch jaxlib yourself with these steps:
- Open this file in a text editor: "C:\Python\Python39\site-packages\jaxlib\cuda_prng.py"
- Around line 75 or 76 you will find this:
layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
- Just change it to:
layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1, dtype=np.int64),
- Save the file and run app.py again.
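The manual edit can also be scripted. This is a minimal sketch that assumes the installed cuda_prng.py contains exactly the expression shown in the steps. patch_cuda_prng and find_cuda_prng are hypothetical helper names, not part of jaxlib. The script keeps a .bak copy and rewrites the file only when the original expression is present:

```python
import importlib.util
import os
import shutil

# Exact expressions from the manual steps: before and after the fix.
OLD = "np.arange(ndims - 1, -1, -1),"
NEW = "np.arange(ndims - 1, -1, -1, dtype=np.int64),"


def patch_cuda_prng(path):
    """Pin the np.arange dtype to int64 in cuda_prng.py; return True if the file was changed."""
    with open(path, encoding="utf-8") as f:
        source = f.read()
    if OLD not in source:
        # Already patched, or this jaxlib version has different code: leave it alone.
        return False
    shutil.copyfile(path, path + ".bak")  # keep a backup of the untouched file
    with open(path, "w", encoding="utf-8") as f:
        f.write(source.replace(OLD, NEW))
    return True


def find_cuda_prng():
    """Locate jaxlib/cuda_prng.py for the current interpreter without importing jaxlib."""
    spec = importlib.util.find_spec("jaxlib")
    if spec is None or not spec.submodule_search_locations:
        return None
    path = os.path.join(list(spec.submodule_search_locations)[0], "cuda_prng.py")
    return path if os.path.exists(path) else None


if __name__ == "__main__":
    target = find_cuda_prng()
    if target is None:
        print("jaxlib/cuda_prng.py not found for this interpreter")
    elif patch_cuda_prng(target):
        print("Patched", target)
    else:
        print("No change needed in", target)
```

Run it with the same Python interpreter that runs app.py, so that it finds the same jaxlib installation.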