In-place operations in the backward pass cause errors in asynchronous concurrent training with FSDP
Hi, I am trying to modify the LoRA training recipe with FSDP to support conditional training: I have two LoRA adapters in a module and use a binary variable to control which adapter the forward pass propagates through.
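A minimal sketch of the idea (names, shapes, and the frozen base layer are illustrative, not my actual recipe):

```python
import torch
import torch.nn as nn


class LoRAAdapter(nn.Module):
    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        return self.up(self.down(x))


class ConditionalLoRALinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)  # freeze the base weight
        self.adapters = nn.ModuleList(
            [LoRAAdapter(in_features, out_features) for _ in range(2)]
        )

    def forward(self, x, adapter_idx: int):
        # The binary variable selects which LoRA adapter contributes;
        # the result is built out-of-place (no += on the base output).
        return self.base(x) + self.adapters[adapter_idx](x)
```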
Then I use asyncio to submit two concurrent tasks, each propagating through a different LoRA adapter. My code runs fine if I do not submit multiple asynchronous tasks at the same time, but I see the following error when tasks run concurrently:
```
Output 0 of TBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
```
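For context, this is roughly how I drive the two tasks (simplified; in the real recipe the model is wrapped with FSDP and the batch/loss come from the training loop):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch


def train_step(model, batch, adapter_idx):
    out = model(batch, adapter_idx=adapter_idx)
    loss = out.pow(2).mean()  # placeholder loss
    loss.backward()           # the error surfaces when backward runs concurrently
    return loss.item()


async def main(model):
    loop = asyncio.get_running_loop()
    batch_a = torch.randn(4, 16)
    batch_b = torch.randn(4, 16)
    with ThreadPoolExecutor(max_workers=2) as pool:
        # Two concurrent tasks, each routed through a different adapter.
        futs = [
            loop.run_in_executor(pool, train_step, model, batch_a, 0),
            loop.run_in_executor(pool, train_step, model, batch_b, 1),
        ]
        return await asyncio.gather(*futs)
```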
I checked the implementation and removed all the in-place operations in the forward pass, but the problem still exists. I suspect it is caused by certain operations in the backward pass.
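The forward-pass changes I made were of this kind (variable names here are just for illustration):

```python
import torch

def add_lora_out(hidden: torch.Tensor, lora_out: torch.Tensor) -> torch.Tensor:
    # before (in-place, can invalidate autograd views):
    #   hidden += lora_out
    #   hidden.relu_()
    # after (out-of-place equivalents):
    return torch.relu(hidden + lora_out)
```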
When I submit multiple concurrent forward passes only, the issue does not appear; it only appears when at least one of the tasks contains backpropagation.
https://github.com/pytorch/pytorch/issues/124017 mentions a similar error, but their problem is quite different. I am wondering whether the FSDP paradigm can support concurrent calls to the model during the training phase.
More specifically, my questions are: are there in-place operations in the backward pass that cannot be avoided? If so, does that suggest a more fundamental conflict with implementing asynchronous concurrent training this way? If not, is there a way to make sure all operations in the backward pass are out-of-place?