Include split+cat in fuse_split optimization

Open erjiang opened this issue 1 year ago • 4 comments

This change extends _fuse_split_and_strided_op to also optimize split followed by cat (when both are on the same dim). The split op is removed and the input_accessors of the cat op are updated.

The new test_fuse_split_cat.py test covers split+cat along dim 0, along dim 1, and a third case where split and cat operate along different dims (a case that we are not optimizing yet).
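To illustrate the idea, here is a minimal sketch on plain 1-D Python lists rather than AITemplate tensors. The helper names (`split`, `cat`, `cat_via_windows`) are hypothetical and only model the behavior described above: once the split is removed, the cat reads (offset, size) windows straight from the split's input, which is roughly the role the updated input_accessors play.

```python
def split(seq, sizes):
    """Split a 1-D list into consecutive chunks of the given sizes."""
    assert sum(sizes) == len(seq)
    chunks, start = [], 0
    for size in sizes:
        chunks.append(seq[start:start + size])
        start += size
    return chunks


def cat(chunks):
    """Concatenate a list of 1-D lists."""
    out = []
    for chunk in chunks:
        out.extend(chunk)
    return out


def cat_via_windows(source, windows):
    """Toy stand-in for a cat whose inputs read (offset, size) windows of `source`."""
    out = []
    for offset, size in windows:
        out.extend(source[offset:offset + size])
    return out


t = list(range(10))
sizes = [3, 7]

# Unfused graph: split materializes x1 and x2, then cat reads them.
unfused = cat(split(t, sizes))

# Fused graph: no split op; cat reads windows of t directly.
windows, start = [], 0
for size in sizes:
    windows.append((start, size))
    start += size
fused = cat_via_windows(t, windows)

assert fused == unfused
```

The sketch only covers the same-dim case on a single dimension; it says nothing about how the real pass handles strides in higher-rank tensors.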

erjiang avatar Jun 01 '23 00:06 erjiang

@erjiang Thanks for the PR! A few thoughts:

  • Tests are breaking at the moment, please have a look
  • To test that the transform actually works, you might want to check not only that the outputs are numerically correct, but also that the graph has the expected structure - e.g. ops which had to be eliminated got eliminated and/or correct new ops are present in the graph. This way you'll catch situations where the test passes just because the transform didn't get applied for some reason
  • It would be good to test a few more complicated scenarios, including when some of the tensors involved in the transformation feed into more than one op or when some of them are inputs/outputs of the graph (you don't want to eliminate an output, for example)
  • You made the transformation part of fuse_split - would it actually work this way? As far as I understand, that transform eliminates a split op by fusing it with a downstream op. In your case, you'd want to eliminate split + concat pair, wouldn't you?
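A structural check along the lines suggested above could look like the following sketch. It uses a toy graph of plain dicts, not AITemplate's real graph classes; `count_ops`, `toy_remove_split_before_cat`, and `make_graph` are hypothetical names, and the toy pass ignores offsets entirely. The point is only the shape of the assertions: count the ops that should have disappeared, and cover the cases where the pass must not fire.

```python
def count_ops(ops, op_name):
    """Count ops of a given type in a toy graph (a list of dicts)."""
    return sum(1 for op in ops if op["op"] == op_name)


def toy_remove_split_before_cat(ops, graph_outputs):
    """Toy pass: drop a split feeding exactly one concatenate on the same dim."""
    kept = []
    for op in ops:
        if op["op"] == "split":
            outs = set(op["outputs"])
            consumers = [o for o in ops if o is not op and outs & set(o["inputs"])]
            fusible = (
                len(consumers) == 1
                and consumers[0]["op"] == "concatenate"
                and consumers[0]["dim"] == op["dim"]
                and not outs & set(graph_outputs)  # never eliminate a graph output
            )
            if fusible:
                cat_op = consumers[0]
                # Rewire cat to the split's input (offset bookkeeping elided here).
                cat_op["inputs"] = [
                    op["inputs"][0] if name in outs else name
                    for name in cat_op["inputs"]
                ]
                continue
        kept.append(op)
    return kept


def make_graph(split_dim, cat_dim):
    return [
        {"op": "split", "dim": split_dim, "inputs": ["t"], "outputs": ["x1", "x2"]},
        {"op": "concatenate", "dim": cat_dim, "inputs": ["x1", "x2"], "outputs": ["y"]},
    ]


same_dim = toy_remove_split_before_cat(make_graph(0, 0), graph_outputs=["y"])
different_dim = toy_remove_split_before_cat(make_graph(0, 1), graph_outputs=["y"])
split_is_output = toy_remove_split_before_cat(make_graph(0, 0), graph_outputs=["y", "x1"])

assert count_ops(same_dim, "split") == 0         # transform was applied
assert count_ops(different_dim, "split") == 1    # transform must not apply
assert count_ops(split_is_output, "split") == 1  # x1 is a graph output
```

In a real test the same kind of count would be taken over the compiled model's op list instead of these dicts.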

sgrigory avatar Jun 01 '23 09:06 sgrigory

Thanks for the comments, appreciate the review! I should add that this PR is meant to go with T154346827; I've asked Yang for help importing it into Phabricator.

  • To test that the transform actually works, you might want to check not only that the outputs are numerically correct, but also that the graph has the expected structure - e.g. ops which had to be eliminated got eliminated and/or correct new ops are present in the graph. This way you'll catch situations where the test passes just because the transform didn't get applied for some reason

Good point, I will find similar test cases and add that check in.

  • You made the transformation part of fuse_split - would it actually work this way? As far as I understand, that transform eliminates a split op by fusing it with a downstream op. In your case, you'd want to eliminate split + concat pair, wouldn't you?

My understanding was that we should eliminate the split and update the TensorAccessors of the concat op to get the same result, similar to how the existing pass works. Can we eliminate both split and concat, and update the op after concat? Maybe that could work if next_op is a certain type, and otherwise we fall back to only removing the split?

erjiang avatar Jun 01 '23 18:06 erjiang

That's right. We can eliminate both split and concat only when the concat does not permute the split's outputs. In the following example, for instance, we can remove only the split:

x1, x2 = split(t)
y = concat(x2, x1)
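A small sketch of why the order matters, again on 1-D Python lists with hypothetical helper names (`split_windows`, `read_windows`, `cat_is_identity`). With the split removed, concat(x2, x1) can still read its inputs as windows of t, but the result differs from t, so the concat itself has to stay.

```python
def split_windows(sizes):
    """Return the (offset, size) window each split output covers in the input."""
    windows, start = [], 0
    for size in sizes:
        windows.append((start, size))
        start += size
    return windows


def read_windows(source, windows):
    """Toy concat that reads (offset, size) windows of `source` in the given order."""
    out = []
    for offset, size in windows:
        out.extend(source[offset:offset + size])
    return out


def cat_is_identity(cat_input_indices, num_split_outputs):
    """True when concat consumes every split output in its original order."""
    return list(cat_input_indices) == list(range(num_split_outputs))


t = list(range(6))
windows = split_windows([2, 4])  # x1 covers t[0:2], x2 covers t[2:6]

in_order = read_windows(t, [windows[0], windows[1]])  # concat(x1, x2)
swapped = read_windows(t, [windows[1], windows[0]])   # concat(x2, x1)

assert in_order == t   # split and concat could both go
assert swapped != t    # only the split can go; concat still reorders data
```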

chenyang78 avatar Jun 02 '23 07:06 chenyang78

@erjiang There is a test failure:

FAILED tests/unittest/compiler/test_transform_memory_ops.py::MemoryOpTransformationTestCase::test_non_fusible_split_reshape_cat - AssertionError: 2 != 3

Please take a look. The assertion may simply have become invalid because of the new pass, but we want to confirm that. Thanks.

chenyang78 avatar Jun 03 '23 07:06 chenyang78

Yep, already fixed in my working branch. There was an additional split op optimized out in that test case so the assertion failed.

erjiang avatar Jun 05 '23 05:06 erjiang

@chenyang78 I've run into an issue that I could use some suggestions on: there was a test failure in test_split_bmm_fusion.py until I added an alignment check: https://github.com/erjiang/AITemplate/commit/f189e0f0f22ce4e65818c4515776865c7209fd4b

2 errors detected in the compilation of "bmm_rcr_3.cu".
make: *** [Makefile:9: bmm_rcr_3.obj] Error 255
make: *** Waiting for unfinished jobs....

2023-06-05 14:24:05,077 INFO <aitemplate.backend.builder> /home/eric/AITemplate/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h:

...
109       /// Size of the access in bytes
110       int SizeInBytes>
111   struct cp_async<SizeInBytes, CacheOperation::Always> {
112   
113     /// Copy
114     CUTLASS_DEVICE
115     cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
116       #if CUDA_CP_ASYNC_ACTIVATED
117   
118         // Make sure the size is supported.
119         static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16),
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
120                   "Size is not supported");
121   
122         unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
123   
124         asm volatile(
125             "{\n"
126             "  .reg .pred p;\n"
127             "  setp.ne.b32 p, %0, 0;\n"
128   #if CUTLASS_ENABLE_L2_PREFETCH
129             "  @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
...

But this alignment check seems stricter than necessary, and it prevents a few optimizations that work fine, including the simple case of torch.split(input_1, [139, 373], 0) followed by concat. Do you know if/how I can narrow the alignment check, or if there's something else I should be doing instead?

erjiang avatar Jun 05 '23 21:06 erjiang

_check_dim_alignment is supposed to be only used for gemm/bmm ops. Sorry for the bad code documentation.

I think we can simply return True if the op is concat, e.g.:

if op._attrs["op"] == "concatenate":
    return True

chenyang78 avatar Jun 06 '23 08:06 chenyang78

I borrowed the alignment check from the mm ops because without it, some tests fail with what looks like a failed alignment assertion. Have you seen the above build error before? (The one with static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16) ...) Any suggestions for alternative ways to fix it?

Edit: Alternatively, I can instead keep the alignment check and simply limit the optimization to aligned splits and not optimize odd splits.

erjiang avatar Jun 06 '23 19:06 erjiang

@erjiang Could you share one failing test? I can take a look at it today. Thanks.

chenyang78 avatar Jun 07 '23 00:06 chenyang78

@chenyang78 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot avatar Jun 07 '23 00:06 facebook-github-bot

@erjiang Could you share one failing test? I can take a look at it today. Thanks.

Try this revision that does not have the alignment check: https://github.com/erjiang/AITemplate/commit/f07834ccfd18763189ce1f7389a5c59a38419d49 (or just remove the alignment check for concatenate op).

This test command fails with the static_assert build errors:

python tests/unittest/compiler/test_split_bmm_fusion.py SplitBmmFusionTestCase.test_split_bmm_rcr_fusion_static

erjiang avatar Jun 07 '23 00:06 erjiang

Thanks, @erjiang. I played with it a bit. I think we are actually uncovering a latent bug in the fuse_split pass: it should take into account the alignment requirement for gemm (and likely concat) with respect to the number of elements from the split_dim. We can compute that number with code like this:

total_elems_from_split_dim = stride * split_input._attrs["shape"][split_dim].value()

So one fix would be to pass total_elems_from_split_dim to _check_alignment and make an additional check there with this new argument, similar to alignment.valid_alignment(offset, dtype).
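As a rough illustration of that extra check, the sketch below computes the element count and tests it against a minimum access size. Everything here is an assumption for illustration: `stride` is taken to be the product of the dims after split_dim, the 4-byte minimum is borrowed from the static_assert in the build log above, and `toy_valid_alignment` is a hypothetical stand-in rather than AITemplate's alignment.valid_alignment.

```python
from math import prod

# Assumed element sizes in bytes for the dtypes used in this sketch.
DTYPE_BYTES = {"float16": 2, "float32": 4}

# Assumption: the smallest supported access is 4 bytes, as in the
# static_assert(SizeInBytes == 4 || 8 || 16) from the build error.
MIN_ACCESS_BYTES = 4


def total_elems_from_split_dim(shape, split_dim):
    """Elements spanned by split_dim and all dims after it (static shapes only)."""
    stride = prod(shape[split_dim + 1:])  # assumed meaning of `stride`
    return stride * shape[split_dim]


def toy_valid_alignment(num_elems, dtype):
    """Hypothetical check: the byte count must be a multiple of the minimum access."""
    return (num_elems * DTYPE_BYTES[dtype]) % MIN_ACCESS_BYTES == 0


aligned = total_elems_from_split_dim([4, 512, 64], split_dim=1)
odd = total_elems_from_split_dim([2, 3, 1], split_dim=1)

assert toy_valid_alignment(aligned, "float16")
assert not toy_valid_alignment(odd, "float16")  # 3 fp16 elements = 6 bytes
assert toy_valid_alignment(odd, "float32")      # 3 fp32 elements = 12 bytes
```

Under these assumptions an odd element count fails for float16 but passes for float32, which matches the intuition that the requirement depends on bytes, not on elements.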

chenyang78 avatar Jun 07 '23 22:06 chenyang78

@chenyang78 Thanks! I implemented your suggestion, skipping the _check_alignment test for concatenate, and it seems to work, including optimizing out split for some odd-sized splits. At this point I don't know of any other issues.

erjiang avatar Jun 08 '23 17:06 erjiang

BTW, there is a lint error, as well. Please take a look. Thanks.

chenyang78 avatar Jun 09 '23 08:06 chenyang78

Seems to be an issue with the linter - it's detecting a minor issue with a file that I didn't edit. The lint-breaking change was in D46442470 which just landed on Wednesday, and I don't want to make edits in other files in this PR. I can make a separate diff for the formatting instead.

Edit: never mind, the formatting change from ufmt conflicts with arc lint, so either import ordering will cause one linter to complain.

erjiang avatar Jun 09 '23 16:06 erjiang

Ah, I see. Thanks for the info.

chenyang78 avatar Jun 09 '23 18:06 chenyang78

@chenyang78 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot avatar Jun 10 '23 07:06 facebook-github-bot

@chenyang78 merged this pull request in facebookincubator/AITemplate@db2a9e9cfdb0452bfd697a445cdd9f4fdcfe555d.

facebook-github-bot avatar Jun 11 '23 07:06 facebook-github-bot