Add torch compile for Mixtral
This PR is a work in progress. It adds torch.compile support for Mixtral and currently also includes changes from #30642, since the two models share some common ground. There are several issues regarding Mixtral:
- We have to set `torch._dynamo.config.capture_dynamic_output_shape_ops = True` in order to capture the full graph with MoE (a minimal compile sketch follows the list). I believe this is inevitable: `MixtralSparseMoeBlock` uses `torch.where` to extract the tokens each expert handles, and the number and indices of those tokens are variable. Even if we forced a static shape (i.e., zeroed out the tokens an expert does not handle), we would add extra computation, because the zeroed-out values still participate in the matmuls and every expert would process the full token set, which defeats the computation-saving point of MoE.
- The logits tests on the main branch are currently failing on my dev machine (a reproduction sketch follows the list):
```
=========================================== short test summary info ===========================================
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralModelTest::test_custom_4d_attention_mask - AssertionError: assert False
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits - AssertionError: Tensor-likes are not close!
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits_batched - AssertionError: Tensor-likes are not close!
=========================== 3 failed, 112 passed, 35 skipped, 47 warnings in 34.78s ===========================
```
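
To make the first point concrete, here is a minimal sketch of compiling Mixtral with full-graph capture once the flag is set; the `hf-internal-testing/Mixtral-tiny` checkpoint and the exact `torch.compile` arguments are illustrative assumptions, not part of this PR:

```python
import torch
import torch._dynamo

from transformers import AutoModelForCausalLM

# Let dynamo capture ops whose output shape depends on tensor values
# (the torch.where calls in the sparse MoE routing); without this, fullgraph capture fails.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Small checkpoint assumed here purely for illustration.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/Mixtral-tiny")
model.eval()

# fullgraph=True makes graph breaks raise instead of silently falling back to eager.
compiled_forward = torch.compile(model.forward, fullgraph=True)

input_ids = torch.tensor([[1, 2, 3, 4, 5]])
with torch.no_grad():
    logits = compiled_forward(input_ids=input_ids).logits
print(logits.shape)  # (1, 5, vocab_size)
```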
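
For reference, a sketch of reproducing the summary above from a Python session; the `RUN_SLOW=1` setting is an assumption about how the slow integration tests were enabled:

```python
import os

import pytest

# Assumption: the integration logits tests are gated behind RUN_SLOW, as is typical
# in transformers, so enable slow tests before collection.
os.environ["RUN_SLOW"] = "1"

# Equivalent to: RUN_SLOW=1 pytest tests/models/mixtral/test_modeling_mixtral.py
pytest.main(["tests/models/mixtral/test_modeling_mixtral.py", "-q"])
```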