SLAM-LLM
fix #92 for fsdp training
What does this PR do?
Fixes #92
Issue explanation
This KeyError occurs because the value of fsdp_config.sharding_strategy is forcibly converted to a str rather than kept as a ShardingStrategy object. This is intrinsic behavior of omegaconf.dictconfig.DictConfig: assignments are coerced to the field's annotated type. See this for more details.
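For reference, the coercion is easy to reproduce. Below is a minimal sketch (not part of this PR; FSDPConfigStr is a made-up stand-in for the pre-fix config, and it assumes omegaconf's structured-config validation coerces assignments as described):

from dataclasses import dataclass
from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy

@dataclass
class FSDPConfigStr:
    sharding_strategy: str = "NO_SHARD"  # pre-fix annotation: str

cfg = OmegaConf.structured(FSDPConfigStr)
# finetune.py used to convert the string to the enum like this:
cfg.sharding_strategy = getattr(ShardingStrategy, cfg.sharding_strategy)
# ...but DictConfig coerces the assignment back to the declared field type (str):
print(type(cfg.sharding_strategy), cfg.sharding_strategy)
# <class 'str'> ShardingStrategy.NO_SHARD

FSDP then receives the string "ShardingStrategy.NO_SHARD" instead of an enum member, which presumably triggers the KeyError when the value is looked up in a mapping keyed by ShardingStrategy.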
How to solve
First, change the type annotation of FSDPConfig.sharding_strategy from str to torch.distributed.fsdp.ShardingStrategy in ${example}_config.py:
from torch.distributed.fsdp import ShardingStrategy
...
@dataclass
class FSDPConfig:
    ...
    # annotated with the enum type; OmegaConf converts the string value automatically
    sharding_strategy: ShardingStrategy = "NO_SHARD"
    ...
Then, remove the type-conversion code from finetune.py, since the conversion is now done automatically:
# from torch.distributed.fsdp import ShardingStrategy
# fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
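With both changes in place, the string coming from the YAML/CLI reaches FSDP as a proper ShardingStrategy member. A quick sanity check (again a minimal sketch, not part of this PR; FSDPConfigEnum is a made-up stand-in for the patched config):

from dataclasses import dataclass
from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy

@dataclass
class FSDPConfigEnum:
    sharding_strategy: ShardingStrategy = "NO_SHARD"  # post-fix annotation

cfg = OmegaConf.merge(
    OmegaConf.structured(FSDPConfigEnum),
    {"sharding_strategy": "SHARD_GRAD_OP"},  # e.g. a command-line override
)
assert cfg.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP  # enum member, not str

so the value can be passed straight to FSDP's sharding_strategy argument without any manual getattr.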
Testing
- [x] Test MALA_ASR finetune & decode scripts on 4 x V100
- [ ] Other examples have not been actually tested, but they should work similarly
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
Thanks for contributing 🎉!