bevy_gaussian_splatting

Depth-order validation over NDC cells

Status: open — github-actions[bot] opened this issue 7 months ago (0 comments).

https://github.com/mosure/bevy_gaussian_splatting/blob/5c9a20a3478b8cca2dd3df2f49fd26232087fa42/tests/gpu/radix.rs#L197


use std::{
    process::exit,
    sync::{
        Arc,
        Mutex,
    },
};

use bevy::{
    prelude::*,
    core::FrameCount,
    core_pipeline::core_3d::{
        CORE_3D,
        Transparent3d,
    },
    render::{
        RenderApp,
        renderer::{
            RenderContext,
            RenderQueue,
        },
        render_asset::RenderAssets,
        render_graph::{
            Node,
            NodeRunError,
            RenderGraphApp,
            RenderGraphContext,
        },
        render_phase::RenderPhase,
        view::ExtractedView,
    },
};

use bevy_gaussian_splatting::{
    GaussianCloud,
    GaussianSplattingBundle,
    random_gaussians,
    sort::SortedEntries,
};

use _harness::{
    TestHarness,
    test_harness_app,
    TestState,
    TestStateArc,
};

mod _harness;


/// Render-graph node labels introduced by this test.
pub mod node {
    /// Label under which the radix-sort validation node is added to the core 3d graph.
    pub const RADIX_SORT_TEST: &str = "radix_sort_test";
}


/// Test entry point: builds the harness app, schedules scene setup, wires the
/// radix-sort validation node into the core 3d render graph, and runs the app.
fn main() {
    let harness = TestHarness {
        resolution: (512.0, 512.0),
    };
    let mut app = test_harness_app(harness);

    app.add_systems(Startup, setup);

    // Only register the node when the render sub-app is present.
    if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
        render_app.add_render_graph_node::<RadixTestNode>(CORE_3D, node::RADIX_SORT_TEST);

        // Edge from the test node into END_MAIN_PASS.
        render_app.add_render_graph_edge(
            CORE_3D,
            node::RADIX_SORT_TEST,
            bevy::core_pipeline::core_3d::graph::node::END_MAIN_PASS,
        );
    }

    app.run();
}

fn setup(
    mut commands: Commands,
    mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
    let cloud = gaussian_assets.add(random_gaussians(10000));

    commands.spawn((
        GaussianSplattingBundle {
            cloud,
            settings: GaussianCloudSettings {
                sort_mode: SortMode::Radix,
                ..default()
            },
            ..default()
        },
        Name::new("gaussian_cloud"),
    ));

    commands.spawn((
        Camera3dBundle {
            transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
            ..default()
        },
    ));
}


/// Render-graph node that reads back each gaussian cloud's sorted entry buffer
/// and checks the depth ordering produced by the radix sort.
pub struct RadixTestNode {
    /// Cloud handles paired with the handle of their sorted-entry asset.
    gaussian_clouds: QueryState<(&'static Handle<GaussianCloud>, &'static Handle<SortedEntries>)>,
    /// Load/completion flags shared between `update` and `run`.
    state: TestStateArc,
    /// Extracted views whose camera position the ordering is checked against.
    views: QueryState<(&'static ExtractedView, &'static RenderPhase<Transparent3d>)>,
    /// Frame at which sorted entries were first seen; `0` means "not recorded yet".
    start_frame: u32,
}

impl FromWorld for RadixTestNode {
    /// Creates the node with fresh test state and queries bound to `world`.
    fn from_world(world: &mut World) -> Self {
        // Same creation order as the field list: clouds query, state, views query.
        let gaussian_clouds = world.query();
        let state = Arc::new(Mutex::new(TestState::default()));
        let views = world.query();

        Self {
            gaussian_clouds,
            state,
            views,
            start_frame: 0,
        }
    }
}


impl Node for RadixTestNode {
    /// Advances the test's frame bookkeeping and refreshes the node's queries.
    ///
    /// Once sorted entries have been observed (`test_loaded`), the node waits
    /// `FRAME_LIMIT` frames and marks the test completed. On the next update it
    /// exits the process with status 0.
    fn update(
        &mut self,
        world: &mut World,
    ) {
        let frame_count = world.resource::<FrameCount>().0;

        let mut state = self.state.lock().unwrap();
        if state.test_completed {
            exit(0);
        }

        // `start_frame == 0` doubles as "not recorded yet".
        if state.test_loaded && self.start_frame == 0 {
            self.start_frame = frame_count;
        }

        const FRAME_LIMIT: u32 = 10;
        if state.test_loaded && frame_count >= self.start_frame + FRAME_LIMIT {
            state.test_completed = true;
        }
        drop(state);

        self.gaussian_clouds.update_archetypes(world);
        self.views.update_archetypes(world);
    }

