apex
Fix an issue where `child_name` can be None and cause the overall ASP process to fail
When attempting to sparsify a transformers model, `child_name` can, for some reason, be `None`. In that case `fx_graph.get(None)` returns `None` and the overall process crashes.
This PR hotfixes the crash but does not investigate why the FX graph generates such `None` children.
I'm checking that the output of `fx_graph.get(...)` is non-null rather than checking that `child_name` is non-null, because the first check covers the second (`fx_graph.get(None)` -> `None`).
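A minimal sketch of the guard described above, assuming (hypothetically) that the FX graph summary is a plain dict mapping node names to info dicts that carry a `"children"` list. The `collect_children` helper and this layout are illustrative only, not apex's actual code; the point is that testing the lookup result also catches a `None` child name.

```python
from typing import Dict, List, Optional

# Hypothetical stand-in for the FX graph summary: node name -> node info.
FxGraph = Dict[str, dict]


def collect_children(fx_graph: FxGraph, node_name: str) -> List[dict]:
    """Return info dicts for a node's children, skipping unresolvable ones."""
    children: List[dict] = []
    for child_name in fx_graph[node_name].get("children", []):
        # Guarding on the lookup result also covers child_name being None,
        # because fx_graph.get(None) returns None for a str-keyed dict.
        child_info: Optional[dict] = fx_graph.get(child_name)
        if child_info is None:
            continue
        children.append(child_info)
    return children


if __name__ == "__main__":
    graph = {
        "linear1": {"children": ["gelu", None, "missing"]},
        "gelu": {"children": []},
    }
    # Only "gelu" resolves; None and "missing" are skipped.
    print(collect_children(graph, "linear1"))
```

A single `is None` check on the lookup handles both a `None` name and a name absent from the graph, which is why it subsumes a check on `child_name` alone.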
CC @jpool-nv
Thanks for reporting this issue and for the fix. Can you also share the transformer model structure to help us understand the failing case?
@ChongyuNVIDIA sorry for the delay, please find below an example repro:
Please note that you need the transformers main branch in order to have FX support for BLOOM.
```python
from apex.contrib.sparsity import ASP
from transformers import BloomForCausalLM, BloomTokenizerFast
from torch.optim import AdamW

if __name__ == '__main__':
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-350m")
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-350m")
    optimizer = AdamW(model.parameters(), amsgrad=True)
    ASP.prune_trained_model(model, optimizer)
```