NVFuser JIT eager mode InstanceNorm3d

Open · eqy opened this issue 3 years ago · 12 comments

Prototype version with "cuDNN conv style" caching.

eqy, Feb 25 '22

Also, can there be multiple kernels per fusion key? Or, since all the inputs are forced to be contiguous, is there only a single kernel, with tuning done at runtime at the level of thread/block sizes?

My understanding is that there would only be one compilation per fusion key; the restriction to contiguous inputs is for initial testing purposes. The kernel is supposed to be general across multiple shapes, but the fuser folks would know more about launch-bounds tuning.

eqy, Mar 10 '22

My understanding is that there would only be one compilation per fusion key;

I don't think that's true. At the very least, there could be kernels generated with int32 indexing for smaller tensors and int64 indexing for larger ones. Couldn't there also be other situations where a different parallelization strategy is chosen depending on the sizes?
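To make the int32/int64 point concrete, here is a minimal sketch of how a code generator might pick an index type: if the largest linear element offset any input can reach fits in a signed 32-bit integer, 32-bit indexing is safe; otherwise 64-bit indexing is needed. The helper names and the exact rule are assumptions for illustration, not nvFuser's actual logic (a real rule may also need to account for byte offsets and intermediate buffers).

```python
INT32_MAX = 2**31 - 1


def contiguous_strides(sizes):
    """Element strides of a contiguous (row-major) tensor with the given sizes."""
    strides, acc = [], 1
    for size in reversed(sizes):
        strides.append(acc)
        acc *= size
    return list(reversed(strides))


def max_linear_offset(sizes, strides):
    """Largest element offset reachable for the given sizes/strides.

    Assumes non-negative strides; an empty tensor has no reachable offsets.
    """
    if any(s == 0 for s in sizes):
        return 0
    return sum((size - 1) * stride for size, stride in zip(sizes, strides))


def choose_index_type(tensors):
    """Hypothetical rule: "int32" if every (sizes, strides) pair fits, else "int64"."""
    for sizes, strides in tensors:
        if max_linear_offset(sizes, strides) > INT32_MAX:
            return "int64"
    return "int32"
```

With this rule, a (2, 32, 64, 64, 64) input stays on int32 indexing, while a (32, 64, 128, 128, 128) input has 2^32 elements and forces int64, so the same fusion would need two differently compiled kernels.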

ngimel, Mar 10 '22

Sorry, I meant "compilation step"; that is, if something is in the cache, running the fusion would not be expected to trigger another compilation. In general, though, I agree that more data needs to be collected on the latency of the first execution and of subsequent executions.

eqy, Mar 10 '22

I've checked the code, and it is definitely possible to trigger another compilation even if something is in the cache: for example, for differently aligned inputs, for some very different input sizes (I don't know which exact tensor attributes, other than alignment, will trigger a recompile), or for large inputs (where int64 indexing has to be used). That is fine, and the existing cache still saves some computation, but it should be documented that even hitting the function-level cache doesn't guarantee fast turnaround. It would be good to understand which parameter ranges will result in a kernelRuntime match in runFusionWithInputs. And we still need measurements for a cache miss and for a full cache hit (the function-level cache is hit, and no recompilation is triggered in runFusionWithInputs).

ngimel, Mar 10 '22

You're correct, @ngimel. We can end up generating many kernels; if we want to limit the number of kernels we generate, we would need to implement coarser-grained heuristics (definitely possible, but we haven't done it yet). Cache misses will be dominated by nvrtc compile time. Cache hits are also not easy to quantify, as our finest-grained cache is on input sizes, then there's a cache for heuristics, then this top-level cache. A miss at the lowest level will not force recompilation, so it will be cheaper, but misses at the two higher levels require recompilation and will be dominated by nvrtc compile time. I'll work with Eddie to measure the hits at the lower two levels.
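The three levels described here can be modeled as a top-level cache keyed on the fusion, a finest-grained cache keyed on input sizes, and a heuristics cache that is consulted only when the size lookup misses. In this sketch a size-level miss re-runs the heuristic but reuses an existing kernel when the heuristic output is unchanged, while a miss at the heuristic or top level triggers a (simulated) compile. The class, its methods, and the placeholder persistent/non-persistent threshold are hypothetical and much simpler than nvFuser's actual caches.

```python
from dataclasses import dataclass, field


@dataclass
class FusionCache:
    compiles: int = 0               # simulated nvrtc compilations
    heuristic_evaluations: int = 0  # times the scheduler heuristic ran
    # fusion_key -> {"by_size": {sizes: kernel}, "by_params": {params: kernel}}
    _fusions: dict = field(default_factory=dict)

    def heuristics_for(self, sizes):
        """Placeholder heuristic: persistent if the reduced extent is small."""
        self.heuristic_evaluations += 1
        reduction_extent = 1
        for s in sizes[2:]:  # InstanceNorm reduces over the spatial dims
            reduction_extent *= s
        return "persistent" if reduction_extent <= 32 * 1024 else "non-persistent"

    def run(self, fusion_key, sizes):
        sizes = tuple(sizes)
        # Top level: one entry per fusion.
        entry = self._fusions.setdefault(fusion_key, {"by_size": {}, "by_params": {}})
        kernel = entry["by_size"].get(sizes)
        if kernel is not None:  # finest-level hit: cheapest path
            return kernel
        params = self.heuristics_for(sizes)
        kernel = entry["by_params"].get(params)
        if kernel is None:  # heuristic miss: compile a new kernel
            self.compiles += 1
            kernel = f"{fusion_key}:{params}"
            entry["by_params"][params] = kernel
        entry["by_size"][sizes] = kernel  # remember this size for next time
        return kernel
```

Running (2, 4, 16, 16, 16) twice, then (4, 8, 16, 16, 16), then (2, 4, 64, 64, 64) gives two compiles and three heuristic evaluations: the repeated shape is a pure size hit, and the new batch/channel shape reuses the persistent kernel.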

csarofeen, Mar 10 '22

Awesome, thanks! So as far as I understand, the lowest level will adjust launch parameters, maybe the dynamic shared memory size, and things like that? For my education, can you roughly describe what will trigger recompilation? E.g., if I compiled for a tensor of size (B,C,H,W), how different would (B1,C1,H1,W1) have to be to trigger it? Or is there documentation somewhere that I can look up?

ngimel, Mar 10 '22

Do you mean the highest level that Eddie was working on, or the heuristics? Heuristic recompilations depend entirely on the heuristics, which are subject to change from one release to the next.

In practice, for InstanceNorm3d the big switches are:

  • Persistent vs. non-persistent (we're working on grid persistence now)
  • Alignment on vectorization size (1, 2, 4, 8)
  • If not vectorized, the same factors for unrolling (depending on what fits well)
  • Grid reduction vs. non-grid reduction

For persistent cases these decisions are returned by https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/normalization.h#L21, which returns this reduction heuristic structure. If that structure comes back unchanged according to this fun mega "==" check, https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h#L111-L140, then no recompilation has to happen.
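The "recompile only if the returned heuristic structure differs" rule can be sketched with a frozen dataclass, whose generated `__eq__` and `__hash__` play the role of the `==` check linked above. `SimpleReductionParams` is a hypothetical, much smaller stand-in for the real heuristic structure; its fields mirror the switches listed, and the thresholds in `compute_params` are placeholders, not nvFuser's tuning.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleReductionParams:
    persistent: bool
    vectorize: bool
    factor: int           # vectorization width if vectorize, else unroll factor
    grid_reduction: bool


def vector_width(inner_extent, alignment_elems, candidates=(8, 4, 2, 1)):
    """Largest candidate width dividing both the innermost extent and the
    pointer alignment (measured in elements)."""
    for w in candidates:
        if inner_extent % w == 0 and alignment_elems % w == 0:
            return w
    return 1


def compute_params(reduction_extent, inner_extent, alignment_elems, num_outputs):
    width = vector_width(inner_extent, alignment_elems)
    return SimpleReductionParams(
        persistent=reduction_extent <= 32 * 1024,  # placeholder threshold
        vectorize=width > 1,
        factor=width if width > 1 else 4,          # placeholder unroll factor
        grid_reduction=num_outputs < 80,           # placeholder: too few outputs to fill the GPU
    )


kernels = {}  # params -> compiled kernel (a name stands in for the binary)


def get_kernel(params):
    if params not in kernels:  # params differ from every cached set: recompile
        kernels[params] = f"kernel#{len(kernels)}"
    return kernels[params]
```

Here doubling the number of outputs from 256 to 512 leaves the structure unchanged and reuses the kernel, whereas an innermost extent of 18 instead of 16 drops the vector width from 8 to 2 and forces a second compile.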

The function that determines these parameters is quite complex, and its "predictability" (obviously we know what will happen in any given instance, because it's code) may be relatively low, since things can easily change as we continue to tune the heuristics.

If this is unacceptable for eager mode (I can imagine it may be, as we don't want to recompile too much), we could generate simpler/flatter heuristics in place of the current ones, which are oriented toward maximum performance at the cost of compilation. What we don't understand is how quickly these heuristics would converge; in practice, on something like large transformers, they converge very quickly even with dynamic sequence sizes. If we instead had something like a 3D segmentation model, where inputs could frequently change in every dimension, there are many more degrees of freedom and an approach like this may not be applicable. I don't know how we could easily solve the latter problem without impacting the former case (lower performance but less recompilation).

The one piece of good news is that the heuristic structure is what determines recompilation: if we reimplemented https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp#L35-L492 to be coarser, recompilation would be less frequent. So we do have control over this; that design was needed anyway so that we don't have to constantly change lots of code as we update the heuristics.
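One illustrative way to make heuristics coarser is to bucket the size-dependent inputs before they reach the heuristic, for example by rounding each extent up to a power of two, so that nearby shapes map to the same parameter structure and share a kernel. This is a generic technique offered as a sketch, not something nvFuser is stated to do; the trade-off is that parameters are tuned for the bucket rather than for the exact shape.

```python
def next_pow2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def bucketed_key(sizes):
    """Coarsen every extent to its power-of-two bucket."""
    return tuple(next_pow2(s) for s in sizes)


def count_distinct(shapes, key_fn):
    """Number of distinct heuristic inputs (an upper bound on compiles)."""
    return len({key_fn(s) for s in shapes})
```

For a 3D segmentation-style workload whose depth varies from 65 to 128, keying on exact shapes yields 64 distinct heuristic inputs, while the bucketed key collapses them to one.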

csarofeen, Mar 10 '22

Thanks, that's very helpful!

ngimel, Mar 10 '22

Some ballpark numbers for the basic compile-cache-reuse workflow on the small workload in the tests: the first execution takes around 0.5-0.9 s on a V100, of which about 30-60 us is actual kernel execution time and the bulk is compilation. After the first execution, the cache lookup takes around 0.9-1.5 us.
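Putting these ballpark figures side by side shows why the lookup cost is acceptable and why compilation dominates. The arithmetic below uses only the numbers quoted above and assumes that compile time is roughly the first-execution time minus the kernel time (other first-call overheads are lumped into "compile").

```python
first_exec_s = (0.5, 0.9)  # first execution, seconds
kernel_us = (30, 60)       # kernel execution time, microseconds
lookup_us = (0.9, 1.5)     # cache lookup after the first execution, microseconds

# Lookup overhead relative to kernel time:
# best case = fast lookup on a long kernel, worst case = slow lookup on a short kernel.
overhead_best = lookup_us[0] / kernel_us[1]
overhead_worst = lookup_us[1] / kernel_us[0]

# Approximate compile time in microseconds (first execution minus kernel time).
compile_us_low = first_exec_s[0] * 1e6 - kernel_us[1]
compile_us_high = first_exec_s[1] * 1e6 - kernel_us[0]

# How many kernel executions a single compile costs.
kernel_equiv_low = compile_us_low / kernel_us[1]
kernel_equiv_high = compile_us_high / kernel_us[0]

print(f"lookup overhead: {overhead_best:.1%} to {overhead_worst:.1%} of kernel time")
print(f"one compile ~ {kernel_equiv_low:,.0f} to {kernel_equiv_high:,.0f} kernel runs")
```

So the lookup adds roughly 1.5% to 5% on top of the kernel, while one compile costs on the order of 8,000 to 30,000 kernel executions.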

eqy, Mar 15 '22

0.9-1.5us is great

ngimel, Mar 16 '22

I want PyTorch to ship ATen/native/utils/*.h, torch/csrc/jit/codegen/cuda/*.h, and torch/csrc/jit/codegen/cuda/ops/*.h before merging this.
For this, I guess what we need to do is update:

  • https://github.com/pytorch/pytorch/blob/543eaac415bfba59e648a3da8be5c4f964f9bc6e/setup.py#L938
  • https://github.com/pytorch/pytorch/blob/543eaac415bfba59e648a3da8be5c4f964f9bc6e/aten/src/ATen/CMakeLists.txt#L149
  • some file e.g. https://github.com/pytorch/pytorch/blob/master/torch/CMakeLists.txt and/or https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/nvfuser.cmake
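Before wiring these globs into `setup.py` and the CMake install rules, it can help to check exactly which headers the three patterns match in a source tree. The helper below is a hypothetical convenience script, not part of PyTorch's build; it uses only `pathlib` globbing, and it assumes the ATen headers live under `aten/src/ATen/native/utils/` in the repository.

```python
from pathlib import Path

# The three header sets listed above, relative to the repository root.
HEADER_PATTERNS = (
    "aten/src/ATen/native/utils/*.h",
    "torch/csrc/jit/codegen/cuda/*.h",
    "torch/csrc/jit/codegen/cuda/ops/*.h",
)


def headers_to_ship(repo_root, patterns=HEADER_PATTERNS):
    """Return sorted, repo-relative paths of the headers matched by `patterns`.

    '*.h' is non-recursive, so headers in other subdirectories
    (e.g. scheduler/) are NOT picked up unless listed explicitly.
    """
    root = Path(repo_root)
    found = set()
    for pattern in patterns:
        for path in root.glob(pattern):
            if path.is_file():
                found.add(path.relative_to(root).as_posix())
    return sorted(found)
```

The non-recursive globs are a design point worth noticing: any header that the shipped headers include from a deeper subdirectory would also need its own pattern.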

crcrpar, May 03 '22

Getting closer to removing the PYTORCH_HOME env var dependency: https://github.com/pytorch/pytorch/pull/78281

crcrpar, Jun 22 '22