jaxngp
Installation with pip?
Hi, great work on this package!
I'm looking to use this package in conjunction with other dependencies in a larger project, and I have no experience with nix. Since there are multiple top-level packages, it's not clear to me how to correctly install jaxngp. Further, when I try to install from a subdirectory (e.g., trying to install only jax-tcnn), I'm unable to do so. For example, when trying to install only the jax-tcnn subpackage using pip install "git+https://github.com/blurgyy/jaxngp.git#egg=jax-tcnn&subdirectory=deps/jax-tcnn", I get the error:
Building wheels for collected packages: jax-tcnn
Building wheel for jax-tcnn (pyproject.toml) ... error
error: subprocess-exited-with-error
× Building wheel for jax-tcnn (pyproject.toml) did not run successfully.
│ exit code: 1
╰─> [70 lines of output]
WARNING setuptools_scm._integration.setuptools pyproject.toml does not contain a tool.setuptools_scm section
running bdist_wheel
running build
running build_py
creating build
creating build/lib.linux-x86_64-cpython-310
creating build/lib.linux-x86_64-cpython-310/jaxtcnn
copying src/jaxtcnn/__init__.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn
creating build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
copying src/jaxtcnn/hashgrid_tcnn/impl.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
copying src/jaxtcnn/hashgrid_tcnn/lowering.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
copying src/jaxtcnn/hashgrid_tcnn/__init__.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
copying src/jaxtcnn/hashgrid_tcnn/abstract.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
running egg_info
writing src/jax_tcnn.egg-info/PKG-INFO
writing dependency_links to src/jax_tcnn.egg-info/dependency_links.txt
writing requirements to src/jax_tcnn.egg-info/requires.txt
writing top-level names to src/jax_tcnn.egg-info/top_level.txt
writing manifest file 'src/jax_tcnn.egg-info/SOURCES.txt'
running build_ext
Traceback (most recent call last):
File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
main()
File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
json_out['return_val'] = hook(**hook_input['kwargs'])
File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 251, in build_wheel
return _build_backend().build_wheel(wheel_directory, config_settings,
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 434, in build_wheel
return self._build_with_temp_dir(
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 419, in _build_with_temp_dir
self.run_setup()
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 341, in run_setup
exec(code, locals())
File "<string>", line 86, in <module>
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/__init__.py", line 103, in setup
return distutils.core.setup(**attrs)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 185, in setup
return run_commands(dist)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
dist.run_commands()
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
self.run_command(cmd)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
super().run_command(command)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
cmd_obj.run()
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/wheel/bdist_wheel.py", line 364, in run
self.run_command("build")
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
self.distribution.run_command(command)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
super().run_command(command)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
cmd_obj.run()
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build.py", line 131, in run
self.run_command(cmd_name)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
self.distribution.run_command(command)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
super().run_command(command)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
cmd_obj.run()
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 88, in run
_build_ext.run(self)
File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 345, in run
self.build_extensions()
File "<string>", line 54, in build_extensions
File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/os.py", line 680, in __getitem__
raise KeyError(key) from None
KeyError: 'cmakeFlags'
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for jax-tcnn
Failed to build jax-tcnn
ERROR: Could not build wheels for jax-tcnn, which is required to install pyproject.toml-based projects
Some guidance on installation would be helpful! If it helps, I only really need a jax version of the TCNN hash encoder.
Hi @alberthli, I think there are several steps needed to install this package:
- First of all, obtain a local copy of this repository.
- The error about the environment variable can be avoided by removing the os.environ["cmakeFlags"].split() part here (see the sketch after this list).
- To build the binding, you also have to copy the serde-helper.h file (here) to the directory deps/jax-tcnn/lib, and include it with relative paths in deps/jax-tcnn/lib/ffi.cc and deps/jax-tcnn/lib/impl/hashgrid.cu.
- You should build tiny-cuda-nn's library first (follow tiny-cuda-nn's instructions), and probably modify deps/jax-tcnn/CMakeLists.txt so that CMake can find tiny-cuda-nn's library.
- tiny-cuda-nn's include/ directory should be in the search path of your build environment (one way to specify this is to supply the argument -I/path/to/tiny-cuda-nn/include to the compiler).
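Rather than deleting the lookup outright, the Nix-provided flags can also be made optional. Here is a minimal sketch of that change; the surrounding build_extensions code is assumed for illustration and is not copied from the repository:
import os

cmake_args = []  # illustrative; the real list also carries the usual CMake arguments
# The original line raises KeyError outside of a Nix shell, where cmakeFlags is unset:
#     cmake_args += os.environ["cmakeFlags"].split()
# Falling back to an empty string keeps the Nix behaviour when the variable is set
# and avoids the error when it is not:
cmake_args += os.environ.get("cmakeFlags", "").split()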
Here's a modified CMakeLists.txt for pip installation:
cmake_minimum_required(VERSION 3.23)
project(volume_rendering_jax LANGUAGES CXX CUDA)
# use `cmake -DCMAKE_CUDA_ARCHITECTURES=61;62;75` to build for compute capabilities 61, 62, and 75
# set(CMAKE_CUDA_ARCHITECTURES "all")
message(STATUS "Enabled CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "Using CMake version " ${CMAKE_VERSION})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda")
find_package(Python COMPONENTS Interpreter Development REQUIRED)
find_package(pybind11 CONFIG REQUIRED)
find_package(fmt REQUIRED)
include_directories(${CMAKE_CURRENT_LIST_DIR}/lib)
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
pybind11_add_module(
tcnnutils
${CMAKE_CURRENT_LIST_DIR}/lib/impl/hashgrid.cu
${CMAKE_CURRENT_LIST_DIR}/lib/ffi.cc
)
# e.g. `cmake -DTCNN_MIN_GPU_ARCH=61`
message(STATUS "TCNN_MIN_GPU_ARCH=35")
target_compile_definitions(tcnnutils PUBLIC -DTCNN_MIN_GPU_ARCH=35)
target_link_libraries(tcnnutils PRIVATE tiny-cuda-nn fmt::fmt)
install(TARGETS tcnnutils DESTINATION jaxtcnn)
To add all required headers, you can first clone tiny-cuda-nn (with all its submodules) and check out the v1.6 tag, then go to deps/jax-tcnn/lib and symlink the directories:
$ git clone https://github.com/nvlabs/tiny-cuda-nn.git --recursive
$ cd tiny-cuda-nn
$ git checkout v1.6
$ cd ..
$ git clone https://github.com/blurgyy/jaxngp.git
$ cd jaxngp/deps/jax-tcnn/lib
$ ln -s /path/to/tiny-cuda-nn/include/tiny-cuda-nn /path/to/tiny-cuda-nn/dependencies/* .
You should then build tiny-cuda-nn to obtain the static library libtiny-cuda-nn.a, put it in a proper location on your system, and run pip install -v /path/to/jaxngp/deps/jax-tcnn.
@blurgyy Thanks for the edits/instructions - I finally got back around to looking at this and successfully built jax-tcnn.
I have a follow-up question: my goal is to train a nerfacto model from nerfstudio using the tcnn fully fused MLP (which calls the torch bindings) and then load those weights into a jax model that just needs to do inference. The reason for this is that I want the NeRF density field to be used in conjunction with a large jax codebase, but I also want to reduce the amount of time it takes to train the NeRF, and nerfstudio is very entrenched in the NeRF side of our experimental pipeline.
However, when I initialize jaxtcnn TCNNHashGridEncoder and CoordinateBasedMLP modules, I find that the total number of parameters is exactly 24 fewer than in tcnn.NetworkWithInputEncoding (12196216 + 3072 parameters vs. 12199312). I have also confirmed that the entire difference is accounted for by the hash grid and not the MLP, by forcing nerfstudio to train a NeRF with the encoder and MLP separated and recounting the parameters. I have double-checked that the parameters I use to initialize the encoders are the same across the jax and torch versions.
Do you have any idea where this parameter discrepancy comes from in the hash grid implementation? Could this be resolved just by building the most recent version of tiny-cuda-nn instead of v1.6, or would there be modifications needed on the jax-tcnn side? I would be happy to provide more information and relevant files if needed.
EDIT: In my fork of this branch, I've made some minor modifications to allow jax-tcnn to be compatible with the latest versions of jax and jaxlib as well as the latest version of tiny-cuda-nn, just in case there were differences in versioning causing the discrepancy. I re-built tcnn and jax-tcnn with these changes and the same parameter discrepancy exists. See the diff between your repo and my fork here.
To initialize the encoder, I'm using the parameters:
L=16
F=2
T=2**19
N_min=16
N_max=2048
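For what it's worth, a rough level-by-level parameter count under these settings can be written down as below. This is only a sketch: the per-level resolution rounding and any per-level alignment padding are implementation details that differ between tiny-cuda-nn and jaxngp, and either convention could account for a small constant offset like the 24 parameters above.
import math

L, F, T = 16, 2, 2**19
N_min, N_max = 16, 2048
growth = math.exp((math.log(N_max) - math.log(N_min)) / (L - 1))

total = 0
for level in range(L):
    # Per-level resolution; the exact rounding (floor/ceil, +1 border) varies by implementation.
    res = int(math.ceil(N_min * growth ** level))
    # Dense grid below the hash-table cap T, hashed (capped at T) above it.
    n_entries = min(res ** 3, T)
    total += n_entries * F
    print(f"level {level:2d}: resolution {res:5d}, {n_entries * F} parameters")
print("total grid parameters:", total)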
Hi @alberthli,
I recently encountered a use case where I needed to calibrate the JAX HashGridEncoder's (not TCNNHashGridEncoder's) parameter layout against that of tiny-cuda-nn, and I just pushed an update to address this (commit 04bcea2ab5bd2e2ecec8fc83c9ce191120f9cd8e).
I used the following parameters to initialize and train a HashGridEncoder from the JAX side, then used tcnn's pytorch bindings to load it; the parameter count and per-layer hashgrid outputs were checked and match (with absolute errors no larger than 1e-3).
- Initialization from jaxngp (this repo):
from models.encoders import HashGridEncoder

jax_hg = HashGridEncoder(L=16, T=2**19, F=2, N_min=32, N_max=2048, tv_scale=0.)
- Initializing using tcnn's pytorch bindings:
import math
import tinycudann as tcnn
import torch

L = 16
F = 2
N_min = 32
N_max = 2048
encoding_config = {
    "otype": "HashGrid",
    "n_levels": L,
    "n_features_per_level": F,
    "log2_hashmap_size": 19,
    "base_resolution": N_min,
    "per_level_scale": math.exp((math.log(N_max) - math.log(N_min)) / (L - 1)),
    "interpolation": "Linear",
}
tcnn_hg = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_config, dtype=torch.float32)
- Load parameters from jaxngp's hashgrid into tcnn's:
state_dict = {
    "params": torch.as_tensor(jax_hg_params_dict["latent codes stored on grid vertices"].ravel()).to("cuda"),
}
tcnn_hg.load_state_dict(state_dict)
The jax_hg_params_dict holds the hashgrid encoder's parameters; it can be saved to disk using numpy.save and loaded using numpy.load(path_to_the_npy_file).item() to get the dict object back. Note that the parameter key is literally latent codes stored on grid vertices, see here.
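As a minimal round-trip sketch (the file path is illustrative; with recent numpy versions, allow_pickle=True is needed on load because the saved object is a dict rather than a plain array):
import numpy as np

# Save the hashgrid parameters obtained from the JAX model.
np.save("jax_hg_params.npy", jax_hg_params_dict)

# Load them back; .item() unwraps the 0-d object array into the original dict.
jax_hg_params_dict = np.load("jax_hg_params.npy", allow_pickle=True).item()
grid_params = jax_hg_params_dict["latent codes stored on grid vertices"]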
I hope this helps.