trax
Trax cannot find GPU
Description
I have a GPU, but Trax does not detect it. It reports: WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
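The warning suggests rerunning with TF_CPP_MIN_LOG_LEVEL=0. A minimal sketch of how that check could look is below: it sets the variable before JAX is imported and then lists the devices JAX reports. The helper names `summarize_devices` and `has_accelerator` are hypothetical and exist only for this illustration; `jax.devices()` and the `platform` attribute of each device are the only JAX calls used.

```python
import os

# Assumption: the variable must be set before JAX is imported for the
# extra backend logging to show up. setdefault keeps any value that is
# already exported in the shell.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")


def summarize_devices(devices):
    """Group device descriptions by platform name (for example 'cpu' or 'gpu')."""
    summary = {}
    for device in devices:
        summary.setdefault(device.platform, []).append(str(device))
    return summary


def has_accelerator(devices):
    """Return True if any listed device is something other than the CPU."""
    return any(device.platform != "cpu" for device in devices)


if __name__ == "__main__":
    try:
        import jax  # imported late so the environment variable above applies
    except ImportError:
        print("jax is not installed in this environment")
    else:
        devices = jax.devices()
        print(summarize_devices(devices))
        if not has_accelerator(devices):
            # One possible cause (an assumption, not a diagnosis):
            # the installed jaxlib is a CPU-only build.
            print("JAX only sees the CPU.")
```

If only CPU devices are listed, the verbose log printed during the import is the place to look for the reason the GPU backend was not initialized.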
...
Environment information
OS: <your answer here>
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.3 LTS"
$ pip freeze | grep trax
trax==1.3.6
$ pip freeze | grep tensor
mesh-tensorflow==0.1.17
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-datasets==4.1.0
tensorflow-estimator==2.3.0
tensorflow-gpu==2.3.1
tensorflow-metadata==0.25.0
tensorflow-text==2.3.0
$ pip freeze | grep jax
jax==0.2.6
jaxlib==0.1.57
$ python -V
Python 3.7.4
For bugs: reproduction and error logs
Run the following Python script:
import os
import numpy as np
import trax
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')
# Initialize using pre-trained weights.
# model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
model.init_from_file('./ende_wmt32k.pkl.gz',
                     weights_only=True)
# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]
# Decode from the Transformer.
tokenized = tokenized[None, :] # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.
# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
...
# Error logs:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2020-11-28 14:51:01.156049: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found:
Could not locate the credentials file.". Retrieving token from GCE failed with "Failed precondition: Error executing an HTTP request: HTTP response code 302 with body '<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="X-UA-Compatible" content="IE=7" />
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<meta name="keywords" content="MWG,Proxy" />
<title>Huawei Proxy Notification</title>
<style type=text/css>
body {
float: none;
background-color: #CCCCCC;
text-align: center;
font-size: 0.75em;
pad'".
Have you tried installing a CUDA-enabled jaxlib? For example:
pip3 install --upgrade jax jaxlib==0.1.57+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
The +cuda suffix depends on your CUDA version.
More here: https://github.com/google/jax#installation
@Koesters Thanks, this works for me. A minor correction: as of today, jaxlib==0.1.57+cuda112 doesn't exist; use jaxlib==0.1.57+cuda111 instead.
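Since the wheel suffix has to match the installed CUDA version, a small helper can build the requirement string instead of typing it by hand. This is only a sketch: the function names are hypothetical, and it assumes the `+cudaXYZ` suffix is the CUDA major and minor version concatenated, as in the `cuda111` and `cuda112` tags mentioned in this thread. It does not check that a given wheel actually exists on the releases page.

```python
RELEASES_URL = "https://storage.googleapis.com/jax-releases/jax_releases.html"


def jaxlib_requirement(jaxlib_version, cuda_version):
    """Build a pip requirement such as 'jaxlib==0.1.57+cuda111'.

    cuda_version is 'major.minor' text, e.g. '11.1'. Any patch component
    (as in '11.1.105') is ignored.
    """
    major, minor = cuda_version.split(".")[:2]
    return f"jaxlib=={jaxlib_version}+cuda{int(major)}{int(minor)}"


def pip_command(jaxlib_version, cuda_version):
    """Return the install command in the form used in this thread."""
    requirement = jaxlib_requirement(jaxlib_version, cuda_version)
    return f"pip3 install --upgrade jax {requirement} -f {RELEASES_URL}"


if __name__ == "__main__":
    # Example values taken from this thread; substitute your own versions.
    print(pip_command("0.1.57", "11.1"))
```

Whether a particular jaxlib and CUDA combination is published still has to be confirmed against the releases page, as the cuda112 case above shows.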
Unfortunately, this didn't work for me. An import error told me I had to use jaxlib==0.1.59. With that version, execution got a bit further but still failed.
pip3 install --upgrade jax jaxlib==0.1.59+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Traceback (most recent call last):
File "test-trax.py", line 15, in <module>
weights_only=True)
File "/home/ubuntu/trax/trax/layers/base.py", line 334, in init_from_file
input_signature, unsafe=True)
File "/home/ubuntu/trax/trax/layers/base.py", line 492, in weights_and_state_signature
rng, state, weights = self.rng, self.state, self.weights
File "/home/ubuntu/trax/trax/layers/base.py", line 510, in rng
self._rng = fastmath.random.get_prng(self._rng_seed_int)
File "/home/ubuntu/trax/trax/fastmath/ops.py", line 68, in get_prng
return backend()['random_get_prng'](seed)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
return cpp_jitted_f(context, *args, **kwargs)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/api.py", line 295, in cache_miss
donated_invars=donated_invars)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
return call_bind(self, fun, *args, **params)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 1278, in process
return trace.process_call(self, fun, tracers, params)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 631, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
ans = call(fun, *args)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 727, in _xla_callable
compiled = backend_compile(backend, built, options)
File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 352, in backend_compile
return backend.compile(built_c, compile_options=options)
RuntimeError: Unknown: no kernel image is available for execution on the device
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc(44): 'cuLinkAddData( link_state, CU_JIT_INPUT_CUBIN, static_cast<void*>(image.bytes.data()), image.bytes.size(), "", 0, nullptr, nullptr)'