
Skip the execution if all Pending IRs are device data

Open JackCaoG opened this issue 11 months ago • 6 comments

Today, if we have a bunch of tensors whose pending IR is just DeviceData and call xm.mark_step, we execute a graph that looks like this:

    ENTRY %IrToHlo.4 (p0.1: f32[4,2,2], p1.2: f32[4,2,2]) -> (f32[4,2,2], f32[4,2,2]) {
      %p0.1 = f32[4,2,2]{2,1,0} parameter(0)
      %p1.2 = f32[4,2,2]{2,1,0} parameter(1)
      ROOT %tuple.3 = (f32[4,2,2]{2,1,0}, f32[4,2,2]{2,1,0}) tuple(f32[4,2,2]{2,1,0} %p0.1, f32[4,2,2]{2,1,0} %p1.2)
    }

This is essentially a no-op graph: it generates new buffers (which may or may not be aliased with the inputs) holding the same values. Executing such a graph adds both runtime overhead and device execution overhead. For training this is usually not a big deal, but for inference it can be very annoying, especially in the dynamo bridge, where we sometimes need to call mark_step to sync inputs. My change makes the tensor drop its reference to the DeviceData IR and point to the XLAData directly, which skips the execution for that tensor.
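The fast path can be illustrated with a toy Python model. This is a minimal sketch under stated assumptions, not the torch_xla implementation: Buffer, DeviceData, Add, LazyTensor, evaluate and mark_step are all hypothetical stand-ins. When every pending tensor's IR is just a DeviceData node, the sketch rebinds each tensor to the existing buffer and launches no execution; otherwise it "executes" the graph and every pending tensor (including DeviceData-only ones) gets a fresh buffer, as in the old design.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Buffer:
    """Hypothetical stand-in for a device buffer (XLAData)."""
    value: float


@dataclass
class DeviceData:
    """IR node that only wraps an already-materialized device buffer."""
    buffer: Buffer


@dataclass
class Add:
    """IR node for a real computation that has to be executed."""
    lhs: "Node"
    rhs: "Node"


Node = Union[DeviceData, Add]


@dataclass
class LazyTensor:
    ir: Optional[Node] = None      # pending IR, if any
    data: Optional[Buffer] = None  # materialized buffer, if any


def evaluate(node: Node) -> float:
    """Toy interpreter standing in for compiling and running a graph."""
    if isinstance(node, DeviceData):
        return node.buffer.value
    return evaluate(node.lhs) + evaluate(node.rhs)


def mark_step(tensors: List[LazyTensor]) -> int:
    """Sync pending tensors and return how many graph executions ran."""
    pending = [t for t in tensors if t.ir is not None]
    if not pending:
        return 0
    if all(isinstance(t.ir, DeviceData) for t in pending):
        # Fast path: the graph would only return its parameters, so point
        # each tensor at the existing buffer and drop the IR reference.
        for t in pending:
            t.data = t.ir.buffer
            t.ir = None
        return 0
    # Slow path: "execute" the graph; every output gets a new buffer.
    for t in pending:
        t.data = Buffer(evaluate(t.ir))
        t.ir = None
    return 1


a = LazyTensor(ir=DeviceData(Buffer(1.0)))
b = LazyTensor(ir=DeviceData(Buffer(2.0)))
print(mark_step([a, b]))  # 0: the no-op graph is skipped

c = LazyTensor(ir=Add(DeviceData(Buffer(1.0)), DeviceData(Buffer(2.0))))
print(mark_step([c]), c.data.value)  # 1 3.0
```

The all-or-nothing check mirrors the title of this change: the skip only applies when all pending IRs are device data, so any real computation in the batch still goes through a normal execution.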

My change introduces a problem when two tensors share the same buffer. In the old design, both tensors get new buffers after execution. With my change, no execution happens and both point to the same PJRT buffer. However, if that input buffer is then aliased with an output (an in-place update to one of the tensors), the other tensor ends up pointing to a deleted buffer. To address this, I make sure to always allocate a new output buffer for a cloned tensor.
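The hazard and the fix can be sketched with another toy model. All names here (Buffer, Tensor, add_one_inplace, clone_sharing, clone_fresh) are hypothetical, and buffer donation is modeled simply by marking the input buffer as deleted once it has been aliased to an output; this illustrates the failure mode described above rather than reproducing the actual PJRT or torch_xla code.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Buffer:
    """Hypothetical device buffer; `deleted` mimics a donated buffer."""
    values: List[float]
    deleted: bool = False

    def read(self) -> List[float]:
        if self.deleted:
            raise RuntimeError("read from a deleted (donated) buffer")
        return list(self.values)


@dataclass
class Tensor:
    buffer: Buffer


def add_one_inplace(t: Tensor) -> None:
    """In-place update whose input buffer is aliased to the output.

    After donation the old buffer handle is no longer valid, which the
    toy model expresses by setting `deleted` on it.
    """
    out = Buffer([v + 1.0 for v in t.buffer.read()])
    t.buffer.deleted = True  # input buffer was donated to the output
    t.buffer = out


def clone_sharing(t: Tensor) -> Tensor:
    """Unsafe clone: the new tensor points at the same buffer."""
    return Tensor(t.buffer)


def clone_fresh(t: Tensor) -> Tensor:
    """Safe clone: always allocate a new buffer for the cloned tensor."""
    return Tensor(Buffer(t.buffer.read()))


x = Tensor(Buffer([1.0, 2.0]))
y = clone_sharing(x)
add_one_inplace(x)
try:
    y.buffer.read()
except RuntimeError as err:
    print("shared buffer:", err)

x = Tensor(Buffer([1.0, 2.0]))
y = clone_fresh(x)
add_one_inplace(x)
print(x.buffer.read(), y.buffer.read())  # [2.0, 3.0] [1.0, 2.0]
```

Giving the clone its own buffer trades one extra allocation for safety: donation of the original's buffer can no longer invalidate memory that another tensor still references.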

JackCaoG, Feb 28 '24

lol, it breaks test_shard_hashing; I think this is where I gave up last time. Let me take another look at why that's the case.

JackCaoG, Feb 29 '24

test/test_input_output_aliases.py still fails. I know why it fails (the input and output shapes have different layouts, so they won't alias), but what's weird is that it also fails when I build from master... Not sure what happened.

JackCaoG, Mar 01 '24

I also need to add more tests for torch.clone, which creates a new DeviceData that shares the same XLAData; I need to make sure aliasing won't cause issues there.

JackCaoG, Mar 01 '24

More test failures... I will need to look into these.

JackCaoG, Mar 07 '24

Finally, the tests are green lol. I don't want to merge this into 2.3 since it is too risky; I will leave it for nightly for now.

JackCaoG, Mar 11 '24

Thanks @alanwaketan! I am going to wait for the branch cut to happen and then merge this change.

JackCaoG, Mar 12 '24