    /// For every view and every cloud whose GPU assets are prepared, schedules an
    /// asynchronous read-back of the sorted entry buffer. The callback then checks
    /// that entries are in non-decreasing distance from the view's camera.
    ///
    /// Clouds whose assets are not ready yet are skipped. The first ready cloud
    /// sets `test_loaded`, which starts the frame countdown in `update`.
    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), NodeRunError> {
        // Loop-invariant: fetch the render-world resources once per run.
        let gaussian_cloud_res = world.resource::<RenderAssets<GaussianCloud>>();
        let sorted_entries_res = world.resource::<RenderAssets<SortedEntries>>();
        let render_queue = world.resource::<RenderQueue>();

        for (view, _phase) in self.views.iter_manual(world) {
            let camera_position = view.transform.translation();

            for (cloud_handle, sorted_entries_handle) in self.gaussian_clouds.iter_manual(world) {
                let (Some(cloud), Some(sorted_entries)) = (
                    gaussian_cloud_res.get(cloud_handle),
                    sorted_entries_res.get(sorted_entries_handle),
                ) else {
                    continue;
                };

                self.state.lock().unwrap().test_loaded = true;

                // The callback below is a `move` closure that runs after this
                // frame's borrows end, so it takes an owned copy of the gaussians.
                let gaussians = cloud.debug_gpu.gaussians.clone();

                wgpu::util::DownloadBuffer::read_buffer(
                    render_context.render_device().wgpu_device(),
                    render_queue.0.as_ref(),
                    &sorted_entries.sorted_entry_buffer.slice(
                        0..sorted_entries.sorted_entry_buffer.size()
                    ),
                    move |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
                        let binding = buffer.unwrap();
                        let entries = bytemuck::cast_slice::<u8, u32>(&*binding);

                        // TODO: depth order validation over ndc cells
                        assert_non_decreasing_depth(entries, |idx| {
                            let position = gaussians[idx].position;
                            let position_vec3 = Vec3::new(position[0], position[1], position[2]);
                            (position_vec3 - camera_position).length()
                        });
                    }
                );
            }
        }

        Ok(())
    }
}

/// Checks that the sorted `(key, index)` pairs in `entries` visit gaussians in
/// non-decreasing depth, where `depth_of(index)` returns a gaussian's depth.
///
/// Pairs are read as consecutive `u32`s: a key followed by a gaussian index. A
/// trailing unpaired word is ignored. Pairs whose index is `0`, or whose key is
/// `0x0` or `0xffffffff`, are skipped.
///
/// # Panics
/// Panics if a kept entry is shallower than the deepest kept entry before it.
/// The neighbouring keys are printed first. Keys missing at the buffer edges are
/// shown as dashes rather than indexed, so the report itself cannot panic.
fn assert_non_decreasing_depth(entries: &[u32], depth_of: impl Fn(usize) -> f32) {
    let mut max_depth = 0.0_f32;

    for (pair_idx, pair) in entries.chunks_exact(2).enumerate() {
        let (key, index) = (pair[0], pair[1] as usize);
        if index == 0 || key == 0xffffffff || key == 0x0 {
            continue;
        }

        let depth = depth_of(index);
        let depth_is_non_decreasing = max_depth <= depth;
        if !depth_is_non_decreasing {
            let key_at = |i: Option<usize>| {
                i.and_then(|i| entries.get(i))
                    .map_or_else(|| "----------".to_string(), |k| format!("{:#010x}", k))
            };
            let key_idx = pair_idx * 2;
            println!(
                "radix keys: [..., {}, {:#010x}, {}, ...]",
                key_at(key_idx.checked_sub(2)),
                key,
                key_at(Some(key_idx + 2)),
            );
        }

        assert!(depth_is_non_decreasing, "radix sort, non-decreasing check failed: {} > {}", max_depth, depth);

        max_depth = max_depth.max(depth);
    }
}

github-actions[bot] · Dec 15 '23 16:12