openfold
Minor optimizations & fixes to support ESMFold
Hi OpenFold team! Thanks for your great implementation. Here are some suggestions for the codebase to support the upcoming ESMFold release:
- Constant tensors like default_frames, group_idx, atom_mask and lit_positions in StructureModule are lazily initialized as buffers, so they move with the model during CPU<->GPU conversion;
- Vectorized the rot_matmul and rot_vec_mul ops and disabled mixed precision for them, since precision loss may occur;
- Constant quaternions are cached so they are not recreated on each call.
Thanks!
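A minimal sketch of the first and third points, assuming PyTorch: a constant table registered lazily as a non-persistent buffer on the first forward call, and a constant quaternion built once per (dtype, device) pair with `functools.lru_cache`. The class `LazyConstants`, the helper `get_identity_quat` and the table contents are hypothetical illustrations, not OpenFold's actual code.

```python
import functools

import torch
from torch import nn

# Hypothetical stand-in for a constant table such as default_frames.
_DEFAULT_TABLE = [[1.0, 0.0], [0.0, 1.0]]

# Hypothetical constant quaternion (identity rotation, w-first).
_IDENTITY_QUAT = [1.0, 0.0, 0.0, 0.0]


@functools.lru_cache(maxsize=None)
def get_identity_quat(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Built once per (dtype, device) pair, then returned from the cache.
    # Callers share the same tensor, so it must not be modified in place.
    return torch.tensor(_IDENTITY_QUAT, dtype=dtype, device=device)


class LazyConstants(nn.Module):
    def _init_constants(self, dtype: torch.dtype, device: torch.device) -> None:
        # Register the buffer only on first use. Buffers follow the module
        # through .to() / .cuda() / .cpu(), unlike plain tensor attributes.
        # persistent=False keeps it out of the state_dict.
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(_DEFAULT_TABLE, dtype=dtype, device=device),
                persistent=False,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._init_constants(x.dtype, x.device)
        return x @ self.default_frames
```

Because the buffer is non-persistent, it is moved by `.to()` but omitted from the `state_dict`, so checkpoints saved without it still load.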
Do those autocast fixes work during DeepSpeed training, where an APEX-based autocast framework is used instead of the native PyTorch one? The reason those operations are written out manually in the first place is to avoid automatic casting of any kind.
@gahdritz Right, this won't work with APEX AMP, only with native torch autocast. I guess I can just roll back this change.
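To illustrate the trade-off in this exchange, here is a sketch of two ways to multiply `[*, 3, 3]` rotation matrices, assuming PyTorch 1.10 or later (for `torch.autocast`). The first spells the product out with elementwise multiplies and adds, the style the reviewer describes. The second uses a single `torch.matmul` inside a disabled native autocast region. The function names are hypothetical. As the thread notes, `torch.autocast(enabled=False)` only controls native PyTorch autocast and does not stop APEX AMP from casting.

```python
import torch


def rot_matmul_manual(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Explicit elementwise product of [*, 3, 3] matrices: only * and + are
    # used, so no matmul kernel is involved that a framework might cast.
    rows = []
    for i in range(3):
        row = [
            a[..., i, 0] * b[..., 0, j]
            + a[..., i, 1] * b[..., 1, j]
            + a[..., i, 2] * b[..., 2, j]
            for j in range(3)
        ]
        rows.append(torch.stack(row, dim=-1))
    return torch.stack(rows, dim=-2)


def rot_matmul_vectorized(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # One batched matmul in float32. Disabling autocast here affects native
    # PyTorch autocast only; APEX AMP would not be governed by this context.
    with torch.autocast(device_type=a.device.type, enabled=False):
        return torch.matmul(a.float(), b.float())


def rot_vec_mul_vectorized(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Apply [*, 3, 3] rotations to [*, 3] vectors in float32.
    with torch.autocast(device_type=r.device.type, enabled=False):
        return torch.matmul(r.float(), t.float().unsqueeze(-1)).squeeze(-1)
```

Note that the vectorized versions return float32 even for half-precision inputs; whether to cast back is a separate design choice.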