Ex4DGS icon indicating copy to clipboard operation
Ex4DGS copied to clipboard

Out of Memory issue

Open Drow999 opened this issue 1 year ago • 3 comments

Thank you for the excellent work on this project! I am currently reproducing your work and ran into an out-of-memory (OOM) error. I tried training this model on an RTX 3090. Could this be a memory leak? Shouldn't the 3090 and the 4090 have the same amount of memory? [screenshot]

Drow999 avatar Nov 12 '24 15:11 Drow999

I have the same issue with an RTX 4090.

KyungdaePark avatar Nov 14 '24 02:11 KyungdaePark

I believe this issue arises because images are loaded on the CPU faster than the GPU processes them, so loaded images pile up in memory faster than they are consumed. The data-loading process relies on joblib, and unfortunately I have not found a complete solution within joblib to address this. Reducing the number of CPU workers may help mitigate the problem: I recommend lowering the `n_jobs` argument in the following two lines.

https://github.com/juno181/Ex4DGS/blob/81ec2b8658c8effe6d6375c2355ea197e759839d/scene/__init__.py#L185

https://github.com/juno181/Ex4DGS/blob/81ec2b8658c8effe6d6375c2355ea197e759839d/scene/__init__.py#L229

juno181 avatar Nov 14 '24 12:11 juno181

Okay, replacing `getTrainCameras()` and `getTestCameras()` with the following worked for me:

def getTrainCameras2(self, scale=1.0, shuffle=True, return_as="generator", return_path=False, get_img=True):
        """Return the sampled training cameras and, optionally, their images.

        In lazy-loader mode, images are decoded one at a time, on demand, in the
        calling thread (no worker pool). Only the image currently being consumed
        is materialised when ``return_as`` is not ``"list"``.

        Args:
            scale: key into ``self.train_cameras`` selecting the camera set.
            shuffle: if True, randomly reorder the cameras with the ``random``
                module. Each camera stays paired with its image/path entry.
            return_as: ``"list"`` returns a list of images. Any other value
                returns a generator over the images.
            return_path: if True, return path tuples instead of images. Lazy mode
                gives ``(image_path, resolution, im_scale)``. Eager mode gives
                ``(image_path, resolution)``.
            get_img: lazy mode only. If False, the second return value is None.
                Eager mode ignores this flag; its images are already in memory.

        Returns:
            ``(cameras, images_or_paths)``, where ``cameras`` is a list.
        """
        # ``samplelist`` acts as a boolean mask over the camera list.
        cams = list(compress(self.train_cameras[scale], self.samplelist))

        if self.lazy_loader:
            entries = [(c.image_path, c.resolution, c.im_scale) for c in cams]
        elif return_path:
            entries = [(c.image_path, c.resolution) for c in cams]
        else:
            entries = [c.image for c in cams]

        if shuffle:
            # Shuffle one index permutation and apply it to both lists, so the
            # camera/entry pairing is preserved. Unlike unpacking ``zip(*pairs)``,
            # this also works when the selection is empty.
            order = list(range(len(cams)))
            random.shuffle(order)
            cams = [cams[i] for i in order]
            entries = [entries[i] for i in order]

        if return_path:
            return cams, entries

        if not self.lazy_loader:
            if return_as == "list":
                return cams, entries
            return cams, (img for img in entries)

        if not get_img:
            return cams, None

        def im_reader(path, resolution, im_scale):
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # The context manager releases the file handle as soon as the image
            # has been converted, rather than leaving it to the garbage collector.
            with Image.open(path) as pil_img:
                return (PILtoTorch(pil_img, resolution)[:3, ...] / im_scale).clamp(0, 1)

        if return_as == "list":
            return cams, [im_reader(*entry) for entry in entries]
        return cams, (im_reader(*entry) for entry in entries)

    def getTestCameras2(self, scale=1.0, shuffle=True, return_as="generator", return_path=False, get_img=True):
        """Return the sampled test cameras and, optionally, their images.

        In lazy-loader mode, images are decoded one at a time, on demand, in the
        calling thread (no worker pool). Only the image currently being consumed
        is materialised when ``return_as`` is not ``"list"``.

        Args:
            scale: key into ``self.test_cameras`` selecting the camera set.
            shuffle: if True, randomly reorder the cameras with the ``random``
                module. Each camera stays paired with its image/path entry.
            return_as: ``"list"`` returns a list of images. Any other value
                returns a generator over the images.
            return_path: if True, return path tuples instead of images. Lazy mode
                gives ``(image_path, resolution, im_scale)``. Eager mode gives
                ``(image_path, resolution)``.
            get_img: lazy mode only. If False, the second return value is None.
                Eager mode ignores this flag; its images are already in memory.

        Returns:
            ``(cameras, images_or_paths)``, where ``cameras`` is a list.
        """
        # ``test_samplelist`` acts as a boolean mask over the camera list.
        cams = list(compress(self.test_cameras[scale], self.test_samplelist))

        if self.lazy_loader:
            entries = [(c.image_path, c.resolution, c.im_scale) for c in cams]
        elif return_path:
            entries = [(c.image_path, c.resolution) for c in cams]
        else:
            entries = [c.image for c in cams]

        if shuffle:
            # Shuffle one index permutation and apply it to both lists, so the
            # camera/entry pairing is preserved. Unlike unpacking ``zip(*pairs)``,
            # this also works when the selection is empty.
            order = list(range(len(cams)))
            random.shuffle(order)
            cams = [cams[i] for i in order]
            entries = [entries[i] for i in order]

        if return_path:
            return cams, entries

        if not self.lazy_loader:
            if return_as == "list":
                return cams, entries
            return cams, (img for img in entries)

        if not get_img:
            return cams, None

        def im_reader(path, resolution, im_scale):
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # The context manager releases the file handle as soon as the image
            # has been converted, rather than leaving it to the garbage collector.
            with Image.open(path) as pil_img:
                return (PILtoTorch(pil_img, resolution)[:3, ...] / im_scale).clamp(0, 1)

        if return_as == "list":
            return cams, [im_reader(*entry) for entry in entries]
        return cams, (im_reader(*entry) for entry in entries)

vikramsandu avatar Nov 15 '24 09:11 vikramsandu