
fix #92 for fsdp training

Open nuaalixu opened this issue 1 year ago • 0 comments

What does this PR do?

Fixes #92

Issue explanation

This KeyError occurs because the value of fsdp_config.sharding_strategy is forcibly converted to a str rather than kept as a ShardingStrategy enum member. This is intrinsic behavior of omegaconf.dictconfig.DictConfig. See this for more details.
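A minimal sketch of the failure mode (using a hypothetical stand-in enum in place of torch.distributed.fsdp.ShardingStrategy, so it runs without torch): the config hands back a plain string, so code that expects an enum member raises KeyError.

```python
from enum import Enum, auto

# Hypothetical stand-in for torch.distributed.fsdp.ShardingStrategy.
class ShardingStrategy(Enum):
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()

# DictConfig hands the value back as a plain string...
sharding_strategy = "NO_SHARD"
assert not isinstance(sharding_strategy, ShardingStrategy)

# ...so code that indexes an enum-keyed table with it raises KeyError.
handlers = {s: s.name.lower() for s in ShardingStrategy}
try:
    handlers[sharding_strategy]
except KeyError as e:
    print(f"KeyError: {e}")  # KeyError: 'NO_SHARD'
```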

How to solve

First, change the type annotation of FSDPConfig.sharding_strategy from str to torch.distributed.fsdp.ShardingStrategy in ${example}_config.py.

from torch.distributed.fsdp import ShardingStrategy
...
@dataclass
class FSDPConfig:
    ...
    sharding_strategy: ShardingStrategy = "NO_SHARD"  # OmegaConf converts the string to ShardingStrategy.NO_SHARD
    ...

Then we can remove the type-conversion code in finetune.py, as OmegaConf now performs the conversion automatically.

        # from torch.distributed.fsdp import ShardingStrategy
        # fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)

Testing

  • [x] Test MALA_ASR finetune & decode scripts on 4 x V100

  • [ ] Other examples have not been tested directly, but they should behave the same way

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

Thanks for contributing 🎉!

nuaalixu avatar Oct 14 '24 07:10 nuaalixu