
Disabling LoRA with compiled models

Open BenjaminBossan opened this issue 1 year ago • 2 comments

In PEFT, we've encountered a bunch of issues with LoRA when using compiled models. I was curious how torchtune solves these issues, so I checked one of them, namely disabling the LoRA adapter. It seems that torchtune runs into the same issue, as disabling does not work with compiled models.
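(For readers unfamiliar with the mechanism: `disable_adapter` temporarily flips a `disabled` flag on the adapter modules so the forward pass skips the low-rank update. Here is a torch-free sketch of that pattern, with illustrative names (`LoRALinearToy` is not torchtune's class); it only approximates the real implementation:)

```python
# Simplified, torch-free sketch of the adapter-disable pattern under test.
# LoRALinearToy and the scalar math are illustrative, not torchtune's API.
from contextlib import contextmanager

class LoRALinearToy:
    """Scalar stand-in for a LoRA linear layer: y = w*x + (b*a)*x."""
    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b
        self.disabled = False  # the attribute at the center of this issue

    def __call__(self, x):
        out = self.w * x
        if not self.disabled:
            out += self.b * self.a * x  # low-rank update
        return out

@contextmanager
def disable_adapter(module):
    """Temporarily turn off the LoRA contribution, restoring the flag on exit."""
    prev = module.disabled
    module.disabled = True
    try:
        yield
    finally:
        module.disabled = prev

layer = LoRALinearToy(w=2.0, a=0.5, b=4.0)
enabled = layer(3.0)           # 2*3 + 4*0.5*3 = 12.0
with disable_adapter(layer):
    disabled = layer(3.0)      # 2*3 = 6.0
```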

To show this bug, I added the following tests to test_lora.py:

# torch v2.2.2
import torch

from torchtune.modules.peft.peft_utils import disable_adapter

...

    def test_disable_lora(self, inputs, lora_linear) -> None:
        enabled = lora_linear(inputs)

        with disable_adapter(lora_linear):
            disabled = lora_linear(inputs)

        # passes
        assert not torch.allclose(enabled, disabled, atol=1e-4, rtol=1e-4)

    def test_disable_lora_compiled(self, inputs, lora_linear) -> None:
        with disable_adapter(lora_linear):
            disabled_no_compile = lora_linear(inputs)

        lora_linear = torch.compile(lora_linear)
        enabled = lora_linear(inputs)

        with disable_adapter(lora_linear):
            disabled_compile = lora_linear(inputs)

        # both asserts fail
        assert not torch.allclose(enabled, disabled_compile, atol=1e-4, rtol=1e-4)
        torch.testing.assert_close(disabled_no_compile, disabled_compile)

For me, the first test for disabling LoRA passes, as it doesn't use torch.compile. However, the second test fails: disabling has no effect on the compiled model. Can you confirm that this is indeed an issue, and are there any ideas/plans to resolve it in torchtune or PyTorch?

BenjaminBossan avatar Apr 15 '24 12:04 BenjaminBossan

@BenjaminBossan thanks so much for this really interesting issue. @yf225 mentioned that he'd be very interested in taking a look at this. So I'll let him comment more intelligently on what might be going on here.

kartikayk avatar Apr 15 '24 22:04 kartikayk

Looked into the issue - it's because torch.compile doesn't currently guard on user NN module attributes (in this case `.disabled`), so mutating the attribute value doesn't trigger a recompile, which causes the compiled output to differ from the eager output.
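To picture why the stale output appears, here is a self-contained pure-Python analogy (not Dynamo's actual guard machinery): a toy "compiler" that caches a specialized function keyed only on the input, so the attribute value read at trace time is baked in and never re-checked:

```python
# Analogy of the bug: the cached trace guards on the input only, so a later
# change to mod.disabled is never observed by the "compiled" function.
class ToyAdapter:
    def __init__(self):
        self.disabled = False

    def forward(self, x):
        return x if self.disabled else x + 1.0

def naive_compile(mod):
    cache = {}
    def compiled(x):
        key = type(x)                 # guard on the input type only
        if key not in cache:
            flag = mod.disabled       # attribute value baked in at trace time
            cache[key] = (lambda v: v) if flag else (lambda v: v + 1.0)
        return cache[key](x)
    return compiled

m = ToyAdapter()
f = naive_compile(m)
first = f(0.0)        # traces with disabled=False -> 1.0
m.disabled = True
stale = f(0.0)        # still 1.0: no guard on .disabled, no retrace
eager = m.forward(0.0)  # eager mode sees the flag -> 0.0
```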

Opened an issue on the PyTorch GitHub repo: https://github.com/pytorch/pytorch/issues/124717 - we will follow up on this soon.

yf225 avatar Apr 23 '24 09:04 yf225

This should be fixed now per https://github.com/pytorch/pytorch/issues/124717#issuecomment-2157261609.
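One way to picture the fix, continuing the toy caching analogy rather than Dynamo's real guard machinery: include the attribute in the guard key, so mutating it invalidates the cached trace and forces a "recompile":

```python
# Analogy of the fix: the attribute is part of the guard key, so flipping it
# misses the cache, retraces, and the compiled output matches eager again.
class ToyAdapter:
    def __init__(self):
        self.disabled = False

    def forward(self, x):
        return x if self.disabled else x + 1.0

def guarded_compile(mod):
    cache = {}
    def compiled(x):
        key = (type(x), mod.disabled)   # guard on the attribute too
        if key not in cache:
            flag = mod.disabled
            cache[key] = (lambda v: v) if flag else (lambda v: v + 1.0)
        return cache[key](x)
    return compiled

m = ToyAdapter()
f = guarded_compile(m)
before = f(0.0)      # traces with disabled=False -> 1.0
m.disabled = True
after = f(0.0)       # new guard key -> retrace -> 0.0, matching eager
```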

yf225 avatar Jun 27 '24 21:06 yf225