TensoRF
Optimizing compute_appfeature?
Hey team,
So I'm building a renderer for TensoRF and working on optimizing the code in this repo so that it can run in real time.
I wrote what I think is a vectorized implementation of TensorVMSplit.compute_appfeature() that leverages NumPy functions. I tested it on an AWS EC2 instance (a g3s.xlarge), and it seems to run without error.
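For context on the "vectorized (I think)" part: NumPy's documentation says `np.vectorize` is provided for convenience, not performance, and is essentially a for loop. A small standalone check with toy tensors shows it produces the same result as a plain list comprehension. The `factors` list and the `get_factor` helper are hypothetical stand-ins for the per-plane `grid_sample` results.

```python
import numpy as np
import torch

# Hypothetical stand-ins for the per-plane factors (each of shape (n_comp, N)).
factors = [torch.full((2, 4), float(i)) for i in range(3)]


def get_factor(idx):
    # np.vectorize calls this once per element of the index array.
    return factors[int(idx)]


# otypes=[object] makes np.vectorize collect the returned tensors in an object array.
vectorized = np.vectorize(get_factor, otypes=[object])
obj_array = vectorized(np.arange(3))  # 1-D object ndarray holding 3 tensors

via_vectorize = torch.cat(list(obj_array))
via_loop = torch.cat([get_factor(i) for i in range(3)])
print(obj_array.shape, via_vectorize.shape, torch.equal(via_vectorize, via_loop))
```

So I wouldn't expect the NumPy wrapper to be faster than the original loop; with only three planes, the Python-level loop is probably not where the time goes.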
I'm mainly looking for feedback: is this a good direction to pursue? Does anyone know of easier or better ways to run this function without running into an out-of-memory error?
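On the out-of-memory side, one common pattern is to bound peak memory by evaluating the sampled points in fixed-size chunks and concatenating the results. This is only a sketch under assumptions: `apply_in_chunks` and `toy_fn` are hypothetical names, and it relies on each point's feature depending only on that point, which holds for compute_appfeature since it does per-point `grid_sample` lookups followed by a per-row linear layer.

```python
import torch


def apply_in_chunks(fn, xyz, chunk_size=4096):
    """Apply fn to xyz of shape (N, 3) in row slices of at most chunk_size points.

    Peak intermediate memory then scales with chunk_size instead of N.
    """
    outputs = []
    for start in range(0, xyz.shape[0], chunk_size):
        outputs.append(fn(xyz[start:start + chunk_size]))
    return torch.cat(outputs, dim=0)


def toy_fn(pts):
    # Hypothetical stand-in for a per-point feature function such as compute_appfeature.
    return torch.cat([torch.sin(pts), pts ** 2], dim=-1)


xyz = torch.rand(10_000, 3) * 2 - 1
with torch.no_grad():
    chunked = apply_in_chunks(toy_fn, xyz, chunk_size=1024)  # last chunk has 784 points
    full = toy_fn(xyz)
print(chunked.shape, torch.allclose(chunked, full))
```

The chunk size trades memory for the number of kernel launches, so it would need tuning on the target GPU.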
```python
import numpy as np
import torch
import torch.nn.functional as F


def compute_appfeature(self, xyz_sampled):
    """
    Returns the appearance feature vectors for a set of XYZ locations.

    Parameters:
        xyz_sampled: torch.Tensor of shape (N, 3).
    Returns:
        torch.Tensor of shape (N, data_dim_color).
    """
    def compute_factors(idx_plane, grid_mode='plane'):
        """
        Helper that samples one factor of the vector-matrix (VM) decomposition.

        Parameters:
            idx_plane (int): index of the XY, XZ, or YZ plane (and its matching line)
            grid_mode (str): 'plane' for the matrix factor, 'line' for the vector factor
        Returns:
            torch.Tensor: the sampled factor, shape (n_comp, N)
        """
        idx_plane = int(idx_plane)
        if grid_mode == 'plane':
            grid = coordinate_plane  # defined below
            factor_grid = self.app_plane[idx_plane]
        else:  # grid_mode == 'line'
            grid = coordinate_line  # defined below
            factor_grid = self.app_line[idx_plane]  # line factors live in app_line, not app_plane
        factor = F.grid_sample(
            factor_grid.cpu(),
            grid[[idx_plane]],
            align_corners=True,
        ).view(-1, *xyz_sampled.shape[:1])
        return factor

    ### MAIN CODE
    xyz_sampled = xyz_sampled.to(device="cpu")
    ...  # unchanged code
    # figure out the vector-matrix outer products, trying vectorization
    # (np.vectorize is a convenience wrapper around a Python loop, not true vectorization)
    app_plane_indices = np.arange(len(self.app_plane))
    compute_VM_factors = np.vectorize(compute_factors, otypes=[object])
    plane_coef_point = compute_VM_factors(app_plane_indices, 'plane')  # 1D object ndarray of 2D tensors
    plane_coef_point = torch.cat(list(plane_coef_point)).to(device=self.device)  # 2D tensor
    line_coef_point = compute_VM_factors(app_plane_indices, 'line')  # same layout as plane_coef_point
    line_coef_point = torch.cat(list(line_coef_point)).to(device=self.device)
    return self.basis_mat((plane_coef_point * line_coef_point).T)
```
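For comparison, here is a minimal standalone sketch of the per-plane loop that the NumPy version replaces, with everything kept on one device so there is no CPU round trip per call. It is not the repo's code: the toy sizes, the random `app_plane`/`app_line`/`basis_mat`, and the free function `compute_appfeature` are hypothetical, and `MAT_MODE`/`VEC_MODE` are my assumption of the axis pairing TensorBase uses (worth checking against `matMode`/`vecMode` in the repo).

```python
import torch
import torch.nn.functional as F

# Hypothetical toy sizes: 3 plane/line pairs, 4 components each, a 16^3 grid, 27-dim output.
N_COMP, RES, APP_DIM = 4, 16, 27
MAT_MODE = [[0, 1], [0, 2], [1, 2]]  # axis pair spanned by each plane
VEC_MODE = [2, 1, 0]                 # axis along each matching line
device = "cuda" if torch.cuda.is_available() else "cpu"

app_plane = [torch.randn(1, N_COMP, RES, RES, device=device) for _ in range(3)]
app_line = [torch.randn(1, N_COMP, RES, 1, device=device) for _ in range(3)]
basis_mat = torch.nn.Linear(3 * N_COMP, APP_DIM, bias=False).to(device)


def compute_appfeature(xyz):
    """xyz: (N, 3) coordinates in [-1, 1]. Returns (N, APP_DIM) features."""
    n = xyz.shape[0]
    # One 2-D lookup grid per plane: (3, N, 1, 2).
    coord_plane = torch.stack([xyz[:, m] for m in MAT_MODE]).view(3, -1, 1, 2)
    # Lines are sampled as width-1 images, so the x (width) coordinate is padded with zeros.
    coord_line = torch.stack([xyz[:, v] for v in VEC_MODE])
    coord_line = torch.stack((torch.zeros_like(coord_line), coord_line), dim=-1).view(3, -1, 1, 2)

    plane_coef, line_coef = [], []
    for i in range(3):
        # grid_sample returns (1, N_COMP, N, 1); reshape to (N_COMP, N).
        plane_coef.append(F.grid_sample(app_plane[i], coord_plane[[i]], align_corners=True).view(-1, n))
        line_coef.append(F.grid_sample(app_line[i], coord_line[[i]], align_corners=True).view(-1, n))
    # Element-wise product of matching plane/line components, then a linear map to APP_DIM.
    coef = torch.cat(plane_coef) * torch.cat(line_coef)  # (3 * N_COMP, N)
    return basis_mat(coef.T)  # (N, APP_DIM)


xyz = torch.rand(5000, 3, device=device) * 2 - 1
with torch.no_grad():
    feats = compute_appfeature(xyz)
print(feats.shape)  # torch.Size([5000, 27])
```

Because each output row depends only on its own point, this version can also be run over the points in slices if memory is tight.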