Optimizing compute_appfeature?

Open UPstartDeveloper opened this issue 3 years ago • 0 comments

Hey team,

So I'm building a renderer for TensoRF, and I'm working on optimizing the code in this repo so that it can run in real time.

I wrote a vectorized implementation (I think) of TensorVMSplit.compute_appfeature() that leverages NumPy functions. I tested it on an AWS EC2 instance (a g3s.xlarge), and it seems to run without error.

I'm just looking for feedback: is this a good direction to pursue? Do folks know of potentially easier/better ways to run this function without running into an out-of-memory error?


def compute_appfeature(self, xyz_sampled):
        """
        Returns the appearance feature vectors for a set of XYZ locations.

        Parameters:
            xyz_sampled: multi-dimensional Tensor. Last dim should have size 3.

        Returns: multi-dimensional Tensor.
            Last dim will have the same size as data_dim_color.
        """
        def compute_factors(idx_plane, grid_mode='plane'):
            """
            Helper function used to compute the factors used for 
            vector-matrix decomposition.

            Parameters:
                idx_plane (int): points to either the XY, XZ, or YZ planes
                grid_mode (str): specifies whether we want a 
                                 matrix/vector factor

            Returns: torch.Tensor: the factor needed for VM decomposition
            """
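            if grid_mode == 'plane':
                grid = coordinate_plane  # defined below
                features = self.app_plane  # matrix factors
            else:  # grid_mode == 'line'
                grid = coordinate_line  # defined below
                features = self.app_line  # vector factors

            # sample on the CPU to keep GPU memory free
            input_features = features[idx_plane].cpu()
            factor = F.grid_sample(
                input_features,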
                grid[[idx_plane]],
                align_corners=True,
            ).view(-1, *xyz_sampled.shape[:1])

            return factor

        ### MAIN CODE
        xyz_sampled = xyz_sampled.to(device="cpu")

        ...  # unchanged code

        # figure out the vector-matrix outer products, trying vectorization
        app_plane_indices = np.arange(len(self.app_plane))
        compute_VM_factors = np.vectorize(compute_factors, otypes=[torch.Tensor])

        # grid_mode must be 'plane' here; any other value selects the line branch
        plane_coef_point = compute_VM_factors(app_plane_indices, 'plane')  # 1D np.ndarray with 2D Tensors
        plane_coef_point = torch.cat(list(plane_coef_point)).to(device=self.device)  # 2D Tensor

        # same type of object as plane_coef_point
        line_coef_point = compute_VM_factors(app_plane_indices, 'line')  # same as above
        line_coef_point = torch.cat(list(line_coef_point)).to(device=self.device)

        return self.basis_mat((plane_coef_point * line_coef_point).T)
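
One caveat on the np.vectorize call above: the NumPy docs describe it as a convenience wrapper rather than a performance tool, and its implementation is essentially a for loop. So it likely won't change speed or memory use compared with the original loop over planes. Here is a minimal NumPy-only sketch of that behavior. `fake_factor` is a hypothetical stand-in for compute_factors (it is not code from this repo), and the array shapes are made up for illustration.

```python
import numpy as np

calls = []  # records every argument the wrapped function receives


def fake_factor(idx_plane):
    """Hypothetical stand-in for compute_factors: one small array per plane."""
    calls.append(int(idx_plane))
    return np.full((2, 4), float(idx_plane))


# Passing otypes skips the extra probing call NumPy would otherwise make
# on the first element to infer the output type.
looped = np.vectorize(fake_factor, otypes=[object])
out = looped(np.arange(3))

print(len(calls))  # one Python-level call per plane index
print(out.dtype)   # object array holding one result per call

# A plain Python loop produces the same stacked result.
stacked_from_vectorize = np.concatenate(list(out))
stacked_from_loop = np.concatenate([fake_factor(i) for i in range(3)])
assert np.array_equal(stacked_from_vectorize, stacked_from_loop)
```

If that holds for the real function too, the memory footprint is still driven by how many points go through F.grid_sample at once, not by how the three planes are iterated.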

UPstartDeveloper, Jul 24 '22 22:07