trax
Cannot run trax on the GPU of a local computer
Description
Hi, I would like to install trax locally. First, I found that the jax I had installed did not support the GPU, so I followed the jax GitHub instructions to install the CUDA version of jax. Next, I verified that jax can detect the GPU on my local computer, but I still cannot run the sample code, such as the Transformer and Fast Math examples.
Environment information
OS: Pop!_OS (based on Ubuntu 22.04)
$ pip freeze | grep trax
# trax==1.4.1
$ pip freeze | grep tensor
# tensorboard==2.12.3
# tensorboard-data-server==0.7.1
# tensorflow==2.12.0
# tensorflow-datasets==4.9.2
# tensorflow-estimator==2.12.0
# tensorflow-hub==0.13.0
# tensorflow-io-gcs-filesystem==0.32.0
# tensorflow-metadata==1.13.1
# tensorflow-text==2.12.1
$ pip freeze | grep jax
# jax==0.4.12
# jaxlib==0.4.12+cuda11.cudnn86
$ python -V
# Python 3.11.3
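The same version information can also be collected from inside Python, which avoids mixing up environments when several conda envs exist. This is a minimal sketch using only the standard library (`importlib.metadata`); the helper names `package_versions` and `is_cuda_jaxlib` are hypothetical. The `+cuda` check is a heuristic that matches jaxlib wheels like the `0.4.12+cuda11.cudnn86` build listed above; newer jax releases ship CUDA support as separate plugin packages, so it may not apply there.

```python
import importlib.metadata
import platform

# Distributions relevant to this report (names as shown by pip freeze).
PACKAGES = ("trax", "jax", "jaxlib", "tensorflow", "tensorflow-datasets")


def package_versions(names=PACKAGES):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def is_cuda_jaxlib(version):
    """Heuristic: True if the jaxlib version carries a '+cuda' local tag."""
    return version is not None and "+cuda" in version


if __name__ == "__main__":
    report = package_versions()
    print("Python", platform.python_version())
    for name, version in report.items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
    print("CUDA build of jaxlib:", is_cuda_jaxlib(report["jaxlib"]))
```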
For bugs: reproduction and error logs
# Steps to reproduce:
1) Install trax
- pip install trax
- pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2) Use jax to detect the GPU
- code:
import jax
print(jax.devices())
- output:
[gpu(id=0)]
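Listing devices only shows that jax found the CUDA backend; it does not compile or run anything, while both failures below happen when XLA first compiles and executes a kernel. A slightly stronger check, sketched here under the assumption that a tiny computation is representative, runs one op and waits for the result. The helper name `jax_smoke_test` is hypothetical; `jax.default_backend`, `jax.devices` and `block_until_ready` are standard jax APIs.

```python
import importlib.util


def jax_smoke_test():
    """Compile and run a tiny computation on JAX's default device.

    jax.devices() only proves the backend was discovered. Running an op
    forces XLA to compile a kernel and allocate device memory, which is the
    stage where both errors in this report occur.
    """
    import jax
    import jax.numpy as jnp

    x = jnp.ones(3)  # the same call that fails in the Fast Math example
    total = float(jnp.sum(x).block_until_ready())  # wait for the device to finish
    return {
        "backend": jax.default_backend(),  # e.g. 'gpu' or 'cpu'
        "devices": [str(d) for d in jax.devices()],
        "sum_of_ones": total,
    }


if importlib.util.find_spec("jax") is not None:
    print(jax_smoke_test())
else:
    print("jax is not installed in this environment")
```

If this already raises the cuDNN or out-of-memory error, the problem is in the jax/CUDA setup rather than in trax.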
# Error logs:
1) Run the pre-trained Transformer sample code from the trax README
- code:
import os
import numpy as np
import trax
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
input_vocab_size=33300,
d_model=512, d_ff=2048,
n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
max_len=64, mode='predict')
# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
weights_only=True)
# input_signature=input_signature)
# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]), # Operates on streams.
vocab_dir='gs://trax-ml/vocabs/',
vocab_file='ende_32k.subword'))[0]
# Decode from the Transformer.
tokenized = tokenized[None, :] # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
model, tokenized, temperature=0.0) # Higher temperature: more diverse results.
# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
vocab_dir='gs://trax-ml/vocabs/',
vocab_file='ende_32k.subword')
print(translation)
- Error Output:
2023-06-22 15:58:35.266959: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
2023-06-22 15:58:56.630331: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Failed to get stream's capture status: out of memory
2023-06-22 15:58:56.630403: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
Traceback (most recent call last):
File "/home/littleliu/Documents/project/trax_learning/tryTrax.py", line 22, in <module>
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 349, in init_from_file
self.init(input_signature)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 310, in init
raise LayerError(name, 'init', self._caller,
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 371
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:float32})
File [...]/trax/layers/combinators.py, line 108, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File [...]/trax/layers/base.py, line 641, in _forward_abstract
layer created in file [...]/trax/models/transformer.py, line 372
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
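One possible cause, not confirmed by these logs alone: by default JAX preallocates 75% of GPU memory when its GPU backend starts, and TensorFlow (which trax likely pulls in for data loading and tokenization) also reserves GPU memory by default once it initialises its GPU devices. On a card of about 4 GB, two frameworks in one process can leave almost nothing free. The sketch below applies the two commonly suggested knobs: the `XLA_PYTHON_CLIENT_PREALLOCATE` / `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables from the JAX GPU memory documentation, and `tf.config.set_visible_devices([], "GPU")` to keep TensorFlow on the CPU. The helper names are hypothetical.

```python
import importlib.util
import os


def limit_jax_gpu_memory(preallocate=False, mem_fraction=None, environ=None):
    """Set JAX's GPU allocator options through environment variables.

    They are read when jax initialises its GPU backend, so this must run
    before `import jax` / `import trax` in the script.
    """
    env = os.environ if environ is None else environ
    # "false" makes JAX allocate on demand instead of reserving memory up front.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Cap the fraction of total GPU memory JAX preallocates.
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    return env


def hide_gpus_from_tensorflow():
    """Keep TensorFlow on the CPU so it does not reserve GPU memory.

    Must run before TensorFlow initialises its GPU devices; otherwise
    tf.config.set_visible_devices raises RuntimeError.
    """
    if importlib.util.find_spec("tensorflow") is None:
        return False
    import tensorflow as tf

    tf.config.set_visible_devices([], "GPU")
    return True
```

In tryTrax.py this would mean calling `limit_jax_gpu_memory()` and then `hide_gpus_from_tensorflow()` at the very top, before `import trax`. If the error persists with no other process using the GPU, memory contention is probably not the explanation.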
2) Run the Fast Math sample code:
- code:
import trax
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.
matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix =\n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
- Error Output:
matrix =
[[1 2 3]
[4 5 6]
[7 8 9]]
2023-06-22 16:03:23.041313: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total.
2023-06-22 16:03:23.041476: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 525.85.5
Traceback (most recent call last):
File "/home/littleliu/Documents/project/trax_learning/fastnumpy.py", line 7, in <module>
vector = fastnp.ones(3)
^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2161, in ones
return lax.full(shape, 1, _jnp_dtype(dtype))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1205, in full
return broadcast(fill_value, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
return broadcast_in_dim_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
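The cuDNN memory line in this log already points at memory pressure: 36175872 bytes free out of 4093902848 is about 34.5 MiB of roughly 3904 MiB, i.e. under 1% free when cuDNN tried to create its handle. That is consistent with the memory having been claimed beforehand (by an allocator in this process or by another process), although the log also flags the driver version, so a driver issue is not ruled out. The small helper below (hypothetical, written only to do this arithmetic) parses that line; the regex assumes the exact message format shown above.

```python
import re

# Matches the cuDNN diagnostic format copied from the log above.
MEMORY_LINE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")


def parse_cudnn_memory(line):
    """Return free/total GPU memory (MiB) from a cuDNN log line, or None."""
    match = MEMORY_LINE.search(line)
    if match is None:
        return None
    free, total = int(match.group(1)), int(match.group(2))
    return {
        "free_mib": free / 2**20,
        "total_mib": total / 2**20,
        "free_fraction": free / total,
    }


log_line = ("2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/"
            "cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total.")
stats = parse_cudnn_memory(log_line)
print(f"{stats['free_mib']:.1f} MiB free of {stats['total_mib']:.1f} MiB "
      f"({stats['free_fraction']:.2%} free)")
```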