
trax cannot find GPU

shengyushen opened this issue · 3 comments

Description

I have a GPU, but trax just cannot find it. It reports:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

...
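A minimal check of which backend JAX itself picks up in this environment (assuming jax imports cleanly here):

import jax
from jax.lib import xla_bridge

# 'gpu' means a CUDA-enabled jaxlib found the card; 'cpu' matches the warning above.
print(xla_bridge.get_backend().platform)
print(jax.devices())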

Environment information

OS:
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.3 LTS"

$ pip freeze | grep trax
trax==1.3.6

$ pip freeze | grep tensor
mesh-tensorflow==0.1.17
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-datasets==4.1.0
tensorflow-estimator==2.3.0
tensorflow-gpu==2.3.1
tensorflow-metadata==0.25.0
tensorflow-text==2.3.0


$ pip freeze | grep jax
jax==0.2.6
jaxlib==0.1.57

$ python -V
Python 3.7.4

For bugs: reproduction and error logs

Run the following Python script:

import os
import numpy as np
import trax

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
#model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
model.init_from_file('./ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
...
# Error logs:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2020-11-28 14:51:01.156049: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Failed precondition: Error executing an HTTP request: HTTP response code 302 with body '<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
  <meta http-equiv="X-UA-Compatible" content="IE=7" />
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<meta name="keywords" content="MWG,Proxy" />
<title>Huawei Proxy Notification</title>
<style type=text/css>
body {
        float: none;
        background-color: #CCCCCC;
        text-align: center;
        font-size: 0.75em;
        pad'".
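Separate from the GPU problem, the proxy page in the log shows the request to gs://trax-ml being intercepted. If that blocks the vocab download, the same tokenize call from the script also works against a locally downloaded copy of ende_32k.subword (the directory below is a placeholder):

# Placeholder local path; assumes ende_32k.subword was downloaded by hand into ./vocabs/.
tokenized = list(trax.data.tokenize(iter([sentence]),
                                    vocab_dir='./vocabs/',
                                    vocab_file='ende_32k.subword'))[0]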

shengyushen · Nov 28, 2020

Have you tried getting jaxlib from:

pip3 install --upgrade jax jaxlib==0.1.57+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Obviously, adjust depending on your CUDA version.

More here: https://github.com/google/jax#installation
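After reinstalling, a quick way to confirm the GPU is actually picked up before rerunning the trax script (a minimal check, not specific to trax):

import jax
import jax.numpy as jnp

print(jax.devices())  # should now list a GPU device instead of only CpuDevice
print(jnp.dot(jnp.ones((2, 2)), jnp.ones((2, 2))))  # trivial op on the default backend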

Koesters · Dec 2, 2020

@Koesters Thanks, this works for me. One minor correction: as of today, jaxlib==0.1.57+cuda112 doesn't exist; use jaxlib==0.1.57+cuda111 instead.

Z-Y00 · Jan 23, 2021

Unfortunately, this didn't work for me. According to an import error, I had to use jaxlib==0.1.59 instead. Execution then got a bit further but still failed:

pip3 install --upgrade jax jaxlib==0.1.59+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Traceback (most recent call last):
  File "test-trax.py", line 15, in <module>
    weights_only=True)
  File "/home/ubuntu/trax/trax/layers/base.py", line 334, in init_from_file
    input_signature, unsafe=True)
  File "/home/ubuntu/trax/trax/layers/base.py", line 492, in weights_and_state_signature
    rng, state, weights = self.rng, self.state, self.weights
  File "/home/ubuntu/trax/trax/layers/base.py", line 510, in rng
    self._rng = fastmath.random.get_prng(self._rng_seed_int)
  File "/home/ubuntu/trax/trax/fastmath/ops.py", line 68, in get_prng
    return backend()['random_get_prng'](seed)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
    return cpp_jitted_f(context, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/api.py", line 295, in cache_miss
    donated_invars=donated_invars)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 1278, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/core.py", line 631, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 727, in _xla_callable
    compiled = backend_compile(backend, built, options)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 352, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Unknown: no kernel image is available for execution on the device
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc(44): 'cuLinkAddData( link_state, CU_JIT_INPUT_CUBIN, static_cast<void*>(image.bytes.data()), image.bytes.size(), "", 0, nullptr, nullptr)'
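The "no kernel image is available" error usually means the installed jaxlib CUDA wheel does not ship compiled kernels for this GPU's architecture. A minimal way to see which device JAX is targeting (device_kind is assumed to be exposed by this jax version):

import jax

# Prints the backend and GPU model, e.g. 'gpu Tesla K80'; if the card's compute
# capability is not covered by the prebuilt cuda111 wheel, a different wheel or a
# source build of jaxlib may be needed.
for d in jax.devices():
    print(d.platform, d.device_kind)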

mokshasoft · Feb 2, 2021