
Add fusion_debug dump option with Val log

Open jacobhinkle opened this pull request • 2 comments

This PR adds a new dump option called fusion_debug, which prints all the Vals and Exprs in a Fusion during compileFusion, along with an ordered list of the operations performed on those Vals. For example:

  Logged operations:                                                                                                                      
    0) T2_g : TensorView::reorder(T2_g):                                                                                                  
    1) [ iblockIdx.x4, rS8, rthreadIdx.x7, rUS9 ] : TensorDomain::reorder([ iS4, rS5 ]):   1->0                                           
    2) T2_g : TensorView::reorder(T2_g):                                                                                                  
    3) [ iblockIdx.x4, rS8, rthreadIdx.x7, rUS9 ] : TensorDomain::reorder([ rS5, iS4 ]):   1->0                                           
    4) T2_g : TensorView::split(T2_g): 1, blockDim.x, 1, 0                                                                                
    5) [ iblockIdx.x4, rS8, rthreadIdx.x7, rUS9 ] : TensorDomain::split([ iS4, rS5 ]): 1, blockDim.x, 1, 0                                
    6) rS5 : IterDomain::split(rS5): blockDim.x, 1, 0                                                                                     
    7) rthreadIdx.x7 : IterDomain::parallelize(rS7): threadIdx.x                                                                          
    8) T2_g : TensorView::split(T2_g): 1, 1, 1, 0                                                                                         
    9) [ iblockIdx.x4, rS8, rthreadIdx.x7, rUS9 ] : TensorDomain::split([ iS4, rS6, rthreadIdx.x7 ]): 1, 1, 1, 0                          
    10) rS6 : IterDomain::split(rS6): 1, 1, 0                                                                                             
    11) rUS9 : IterDomain::parallelize(rS9): US                                                                                           
    12) iblockIdx.x4 : IterDomain::parallelize(iS4): blockIdx.x                                                                           
    13) T2_g : TensorView::reorder(T2_g):                                                                                                 
    14) [ iblockIdx.x4, rS8, rthreadIdx.x7, rUS9 ] : TensorDomain::reorder([ iblockIdx.x4, rS8, rUS9, rthreadIdx.x7 ]):   3->2  2->3  1->1  0->0
    15) T2_g : TensorView::rFactor(T2_g):  1 3                                                                                            
    16) ithreadIdx.x13 : IterDomain::parallelize(iS13): threadIdx.x                                                                       
    17) rUS15 : IterDomain::parallelize(rS15): US                                                                                         
    18) iblockIdx.x16 : IterDomain::parallelize(iS16): blockIdx.x                                                                         
    19) rthreadIdx.x17 : IterDomain::parallelize(rS17): threadIdx.x                                                                       
    20) iblockIdx.x2 : IterDomain::parallelize(iS2): blockIdx.x                                                                           
    21) ithreadIdx.x19 : IterDomain::parallelize(iS19): threadIdx.x                                                                       
    22) iUS21 : IterDomain::parallelize(iS21): US                                                                                         
    23) T0_g : TensorView::inlineAt(T0_g): -1, 1                                                                                          
    24) T1_l : TensorView::inlineAt(T1_l): -1, 1                                                                                          
    25) T3_l : TensorView::inlineAt(T3_l): -1, 1                                                                                          
    26) T2_g : TensorView::inlineAt(T2_g): -1, 1 

Notice that the names of Vals in this log are abbreviated: only the name is printed, without the extra detail that toString() normally includes. To accomplish that, toString now takes a fmt argument whose type, SerializationFormat, can be one of Default (the current behavior), NameOnly, or Debug. The Debug option prints slightly more than Default: for a TensorView, for example, it also prints contiguity_. The SerializationFormat enum can be extended in the future to include formats like JSON, but I haven't explored that much yet.
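
As a rough sketch, based only on the description above rather than the actual diff, the new format option might look something like this (exact names and signatures in the PR may differ):

  // Sketch of the format option described above (illustrative only).
  enum class SerializationFormat {
    Default,  // current toString() behavior
    NameOnly, // print only the Val's name, e.g. "T2_g" or "iS4"
    Debug     // print extra internals, e.g. a TensorView's contiguity_
  };

  // Vals and Exprs would then accept the format as an optional argument, e.g.:
  //   std::string toString(
  //       SerializationFormat fmt = SerializationFormat::Default) const;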

jacobhinkle commented on Jan 12 '23

This PR was motivated by cases like the following. Consider this basic fusion: https://github.com/csarofeen/pytorch/blob/a06224f1976e2896d27f199984bd6b3f98707424/third_party/nvfuser/test/test_gpu_match_frontend.cpp#L197-L205

Manual and automatic scheduling give identical fusion_ir printouts, but the generated kernels differ slightly:

[screenshot: side-by-side diff of the two generated kernels]

These kernels are obviously equivalent, but there are some differences, possibly in the ordering of operations. When we dump fusion_debug for both and diff the output, we see the following operation log (green is the automatic schedule, red is the manual one):

[screenshot: diff of the two fusion_debug operation logs]

Now we see a lot of differences. There are further differences in the detailed dump that precedes the op log, indicating that some intermediate tensors (which are not shown at all in the fusion math or transforms view) may be parallelized differently, e.g.:

[screenshot: diff showing an intermediate tensor parallelized differently]
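
For context, a fusion of roughly this shape would be a 2-D pointwise op feeding an inner-axis reduction; the sketch below is a hypothetical stand-in for the linked test (its actual ops may differ) that matches the tv0/tv1/tv2 naming used in the schedules later in this thread:

  // Hypothetical stand-in for the linked test (illustrative only): a 2-D
  // input, a pointwise producer, and a sum reduction over the inner axis.
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);                    // T0_g: [i0, i1]
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1.0)); // T1_l: pointwise op
  auto tv2 = sum(tv1, {1});                            // T2_g: reduce axis 1
  fusion.addOutput(tv2);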

jacobhinkle commented on Jan 13 '23

With the Fusion operation log printed above, I was able to exactly match the automatic scheduler for this problem, and I learned a few things about how it works along the way. Here is what my previous manual schedule looked like:

  // Perform manual scheduling
  tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); // split the reduction axis by blockDim.x
  tv2->split(1, 1);                   // split again by a factor of 1 (this becomes the unswitch axis)
  tv2->reorder({{-1, -2}, {-2, -1}}); // swap the last two axes

  auto tv3 = tv2->rFactor({1, 3});

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(2)->parallelize(ParallelType::TIDx);
  tv3->axis(3)->parallelize(ParallelType::Unswitch);

  // propagate the mapping to other tensors
  TransformPropagatorWithCheck propagator(tv3);
  MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
  scheduler_utils::parallelizeAllLike(tv3, {tv0, tv1, tv2});

  tv1->computeAt(tv3, -1);

  inlineMost();

And here is the one that matches the auto scheduler:

  // Perform manual scheduling
  tv2->reorder({{1, 0}});
  tv2->reorder({{1, 0}});
  tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
  tv2->axis(2)->parallelize(ParallelType::TIDx);
  tv2->split(1, 1);
  tv2->axis(2)->parallelize(ParallelType::Unswitch);
  tv2->axis(0)->parallelize(ParallelType::BIDx);

  // tv2->reorder({{-2, -1}}) has same effect but this shows the mapping explicitly
  tv2->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 2}});

  auto tv3 = tv2->rFactor({1, 3});

  // propagate the mapping to other tensors
  TransformPropagatorWithCheck propagator(tv3);
  MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
  scheduler_utils::parallelizeAllLike(tv3, {}, allParallelTypesExcept(
          {ParallelType::Unroll,
           ParallelType::Vectorize,
           ParallelType::MisalignedVectorize}));

  inlineMost();

Note that both versions give identical fusion_ir printouts. Here we can see that the reduction scheduler parallelizes axes right after the splits that create them, which I've found to be good practice: it's harder to keep track of which axes you want to parallelize if you defer it until after numerous splits and reorders. We also see that parallelizeAllLike can be called with an empty list of tensors to mean all tensors, and that you can pass it the set of parallel types you want it to propagate.

One odd thing is the double reorder at the very beginning of the automatic schedule. It doesn't affect the fusion_ir printout or the generated kernel at all, but I included it because it does show up in the operation log. It is probably a special case of a more general part of the scheduler and is not always a trivial transform, but I haven't yet dug in to find out where this happens in the reduction scheduler.
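
As a worked illustration of why the double reorder is a no-op in this particular case (not scheduler code, just the arithmetic of the swap): tv2 has only two axes at that point, so swapping them twice restores the original order, which is what log entries 0 to 3 above show.

  // Illustration only: on the 2-D tv2, applying the same swap twice returns
  // the axes to their original order (cf. log entries 1 and 3 above).
  tv2->reorder({{1, 0}}); // [ iS4, rS5 ] -> [ rS5, iS4 ]
  tv2->reorder({{1, 0}}); // [ rS5, iS4 ] -> [ iS4, rS5 ]  (back to the start)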

jacobhinkle commented on Jan 13 '23