`ForwardContext` is `None` with gradient checkpointing enabled
Environment info

- `adapters` version: latest main
Information
Model I am using (Bert, XLNet ...): any
Language I am using the model on (English, Chinese ...): any
Adapter setup I am using (if any): affects all adapter methods that rely on `ForwardContext`: ReFT, Prefix-Tuning, Prompt Tuning, Fusion, Parallel composition
To reproduce
When gradient checkpointing is enabled before adapter training, i.e.:

```python
model.gradient_checkpointing_enable()
```

the `ForwardContext` is not correctly set during the forward/backward passes. As a result, all functionality that depends on `ForwardContext` breaks in combination with gradient checkpointing. This affects some adapter methods (ReFT, Prompt Tuning, and Prefix-Tuning currently do not work with gradient checkpointing) but not others (LoRA, bottleneck adapters), and it also affects composition setups such as Fusion and Parallel.
For example, it will throw this error:

```
AttributeError: 'NoneType' object has no attribute 'output_adapter_gating_scores'
```
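A minimal sketch that should reproduce this (the model checkpoint and the `"loreft"` config string are placeholder choices on my side, not taken from the notebook):

```python
import adapters
from transformers import AutoModel, AutoTokenizer

# Placeholder model/method; any ForwardContext-dependent adapter method should do.
model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)
model.add_adapter("reft_adapter", config="loreft")
model.train_adapter("reft_adapter")

# Enabling gradient checkpointing before training triggers the bug.
model.gradient_checkpointing_enable()
# Needed so the checkpointed segments actually recompute with grads
# while the base model is frozen by train_adapter().
model.enable_input_require_grads()
model.train()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer("some input text", return_tensors="pt")

# The initial forward runs inside a ForwardContext; the forward that
# checkpointing re-runs during backward does not, so context lookups fail.
outputs = model(**batch)
outputs.last_hidden_state.sum().backward()  # fails with an AttributeError like the one above
```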
Also see https://github.com/adapter-hub/adapters/discussions/677.
To reproduce, try training ReFT using the QLoRA Llama notebook with gradient checkpointing enabled.
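As far as I can tell (my reading of the situation, not confirmed), the context is lost because `ForwardContext` is entered around the model's top-level `forward`, while `torch.utils.checkpoint` re-runs the layer forwards during the backward pass, outside that wrapper. A self-contained toy illustration of the pattern, using a stand-in context class rather than the real `ForwardContext`:

```python
import torch
from torch.utils.checkpoint import checkpoint

class ToyContext:
    """Stand-in for adapters' ForwardContext (not the real class)."""
    _current = None

    @classmethod
    def get_context(cls):
        return cls._current

weight = torch.ones(1, requires_grad=True)

def layer(x):
    # During the re-computed forward inside backward, this prints None.
    print("context during forward:", ToyContext.get_context())
    return x * weight

x = torch.ones(1)
ToyContext._current = "active"   # context set around the outer forward
y = checkpoint(layer, x, use_reentrant=False)
ToyContext._current = None       # context exits with the outer forward
y.sum().backward()               # recomputation happens here -> prints None
```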