
Providing custom depth maps for depth2img

subpanic opened this issue on Dec 27 '22 · 4 comments

What API design would you like to have changed or added to the library? Why? There's an inconsistency in how depth maps are handled compared to other image inputs. While the img2img and depth2img pipelines accept either a PIL image or tensor data for the main input image, the depth2img pipeline accepts only tensor data for the depth map.

Preferably, it would also be possible to optionally provide a PIL image for the depth data.

What use case would this enable or better enable? Can you give us a code example? This would spare anyone injecting their own depth data from doing the PIL-image-to-tensor conversion manually. It would also be a good opportunity to standardize what this depth data should be and look like.

As an example, here is what I currently do to provide depth data (this is almost certainly not the ideal way, just an example):

import numpy as np
import torch
from PIL import Image

batchSize = 1  # should match the batch size passed to the pipeline

depthOut = Image.open("depth.png").convert("L")  # single-channel grayscale
depthOut = np.asarray(depthOut, dtype=np.float32) / 255.0  # scale to [0, 1] first
depthOut = np.expand_dims(depthOut, axis=0)  # channel dim -> (1, H, W)
depthOut = np.expand_dims(depthOut, axis=0).repeat(batchSize, axis=0)  # batch dim -> (B, 1, H, W)
depthOut = torch.from_numpy(depthOut)
depthOut = 2.0 * depthOut - 1.0  # map to [-1, 1]

The above depthOut can then be provided to the depth2img pipeline call via the depth_map argument (which exists but isn't documented). However, I also maintain a separate small hack in prepare_depth_map: I move the depth_map.unsqueeze(1) call up inside the if depth_map is None: block so that it only unsqueezes depth data estimated inside the method, not injected depth data:

        if depth_map is None:
            pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device=device)
            # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
            # So we use `torch.autocast` here for half precision inference.
            context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
            with context_manger:
                depth_map = self.depth_estimator(pixel_values).predicted_depth
        else:
            depth_map = depth_map.to(device=device, dtype=dtype)

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
            mode="bicubic",
            align_corners=False,
        )

becomes

        if depth_map is None:
            pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device=device)
            # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
            # So we use `torch.autocast` here for half precision inference.
            context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
            with context_manger:
                depth_map = self.depth_estimator(pixel_values).predicted_depth
            
            depth_map = depth_map.unsqueeze(1)
        else:
            depth_map = depth_map.to(device=device, dtype=dtype)

        depth_map = torch.nn.functional.interpolate(
            depth_map,
            size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
            mode="bicubic",
            align_corners=False,
        )
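For completeness, here's roughly how the whole thing is then invoked with the patched pipeline. This is just a sketch: the checkpoint id and prompt are placeholders, and init_image stands in for whatever RGB input you're conditioning on:

import torch
from diffusers import StableDiffusionDepth2ImgPipeline
from PIL import Image

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-depth", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("input.png").convert("RGB")  # placeholder input image

# depthOut is the (B, 1, H, W) tensor built above; passing the undocumented
# depth_map argument skips the pipeline's own DPT depth estimation.
result = pipe(
    prompt="a fantasy landscape",  # placeholder prompt
    image=init_image,
    depth_map=depthOut,
).images[0]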

Ideally we could instead just optionally supply the PIL image directly to the pipeline call, and prepare_depth_map would bypass depth estimation and convert the PIL image to tensor data itself.
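Something like the following inside prepare_depth_map might do it. This is only a rough sketch: the isinstance branch, the [0, 1] scaling, and the shape handling are my assumptions rather than existing library behavior (batch_size, device, and dtype are the values already available in the method, and PIL.Image and numpy are assumed to be imported in the file):

        if depth_map is None:
            # ... existing DPT estimation path from above ...
            depth_map = depth_map.unsqueeze(1)
        elif isinstance(depth_map, PIL.Image.Image):
            # Hypothetical branch: accept a PIL depth image and turn it into a
            # (batch, 1, H, W) float tensor, skipping depth estimation entirely.
            depth = np.asarray(depth_map.convert("L"), dtype=np.float32) / 255.0
            depth_map = torch.from_numpy(depth)[None, None].repeat(batch_size, 1, 1, 1)
            depth_map = depth_map.to(device=device, dtype=dtype)
        else:
            depth_map = depth_map.to(device=device, dtype=dtype)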

I would put a PR together, but I'm not 100% sure how best to conform the depth data. The rest of the PR should be pretty simple.

subpanic · Dec 27 '22

Thanks @subpanic for posting this; I've bumped into the same issue. In the meantime, when running the depth pipeline locally or in a Colab notebook, I've used this workaround:

import numpy as np
import torch
from PIL import Image

img = './i.jpg'
dep = './d.jpg'

init_image = Image.open(img)
init_depth = Image.open(dep)
size = (640, 480)

# just making sure we feed the same tensor dims to the pipe
init_image = init_image.resize(size)
init_depth = init_depth.resize(size)

init_depth = init_depth.convert("L")  # single-channel grayscale
init_depth = np.asarray(init_depth, dtype=np.float32) / 255.0  # scale to [0, 1] first
init_depth = np.expand_dims(init_depth, axis=0)  # batch dim -> (1, H, W)

init_depth = torch.from_numpy(init_depth)
init_depth = 2. * init_depth - 1.  # map to [-1, 1]

after which it's possible to pass depth_map=init_depth in the pipeline call. As you mentioned, it would be much better to have a generic depth-map loading method rather than accepting a tensor only.
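Concretely, with pipe being an already-loaded StableDiffusionDepth2ImgPipeline (the prompt is a placeholder):

result = pipe(prompt="...", image=init_image, depth_map=init_depth).images[0]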

RELNO · Dec 27 '22

Agree very much with this issue! @patil-suraj what do you think? Would you maybe like to open a PR to allow depth images, @subpanic?

patrickvonplaten · Jan 03 '23

Agree; adding this to my todo list!

patil-suraj · Jan 25 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Feb 19 '23