bevy_gaussian_splatting — GPU radix sort test
TODO: depth-order validation over NDC cells
https://github.com/mosure/bevy_gaussian_splatting/blob/5c9a20a3478b8cca2dd3df2f49fd26232087fa42/tests/gpu/radix.rs#L197
use std::{
process::exit,
sync::{
Arc,
Mutex,
},
};
use bevy::{
prelude::*,
core::FrameCount,
core_pipeline::core_3d::{
CORE_3D,
Transparent3d,
},
render::{
RenderApp,
renderer::{
RenderContext,
RenderQueue,
},
render_asset::RenderAssets,
render_graph::{
Node,
NodeRunError,
RenderGraphApp,
RenderGraphContext,
},
render_phase::RenderPhase,
view::ExtractedView,
},
};
use bevy_gaussian_splatting::{
GaussianCloud,
GaussianCloudSettings,
GaussianSplattingBundle,
random_gaussians,
sort::{
SortedEntries,
SortMode,
},
};
use _harness::{
TestHarness,
test_harness_app,
TestState,
TestStateArc,
};
mod _harness;
/// Render graph node labels registered by this test.
pub mod node {
    /// Label of the node that reads back and validates the radix-sorted entries.
    pub const RADIX_SORT_TEST: &str = "radix_sort_test";
}
/// Builds the test harness app, spawns the scene, hooks the radix validation node
/// into the core 3d render graph and runs the app.
fn main() {
    let harness = TestHarness {
        resolution: (512.0, 512.0),
    };
    let mut app = test_harness_app(harness);
    app.add_systems(Startup, setup);

    // The edge orders the validation node before END_MAIN_PASS in the CORE_3D graph.
    if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
        render_app.add_render_graph_node::<RadixTestNode>(
            CORE_3D,
            node::RADIX_SORT_TEST,
        );
        render_app.add_render_graph_edge(
            CORE_3D,
            node::RADIX_SORT_TEST,
            bevy::core_pipeline::core_3d::graph::node::END_MAIN_PASS,
        );
    }

    app.run();
}
/// Startup system: spawns a random gaussian cloud configured for radix sorting,
/// plus a camera for the test node to measure depth from.
fn setup(
    mut commands: Commands,
    mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
    let cloud_settings = GaussianCloudSettings {
        sort_mode: SortMode::Radix,
        ..default()
    };
    let splatting = GaussianSplattingBundle {
        cloud: gaussian_assets.add(random_gaussians(10000)),
        settings: cloud_settings,
        ..default()
    };
    commands.spawn((splatting, Name::new("gaussian_cloud")));

    let camera = Camera3dBundle {
        transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
        ..default()
    };
    commands.spawn(camera);
}
/// Render graph node that reads back each gaussian cloud's sorted entry buffer
/// and checks the ordering against the camera position of every extracted view.
pub struct RadixTestNode {
    /// Gaussian clouds together with the handle of their sorted entries.
    gaussian_clouds: QueryState<(
        &'static Handle<GaussianCloud>,
        &'static Handle<SortedEntries>,
    )>,
    /// Extracted views; the view transform supplies the camera position.
    views: QueryState<(
        &'static ExtractedView,
        &'static RenderPhase<Transparent3d>,
    )>,
    /// Shared test state; this node reads/writes its `test_loaded` and `test_completed` flags.
    state: TestStateArc,
    /// Frame count recorded once the sorted data is available; 0 means "not recorded yet".
    start_frame: u32,
}
impl FromWorld for RadixTestNode {
    /// Creates the node with fresh query states, a default test state owned by
    /// this node, and no start frame recorded.
    fn from_world(world: &mut World) -> Self {
        let gaussian_clouds = world.query();
        let state = Arc::new(Mutex::new(TestState::default()));
        let views = world.query();

        Self {
            gaussian_clouds,
            views,
            state,
            start_frame: 0,
        }
    }
}
impl Node for RadixTestNode {
    /// Advances the test's frame bookkeeping and refreshes the cached query states.
    ///
    /// `run` sets `test_loaded` once it first sees prepared sort data. The next update
    /// records that frame in `start_frame`. After `FRAME_LIMIT` further frames, the test
    /// is marked completed, and the following call to `update` exits the process with status 0.
    fn update(
        &mut self,
        world: &mut World,
    ) {
        /// Frames to keep validating after the sorted data first becomes available.
        const FRAME_LIMIT: u32 = 10;

        let mut state = self.state.lock().unwrap();
        if state.test_completed {
            exit(0);
        }

        if state.test_loaded {
            let frame_count = world.get_resource::<FrameCount>().unwrap().0;
            // 0 doubles as "start frame not recorded yet".
            if self.start_frame == 0 {
                self.start_frame = frame_count;
            }
            // `FrameCount` wraps on overflow, so measure elapsed frames with wrapping
            // subtraction instead of `start_frame + FRAME_LIMIT`, which could overflow.
            if frame_count.wrapping_sub(self.start_frame) >= FRAME_LIMIT {
                state.test_completed = true;
            }
        }
        drop(state);

        self.gaussian_clouds.update_archetypes(world);
        self.views.update_archetypes(world);
    }

