vector-quantize-pytorch
VectorQuantize not JIT-safe
The current code controls the program based on the values of tensors (like the "initted" buffer) and will not work when compiling a jit trace.
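For illustration, here's a boiled-down sketch of the pattern (not the actual library code; in the real module the branch guards the kmeans init of the codebook). The branch on the `initted` buffer is evaluated once at trace time, so whichever path the example input takes gets baked permanently into the traced graph:

```python
import torch
import torch.nn as nn

class ToyCodebook(nn.Module):
    def __init__(self, codebook_size: int = 8, dim: int = 4):
        super().__init__()
        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('embed', torch.randn(codebook_size, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.initted:                 # control flow on a tensor value
            embed = x[:self.embed.shape[0]]  # stand-in for the data-dependent init path
        else:
            embed = self.embed
        return torch.cdist(x, embed).argmin(dim=-1)

traced = torch.jit.trace(ToyCodebook(), torch.randn(16, 4), check_trace=False)
# `traced` now takes the init path on every call, regardless of what `initted`
# says later -- the tracer only warns about the tensor-to-bool conversion.
```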
@JackMcCoy darn, i'm not sure how to approach that - i guess if we aren't doing a kmeans init of the codebook on the first pass, i could remove that piece of logic
Yeah, that was my thought. I set kmeans_init and use_cosine_sim to False (I think both are good features, I just didn't see a straightforward way of keeping them) and tried running jit.trace with checks. It ran, but the model didn't train properly (maybe something else is going on here, so I'm not saying this is conclusive). I was going to copy the code into my repo and delete the "initted" variable/check, but then I tried running the network without JIT compiling it, and it wasn't any slower, so I didn't test it further.
Other libraries have tensor-based control flow functions, but I don't see anything like that for PyTorch.
Another option would be to put the control-flow handling in separate methods tagged with @torch.jit.ignore.
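Roughly what I have in mind (just a sketch with made-up method names, not the library's actual code); the ignored method stays a plain Python call, so torch.jit.script leaves it alone:

```python
import torch
import torch.nn as nn

class Codebook(nn.Module):
    def __init__(self, codebook_size: int = 8, dim: int = 4):
        super().__init__()
        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('embed', torch.randn(codebook_size, dim))

    @torch.jit.ignore
    def init_embed_(self, x: torch.Tensor) -> None:
        # data-dependent init kept in eager Python; the compiler skips this body
        if not self.initted:
            self.embed.copy_(x[:self.embed.shape[0]])  # stand-in for the kmeans init
            self.initted.fill_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.init_embed_(x)
        return torch.cdist(x, self.embed).argmin(dim=-1)
```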
@JackMcCoy ok, let's try the torch.jit.ignore! see if the latest version helps
Unfortunately, no. Looking at the code again, ema_inplace() is another issue standing in the way. Obviously that's an important performance choice, but keeping it while adding a JIT-safe routing might end up looking fairly messy. Maybe a separate version is the best option, if it's worth doing.
Any thoughts? Looking around, it seems JIT might fuse the out-of-place operations anyway. I may poke around and will be sure to comment on anything I find.
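To make the comparison concrete, the two shapes I mean (sketch only; buffer names are placeholders, not the exact library code):

```python
import torch

def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    # roughly the in-place style used now: mutates the buffer directly
    moving_avg.mul_(decay).add_(new, alpha=1 - decay)

def ema(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> torch.Tensor:
    # pure version: returns a fresh tensor the caller copies back into the buffer,
    # which the JIT fuser can in principle collapse into a single kernel
    return moving_avg * decay + new * (1 - decay)

# e.g. inside forward():
#   self.cluster_size.copy_(ema(self.cluster_size, counts, decay))
# (`cluster_size` / `counts` are placeholder names)
```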
CLIP isn't JIT-safe either. Or rather, JIT isn't safe, stop using it? :P
you know, there are ways to use a quantized codebook besides with CLIP!
In our internal version of VQ-VAE we had the same issue with DDP syncing the codebooks and we resorted to the following solution:
```python
@torch.jit.unused
def ddp_sync(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
    # sum the accumulated codebook statistics across DDP workers;
    # torch.jit.unused keeps this method out of the compiled graph
    if self._ddp_sync and torch.distributed.is_initialized():
        torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)
    else:
        pass

def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    ...
    if self._ddp_sync:
        self.ddp_sync(encodings_sum, dw)
```
Basically, anything that isn't jittable should be able to be turned off and routed around with plain if statements.
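A minimal self-contained version of the same pattern, for anyone who wants to check that it scripts (toy module, not our real VQ-VAE; torch.jit.unused compiles the method into a stub that would only raise if the guarded branch were actually taken):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, ddp_sync: bool = False):
        super().__init__()
        self._ddp_sync = ddp_sync

    @torch.jit.unused
    def ddp_sync(self, x: torch.Tensor) -> None:
        # eager/DDP-only path; never compiled by TorchScript
        torch.distributed.all_reduce(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._ddp_sync:
            self.ddp_sync(x)
        return x * 2

scripted = torch.jit.script(Toy(ddp_sync=False))
print(scripted(torch.ones(3)))  # runs fine; the _ddp_sync branch is never taken
```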