Cannot run trax on the GPU on my local computer
Hi, I would like to install trax locally. First, I found that the jax I had installed did not support the GPU, so I followed the jax GitHub instructions to install the CUDA version of jax. Next, I verified that jax can detect the GPU on my local computer, but I still cannot run the sample code, such as the Transformer and Fast Math examples.
Environment information
OS: Pop!_OS (based on Ubuntu 22.04)
$ pip freeze | grep trax
# trax==1.4.1
$ pip freeze | grep tensor
# tensorboard==2.12.3
# tensorboard-data-server==0.7.1
# tensorflow==2.12.0
# tensorflow-datasets==4.9.2
# tensorflow-estimator==2.12.0
# tensorflow-hub==0.13.0
# tensorflow-io-gcs-filesystem==0.32.0
# tensorflow-metadata==1.13.1
# tensorflow-text==2.12.1
$ pip freeze | grep jax
# jax==0.4.12
# jaxlib==0.4.12+cuda11.cudnn86
$ python -V
# Python 3.11.3
For bugs: reproduction and error logs
# Steps to reproduce:
1) Install trax
- pip install trax
- pip install --upgrade "jax[cuda11_pip]" -f
2) Use jax to detect the GPU
- code:
import jax
- output:
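For reference, a minimal sketch of this kind of check, assuming only the public `jax.default_backend()` and `jax.devices()` calls. The helper name `describe_jax_backend` is hypothetical, and this is not necessarily the exact code that was run in step 2:

```python
def describe_jax_backend():
    """Return the backend and devices jax reports, or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None  # jax is not installed in this environment
    return {
        "backend": jax.default_backend(),  # e.g. 'gpu' or 'cpu'
        "devices": [str(d) for d in jax.devices()],
    }


print(describe_jax_backend())
```

If the CUDA build of jaxlib is picked up, the reported backend should be a GPU backend rather than `cpu`.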
# Error logs:
1) Run the sample code for the pre-trained Transformer from your README tutorial
- code:
import os
import numpy as np
import trax
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=64, mode='predict')
# Initialize using pre-trained weights.
# input_signature=input_signature)
# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list([sentence]), # Operates on streams.
# Decode from the Transformer.
tokenized = tokenized[None, :] # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
model, tokenized, temperature=0.0) # Higher temperature: more diverse results.
# De-tokenize,
tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.
translation =,
- Error Output:
2023-06-22 15:58:35.266959: W tensorflow/tsl/platform/cloud/] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host:".
2023-06-22 15:58:56.630331: W external/xla/xla/service/gpu/runtime/] Intercepted XLA runtime error:
INTERNAL: Failed to get stream's capture status: out of memory
2023-06-22 15:58:56.630403: E external/xla/xla/pjrt/] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
Traceback (most recent call last):
File "/home/littleliu/Documents/project/trax_learning/", line 22, in <module>
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/", line 349, in init_from_file
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/", line 310, in init
raise LayerError(name, 'init', self._caller,
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/, line 371
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:float32})
File [...]/trax/layers/, line 108, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
File [...]/trax/layers/, line 641, in _forward_abstract
layer created in file [...]/trax/models/, line 372
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
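The `out of memory` failure already occurs in `jit_PRNGKey`, a very small computation, which suggests that most GPU memory was taken before the model weights were touched. One thing that may be worth trying (an assumption, not a confirmed fix) is to change how jax reserves GPU memory through the environment variables described in the jax GPU memory allocation documentation. They have to be set before jax is first imported. The helper name `configure_jax_gpu_memory` is hypothetical:

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set jax's GPU memory environment variables; call before importing jax."""
    # Enable or disable up-front preallocation of GPU memory.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of GPU memory to preallocate (used when preallocation is on).
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)


configure_jax_gpu_memory(preallocate=False)
# Import jax / trax only after this point so the settings take effect.
```

If another process or library in the same script has already claimed the GPU memory, these settings alone may not be enough.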
2) Run the Fast Math sample code:
- code:
import trax
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.
matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix =\n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
- Error Output:
matrix =
[[1 2 3]
[4 5 6]
[7 8 9]]
2023-06-22 16:03:23.041313: E external/xla/xla/stream_executor/cuda/] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/] Memory usage: 36175872 bytes free, 4093902848 bytes total.
2023-06-22 16:03:23.041476: E external/xla/xla/stream_executor/cuda/] Possibly insufficient driver version: 525.85.5
Traceback (most recent call last):
File "/home/littleliu/Documents/project/trax_learning/", line 7, in <module>
vector = fastnp.ones(3)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/numpy/", line 2161, in ones
return lax.full(shape, 1, _jnp_dtype(dtype))
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/", line 1205, in full
return broadcast(fill_value, shape)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/", line 768, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/", line 796, in broadcast_in_dim
return broadcast_in_dim_p.bind(
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 277, in cached
return f(*args, **kwargs)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/", line 2329, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/", line 2651, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/", line 2561, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 314, in wrapper
return func(*args, **kwargs)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
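As a quick sanity check on the `Memory usage` line in the log above, the sketch below parses the two byte counts and converts them to MiB. The function name `parse_memory_usage` is hypothetical; the log line is copied from the error output.

```python
import re

LOG_LINE = "Memory usage: 36175872 bytes free, 4093902848 bytes total."


def parse_memory_usage(line):
    """Extract (free_bytes, total_bytes) from an XLA 'Memory usage' log line."""
    match = re.search(r"Memory usage: (\d+) bytes free, (\d+) bytes total", line)
    if match is None:
        raise ValueError("not a 'Memory usage' log line")
    return int(match.group(1)), int(match.group(2))


free_bytes, total_bytes = parse_memory_usage(LOG_LINE)
free_mib = free_bytes / 2**20
total_mib = total_bytes / 2**20
print(f"free:  {free_mib:.2f} MiB")
print(f"total: {total_mib:.2f} MiB")
print(f"free fraction: {free_bytes / total_bytes:.2%}")
```

Less than 1% of the roughly 4 GB of GPU memory was free when cuDNN tried to create its handle, so the `CUDNN_STATUS_NOT_INITIALIZED` error may be a symptom of memory exhaustion rather than of the driver version.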