    /// Validates the sort output for every (view, gaussian cloud) pair.
    ///
    /// Clouds whose render assets are not prepared yet are skipped. For the others, the
    /// test is marked loaded and an asynchronous read-back of the sorted entry buffer is
    /// scheduled. The read-back callback asserts that the entries are ordered by
    /// non-decreasing distance from the view's camera position. A failure therefore panics
    /// inside that callback rather than returning an error from this node.
    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), NodeRunError> {
        let gaussian_cloud_res = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap();
        let sorted_entries_res = world.get_resource::<RenderAssets<SortedEntries>>().unwrap();
        let render_queue = world.get_resource::<RenderQueue>().unwrap();

        for (view, _phase) in self.views.iter_manual(world) {
            let camera_position = view.transform.translation();

            for (
                cloud_handle,
                sorted_entries_handle,
            ) in self.gaussian_clouds.iter_manual(world) {
                let (Some(cloud), Some(sorted_entries)) = (
                    gaussian_cloud_res.get(cloud_handle),
                    sorted_entries_res.get(sorted_entries_handle),
                ) else {
                    continue;
                };

                self.state.lock().unwrap().test_loaded = true;

                // Owned copy: the read-back callback must be 'static.
                let gaussians = cloud.debug_gpu.gaussians.clone();
                let entry_buffer = &sorted_entries.sorted_entry_buffer;

                wgpu::util::DownloadBuffer::read_buffer(
                    render_context.render_device().wgpu_device(),
                    render_queue.0.as_ref(),
                    &entry_buffer.slice(0..entry_buffer.size()),
                    move |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
                        let binding = buffer.unwrap();
                        let entries = bytemuck::cast_slice::<u8, u32>(&*binding);
                        // TODO: depth order validation over ndc cells
                        assert_depth_non_decreasing(entries, |idx| {
                            let position = gaussians[idx].position;
                            let position_vec3 = Vec3::new(position[0], position[1], position[2]);
                            (position_vec3 - camera_position).length()
                        });
                    },
                );
            }
        }

        Ok(())
    }
}

/// Asserts that sorted `(key, gaussian index)` pairs are ordered by non-decreasing depth.
///
/// `entries` holds consecutive `(key, index)` `u32` pairs; a trailing unpaired element
/// is ignored. Pairs are skipped when their index is 0 or their key is `0x0` or
/// `0xffffffff`. These are presumably unused or padding slots (TODO: confirm against
/// the sort shader). `depth_of` maps a gaussian index to the depth being compared.
///
/// # Panics
///
/// Panics when a pair's depth is smaller than the largest depth accepted before it.
/// Before panicking, it prints the keys of the offending pair and its neighbouring pairs.
/// A neighbour outside the buffer is shown as `<none>`.
fn assert_depth_non_decreasing(
    entries: &[u32],
    depth_of: impl Fn(usize) -> f32,
) {
    // Formats the key at `idx`, or a placeholder when it lies outside the buffer, so the
    // diagnostic never turns into an index panic for the first or last pair.
    let key_at = |idx: Option<usize>| match idx.and_then(|i| entries.get(i)) {
        Some(key) => format!("{:#010x}", key),
        None => "<none>".to_string(),
    };

    let mut max_depth = 0.0_f32;
    for (pair_idx, pair) in entries.chunks_exact(2).enumerate() {
        let (key, index) = (pair[0], pair[1] as usize);
        if index == 0 || key == 0xffffffff || key == 0x0 {
            continue;
        }

        let depth = depth_of(index);
        let depth_is_non_decreasing = max_depth <= depth;
        if !depth_is_non_decreasing {
            let key_idx = pair_idx * 2;
            println!(
                "radix keys: [..., {}, {}, {}, ...]",
                key_at(key_idx.checked_sub(2)),
                key_at(Some(key_idx)),
                key_at(Some(key_idx + 2)),
            );
        }
        assert!(depth_is_non_decreasing, "radix sort, non-decreasing check failed: {} > {}", max_depth, depth);
        max_depth = max_depth.max(depth);
    }
}