DeepSpeedExamples
Possible to include an example of DeepNVMe + state dict
The examples under https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/file_access are great. However, real-world applications need to be able to pass a model state dict to a method and load it back.
Would it make sense to add utilities for this? Or is the recommended path to use https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint/torch (which is a bit tough, given that it requires modifying an existing torch installation)?
@sayakpaul thanks for the question. This is the kind of feedback we are looking for to make DeepNVMe useful for the community.
Just to clarify your ask: currently, the DeepNVMe read/write APIs take tensors. Are you suggesting new APIs that take state_dicts as well?
> Would it make sense to add utilities for this? Or is the recommended path to use https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint/torch (which is a bit tough, given that it requires modifying an existing torch installation)?
Yes, this is not easy to use and is really a PoC at best. We plan to work with the torch folks on a cleaner integration if there is sufficient interest. Is this a direction you are generally interested in and open to collaborating on?
> Currently, the DeepNVMe read/write APIs take tensors. Are you suggesting new APIs that take state_dicts as well?
Yes, precisely.
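To make the ask concrete, here is a minimal sketch of how a state_dict API could sit on top of a tensor-only read/write API: flatten every entry into one contiguous buffer plus a small index, so a single large request can be issued. It assumes plain `array.array` objects as hypothetical stand-ins for tensors so it runs without torch or DeepSpeed; `pack_state_dict`, `unpack_state_dict`, and the index field names are illustrative, not DeepNVMe APIs. Only the "8-byte little-endian header length, then a JSON index" framing mirrors the safetensors layout.

```python
import json
import struct
from array import array

# Hypothetical stand-in for torch tensors: each entry is (array.array, shape).
# 8-byte little-endian header length + JSON index, as in safetensors; field names are this sketch's own.
_HDR = struct.Struct("<Q")


def pack_state_dict(state_dict, align=64):
    """Flatten {name: (array, shape)} into one contiguous, aligned byte buffer."""
    index, chunks, offset = {}, [], 0
    for name, (arr, shape) in state_dict.items():
        pad = -offset % align  # start each tensor at an aligned data offset
        chunks.append(bytes(pad))
        offset += pad
        raw = arr.tobytes()
        index[name] = {"typecode": arr.typecode, "shape": list(shape),
                       "offsets": [offset, offset + len(raw)]}
        chunks.append(raw)
        offset += len(raw)
    header = json.dumps(index).encode()
    header += b" " * (-(_HDR.size + len(header)) % align)  # align the data section too
    return _HDR.pack(len(header)) + header + b"".join(chunks)


def unpack_state_dict(buf):
    """Inverse of pack_state_dict: rebuild {name: (array, shape)} from the buffer."""
    (header_len,) = _HDR.unpack_from(buf, 0)
    data_start = _HDR.size + header_len
    index = json.loads(buf[_HDR.size:data_start])
    data = memoryview(buf)[data_start:]
    state_dict = {}
    for name, meta in index.items():
        start, end = meta["offsets"]
        arr = array(meta["typecode"])
        arr.frombytes(data[start:end])
        state_dict[name] = (arr, tuple(meta["shape"]))
    return state_dict
```

With torch, the packed buffer would presumably become one (ideally pinned) uint8 tensor handed to the existing tensor write call in a single request, and loading would read it back and carve per-entry views. That mapping is an assumption about how such an API might look, not something DeepNVMe currently provides. The alignment padding is there because direct I/O paths generally want aligned offsets.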
> We plan to work with the torch folks on a cleaner integration if there is sufficient interest. Is this a direction you are generally interested in and open to collaborating on?
Oh, that is great to hear. In terms of collaboration, I am happy to test (I think we are connected through Slack via HF). I have a draft PR in the Diffusers repo for the following feature:
We overlap communication with computation: layers of the model that are not currently in use are offloaded to RAM/disk while the current computation completes. For the disk offloading part, I am trying to leverage DeepNVMe in https://github.com/huggingface/diffusers/pull/11758/. As you will notice, I have copy-pasted code from https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint (with credits) to make it work. This is what prompted me to open this issue thread. Any feedback is welcome.
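The overlap described above (compute on one layer while another layer moves between GPU and RAM/disk) can be sketched with a single background I/O thread. This is a generic illustration, not the implementation in the linked Diffusers PR: `load_layer` and `run_layer` are hypothetical placeholders that sleep to simulate disk reads and compute, and the sketch shows the prefetch (read) direction. The offload (write) direction follows the same pattern.

```python
import time
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-ins (not Diffusers or DeepNVMe APIs): load_layer mimics reading
# one layer's weights from disk, run_layer mimics the compute for that layer.
def load_layer(i):
    time.sleep(0.1)  # simulated disk read
    return {"layer": i}


def run_layer(weights, x):
    time.sleep(0.1)  # simulated compute
    return x + weights["layer"]


def forward_with_prefetch(num_layers, x):
    """Run layers in order, fetching layer i + 1 while layer i computes."""
    with ThreadPoolExecutor(max_workers=1) as io_pool:
        pending = io_pool.submit(load_layer, 0)
        for i in range(num_layers):
            weights = pending.result()  # wait until layer i is resident
            if i + 1 < num_layers:
                pending = io_pool.submit(load_layer, i + 1)  # start the next read early
            x = run_layer(weights, x)  # runs while the background read proceeds
            del weights  # drop our reference to layer i (a real system would offload it)
    return x
```

With four layers the loop takes roughly one read plus four compute steps instead of four reads plus four compute steps, because each read after the first is hidden behind the previous layer's compute.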
Even in https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint/, there's no state dict loading code (I am assuming that is not relevant for the purposes of benchmarking).
Another interesting experiment I conducted compares regular torch save + load vs. AIO save + torch load vs. safetensors save + load.
Code (copy-pasted from this repo and hacked together):
```python
import os
import time

import safetensors.torch
import torch

# Helpers from DeepSpeedExamples/deepnvme/model_checkpoint; parse_arguments() is expected
# to provide the --model, --folder, --half, --gpu, --regular_torch_save and --safetensors flags used below.
from save_model_utils import get_model, validate_arguments, parse_arguments
from torch_save_utils import load_io_ops, _test_ds_fast_save, test_save


def test_sft_save(file, buffer, args):
    # Time safetensors.torch.save_file(); `args` is unused but matches the other save helpers.
    st = time.time()
    safetensors.torch.save_file(filename=file, tensors=buffer)
    return time.time() - st


def main():
    print('Performance test of torch.save() integration of fast model checkpointing.')
    print(f'torch version = {torch.__version__}')
    torch.manual_seed(42)
    args = parse_arguments()
    if not validate_arguments(args):
        return
    load_io_ops(args)

    model, tokenizer, model_name, ckpt_name = get_model(args.model)
    if args.half:
        model = model.half()
    # Keep the model, the inputs, and the reloaded state dict on the same device.
    device = "cuda" if args.gpu else "cpu"
    model = model.to(device)
    inputs = tokenizer("I am good", return_tensors="pt").to(device)

    with torch.no_grad():
        model.eval()
        pre_logits = model(**inputs).logits

    # Pick exactly one save path: torch.save, safetensors, or DeepNVMe fast save (default).
    if args.regular_torch_save:
        mode = "torch"
    elif args.safetensors:
        mode = "safetensors"
    else:
        mode = "deepnvme"
    ext = "safetensors" if mode == "safetensors" else "pt"
    file = os.path.join(args.folder, f'{ckpt_name}.{ext}')
    if os.path.exists(file):
        os.remove(file)

    state_dict = model.state_dict()
    if mode == "deepnvme":
        write_sec = _test_ds_fast_save(file, state_dict, args, False)
    elif mode == "torch":
        write_sec = test_save(file, state_dict, args)
    else:
        write_sec = test_sft_save(file, state_dict, args)

    ckpt_size = os.path.getsize(file)
    gb_size = ckpt_size / (1024**3)
    gb_per_sec = gb_size / write_sec
    print(f'{gb_size:5.2f} GB, {write_sec:5.2f} secs, {gb_per_sec:5.2f} GB/s')

    # Files written by DeepNVMe fast save or torch.save are both read back with torch.load.
    st = time.time()
    if mode == "safetensors":
        loaded_sd = safetensors.torch.load_file(file, device=device)
    else:
        loaded_sd = torch.load(file, weights_only=True, map_location=device)
    load_sec = time.time() - st
    print(f"Loaded in {load_sec:5.2f} seconds.")

    # The reloaded weights must reproduce the original logits.
    model.load_state_dict(loaded_sd)
    with torch.no_grad():
        model.eval()
        post_logits = model(**inputs).logits
    assert torch.allclose(pre_logits, post_logits, atol=1e-3, rtol=1e-3)
    os.remove(file)


if __name__ == "__main__":
    main()
```
Results (checkpoint size, save time, and save bandwidth, then load time):
```text
# DeepNVMe (AIO) save + torch.load
python compare_outputs.py --model phi3 --folder /data --gpu
14.23 GB, 1.19 secs, 11.97 GB/s
Loaded in 12.18 seconds.

# torch.save + torch.load
python compare_outputs.py --model phi3 --folder /data --gpu --regular_torch_save
14.23 GB, 9.99 secs, 1.42 GB/s
Loaded in 9.08 seconds.

# safetensors save + load
python compare_outputs.py --model phi3 --folder /data --gpu --safetensors
14.23 GB, 16.37 secs, 0.87 GB/s
Loaded in 1.64 seconds.
```
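The script prints the store bandwidth but only the load time. Putting all three combinations on the same footing is simple arithmetic on the numbers reported above (GB means 1024**3 bytes, as in the script); nothing here is a new measurement.

```python
# Numbers copied from the three runs above (size in GB = 1024**3 bytes, times in seconds).
SIZE_GB = 14.23
RUNS = {
    "DeepNVMe (AIO) save + torch.load": (1.19, 12.18),
    "torch.save + torch.load": (9.99, 9.08),
    "safetensors save + load": (16.37, 1.64),
}


def bandwidths(size_gb, write_sec, load_sec):
    """Return (store GB/s, load GB/s) for one run."""
    return size_gb / write_sec, size_gb / load_sec


for name, (write_sec, load_sec) in RUNS.items():
    store, load = bandwidths(SIZE_GB, write_sec, load_sec)
    print(f"{name:34s} store {store:5.2f} GB/s   load {load:5.2f} GB/s")
```

This gives load bandwidths of roughly 1.17, 1.57, and 8.68 GB/s, respectively. The first store figure comes out as 11.96 rather than the printed 11.97 only because the size is rounded to two decimals here.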
> Oh, that is great to hear. In terms of collaboration, I am happy to test (I think we are connected through Slack via HF)
Cool. I have pinged you on Slack.
> https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint (with credits) to make it work. This is what prompted me to open this issue thread. Any feedback is welcome.
This is very cool. Thanks for sharing. It seems you are using the API as well as it currently allows. As you have observed, the API is really designed for torch.save(). A state_dict API would be more efficient and convenient for your use case.
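One plausible reason a state_dict-level API can be more efficient is that it sees all tensors up front and can coalesce many small per-tensor writes into a few large ones, which fast storage handles much better. The sketch below shows only that coalescing idea, using blocking `os.write()` and a fixed staging buffer. `BufferedChunkWriter` is a hypothetical helper, and the note that a DeepNVMe-backed version would stage into a pinned buffer and submit chunks asynchronously is an assumption, not a description of DeepNVMe internals.

```python
import os


class BufferedChunkWriter:
    """Hypothetical helper: coalesce many small writes into few large fixed-size ones.

    A DeepNVMe-backed version would presumably stage into a pinned buffer and submit
    each full chunk asynchronously; this sketch uses plain blocking os.write().
    """

    def __init__(self, path, chunk_bytes=1 << 20):
        self.fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
        self.buf = bytearray(chunk_bytes)
        self.fill = 0        # bytes currently staged in self.buf
        self.num_writes = 0  # number of chunk flushes issued to the file

    def write(self, data):
        view = memoryview(data)
        while len(view):
            n = min(len(view), len(self.buf) - self.fill)
            self.buf[self.fill:self.fill + n] = view[:n]
            self.fill += n
            view = view[n:]
            if self.fill == len(self.buf):
                self._flush()

    def _flush(self):
        if not self.fill:
            return
        pending = memoryview(self.buf)[:self.fill]
        while len(pending):  # os.write may write fewer bytes than requested
            pending = pending[os.write(self.fd, pending):]
        self.fill = 0
        self.num_writes += 1

    def close(self):
        self._flush()  # write out any partially filled chunk
        os.close(self.fd)
```

For example, 1,000 tensors of 100 bytes each, written through a 4 KiB staging buffer, reach the file as 25 writes instead of 1,000.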
> Even in https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/model_checkpoint/, there's no state dict loading code (I am assuming that is not relevant for the purposes of benchmarking).
Yes, FastPersist did not address checkpoint loading. But there is now good reason to revisit that.
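For the layer-offloading use case, loading matters as much as saving, and it usually needs only a subset of tensors at a time. The safetensors format supports this because its header records each tensor's byte range: an 8-byte little-endian header size, a JSON header mapping names to `dtype`, `shape`, and `data_offsets` (relative to the data section), then the raw bytes. safetensors already exposes lazy access through `safe_open(...).get_tensor(name)`. The sketch below only shows the underlying seek/read pattern, which a DeepNVMe-backed loader might replace with its own reads; that mapping is an assumption. `read_safetensors_index` and `read_tensor_bytes` are hypothetical helpers.

```python
import json
import struct


def read_safetensors_index(path):
    """Return (header dict, absolute file offset where the data section starts)."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(n))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header, 8 + n


def read_tensor_bytes(path, names):
    """Read only the requested tensors: one seek + read per tensor, returning (dtype, shape, raw bytes)."""
    header, data_start = read_safetensors_index(path)
    out = {}
    with open(path, "rb") as f:
        for name in names:
            meta = header[name]
            begin, end = meta["data_offsets"]
            f.seek(data_start + begin)
            out[name] = (meta["dtype"], tuple(meta["shape"]), f.read(end - begin))
    return out
```

Decoding the raw bytes into a tensor of the recorded dtype and shape is left to the caller, which keeps the I/O step independent of the framework.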
> Another interesting experiment I conducted compares regular torch save + load vs. AIO save + torch load vs. safetensors save + load.
Thanks for sharing these early results, but I am not sure how to interpret them. Can you please confirm the store and load bandwidths of each of the three combinations?