
Reprojected loss function

Open Wenchao-Du opened this issue 4 years ago • 5 comments

@VitorGuizilini Where can I find the code and model for the paper "Robust Semi-Supervised Monocular Depth Estimation with Reprojected Distances" (CoRL 2019 spotlight)? Thank you.

Wenchao-Du avatar Feb 07 '21 08:02 Wenchao-Du

Hi, thank you for your interest. We haven't added support for this loss function yet, but I'm planning to do so soon; I'll keep you informed.

VitorGuizilini-TRI avatar Mar 27 '21 16:03 VitorGuizilini-TRI

Hi, are you still planning to release the code for the loss function? Thanks

pjckoch avatar May 25 '21 10:05 pjckoch

As far as I understand it, part of it should be similar to the view_synthesis() function: https://github.com/TRI-ML/packnet-sfm/blob/c03e4bf929f202ff67819340135c53778d36047f/packnet_sfm/geometry/camera_utils.py#L27-L59

First, to get the world coordinates, call cam.reconstruct() with the lidar depth, then call the same function again with the predicted depth. After that, use the predicted pose to project both world point clouds (from lidar and from the prediction) onto the reference camera with ref_cam.project(). The Euclidean distance between the two projections would then be the loss, right?
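
If I read the paper correctly, with $\phi$ the unprojection (cam.reconstruct), $\pi_{ref}$ the projection onto the reference camera (ref_cam.project), and $T$ the predicted relative pose, the per-pixel loss would be

$$\mathcal{L}(p) = \left\| \pi_{ref}\big(T\,\phi(p, \hat d_p)\big) - \pi_{ref}\big(T\,\phi(p, d_p)\big) \right\|_2$$

where $\hat d_p$ is the lidar depth and $d_p$ the predicted depth at pixel $p$.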

So, something like the following should work, shouldn't it? Am I missing something?

import torch
from packnet_sfm.geometry.camera import Camera
from packnet_sfm.utils.depth import inv2depth
from packnet_sfm.utils.image import match_scales

def reprojected_distance_loss(depth_pred: torch.Tensor, depth_gt: torch.Tensor,
                              mask: torch.Tensor, ref_cam: Camera,
                              cam: Camera) -> torch.Tensor:
    # Lift both depth maps to world points from the target camera: once with
    # the (sparse) lidar ground truth, once with the predicted depth
    world_points = cam.reconstruct(depth_gt, frame='w')
    world_points_pred = cam.reconstruct(depth_pred, frame='w')
    # Project both point clouds onto the reference camera
    # (project() returns normalized pixel coordinates of shape [B,H,W,2])
    ref_coords = ref_cam.project(world_points, frame='w')
    ref_coords_pred = ref_cam.project(world_points_pred, frame='w')
    # Euclidean distance between the two projections, averaged over valid
    # pixels; drop the channel dim so the [B,1,H,W] mask indexes [B,H,W,2]
    mask = mask.squeeze(1)
    return torch.linalg.norm(ref_coords[mask] - ref_coords_pred[mask], dim=1).mean()


# Assumed to come from the surrounding training step: preds (multi-scale
# inverse depth predictions), depth_gt, poses, K, ref_K, W (full-resolution
# image width), device, and self.n (number of scales)
depth_gts = match_scales(depth_gt, preds, self.n, mode='nearest', align_corners=None)
depth_preds = [inv2depth(preds[i]) for i in range(self.n)]

# Only supervise pixels where the sparse lidar ground truth is valid
masks = [(depth_gts[i] > 0).detach() for i in range(self.n)]

loss = 0.
for i in range(len(poses)):
    # Build target and reference cameras for every prediction scale
    cams, ref_cams = [], []
    for j in range(self.n):
        _, _, DH, DW = depth_preds[j].shape
        scale_factor = DW / float(W)
        cams.append(Camera(K=K.float()).scaled(scale_factor).to(device))
        ref_cams.append(Camera(K=ref_K.float(), Tcw=poses[i]).scaled(scale_factor).to(device))
    # Index by j, not i, so the scale index doesn't shadow the pose index
    loss += sum(reprojected_distance_loss(depth_preds[j], depth_gts[j], masks[j],
                                          ref_cams[j], cams[j]) for j in range(self.n))

# Average over reference frames and scales
loss /= len(poses)
loss /= self.n
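
For what it's worth, here's a minimal smoke test I'd run against the helper (the shapes, intrinsics, and 0.5 m baseline are all made up for illustration; depths follow packnet-sfm's [B,1,H,W] convention):

import torch
from packnet_sfm.geometry.camera import Camera
from packnet_sfm.geometry.pose import Pose

B, H, W = 2, 96, 320
depth_pred = 1.0 + torch.rand(B, 1, H, W) * 79.0
depth_gt = depth_pred + torch.randn(B, 1, H, W)  # pretend the prediction is close to GT
depth_gt[:, :, ::4] = 0.0                        # simulate sparse lidar by zeroing rows
mask = depth_gt > 0

# Hypothetical intrinsics shared by both cameras
K = torch.tensor([[700., 0., W / 2.],
                  [0., 700., H / 2.],
                  [0., 0., 1.]]).repeat(B, 1, 1)

# Reference camera displaced by a made-up 0.5 m lateral baseline
T = torch.eye(4).repeat(B, 1, 1)
T[:, 0, 3] = 0.5
cam = Camera(K=K)
ref_cam = Camera(K=K.clone(), Tcw=Pose(T))

loss = reprojected_distance_loss(depth_pred, depth_gt, mask, ref_cam, cam)
print(loss)  # small scalar, since the prediction is close to the ground truth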

pjckoch avatar May 26 '21 08:05 pjckoch

@VitorGuizilini-TRI, any update on when you plan to add the implementation of the Reprojected Distance Loss function? Thanks

iariav avatar May 17 '22 08:05 iariav

@iariav, @Wenchao-Du, @VitorGuizilini-TRI I've implemented the Reprojected Distance Loss function and opened a pull request. Hopefully it meets your needs.

Best regards!

aartykov avatar Aug 13 '22 15:08 aartykov