Skip the execution if all Pending IRs are device data
Today, if we have a bunch of DeviceData IRs and call xm.mark_step, we will execute a graph that looks like:
# ENTRY %IrToHlo.4 (p0.1: f32[4,2,2], p1.2: f32[4,2,2]) -> (f32[4,2,2], f32[4,2,2]) {
# %p0.1 = f32[4,2,2]{2,1,0} parameter(0)
# %p1.2 = f32[4,2,2]{2,1,0} parameter(1)
# ROOT %tuple.3 = (f32[4,2,2]{2,1,0}, f32[4,2,2]{2,1,0}) tuple(f32[4,2,2]{2,1,0} %p0.1, f32[4,2,2]{2,1,0} %p1.2)
# }
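For reference, here is a minimal sketch of how this situation can arise (shapes chosen to match the HLO above; this is an illustrative repro, not code from the change itself):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Moving CPU tensors to the XLA device only creates DeviceData IR nodes;
# no real computation has been traced yet.
a = torch.randn(4, 2, 2).to(device)
b = torch.randn(4, 2, 2).to(device)
# With nothing but DeviceData IRs pending, mark_step still compiles and
# executes the no-op tuple graph shown above.
xm.mark_step()
```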
This is pretty much a no-op graph that generates new buffers (which might or might not be aliased with the inputs) holding the same values. Executing such a graph incurs runtime overhead as well as device execution overhead. For training this is usually not a big deal, but for inference it can be very annoying, especially in the dynamo bridge where we sometimes need to call mark_step to sync inputs. My change forces the tensor to drop its reference to the DeviceData IR and point to the XLAData directly, saving the execution for that tensor.
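A minimal Python sketch of the idea (the actual change lives in the torch_xla internals; the class and function names below are hypothetical stand-ins, not the real implementation):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class XLAData:
    """Stand-in for a handle to a device buffer."""
    buffer: object

@dataclass
class DeviceData:
    """IR node that merely wraps an existing device buffer."""
    data: XLAData

@dataclass
class LazyTensor:
    ir_value: Optional[object] = None   # pending IR node, if any
    xla_data: Optional[XLAData] = None  # materialized device buffer, if any

def sync(tensors):
    """If every pending IR is a DeviceData, skip compilation/execution and
    point each tensor at the underlying XLAData directly."""
    pending = [t for t in tensors if t.ir_value is not None]
    if pending and all(isinstance(t.ir_value, DeviceData) for t in pending):
        for t in pending:
            t.xla_data = t.ir_value.data
            t.ir_value = None
        return "skipped"
    # Otherwise fall back to the normal path: lower to HLO, compile, execute.
    return "executed"
```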
My change introduces a problem when two tensors share the same buffer. In the old design, both tensors would get new buffers after execution. With my new change, no execution happens and they both keep pointing at the same PJRT buffer; if that input buffer is then aliased with an output (an in-place update on one tensor), the other tensor ends up pointing at a deleted buffer. To address this, I make sure to always allocate a new output buffer for cloned tensors.
lol, it breaks test_shard_hashing. I think this is where I gave up last time. Let me take another look at why that's the case.
test/test_input_output_aliases.py still fails. I know why it failed, but what's weird is that it also fails even when I build from master (the input and output shapes have different layouts, so they shouldn't alias)... Not sure what happened.
I also need to add more tests around torch.clone, which will create a new DeviceData that shares the same XLAData; I need to make sure aliasing won't cause issues here.
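Something along these lines (a sketch of the scenario I want to cover, not the actual test that will be added):

```python
import torch
import torch_xla.core.xla_model as xm

def test_clone_survives_inplace_update():
    device = xm.xla_device()
    a = torch.arange(8.0).reshape(4, 2).to(device)
    b = a.clone()   # b's DeviceData may end up sharing a's underlying buffer
    xm.mark_step()  # with the fast path, no graph is executed here
    a.mul_(2.0)     # in-place update that can alias a's buffer with the output
    xm.mark_step()
    # b must keep the pre-update values even though it shared a's buffer.
    assert torch.allclose(b.cpu(), torch.arange(8.0).reshape(4, 2))
```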
More test failures... I will need to look into those.
Finally the tests are green, lol. I don't want to merge this into 2.3 since it is too risky; I will leave it in nightly for now.
Thanks @alanwaketan. I am going to wait for the branch cut to happen and then merge this change.