
Would it be possible to modify this to also use CPU memory?

Open sirisian opened this issue 2 years ago • 6 comments

I looked at issue #2 and I'm curious whether you could modify the project to use CPU memory by swapping images and other data structures in and out as needed? (I'm not sure if https://github.com/IBM/pytorch-large-model-support could be used for the tensor part.) I noticed the limitations section in the paper already mentions the memory consumption, so if this isn't easy, feel free to close this.

Part of my motivation is that I'm interested in seeing what happens with larger images and more of them. (This assumes training duration isn't a high priority and that model quality is the goal. For reference, I have a 3090 and a lot of DDR5 memory.)

sirisian avatar Apr 04 '22 22:04 sirisian

Hello,

  • More images should work out of the box. We pre-load datasets into CPU memory for performance (to avoid image parsing in the training loop), but preloading can also be disabled using the pre_load flag. The training data is then uploaded to the device in the prepare_batch function (https://github.com/NVlabs/nvdiffrec/blob/main/train.py#L68) inside the training loop.

  • Images larger than, say, 2k x 2k are trickier to support, as we currently rely on nvdiffrast for differentiable rasterization using GPU hardware, and need to render full views (with correspondingly large GPU buffers for backprop etc.). You can work around this by, e.g., rendering random crops of the larger images in each training iteration, or by replacing the primary-visibility rasterization step with differentiable ray casting. With ray casting, you can sample a subset of random rays instead of full views, similar to the NeRF training setup (see the sketch after this list).
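A rough, untested PyTorch sketch of the ray-subsampling idea (render_rays is a hypothetical differentiable ray caster, not something nvdiffrec ships):

import torch

H, W = 7000, 9000          # full image resolution
N_RAYS = 4096              # rays sampled per training iteration

gt = torch.rand(H, W, 3)   # stand-in for a ground-truth photo kept in CPU memory

# Sample a random subset of pixel coordinates
flat = torch.randint(0, H * W, (N_RAYS,))
py, px = flat // W, flat % W

# render_rays(px, py) would cast one ray per sampled pixel and return
# [N_RAYS, 3] colors, so GPU memory scales with N_RAYS rather than H * W:
# pred = render_rays(px, py)
# loss = (pred - gt[py, px, :].cuda()).abs().mean()
# loss.backward()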

jmunkberg avatar Apr 05 '22 05:04 jmunkberg

Hi @sirisian,

I don't think we'll do significant memory optimizations to the code in the short term. We want it to stay true to the paper version, and as researchers we have limited time to dedicate to this release.

The package looks fairly promising, but I'm not sure how it manages GPU/CPU transitions. As Jacob mentioned above, we rely on quite a few CUDA kernels which have no CPU fallback. If the LMS library always runs kernels on the CUDA device and only temporarily swaps tensors back to CPU memory, it could work, but if it requires both CPU and CUDA kernels it will be very hard.

Another optimization you might want to try is changing the batching code. Currently we batch by rendering to [N, H, W, C] tensors, where N is the batch size. Temporary results need to be stored at the same resolution, so memory consumption grows linearly with N. The main benefit of batching is gradient averaging, so you could instead run a loop of N forward + backward passes over [1, H, W, C] tensors and just average the final gradients.
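A minimal sketch of that gradient-accumulation variant, with a toy model and loss standing in for the actual nvdiffrec pipeline:

import torch

model = torch.nn.Linear(8, 3)                      # stand-in for the full renderer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
views = [torch.rand(1, 8) for _ in range(4)]       # N "views" processed one at a time

optimizer.zero_grad()
for view in views:
    loss = model(view).abs().mean() / len(views)   # scale so gradients average over N
    loss.backward()                                # gradients accumulate in .grad
optimizer.step()                                   # one update with the averaged gradient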

JHnvidia avatar Apr 05 '22 09:04 JHnvidia

Thanks for all the ideas.

More images should work out of the box.

Images larger than, say, 2k x 2k are trickier to support

You can work around this by, e.g., rendering random crops of the larger images in each training iteration

I'm wondering if there's a more naive way to do something like this without modifying the code (or modifying it very little). For example, could one take a single 9K x 7K camera image and treat it as 9 x 7 = 63 cameras, each with a resolution of 1K x 1K pixels, throwing away the fully masked-out cameras? (So only the fake cameras that aren't masked out are included.) The frustums wouldn't follow normal camera intrinsic definitions though, so I imagine this wouldn't be viable without relating them back to the single original camera frustum and intrinsics.

sirisian avatar Apr 07 '22 22:04 sirisian

Yes, you need to adjust the camera frustum. We have a function util.perspective_offcenter that does this. We tested random cropping at some point but removed it for the public release.

In dataset_llff.py or dataset_nerf.py, you can modify the end of the _parse_frame(self, idx) function with something like this to enable random cropping (the code below is untested, but you get the idea):

...
CROP_SIZE = 256
height = img.shape[0]
width  = img.shape[1]

# Pick a random CROP_SIZE x CROP_SIZE window inside the full image
xstart = np.random.randint(0, width - CROP_SIZE)
ystart = np.random.randint(0, height - CROP_SIZE)
img = img[ystart:ystart + CROP_SIZE, xstart:xstart + CROP_SIZE, :]

# Crop offsets as fractions of the full image
_rx = xstart / width
_ry = ystart / height

# Override projection matrix and mvp with an off-center frustum covering only the crop
proj_mtx = util.perspective_offcenter(fovy, CROP_SIZE / width, _rx, _ry, width / height, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
mvp      = proj_mtx @ mv
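For intuition, an off-center (asymmetric) perspective matrix in OpenGL conventions looks like the sketch below; the exact parameterization of util.perspective_offcenter may differ, but the idea is to remap a sub-rectangle of the image plane to the full render target:

import numpy as np

def frustum(l, r, b, t, n, f):
    # Maps the view-space rectangle [l, r] x [b, t] on the near plane to the
    # full [-1, 1] NDC range, so rendering with this matrix fills the whole
    # render target with just the cropped region of the view.
    return np.array([
        [2*n/(r-l),          0,  (r+l)/(r-l),             0],
        [        0,  2*n/(t-b),  (t+b)/(t-b),             0],
        [        0,          0, -(f+n)/(f-n), -2*f*n/(f-n)],
        [        0,          0,           -1,             0]], dtype=np.float32)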

jmunkberg avatar Apr 08 '22 05:04 jmunkberg

@jmunkberg adding this code seems to break an assertion in render/render.py (line 215): assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1]). Any suggestions for working around this?

For context, I changed the end of _parse_frame in dataset_llff.py to:

...
        if self.FLAGS.crop_dataset:
            print('cropping dataset')
            # Fall back to a 256 px crop if no size was given
            CROP_SIZE = 256 if self.FLAGS.CROP_SIZE is None else self.FLAGS.CROP_SIZE

            height = img.shape[0]
            width = img.shape[1]

            # Random CROP_SIZE x CROP_SIZE window and its fractional offsets
            xstart = np.random.randint(0, width - CROP_SIZE)
            ystart = np.random.randint(0, height - CROP_SIZE)
            img = img[ystart:ystart + CROP_SIZE, xstart:xstart + CROP_SIZE, :]
            _rx = xstart / width
            _ry = ystart / height

            # Override projection matrix with an off-center frustum covering the crop
            proj = util.perspective_offcenter(self.fovy[idx, ...], CROP_SIZE / width, _rx, _ry, width / height,
                                              self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
        else:
            # Setup transforms
            proj = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0],
                                    self.FLAGS.cam_near_far[1])

        mv = torch.linalg.inv(self.imvs[idx, ...])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = proj @ mv

        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...]

Here I pass the custom flags self.FLAGS.crop_dataset and self.FLAGS.CROP_SIZE in with the rest of the flags.
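(A guess at the cause, untested: prepare_batch in train.py sizes the background buffer from the image, which is now CROP_SIZE x CROP_SIZE, while the resolution that render/render.py checks against presumably still comes from the full-size value the dataset reports. If so, overriding that reported resolution to the crop size as well, e.g. setting it to [CROP_SIZE, CROP_SIZE] in the crop branch, should satisfy the assertion.)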

mexicantexan avatar Jun 24 '22 00:06 mexicantexan

Anyway, I assume that self.fovy[idx, ...] should not be indexed with idx (the frame id) if your images are RGBA.

hzhshok avatar Aug 01 '22 08:08 hzhshok