
VectorQuantize not JIT-safe

Open JackMcCoy opened this issue 3 years ago • 8 comments

The current code controls program flow based on the values of tensors (like the "initted" buffer), so it will not work when compiling a JIT trace.
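For concreteness, here is a minimal sketch of the kind of pattern I mean (my own simplification, not the library's exact code): the branch depends on a registered buffer's value, so torch.jit.trace just bakes in whichever path the example input happens to take.

    import torch
    from torch import nn

    class LazyInitCodebook(nn.Module):
        def __init__(self, codebook_size: int = 512, dim: int = 64):
            super().__init__()
            self.register_buffer('initted', torch.tensor(False))
            self.register_buffer('embed', torch.randn(codebook_size, dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not self.initted:  # tensor-valued condition: the trace only records one branch
                # stand-in for the kmeans init of the codebook on the first pass
                self.embed.copy_(torch.randn_like(self.embed) * x.std() + x.mean())
                self.initted.fill_(True)
            return torch.cdist(x, self.embed)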

JackMcCoy avatar Dec 02 '21 02:12 JackMcCoy

@JackMcCoy darn, i'm not sure how to approach that - i guess if we aren't doing a kmeans init of the codebook on the first pass, i could remove that piece of logic

lucidrains avatar Dec 02 '21 20:12 lucidrains

> @JackMcCoy darn, i'm not sure how to approach that - i guess if we aren't doing a kmeans init of the codebook on the first pass, i could remove that piece of logic

Yeah, that was my thought. I set kmeans_init and use_cosine_sim to False (I think both are good features, I just didn't see a straightforward way of using them here) and thought that running jit.trace with checks might work. It ran, but didn't train properly (maybe something else was going on, so I'm not saying this is conclusive). I was going to copy the code into my repo and delete the "initted" variable/check, but then I tried running without JIT-compiling that network and it wasn't any slower, so I didn't test it further.

Other libraries have tensor-based control flow functions, but I don't see anything like that for PyTorch.

JackMcCoy avatar Dec 02 '21 20:12 JackMcCoy

Another option would be to put the control-flow handling in separate methods that are tagged with @torch.jit.ignore.
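A minimal sketch of that approach (illustrative names, not the library's actual code): the tensor-dependent init moves into a method decorated with @torch.jit.ignore, which torch.jit.script dispatches back to the Python interpreter instead of compiling (with the caveat, as far as I know, that you then can't torch.jit.save the module).

    import torch
    from torch import nn

    class CodebookIgnoredInit(nn.Module):
        def __init__(self, codebook_size: int = 512, dim: int = 64):
            super().__init__()
            self.register_buffer('initted', torch.tensor(False))
            self.register_buffer('embed', torch.randn(codebook_size, dim))

        @torch.jit.ignore
        def maybe_init_(self, x: torch.Tensor) -> None:
            # all data-dependent control flow is kept here; scripted code
            # dispatches this call back to eager Python instead of compiling it
            if not self.initted:
                self.embed.copy_(torch.randn_like(self.embed) * x.std() + x.mean())
                self.initted.fill_(True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.maybe_init_(x)
            return torch.cdist(x, self.embed)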

JackMcCoy avatar Dec 02 '21 20:12 JackMcCoy

@JackMcCoy ok, let's try the torch.jit.ignore! see if the latest version helps

lucidrains avatar Dec 02 '21 21:12 lucidrains

Unfortunately, no. Looking at the code again, ema_inplace() is another issue standing in the way... Obviously the in-place update is an important performance choice. Keeping it while also offering a JIT-safe path might end up looking fairly messy; maybe a separate version is the best option, if it's worth doing at all.

Any thoughts? From what I've found so far, it seems the JIT might fuse the out-of-place operations anyway. I may poke around and will be sure to comment on anything I find.
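For reference, the update in question is just an exponential moving average; roughly (from memory, so treat this as a sketch rather than the exact source), the in-place form and an out-of-place equivalent would be:

    import torch

    def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
        # current in-place form: mutates a registered buffer directly
        moving_avg.data.mul_(decay).add_(new, alpha=1 - decay)

    def ema(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> torch.Tensor:
        # out-of-place form: returns a new tensor for the caller to copy back into
        # the buffer, which is the kind of elementwise chain the fuser can handle
        return moving_avg * decay + new * (1 - decay)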

JackMcCoy avatar Dec 02 '21 23:12 JackMcCoy

CLIP isn't JIT-safe either. Or rather, JIT isn't safe, so stop using it? :P

fractaldna22 avatar Jan 12 '22 05:01 fractaldna22

> CLIP isn't JIT-safe either. Or rather, JIT isn't safe, so stop using it? :P

you know, there are ways to use a quantized codebook besides with clip!

JackMcCoy avatar Jan 12 '22 05:01 JackMcCoy

In our internal version of VQ-VAE we had the same issue with DDP syncing the codebooks and we resorted to the following solution:

    @torch.jit.unused
    def ddp_sync(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
        # excluded from TorchScript compilation; only ever runs in eager mode,
        # and only when syncing is enabled and a process group is initialized
        if self._ddp_sync and torch.distributed.is_initialized():
            torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            ....
                # guarded by a plain bool attribute, so the @torch.jit.unused
                # stub is never hit at runtime when syncing is turned off
                if self._ddp_sync:
                    self.ddp_sync(encodings_sum, dw)

Basically, anything that is not jittable should be possible to turn off and route around with if statements.
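A quick usage sketch of that switch (the class name and constructor flag here are hypothetical, just mirroring the snippet above): the non-jittable path only has to be disabled on the model you script.

    import torch

    # hypothetical module exposing a ddp_sync flag wired to self._ddp_sync above
    model = VQVAE(ddp_sync=False)
    # the @torch.jit.unused body is skipped by the compiler, and the if-guard
    # in forward() keeps the stub from ever being called at runtime
    scripted = torch.jit.script(model)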

danieltudosiu avatar Oct 22 '22 17:10 danieltudosiu