
Is there a way to track the splats to the images?

Open abrahamezzeddine opened this issue 1 year ago • 9 comments

Hello,

With COLMAP, it is possible to extract data to determine which images and keypoints correspond to specific 3D points and to identify which images have been matched.

I am wondering if it is possible to track which Gaussian splats correspond to which keypoints, along with their associated image files and pixel coordinates. Is there a way to "tap" into the code to keep track of this when using Nerfstudio?

Specifically, I want to identify the originating image and pixel coordinate for each Gaussian splat (with respect to the point cloud, not the splat itself). While COLMAP provides this information for 3D points, I would like to know if it is possible to extend this tracking to the exported PLY file for the splats.

abrahamezzeddine avatar May 20 '24 18:05 abrahamezzeddine

https://github.com/nerfstudio-project/gsplat/blob/409bcd3cf63491710444e60c29d3c44608d8eafd/gsplat/project_gaussians.py#L51C7-L51C70

You can use the xys returned by the project_gaussians method to get the corresponding 2D pixel locations of the points for that particular camera pose.
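For reference, a minimal sketch of that call (the argument order matches the signature linked above; means3d, scales, quats, the intrinsics, and viewmat, a 4x4 world-to-camera matrix, are assumed to be loaded already):

import torch
import gsplat

# last three positional args: glob_scale=1.0, block_width=16, clip_thresh=0.01
xys, depths, radii, conics, compensation, num_tiles_hit, cov3d = gsplat.project_gaussians(
    means3d, scales, 1.0, quats, viewmat,
    fx, fy, cx, cy, img_height, img_width, 16, 0.01,
)
# xys[i] is the projected 2D pixel center of gaussian i for this camera;
# radii[i] == 0 marks gaussians that were culled (behind the near plane or off-screen)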

hariharan1412 avatar May 20 '24 18:05 hariharan1412

Thank you!!

Do you possibly have a small code snippet that I could use as a reference?

abrahamezzeddine avatar May 20 '24 19:05 abrahamezzeddine

As of now, I don't have any. Refer to these issues; they should really help you:

https://github.com/nerfstudio-project/gsplat/issues/77#issue-2015727922
https://github.com/nerfstudio-project/gsplat/issues/87#issue-2047885540

hariharan1412 avatar May 20 '24 19:05 hariharan1412

So I tried to use that function, taking in xyz, scale, and rotation to produce the required inputs, as well as reading the PLY file and transforms.json, but for some reason the xys tensors are very large, ranging from 0 to 150000... I did not expect them to be this large. If I also reduce the glob_scale value incrementally, everything just becomes zero. Also, it seems the point cloud does not rotate correctly according to the camera frustum. Have I missed anything below?

import numpy as np
import torch
from plyfile import PlyData
import json
import tkinter as tk
from tkinter import filedialog
import gsplat
import sys

def parse_ply(file_path):
    plydata = PlyData.read(file_path)
    vertices = plydata['vertex']
    xyz = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    scale = np.vstack([vertices['scale_0'], vertices['scale_1'], vertices['scale_2']]).T
    rotation = np.vstack([vertices['rot_0'], vertices['rot_1'], vertices['rot_2'], vertices['rot_3']]).T
    return xyz, scale, rotation

def load_transform_json(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data

def normalize_quaternions(quats):
    norms = np.linalg.norm(quats, axis=1, keepdims=True)
    normalized_quats = quats / norms
    return normalized_quats

def progress_update(current, total):
    progress = (current / total) * 100
    sys.stdout.write(f"\rProgress: {progress:.2f}%")
    sys.stdout.flush()

def project_and_plot_frame(means3d, scales, quats, camera_params, frame):
    viewmat = torch.tensor(frame['transform_matrix'], dtype=torch.float32)
    fx = camera_params["fl_x"]
    fy = camera_params["fl_y"]
    cx = camera_params["cx"]
    cy = camera_params["cy"]
    img_height = camera_params["h"]
    img_width = camera_params["w"]
    glob_scale = 1.0
    block_width = 16
    clip_thresh = 0.1
    
    if torch.cuda.is_available():
        viewmat = viewmat.cuda()
    
    print("Projecting 3D Gaussians to 2D...")
    xys, depths, radii, conics, compensation, num_tiles_hit, cov3d = gsplat.project_gaussians(
        means3d, scales, glob_scale, quats, viewmat, fx, fy, cx, cy, img_height, img_width, block_width, clip_thresh
    )
    
    print("\nPrinting tensor content for comparison...")

    
    torch.set_printoptions(threshold=10)  # Set the print threshold to 10
    print(f"xys (2D projections):\n{xys[:10]}")
    print(f"depths (z-depths):\n{depths[:10]}")
    print(f"radii (2D gaussian radii):\n{radii[:10]}")
    print(f"conics (conic parameters):\n{conics[:10]}")
    print(f"compensation (density compensation):\n{compensation[:10]}")
    print(f"num_tiles_hit (tiles hit per gaussian):\n{num_tiles_hit[:10]}")
    print(f"cov3d (3D covariances):\n{cov3d[:10]}")
    print("\nProcessing complete.")

def main():
    # File selection dialog
    root = tk.Tk()
    root.withdraw()
    
    ply_file_path = filedialog.askopenfilename(title="Select PLY File", filetypes=[("PLY files", "*.ply")])
    transform_json_path = filedialog.askopenfilename(title="Select Transform JSON File", filetypes=[("JSON files", "*.json")])
    
    if not ply_file_path or not transform_json_path:
        print("File selection cancelled.")
        return
    
    print("Parsing PLY file...")
    xyz, scale, rotation = parse_ply(ply_file_path)
    print(f"Extracted {xyz.shape[0]} points from PLY file.")
    
    print("Loading transform JSON...")
    transform_data = load_transform_json(transform_json_path)
    camera_params = transform_data
    
    means3d = torch.tensor(xyz, dtype=torch.float32)
    scales = torch.tensor(scale, dtype=torch.float32)
    
    print("Normalizing quaternions...")
    normalized_quats = normalize_quaternions(rotation)
    quats = torch.tensor(normalized_quats, dtype=torch.float32)
   
    if torch.cuda.is_available():
        print("CUDA is available. Moving tensors to GPU.")
        means3d = means3d.cuda()
        scales = scales.cuda()
        quats = quats.cuda()
    else:
        print("CUDA is not available. Running on CPU.")
    
    for frame in camera_params['frames']:
        print(f"Processing frame: {frame['file_path']}")
        project_and_plot_frame(means3d, scales, quats, camera_params, frame)

if __name__ == "__main__":
    main()

xys (2D projections):
tensor([[35513.7109,   377.4275],
        [32701.4141,  1122.3706],
        [    0.0000,     0.0000],
        ...,
        [    0.0000,     0.0000],
        [67154.5703,  -507.2152],
        [31027.0039,  1390.0424]], device='cuda:0')
depths (z-depths):
tensor([0.3947, 0.4317, 0.0000, 0.3716, 0.0000, 0.1823, 0.4330, 0.0000, 0.2000, 0.4450], device='cuda:0')
radii (2D gaussian radii):
tensor([153113, 120826, 0, 176440, 0, 373944, 164592, 0, 334831, 144265], device='cuda:0', dtype=torch.int32)
conics (conic parameters):
tensor([[5.7169e-10, 4.6618e-10, 1.5412e-09],
        [8.2018e-10, 8.2071e-10, 3.9233e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.9080e-10, 1.3240e-10, 2.3889e-10],
        [4.3249e-10, 7.9552e-12, 1.5409e-09]], device='cuda:0')
compensation (density compensation):
tensor([1.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000], device='cuda:0')
num_tiles_hit (tiles hit per gaussian):
tensor([47628, 47628, 0, 47628, 0, 47628, 47628, 0, 47628, 47628], device='cuda:0', dtype=torch.int32)
cov3d (3D covariances):
tensor([[34.5615,  1.5018,  0.6180, 40.5671,  2.4734, 34.0930],
        [30.7935,  2.2675,  1.0720, 34.1519, -4.6022, 17.5621],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [38.9945,  6.2411,  7.7874, 42.0054,  4.3823, 52.2032],
        [48.7078, -4.1139,  2.3271, 41.7742,  5.0261, 47.7809]], device='cuda:0')

Any advice is appreciated and helpful. Thank you.

I plan to segment images for objects, see which gaussian 2D points end up in specific areas of the images, and then assign those points an object ID. Then, when I click on a specific 3D point in the PLY file, I will know which object segment it belongs to and which images it came from. This is important, for example, if I want to view the images that contributed to that point. In COLMAP, it is possible to select a single point in 3D space and see which images contributed to it.
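For what it's worth, a minimal sketch of that per-frame label lookup (assuming seg is an integer segmentation mask of shape (H, W) for the frame, W and H are the image dimensions, and xys/radii come from project_gaussians for the same camera):

import torch

px = torch.floor(xys).long()
visible = (
    (radii > 0)
    & (px[:, 0] >= 0) & (px[:, 0] < W)
    & (px[:, 1] >= 0) & (px[:, 1] < H)
)

object_ids = torch.full((px.shape[0],), -1, dtype=torch.long)  # -1 = unassigned
vis_px = px[visible]
object_ids[visible] = seg[vis_px[:, 1], vis_px[:, 0]]  # masks are indexed [row, col]

Repeating this over all frames and taking a majority vote per gaussian would give a stable point-to-object assignment.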

abrahamezzeddine avatar May 21 '24 10:05 abrahamezzeddine

I think the problem is with your camera extrinsics. Convert the camera from c2w to w2c:

transform_matrix_viewmats = torch.tensor(frame['transform_matrix'], dtype=torch.float32).to(device)

R = transform_matrix_viewmats[:3, :3]  # 3 x 3 rotation (camera-to-world)
T = transform_matrix_viewmats[:3, 3:4]  # 3 x 1 camera center in world space

# flip the Y and Z camera axes (Nerfstudio poses are OpenGL-style, gsplat expects OpenCV-style)
R_edit = torch.diag(torch.tensor([1, -1, -1], device=device, dtype=R.dtype))
R = R @ R_edit

# invert the rigid transform: w2c rotation is R^T, translation is -R^T @ T
R_inv = R.T
T_inv = -R_inv @ T

viewmats = torch.eye(4, device=R.device, dtype=R.dtype)
viewmats[:3, :3] = R_inv
viewmats[:3, 3:4] = T_inv

Try it and let me know if you are facing any issues.

hariharan1412 avatar May 22 '24 12:05 hariharan1412

> I plan to segment images for objects, see which gaussian 2D points end up in specific areas of the images, and then assign those points an object ID. [...]

Yes, what you are trying to achieve is possible, although you have to sort out a lot of logic to achieve it. All the best

hariharan1412 avatar May 22 '24 12:05 hariharan1412

> (quoting the c2w-to-w2c conversion advice above)

I tried this, and it certainly worked a little better! Thank you very much! However, it seems that the camera is being projected "outside" the whole point cloud. I assume the camera is correctly rotated, but it is placed well outside the intended viewpoint, instead of the actual location we see in the viewer when we train the Gaussian splat with Nerfstudio.

import numpy as np
import torch
from plyfile import PlyData
import json
import tkinter as tk
from tkinter import filedialog
import gsplat
import sys
import matplotlib.pyplot as plt

def parse_ply(file_path):
    plydata = PlyData.read(file_path)
    vertices = plydata['vertex']
    xyz = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    scale = np.vstack([vertices['scale_0'], vertices['scale_1'], vertices['scale_2']]).T
    rotation = np.vstack([vertices['rot_0'], vertices['rot_1'], vertices['rot_2'], vertices['rot_3']]).T
    print("Printing first 10 xyz values", xyz[:10])
    return xyz, scale, rotation

def load_transform_json(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data

def normalize_quaternions(quats):
    norms = np.linalg.norm(quats, axis=1, keepdims=True)
    normalized_quats = quats / norms
    return normalized_quats

def progress_update(current, total):
    progress = (current / total) * 100
    sys.stdout.write(f"\rProgress: {progress:.2f}%")
    sys.stdout.flush()

def project_and_plot_frame(means3d, scales, quats, camera_params, frame):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform_matrix_viewmats = torch.tensor(frame['transform_matrix'], dtype=torch.float32).to(device)

    R = transform_matrix_viewmats[:3, :3]  # 3 x 3
    T = transform_matrix_viewmats[:3, 3:4]  # 3 x 1

    R_edit = torch.diag(torch.tensor([1, -1, -1], device=device, dtype=R.dtype))
    R = R @ R_edit

    R_inv = R.T
    T_inv = -R_inv @ T

    viewmats = torch.eye(4, device=R.device, dtype=R.dtype)
    viewmats[:3, :3] = R_inv
    viewmats[:3, 3:4] = T_inv

    fx = camera_params["fl_x"]
    fy = camera_params["fl_y"]
    cx = camera_params["cx"]
    cy = camera_params["cy"]
    img_height = camera_params["h"]
    img_width = camera_params["w"]
    glob_scale = 1.0
    block_width = 16
    clip_thresh = 0.01
    
    if torch.cuda.is_available():
        means3d = means3d.cuda()
        scales = scales.cuda()
        quats = quats.cuda()
    
    xys, depths, radii, conics, compensation, num_tiles_hit, cov3d = gsplat.project_gaussians(
        means3d.contiguous(), scales.contiguous(), glob_scale, quats.contiguous(), 
        viewmats, fx, fy, cx, cy, img_height, img_width, block_width, clip_thresh
    )
    
    print("Checking projection results...")
    print(f"xys (2D projections):\n{xys[:10]}")
    print(f"depths (z-depths):\n{depths[:10]}")

    """Filter out invalid projections"""
    valid_indices = (xys[:, 0] > 0) & (xys[:, 1] > 0) & (xys[:, 0] < img_width) & (xys[:, 1] < img_height) & (depths > clip_thresh)
    xys_filtered = xys[valid_indices].cpu().numpy()

    print(f"Filtered xys (2D projections):\n{xys_filtered[:10]}")

    if xys_filtered.size == 0:
        print("No valid points to plot for this frame.")
        return
    
    """Plotting in 2D"""
    print("\nPlotting valid 2D projections...")
    plt.figure(figsize=(10, 8))
    plt.scatter(xys_filtered[:, 0], xys_filtered[:, 1], c='blue', s=1)
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title(f'Valid 2D Projections for {frame["file_path"]}')
    plt.gca().invert_yaxis()  # Invert Y axis to match image coordinates
    plt.show()

def main():
    # File selection dialog
    root = tk.Tk()
    root.withdraw()
    
    ply_file_path = filedialog.askopenfilename(title="Select PLY File", filetypes=[("PLY files", "*.ply")])
    transform_json_path = filedialog.askopenfilename(title="Select Transform JSON File", filetypes=[("JSON files", "*.json")])
    
    if not ply_file_path or not transform_json_path:
        print("File selection cancelled.")
        return
    
    print("Parsing PLY file...")
    xyz, scale, rotation = parse_ply(ply_file_path)
    print(f"Extracted {xyz.shape[0]} points from PLY file.")
    
    print("Loading transform JSON...")
    transform_data = load_transform_json(transform_json_path)
    camera_params = transform_data
    
    means3d = torch.tensor(xyz, dtype=torch.float32)
    scales = torch.tensor(scale, dtype=torch.float32)
    
    print("Normalizing quaternions...")
    normalized_quats = normalize_quaternions(rotation)
    quats = torch.tensor(normalized_quats, dtype=torch.float32)
   
    if torch.cuda.is_available():
        print("CUDA is available. Moving tensors to GPU.")
        means3d = means3d.cuda()
        scales = scales.cuda()
        quats = quats.cuda()
    
    for frame in camera_params['frames']:
        print(f"Processing frame: {frame['file_path']}")
        project_and_plot_frame(means3d, scales, quats, camera_params, frame)

if __name__ == "__main__":
    main()

[screenshot: scatter plot of the filtered 2D projections]

abrahamezzeddine avatar May 22 '24 13:05 abrahamezzeddine

import numpy as np
import torch
import matplotlib.pyplot as plt

# W, H are the image width/height; colors is an (N, 3) tensor of per-splat
# RGB values in [0, 1] (e.g. derived from the SH DC coefficients)
xy_to_pix = torch.floor(xys).long()
valid_indices = (
    (xy_to_pix[:, 0] > 0)
    & (xy_to_pix[:, 0] < W)
    & (xy_to_pix[:, 1] > 0)
    & (xy_to_pix[:, 1] < H)
)

xy_to_pix = xy_to_pix[valid_indices]
valid_colors = colors[valid_indices]

u = xy_to_pix[:, 0]
v = xy_to_pix[:, 1]

image = np.zeros((H, W, 3), dtype=np.uint8)

u_int = u.cpu().numpy()
v_int = v.cpu().numpy()
for idx in range(len(u_int)):
    color_point = (valid_colors[idx] * 255).int().cpu().numpy()
    image[v_int[idx], u_int[idx]] = color_point

plt.imshow(image)
plt.axis('off')
plt.show()

Try plotting this to verify your output.
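The per-point loop can be slow for millions of splats; the same scatter can be vectorized with fancy indexing (a sketch, reusing the variables above):

colors_np = (valid_colors * 255).to(torch.uint8).cpu().numpy()
image[v_int, u_int] = colors_np  # writes every valid point in one shot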

hariharan1412 avatar May 22 '24 13:05 hariharan1412

> (quoting the plotting snippet above)

I tried to plot it; it looks a little better, but the view seems very zoomed out and I see the whole point cloud instead. I am not sure the camera is in the right position because:

  1. I see the whole point cloud in the plot.
  2. There is no limit on depth, so everything is rendered, even the farthest points.

Thanks a lot.

abrahamezzeddine avatar May 22 '24 21:05 abrahamezzeddine

Hi, were you successful with your work? Did it work?

hariharan1412 avatar Jun 05 '24 09:06 hariharan1412

> (quoting the c2w-to-w2c conversion snippet above)

Hello,

I skipped this one, actually. :) I have, however, sent you the code with explanations of how I tracked the gaussians, since you have to add the code in different places to make it work. You might have to check your spam folder.

abrahamezzeddine avatar Jun 16 '24 21:06 abrahamezzeddine

Hello, Thank you so much for sharing the code with me.

As for the rotation: if you are using an Inria-trained gaussian splat, this rotation simply won't work, because Inria uses the OpenCV convention and Nerfstudio uses the OpenGL convention. So please apply your transformation accordingly.
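For reference, the convention flip is just a sign change on the camera Y and Z axes (a sketch, assuming c2w is a 4x4 camera-to-world matrix in OpenGL convention):

import numpy as np

def opengl_to_opencv(c2w: np.ndarray) -> np.ndarray:
    # x right, y up, z backward (OpenGL)  ->  x right, y down, z forward (OpenCV)
    return c2w @ np.diag([1.0, -1.0, -1.0, 1.0])

This is the same diag([1, -1, -1]) edit applied to R earlier in this thread, written for the full 4x4 pose.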

hariharan1412 avatar Jun 17 '24 06:06 hariharan1412

@hariharan1412 @abrahamezzeddine

can you please provide the updated code with explanations for this task?

pknmax avatar Jun 17 '24 12:06 pknmax

hello @abrahamezzeddine, can you please provide your code with explanations for this task? Thanks, [email protected]

pknmax avatar Jun 17 '24 21:06 pknmax

> hello @abrahamezzeddine, can you please provide your code with explanations for this task? Thanks, [email protected]

Hello,

The code I shared is not for this; it was for something else (Nerfstudio indexing), not gsplat.

abrahamezzeddine avatar Jun 18 '24 06:06 abrahamezzeddine

> (quoting @abrahamezzeddine's earlier reply in full: the c2w-to-w2c conversion advice, the updated projection script, and the resulting screenshot)

@pknmax, the code you are looking for is already here.

hariharan1412 avatar Jun 18 '24 06:06 hariharan1412

> I tried to plot it; it looks a little better, but the view seems very zoomed out and I see the whole point cloud instead. [...]

I'm having this issue as well. What exactly was the root of the issue? Was it that the quaternions from the original 3DGS paper + code use a different system (left/right handedness) and/or notation ([w,x,y,z] vs [x,y,z,w])?

I first thought it was a scaling issue, as the splat scene produced by gsplat looked tiny compared to the output scene of the original 3DGS splatting process.
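If the notation is the culprit, reordering between the two is a one-liner (a sketch, assuming quats is an (N, 4) array; the Inria PLY files store rot_0..rot_3 with the real part first, i.e. [w, x, y, z]):

import numpy as np

def wxyz_to_xyzw(quats: np.ndarray) -> np.ndarray:
    # [w, x, y, z] -> [x, y, z, w]
    return quats[:, [1, 2, 3, 0]]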

arcman7 avatar Aug 02 '24 21:08 arcman7

While your dataset provides c2w (camera-to-world) matrices, gsplat expects w2c (world-to-camera) matrices. Perform the conversion from c2w to w2c and retest.
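Equivalently, for a full 4x4 pose the inversion is a one-liner (a sketch, assuming c2w is a 4x4 torch tensor with the axis flip already applied):

import torch

w2c = torch.linalg.inv(c2w)  # world-to-camera is the inverse of camera-to-world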

hariharan1412 avatar Aug 02 '24 21:08 hariharan1412

> While your dataset provides c2w (camera-to-world) matrices, gsplat expects w2c (world-to-camera) matrices. Perform the conversion from c2w to w2c and retest.

Sorry, I was not clear about the issue I'm facing (I misread the issue in this thread, but the symptoms are similar). I've run gsplat and the original 3DGS splatting against the same COLMAP db, and the two output splat scenes are very different in size and rotation. If I render both splat scenes in the SuperSplat editor, I find that the splat scene from gsplat is much smaller as well as being oriented differently:

[screenshot: the two splat scenes rendered together in the SuperSplat editor]

GSplat is plotted in blue.

arcman7 avatar Aug 02 '24 23:08 arcman7

Please share the command that you used to train gsplat.

hariharan1412 avatar Aug 03 '24 00:08 hariharan1412

Sure thing, thank you in advance for your help :)

cfg.json:

{"disable_viewer": true, "ckpt": null, "data_dir": "../../dataset/indoor/milago_dark3", "data_factor": 1, "result_dir": "../../dataset/indoor/milago_dark3/gsplat/", "test_every": 8, "patch_size": null, "global_scale": 1.0, "port": 8080, "batch_size": 1, "steps_scaler": 1.0, "max_steps": 30000, "eval_steps": [7000, 30000], "save_steps": [7000, 30000], "sh_degree": 3, "sh_degree_interval": 1000, "init_opa": 0.1, "ssim_lambda": 0.2, "near_plane": 0.01, "far_plane": 10000000000.0, "prune_opa": 0.005, "grow_grad2d": 0.0002, "grow_scale3d": 0.01, "prune_scale3d": 0.1, "refine_start_iter": 500, "refine_stop_iter": 15000, "reset_every": 3000, "refine_every": 100, "packed": false, "sparse_grad": false, "absgrad": false, "antialiased": false, "random_bkgd": false, "pose_opt": false, "pose_opt_lr": 1e-05, "pose_opt_reg": 1e-06, "pose_noise": 0.0, "app_opt": false, "app_embed_dim": 16, "app_opt_lr": 0.001, "app_opt_reg": 1e-06, "depth_loss": false, "depth_lambda": 0.01, "tb_every": 100, "tb_save_image": false}

Edit: these were the arguments used with the python simple_trainer.py command
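For context, a hypothetical invocation matching that config (simple_trainer.py parses its Config fields via tyro, so the exact flag spelling may differ between gsplat versions):

python simple_trainer.py \
    --data-dir ../../dataset/indoor/milago_dark3 \
    --result-dir ../../dataset/indoor/milago_dark3/gsplat/ \
    --max-steps 30000 \
    --disable-viewer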

arcman7 avatar Aug 03 '24 00:08 arcman7

Do you think this could be caused by using a left-handed quaternion system where a right-handed system is expected?

arcman7 avatar Aug 06 '24 21:08 arcman7

> I first thought it was a scaling issue, as the splat scene produced by gsplat looked tiny compared to the output scene of the original 3DGS splatting process.

Yes, this is true; Splatfacto won't give you true scale, only a -1 to 1 scale (though possibly I'm wrong).

hariharan1412 avatar Aug 07 '24 11:08 hariharan1412

> I first thought it was a scaling issue, as the splat scene produced by gsplat looked tiny compared to the output scene of the original 3DGS splatting process.

> Yes, this is true; Splatfacto won't give you true scale, only a -1 to 1 scale (though possibly I'm wrong).

I can confirm that Splatfacto will normalize the values between -1 and 1, for better numerical behavior on CUDA as I understand it.
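A sketch for undoing that normalization, assuming the scene was trained with Nerfstudio, which writes a dataparser_transforms.json (a 3x4 "transform" plus a float "scale") alongside the checkpoints; the path below is hypothetical:

import json
import numpy as np

with open("outputs/experiment/splatfacto/run/dataparser_transforms.json") as f:  # hypothetical path
    meta = json.load(f)

T = np.array(meta["transform"])  # 3x4: original world -> normalized space
s = float(meta["scale"])
R, t = T[:, :3], T[:, 3]

def denormalize(points_n: np.ndarray) -> np.ndarray:
    # normalized point p_n = s * (R @ p + t)  =>  p = R^T @ (p_n / s - t)
    return (points_n / s - t) @ R  # row-vector form of R^T @ (...)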

abrahamezzeddine avatar Aug 07 '24 11:08 abrahamezzeddine

Thanks @hariharan1412 and @abrahamezzeddine that context helps a lot!

arcman7 avatar Aug 07 '24 20:08 arcman7