plenoxels
question about code
Many thanks for your great work!
I have a question about the code. Could you give me a more detailed explanation?
```python
# Compute when the rays enter and leave the grid
offsets_pos = jax.lax.stop_gradient((radius - rays_o) / rays_d)
offsets_neg = jax.lax.stop_gradient((-radius - rays_o) / rays_d)
offsets_in = jax.lax.stop_gradient(jnp.minimum(offsets_pos, offsets_neg))
offsets_out = jax.lax.stop_gradient(jnp.maximum(offsets_pos, offsets_neg))
start = jax.lax.stop_gradient(jnp.max(offsets_in, axis=-1, keepdims=True))
stop = jax.lax.stop_gradient(jnp.min(offsets_out, axis=-1, keepdims=True))
first_intersection = jax.lax.stop_gradient(rays_o + start * rays_d)
```
Let's focus on a single ray. Then `offsets_pos` is the distance along the ray at which the ray crosses the planes defined by the positive-coordinate faces of the grid cube, and `offsets_neg` is the same but for the planes defined by the negative-coordinate faces of the cube. For example, if the radius is 1, then `offsets_pos` would be a 3-vector containing the distances along the ray at which it reaches x coordinate equal to 1, y coordinate equal to 1, and z coordinate equal to 1, respectively (and `offsets_neg` would be a 3-vector with the corresponding ray distances needed to reach coordinate values of -1). For each coordinate direction, there are two potential points of intersection with the cube, one with the positive face and one with the negative face, stored in `offsets_pos` and `offsets_neg`.
For `offsets_in` and `offsets_out`, we want to know which of these two potential intersections is reached first and which second. In other words, we want to know how far along the ray it is potentially entering vs. leaving the cube. In each dimension, the potential entry point is the face intersection with the smaller ray distance (`offsets_in`) and the potential exit point is the face intersection with the larger ray distance (`offsets_out`).
Finally, we want to know which face is actually entered and which is exited, if there is any intersection at all. The ray only actually enters the cube once it has passed `offsets_in` in all three dimensions, hence the definition of `start` with a maximum over `offsets_in`. Likewise, the ray exits the cube as soon as it does so in any of the three dimensions, hence the minimum in the definition of `stop`. Note that if the ray does not actually intersect the cube, we will compute values such that `stop` < `start` and be left with essentially an empty ray.
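
To make the steps above concrete, here is a minimal sketch of the same computation on two hand-picked rays, one that hits the cube and one that misses it. It drops the `jax.lax.stop_gradient` wrappers for brevity, and the function name `ray_box_interval` and the example rays are made up for illustration; they are not taken from the repository.

```python
import jax.numpy as jnp


def ray_box_interval(rays_o, rays_d, radius):
    """Entry/exit distances of rays against the cube [-radius, radius]^3.

    Hypothetical helper for illustration; mirrors the quoted snippet
    without the stop_gradient wrappers.
    """
    # Per-axis distance along the ray to the +radius and -radius planes.
    offsets_pos = (radius - rays_o) / rays_d
    offsets_neg = (-radius - rays_o) / rays_d
    # Per axis, the nearer plane is the potential entry, the farther the potential exit.
    offsets_in = jnp.minimum(offsets_pos, offsets_neg)
    offsets_out = jnp.maximum(offsets_pos, offsets_neg)
    # Inside the cube only after entering on every axis ...
    start = jnp.max(offsets_in, axis=-1, keepdims=True)
    # ... and outside as soon as any axis exits.
    stop = jnp.min(offsets_out, axis=-1, keepdims=True)
    return start, stop


# Two example rays (assumed values): the first crosses the cube, the second passes above it.
rays_o = jnp.array([[-2.0, 0.0, 0.5],
                    [-2.0, 3.0, 0.0]])
rays_d = jnp.array([[1.0, 0.25, 0.1],
                    [1.0, 0.25, 0.1]])

start, stop = ray_box_interval(rays_o, rays_d, radius=1.0)
hits = stop > start  # a ray with stop < start never enters the cube
print(start.ravel(), stop.ravel(), hits.ravel())
```

For the first ray the interval along the ray is [1, 3]. For the second, `stop` comes out smaller than `start`, which is the empty-ray case described above.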
Many thanks for your reply! 1. Can I assume that the origin of the world coordinate system is the center of the grid? 2. Is the code logic of the JAX version the same as in svox2? 3. Is the volume density stored per vertex or per small cube (voxel) in the grid?
Sorry to bother you again! I also have a question about the following code. Could you explain it in more detail?
```python
count = int(resolution*3 / uniform)
intersections = jnp.linspace(start=realstart + uniform*voxel_len, stop=realstart + uniform*voxel_len*(count+1), num=count, endpoint=False)
```
- Yes, the origin is the center of the grid.
- Generally no. The two versions compute the same rendering formula, but the implementations are quite different.
- This is not really an important distinction. In this implementation all values are stored per voxel, but a voxel isn't really that different from a vertex since we are doing trilinear interpolation anyway.
- `count` here is just a loose upper bound on the number of samples we might need along the ray. For JAX to JIT-compile, we need to have the same number of possible intersections along each ray (it's probably not the most efficient implementation; if you see a faster way feel free to submit a PR). There can't be more than `3*resolution` voxel lengths along the portion of the ray inside the cube, and we are taking a sample every `uniform` fraction of a voxel length. Then the intersection points are just spaced out by this distance along the ray, starting where the ray first intersects the cube.
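
As a rough illustration of the sampling described in the last point, the sketch below evaluates the same two expressions for a single ray. The concrete values (`resolution = 4`, `uniform = 0.5`, `radius = 1.0`, `realstart = 1.0`) and the relation `voxel_len = 2 * radius / resolution` are assumptions chosen for the example, not values confirmed by the repository.

```python
import jax.numpy as jnp

# Assumed example values, not taken from the repository.
resolution = 4                       # voxels per side of the grid
radius = 1.0                         # half the side length of the grid cube
uniform = 0.5                        # sample spacing, as a fraction of a voxel length
voxel_len = 2 * radius / resolution  # assumed definition of one voxel length
realstart = 1.0                      # distance along the ray where it enters the cube

# Fixed number of samples per ray, so every ray has the same array shape under JIT.
count = int(resolution * 3 / uniform)

# count samples, evenly spaced by uniform * voxel_len, beginning one step past realstart.
intersections = jnp.linspace(
    start=realstart + uniform * voxel_len,
    stop=realstart + uniform * voxel_len * (count + 1),
    num=count,
    endpoint=False,
)

spacing = jnp.diff(intersections)
print(count, intersections.shape, float(spacing[0]))
```

Because `stop` lies `count` steps beyond `start` and `endpoint=False`, the spacing works out to exactly `uniform * voxel_len`. Samples that fall beyond the point where the ray leaves the cube are still generated, which is the price of keeping the shape fixed.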
Sorry to bother you again! Why is the last intersection discarded?
`pt_sigma = jnp.sum(weights * neighbor_sigma, axis=1)[:-1]`
@sarafridov Why divide directly by `rays_d`? What happens if a direction component is 0?