Is there a standard procedure to make TAPIR run inference on GPU / Ubuntu 22.04?

Open CHYjeremy opened this issue 2 years ago • 4 comments

Hi everyone,

I find it really hard to get TAPIR to run on GPU. Is there a standard procedure for this?

Here is what I tried, after creating a new conda environment:

  1. First, following the JAX installation instructions (for the JAX/CUDA/cuDNN installation, I suppose): pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  2. Then: pip install -r requirements_inference.txt

The following error then appears:

    jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit__threefry_seed,program_id=0#.

Note that the CPU-only version runs without problems (simply pip install -r requirements_inference.txt).

Could someone share their standard procedure for making it work? Many thanks.
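To narrow down whether the failure comes from the JAX/CUDA install rather than from tapnet itself, a small standalone check can help. The sketch below is only an illustration: the function name jax_gpu_smoke_test is made up for this example, and it assumes that creating a PRNG key exercises the same threefry seed kernel named in the error above.

```python
import importlib.util


def jax_gpu_smoke_test():
    """Return a small report on whether JAX can run a trivial op.

    Hypothetical helper for illustration; not part of tapnet or JAX.
    """
    report = {"jax_installed": importlib.util.find_spec("jax") is not None}
    if not report["jax_installed"]:
        return report

    try:
        import jax

        # Backend name as reported by JAX, e.g. "cpu" or "gpu".
        report["backend"] = jax.default_backend()
        report["devices"] = [str(d) for d in jax.devices()]
        # Assumption: building a PRNG key runs the threefry seed kernel
        # that appears in the error message (jit__threefry_seed).
        jax.random.PRNGKey(0).block_until_ready()
        report["ok"] = True
    except Exception as exc:  # broad on purpose: we only want a report
        report["ok"] = False
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report


if __name__ == "__main__":
    print(jax_gpu_smoke_test())
```

If this already fails in a fresh environment with only JAX installed, the problem is likely in the driver/CUDA/JAX combination and not in the tapnet requirements.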

CHYjeremy avatar Aug 06 '23 04:08 CHYjeremy

I've been running on ubuntu 20.04. I wasn't able to make it work using the nvidia drivers and cuda that are distributed with ubuntu; I needed to uninstall these and install nvidia's versions of the driver, CUDA, and CUDNN which match the JAX version. However, after installing, CUDA wasn't on my PATH (IIRC I did get a similar error message about failing to load PTX as a result). I found that export PATH=/usr/local/cuda/bin:$PATH before running the live demo made it work.

I have no idea if this is your issue, however. You might be better off posting this question on the JAX github.
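The PATH workaround described above can also be applied from Python before JAX is imported. This is a minimal sketch, assuming the CUDA toolkit lives in /usr/local/cuda/bin as in the export command above; prepend_cuda_bin is a hypothetical helper name, and checking for ptxas is just one way to see whether the toolkit binaries are visible.

```python
import os
import shutil


def prepend_cuda_bin(env, cuda_bin="/usr/local/cuda/bin"):
    """Return a copy of env with cuda_bin first on PATH (no duplicates).

    Hypothetical helper; the default path is the one used in the
    export command quoted in this thread.
    """
    new_env = dict(env)
    parts = [
        p
        for p in new_env.get("PATH", "").split(os.pathsep)
        if p and p != cuda_bin
    ]
    new_env["PATH"] = os.pathsep.join([cuda_bin] + parts)
    return new_env


if __name__ == "__main__":
    # Do this before importing jax so child tools see the new PATH.
    os.environ.update(prepend_cuda_bin(os.environ))
    # ptxas ships with the CUDA toolkit; None means it was not found.
    print("ptxas:", shutil.which("ptxas"))
```

Setting the variable in the shell, as in the export command, has the same effect; the Python version is only convenient when the demo is launched from a script or notebook.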

cdoersch avatar Aug 09 '23 23:08 cdoersch

Have you solved your problem? If so, could you please share the solution?

xbowlove avatar Nov 25 '23 07:11 xbowlove

@xbowlove I did the following, and running the live demo on my local laptop GPU worked for me.

  • created a local virtual environment using venv
  • git clone https://github.com/deepmind/tapnet.git
  • pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html (for reference)
  • commented out jax and jaxline in the provided tapnet/requirements_inference.txt
  • pip install -r requirements_inference.txt

After this, I followed the remaining steps in the README file to run the live demo.
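The "comment out jax and jaxline" step can be done by hand, or sketched in code as below. This is an illustration only: drop_requirements is a hypothetical helper, and the sample lines in the usage part are invented, since the actual contents of requirements_inference.txt are not shown in this thread.

```python
import re


def drop_requirements(lines, names=("jax", "jaxline")):
    """Comment out requirement lines whose package name is in names.

    Hypothetical helper; matches the full package name, so "jaxlib"
    is left alone when only "jax" is listed.
    """
    drop = {n.lower() for n in names}
    kept = []
    for line in lines:
        # Package name = leading run of name characters, before any
        # version specifier or extras bracket.
        m = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", line)
        if m and m.group(1).lower() in drop:
            kept.append("# " + line.lstrip())
        else:
            kept.append(line)
    return kept


if __name__ == "__main__":
    # Invented sample lines, not the real file contents.
    sample = ["jax>=0.4", "jaxline", "numpy"]
    print("\n".join(drop_requirements(sample)))
```

Commenting the lines out rather than deleting them keeps pip from replacing the CUDA-enabled JAX build while leaving a record of what was changed.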

kumar-sanjeeev avatar Dec 13 '23 14:12 kumar-sanjeeev

Thanks for your answer. However, after finishing all the steps you described, I get the following error at 'from jaxline import platform':

    Traceback (most recent call last):
      File "/home/jishengyin/newpan/tapnet/./experiment.py", line 30, in <module>
        from jaxline import platform
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/jaxline/platform.py", line 34, in <module>
        import tensorflow as tf
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/__init__.py", line 48, in <module>
        from tensorflow._api.v2 import __internal__
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/_api/v2/__internal__/__init__.py", line 8, in <module>
        from tensorflow._api.v2.__internal__ import autograph
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/_api/v2/__internal__/autograph/__init__.py", line 8, in <module>
        from tensorflow.python.autograph.core.ag_ctx import control_status_ctx # line: 34
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/python/autograph/core/ag_ctx.py", line 21, in <module>
        from tensorflow.python.autograph.utils import ag_logging
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/python/autograph/utils/__init__.py", line 17, in <module>
        from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/python/autograph/utils/context_managers.py", line 19, in <module>
        from tensorflow.python.framework import ops
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/python/framework/ops.py", line 29, in <module>
        from tensorflow.core.framework import attr_value_pb2
      File "/home/jishengyin/anaconda3/lib/python3.11/site-packages/tensorflow/core/framework/attr_value_pb2.py", line 5, in <module>
        from google.protobuf.internal import builder as _builder
    ImportError: cannot import name 'builder' from 'google.protobuf.internal' (unknown location)
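The last frame of this traceback fails inside protobuf, not JAX or tapnet. As far as I know, google.protobuf.internal.builder only exists in protobuf 3.20 and later, so an older or partially overwritten protobuf install is one plausible cause; treat that as an assumption, not a confirmed diagnosis. A minimal sketch to inspect the environment (protobuf_report is a hypothetical helper name):

```python
import importlib.metadata
import importlib.util


def protobuf_report():
    """Report the installed protobuf version and whether the
    google.protobuf.internal.builder module can be located.

    Hypothetical helper for illustration only.
    """
    report = {}
    try:
        report["version"] = importlib.metadata.version("protobuf")
    except importlib.metadata.PackageNotFoundError:
        report["version"] = None

    try:
        spec = importlib.util.find_spec("google.protobuf.internal.builder")
    except (ImportError, ValueError):
        # Raised when a parent package is missing or broken.
        spec = None
    report["has_builder"] = spec is not None
    return report


if __name__ == "__main__":
    print(protobuf_report())
```

If has_builder is False, comparing the reported version against what the installed TensorFlow expects would be a reasonable next step. The traceback also shows packages loading from the base anaconda3 site-packages rather than a dedicated environment, which may be worth double-checking.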

nutsintheshell avatar Jan 15 '24 07:01 nutsintheshell