jax-windows-builder
installation is not working
Hi there,
First of all, thank you for supporting Windows! I've used this build before with great success. However, at the moment it isn't working, and I can't find a way to get an older version to work either.
I'm trying to set up jax on a Windows PC with conda, but the provided instructions no longer work, and I haven't been able to get any other version working.
I'm installing on a laptop, this is the output from nvidia-smi:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 527.83 Driver Version: 527.83 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... WDDM | 00000000:01:00.0 Off | N/A |
| N/A 50C P0 9W / 30W | 0MiB / 2048MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
I tried:
conda create -n jax_test python
conda activate jax_test
# install it
pip install jax[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# installs numpy, etc..
# raises a warning:
# WARNING: jax 0.4.19 does not provide the extra 'cuda111'
python -m jax
# File "C:\ProgramData\Anaconda3\envs\jax_test\Lib\site-packages\jax\_src\lib\__init__.py", line 27, in <module>
# raise ModuleNotFoundError(
# ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.
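(As a side check - generic Python, nothing specific to this build - the following shows whether any jaxlib wheel was picked up at all:)
import importlib.util
print(importlib.util.find_spec("jaxlib"))  # None means pip never installed a jaxlib wheel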
This might, of course, not work for CUDA 12.0. However, if I run it with
pip install jax[pip_cuda12] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
I get the same result.
I also tried this with python==3.11, python==3.10, and python==3.9. Same result.
When I just download a jaxlib wheel directly, it also does not work. Sometimes I get a bit further, but no computations can be done and I run into 'AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11''.
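My understanding is that newer ml_dtypes releases dropped the float8_e4m3b11 name that older jaxlib builds still expect; a quick, generic way to see what is installed:
import ml_dtypes
print(ml_dtypes.__version__)
print(hasattr(ml_dtypes, "float8_e4m3b11"))  # False with newer ml_dtypes releases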
What should the python version be? And what would be the right command?
Instructions to get halfway (python 3.10):
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# now, when importing we get
# ModuleNotFoundError: No module named 'ml_dtypes._ml_dtypes_ext'
# solve as per https://developer.apple.com/forums/thread/737890
pip install ml_dtypes==0.2.0
# now, numpy is not working
pip install -U numpy --force-reinstall
# now, it can be run
However, when I now open python, I get:
from jax import numpy as jnp
a = jnp.zeros(5)
# external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
So still not usable.
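For reference, this is how I check which backend JAX actually ends up on (plain public API, nothing version-specific assumed):
import jax
print(jax.default_backend())  # 'gpu' when the CUDA backend initialized, otherwise 'cpu'
print(jax.devices())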
Using -f https://whls.blob.core.windows.net/unstable/index.html may not work anymore because jax changed how it handles extras options. You need to open the link, download the whl file manually, and install a compatible jax, not the latest version.
So, I downloaded and installed the jaxlib wheel manually, and then ran
pip install jax==0.4.13
As far as I can see, that should be compatible with jaxlib==0.4.11 (based on the source code).
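Just to double-check what actually ended up in the environment, a plain version print (nothing assumed beyond jax and jaxlib being importable):
import jax
import jaxlib.version
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.version.__version__)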
If I run it, I still get
#jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Could you please set the environment variable TF_CPP_MIN_LOG_LEVEL to 0?
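# set it before importing jax so the native-library logging level takes effect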
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
import jax
jax.numpy.array([0])
There used to be some useful DLL info in that output; not sure how it behaves now, though. Might be worth a try.
So, I reinstalled everything from scratch, just to make sure the problem isn't caused by some old environment. This is what I tried:
conda create -n jax
conda activate jax
conda install numpy scipy jupyter
# this should download the same file, just putting it in for reproducibility
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install jax==0.4.13
conda install cudatoolkit
And then I ran the script above. I get the following output:
2023-10-25 09:56:55.529123: I external/tsl/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-25 09:56:55.650774: I external/tsl/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-25 09:56:55.653221: I external/tsl/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-25 09:56:55.916782: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:435] TfrtCpuClient created.
2023-10-25 09:56:56.505294: I external/xla/xla/service/service.cc:168] XLA service 0x40f9d20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-25 09:56:56.505474: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce MX550, Compute Capability 7.5
2023-10-25 09:56:56.506396: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:545] Using BFC allocator.
2023-10-25 09:56:56.508144: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 1610416128 bytes on device 0 for BFCAllocator.
2023-10-25 09:56:56.613734: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:438] TfrtCpuClient destroyed.
It starts with not finding cuda, but then it does seem to find it.
The full traceback is here:
XlaRuntimeError Traceback (most recent call last)
Cell In[3], line 1
----> 1 a = jax.numpy.zeros(512)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\numpy\lax_numpy.py:2153, in zeros(shape, dtype)
2151 dtypes.check_user_dtype_supported(dtype, "zeros")
2152 shape = canonicalize_shape(shape)
-> 2153 return lax.full(shape, 0, _jnp_dtype(dtype))
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:1206, in full(shape, fill_value, dtype)
1204 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
1205 fill_value = _convert_element_type(fill_value, dtype, weak_type)
-> 1206 return broadcast(fill_value, shape)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:768, in broadcast(operand, sizes)
754 """Broadcasts an array, adding new leading dimensions
755
756 Args:
(...)
765 jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
766 """
767 dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
--> 768 return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:797, in broadcast_in_dim(operand, shape, broadcast_dimensions)
795 else:
796 dyn_shape, static_shape = [], shape # type: ignore
--> 797 return broadcast_in_dim_p.bind(
798 operand, *dyn_shape, shape=tuple(static_shape),
799 broadcast_dimensions=tuple(broadcast_dimensions))
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378 assert (not config.jax_enable_checks or
379 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380 return self.bind_with_trace(find_top_trace(args), args, params)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383 out = trace.process_primitive(self, map(trace.full_raise, args), params)
384 return map(full_lower, out) if self.multiple_results else full_lower(out)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:815, in EvalTrace.process_primitive(self, primitive, tracers, params)
814 def process_primitive(self, primitive, tracers, params):
--> 815 return primitive.impl(*tracers, **params)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:132, in apply_primitive(prim, *args, **params)
130 try:
131 in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 132 compiled_fun = xla_primitive_callable(
133 prim, in_avals, OrigShardings(in_shardings), **params)
134 except pxla.DeviceAssignmentMismatchError as e:
135 fails, = e.args
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:284, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
282 return f(*args, **kwargs)
283 else:
--> 284 return cached(config._trace_context(), *args, **kwargs)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:277, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
275 @functools.lru_cache(max_size)
276 def cached(_, *args, **kwargs):
--> 277 return f(*args, **kwargs)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:223, in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
221 return out,
222 donated_invars = (False,) * len(in_avals)
--> 223 compiled = _xla_callable_uncached(
224 lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
225 orig_in_shardings)
226 if not prim.multiple_results:
227 return lambda *args, **kw: compiled(*args, **kw)[0]
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:253, in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
248 def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
249 keep_unused, in_avals, orig_in_shardings):
250 computation = sharded_lowering(
251 fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
252 lowering_platform=None)
--> 253 return computation.compile().unsafe_call
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2323, in MeshComputation.compile(self, compiler_options)
2320 executable = MeshExecutable.from_trivial_jaxpr(
2321 **self.compile_args)
2322 else:
-> 2323 executable = UnloadedMeshExecutable.from_hlo(
2324 self._name,
2325 self._hlo,
2326 **self.compile_args,
2327 compiler_options=compiler_options)
2328 if compiler_options is None:
2329 self._executable = executable
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2645, in UnloadedMeshExecutable.from_hlo(***failed resolving arguments***)
2642 mesh = i.mesh # type: ignore
2643 break
-> 2645 xla_executable, compile_options = _cached_compilation(
2646 hlo, name, mesh, spmd_lowering,
2647 tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
2648 tuple(host_callbacks), backend, da, pmap_nreps,
2649 compiler_options_keys, compiler_options_values)
2651 if hasattr(backend, "compile_replicated"):
2652 semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2555, in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
2550 return None, compile_options
2552 with dispatch.log_elapsed_time(
2553 "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
2554 fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2555 xla_executable = dispatch.compile_or_get_cached(
2556 backend, computation, dev, compile_options, host_callbacks)
2557 return xla_executable, compile_options
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:497, in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
493 use_compilation_cache = (compilation_cache.is_initialized() and
494 backend.platform in supported_platforms)
496 if not use_compilation_cache:
--> 497 return backend_compile(backend, computation, compile_options,
498 host_callbacks)
500 cache_key = compilation_cache.get_cache_key(
501 computation, devices, compile_options, backend)
503 cached_executable = _cache_read(module_name, cache_key, compile_options,
504 backend)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
311 @wraps(func)
312 def wrapper(*args, **kwargs):
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
460 return backend.compile(built_c, compile_options=options,
461 host_callbacks=host_callbacks)
462 # Some backends don't have `host_callbacks` option yet
463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Once we get it to work, I could put together a conda environment file that hopefully works without having to go through all of these steps. Would you be interested in including that?
Does conda install cudatoolkit configure the PATH for you? If not, you might need to manually add the directory containing the CUDA libraries to PATH.
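For example (a hypothetical sketch; the directory below is only a guess based on the default conda layout and has to be adjusted to your env):
import os
dll_dir = r"C:\ProgramData\Anaconda3\envs\jax\Library\bin"  # adjust to your env
os.environ["PATH"] = dll_dir + os.pathsep + os.environ["PATH"]
os.add_dll_directory(dll_dir)  # Python 3.8+ on Windows does not use PATH for dependent DLLs
import jax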
Yes, it does. I also just checked by looking at
import os
os.environ
This gives, among others, 'CUDA_PATH': 'C:\\ProgramData\\Anaconda3\\envs\\jax'.
I also installed cupy, and that works without problems.
Then, does cudnn*.dll exist under that dir?
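(A quick way to check from Python - just a recursive file search under CUDA_PATH:)
import glob, os
cuda_path = os.environ.get("CUDA_PATH", "")
print(cuda_path)
print(glob.glob(os.path.join(cuda_path, "**", "cudnn*.dll"), recursive=True))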