support for differentiable warping

Open jacobsn opened this issue 2 years ago • 6 comments

We were hoping to use this library for making differentiable perspective cutouts from panoramas.

We noticed two issues getting in the way:

  1. Some of the function signatures have Python primitive type hints, such as: https://github.com/haruishi43/equilib/blob/3cd9be017c0583560821891d764b332318dba032/equilib/equi2pers/torch.py#L24
  2. Some of the torch_util functions are using np functions, such as: https://github.com/haruishi43/equilib/blob/3cd9be017c0583560821891d764b332318dba032/equilib/torch_utils/rotation.py#L110

We are trying to decide if we should make the necessary changes. I wonder if this is something that's already in the works or if you would be interested in a pull request.

jacobsn avatar May 17 '22 19:05 jacobsn

Hi @jacobsn, thanks for considering my project. Yes, I think having a fully differentiable version is a good idea. I haven't had time to write test scripts that check where the functions fail to be differentiable, though. If you could provide a PR that implements it, I would happily add it to the library.

Some questions:

  • What errors do primitive type hints cause?
  • torch_util functions do have np functions. Are you trying to learn some rotation? We can change the function input to be a more general type that encompasses floats (torch.float, np.float, float, etc).

haruishi43 avatar May 19 '22 08:05 haruishi43

re: Question 1: I wasn't actually writing any of the software; @kcbhatraju was writing the code. I know he mentioned having trouble pushing tensors in and had to resort to calling .item().

re: Question 2: yes, we are trying to learn rotations, and potentially FoVs.
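For readers wondering why the .item() workaround matters when learning rotations, here is a minimal sketch in plain PyTorch (not equilib code; the variable name `yaw` is only illustrative). It shows that a value passed through .item() becomes an ordinary Python float, so no gradient can flow back to the original tensor.

```python
import torch

# A learnable rotation angle in radians; `yaw` is an illustrative name only.
yaw = torch.tensor(0.3, requires_grad=True)

# Staying in torch ops keeps the autograd graph intact.
kept = torch.cos(yaw)
kept.backward()
grad_kept = yaw.grad.clone()  # d/dyaw cos(yaw) = -sin(yaw)

# .item() returns a plain Python float, so anything built from it
# is a constant as far as autograd is concerned.
as_float = yaw.item()
cut = torch.cos(torch.tensor(as_float))

print(grad_kept)          # gradient reached `yaw`
print(cut.requires_grad)  # False
```

So any function that forces the caller to convert angles with .item() would stop gradients at that point, regardless of what happens afterwards.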

jacobsn avatar May 19 '22 17:05 jacobsn

I don't fully understand the situation yet, but if there are some scripts or error logs to work from, I could help out. PRs are always appreciated, though.

haruishi43 avatar May 20 '22 01:05 haruishi43

For anyone who comes across this issue: if you want to use this project in a PyTorch forward/backward pass and only need the conversion to be differentiable w.r.t. the pixel colors, not w.r.t. the rotation, it already works.

A simple test from my project (pseudocode only):

# Learnable cubemap: only the pixel colors are optimized here
cubemap = torch.nn.Parameter(torch.zeros(1, 6, 64, 64, 3).float().to('cuda'))
e_optim = torch.optim.Adam(
    [cubemap],
    lr=1e-3,
    weight_decay=0.0001
)
c2e = Cube2Equi(height=128, width=256, cube_format='dict')
e2c = Equi2Cube(w_face=64, cube_format='dict', z_down=False)
target = torch.zeros([6, 3, 64, 64]).to('cuda')
target = [white, green, red, black, black, black]  # pseudocode: one solid color per face
compute_color_loss = losses.photo_loss  # loss function from my own project
for i in range(1000):
    equi = c2e(cubemap)
    recover_cubemap = e2c(equi)
    loss = compute_color_loss(recover_cubemap, target)
    # backpropagate and update the original cubemap
    e_optim.zero_grad()
    loss.backward()
    e_optim.step()
    # print loss and save original cubemap/equi

The result:

iter 0: loss 0.623
iter 100: loss 0.565
iter 200: loss 0.520
...
iter 900: loss 0.422

Initial cubemap and equi: [images: 0_cubemap_recover, 0_equi]

The optimized cubemap and equi: [images: 900_cubemap_recover, 900_equi]

JiejiangWu avatar Nov 04 '22 09:11 JiejiangWu

@JiejiangWu, thanks for checking that backprop runs. I guess the rotation matrices are treated as a transformation with frozen parameters. To differentiate with respect to the input rotation, the code would need to accept a torch.Tensor as the rotation input, which I haven't implemented.
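As a rough illustration of what tensor-valued rotation input could look like, the sketch below builds a rotation matrix using only torch operations. The function name `rotation_matrix` and the axis convention R = Rz(yaw) @ Ry(pitch) @ Rx(roll) are assumptions made for this example; they are not taken from equilib and may not match its conventions.

```python
import torch

def rotation_matrix(roll, pitch, yaw):
    """Build a 3x3 rotation matrix from 0-dim tensors using torch ops only.

    Assumed convention for this sketch: R = Rz(yaw) @ Ry(pitch) @ Rx(roll).
    """
    one = torch.ones_like(roll)
    zero = torch.zeros_like(roll)
    cr, sr = torch.cos(roll), torch.sin(roll)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cy, sy = torch.cos(yaw), torch.sin(yaw)
    # torch.stack keeps every matrix entry connected to the input angles.
    Rx = torch.stack([one, zero, zero, zero, cr, -sr, zero, sr, cr]).reshape(3, 3)
    Ry = torch.stack([cp, zero, sp, zero, one, zero, -sp, zero, cp]).reshape(3, 3)
    Rz = torch.stack([cy, -sy, zero, sy, cy, zero, zero, zero, one]).reshape(3, 3)
    return Rz @ Ry @ Rx

angles = torch.tensor([0.1, -0.2, 0.3], requires_grad=True)
R = rotation_matrix(angles[0], angles[1], angles[2])

# Any scalar loss on R now produces gradients for the angles.
loss = (R - torch.eye(3)).pow(2).sum()
loss.backward()
print(angles.grad)
```

The key design point is avoiding np functions and Python float conversions anywhere between the angles and the matrix, so that autograd can trace the whole computation.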

haruishi43 avatar Nov 04 '22 12:11 haruishi43

Yes! To make the rotation parameters differentiable, the first step is exactly that: accept a torch.Tensor as the rotation input. Beyond that, there are some more complex things to do. As I understand it, the rotation changes the sampling indices used in the conversion, and those indices are not differentiable when the loss is defined in the pixel domain. It may need some tricky strategies, which could be described as differentiable (soft) indexing.
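One common form of such soft indexing is bilinear interpolation: the output is a weighted blend of neighboring pixels, and the weights vary smoothly with the sampling coordinates. The sketch below uses PyTorch's `torch.nn.functional.grid_sample` on a toy image to show gradients reaching a parameter that moves the sampling grid. The learnable `shift` is a stand-in for a rotation parameter, and nothing here is equilib code.

```python
import torch
import torch.nn.functional as F

# Toy image with a horizontal intensity ramp, shape (N=1, C=1, H=4, W=4).
image = torch.linspace(0.0, 1.0, 4).repeat(4, 1).reshape(1, 1, 4, 4)

# Learnable horizontal shift of the sampling grid (stand-in for a rotation).
shift = torch.tensor(0.1, requires_grad=True)

# Identity sampling grid in normalized [-1, 1] coordinates, shape (1, 4, 4, 2).
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, size=(1, 1, 4, 4), align_corners=True)

# The last grid dimension is (x, y); move every sample point along x.
offset = torch.stack([shift, torch.zeros_like(shift)])
grid = grid + offset

# Bilinear sampling blends neighboring pixels, so the output varies
# smoothly with the grid coordinates and gradients can reach `shift`.
sampled = F.grid_sample(
    image, grid, mode="bilinear", padding_mode="border", align_corners=True
)
loss = sampled.mean()
loss.backward()
print(shift.grad)
```

With a nearest-neighbor lookup, by contrast, the output is piecewise constant in the coordinates, so no useful gradient would reach the parameter that controls the grid.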

JiejiangWu avatar Nov 06 '22 13:11 JiejiangWu