apex
Fix an issue where `child_name` can be None and makes the overall ASP fail
When attempting to sparsify a transformers model, it appears that for some reason child_name can be None, so fx_graph.get(None) returns None and the overall process crashes.
This PR attempts to hotfix this but does not investigate why the FX graph generates such None children.
I'm testing that the output of fx_graph.get(...) is non-null, rather than testing child_name itself, because the first check covers the second (fx_graph.get(None) -> None).
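To illustrate the reasoning behind the check, here is a minimal sketch that assumes the graph behaves like a plain dict keyed by node name. The graph contents, the `collect_children` helper, and the `module_type` / `children` keys are hypothetical stand-ins for illustration, not the actual ASP code.

```python
# Hypothetical stand-in for the FX graph: node name -> node info.
# The None entry mimics the unexpected child reported in this PR.
fx_graph = {
    "linear1": {"module_type": "Linear", "children": ["relu", None, "ghost"]},
    "relu": {"module_type": "ReLU", "children": ["linear2"]},
    "linear2": {"module_type": "Linear", "children": []},
}


def collect_children(graph, node_name):
    """Return (name, module_type) for each child that exists in the graph."""
    found = []
    for child_name in graph[node_name]["children"]:
        child = graph.get(child_name)
        # graph.get(None) returns None, and so does a lookup of a name that
        # is absent from the graph, so this single check covers both cases.
        if child is None:
            continue
        found.append((child_name, child["module_type"]))
    return found


print(collect_children(fx_graph, "linear1"))
```

Checking the lookup result instead of the key also skips names that are non-null but missing from the graph (`"ghost"` above), which a `child_name is None` test alone would not catch.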
CC @jpool-nv
Thanks for reporting this issue and for the fix. Can you also upload the transformer model structure to help us understand the failing case?
@ChongyuNVIDIA sorry for the delay, please find below an example repro:
Please note you need transformers main branch in order to have FX support for BLOOM.
from apex.contrib.sparsity import ASP
from transformers import BloomForCausalLM, BloomTokenizerFast
from torch.optim import AdamW

if __name__ == '__main__':
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-350m")
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-350m")
    optimizer = AdamW(model.parameters(), amsgrad=True)

    # Sparsifying the model is the step that hits the None child.
    ASP.prune_trained_model(model, optimizer)