safetensors

Efficient key-wise streaming

[Open] ljleb opened this issue 1 year ago • 3 comments

Feature request

I'm interested in streaming the tensors of a model key by key, without holding all of them in memory at the same time. Something like this:

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        # proposed API: `stream=True` is not an existing safetensors parameter
        tensor = f.get_tensor(key, stream=True)
        # `tensor` should be freed as soon as the next iteration rebinds it
        # and drops the only reference to it
```

Motivation

When I use safetensors.safe_open to load multiple models, memory usage does not drop, even when no references to the deserialized tensors are held. Below is the memory usage of a key-by-key streamed merge of 5 Stable Diffusion 1.5 checkpoints using a weighted sum:

[image: memory usage during the merge; each vertical gray line is ~8 GB]

For reference, this is my successful attempt at reading keys memory-efficiently in Python: https://github.com/ljleb/sd-mecha/blob/9548ef83dd5d3fccdaf09c8b22dee7a0a7727613/sd_mecha/streaming.py#L12
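The linked code isn't reproduced here. As a rough illustration of the general idea, the sketch below reads the safetensors format directly: an 8-byte little-endian header length, a JSON header mapping each key to its `dtype`, `shape` and `data_offsets` (relative to the start of the data section), then the raw tensor bytes. It uses plain seeks and reads instead of mmap, so each tensor is an ordinary `bytes` object that is freed once nothing references it. `read_header`, `stream_tensors` and `DTYPE_SIZES` are hypothetical names for this sketch, not safetensors API.

```python
import json
import math
import struct
from typing import BinaryIO, Dict, Iterator, List, Tuple

# Bytes per element for common safetensors dtype strings (not exhaustive).
DTYPE_SIZES = {"F64": 8, "F32": 4, "F16": 2, "BF16": 2,
               "I64": 8, "I32": 4, "I16": 2, "I8": 1, "U8": 1, "BOOL": 1}


def read_header(f: BinaryIO) -> Tuple[Dict[str, dict], int]:
    """Return (tensor entries, absolute file offset where the data section starts)."""
    f.seek(0)
    (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64 header length
    header = json.loads(f.read(header_len))  # trailing space padding is valid JSON whitespace
    header.pop("__metadata__", None)  # optional str -> str metadata, not a tensor
    return header, 8 + header_len


def stream_tensors(path: str) -> Iterator[Tuple[str, str, List[int], bytes]]:
    """Yield (key, dtype, shape, raw_bytes), reading one tensor per step."""
    with open(path, "rb") as f:
        entries, data_start = read_header(f)
        # Visit tensors in on-disk order so the reads are sequential.
        ordered = sorted(entries.items(), key=lambda kv: kv[1]["data_offsets"][0])
        for key, info in ordered:
            begin, end = info["data_offsets"]  # relative to the data section
            dtype, shape = info["dtype"], info["shape"]
            if dtype in DTYPE_SIZES and end - begin != math.prod(shape) * DTYPE_SIZES[dtype]:
                raise ValueError(f"{key}: byte range does not match dtype/shape")
            f.seek(data_start + begin)
            raw = f.read(end - begin)
            yield key, dtype, shape, raw
            # Drop the generator's own reference before reading the next tensor,
            # so the buffer is freed as soon as the caller lets go of it too.
            del raw
```

For an `F32` entry, `numpy.frombuffer(raw, dtype=numpy.float32).reshape(shape)` would turn the bytes into a read-only array that shares the buffer rather than copying it.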

And this is my successful attempt at writing keys memory-efficiently: https://github.com/ljleb/sd-mecha/blob/9548ef83dd5d3fccdaf09c8b22dee7a0a7727613/sd_mecha/streaming.py#L156
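Writing is harder to stream because the header, which stores every tensor's byte offsets, comes before the data. The linked code isn't reproduced here; the sketch below shows one way around this under an explicit assumption: the caller declares every key's dtype and shape up front (for a merge, presumably taken from an input's header). The writer then emits the header immediately and accepts tensor bytes one key at a time, in the declared order. `StreamingSafetensorsWriter` and its methods are hypothetical names, not part of safetensors.

```python
import json
import math
import struct
from typing import Dict, List, Optional, Tuple

# Bytes per element for common safetensors dtype strings (not exhaustive).
DTYPE_SIZES = {"F64": 8, "F32": 4, "F16": 2, "BF16": 2,
               "I64": 8, "I32": 4, "I16": 2, "I8": 1, "U8": 1, "BOOL": 1}


class StreamingSafetensorsWriter:
    """Hypothetical helper: write tensors one at a time in a pre-declared order."""

    def __init__(self, path: str, layout: List[Tuple[str, str, List[int]]],
                 metadata: Optional[Dict[str, str]] = None) -> None:
        header: Dict[str, object] = {}
        if metadata:
            header["__metadata__"] = metadata
        self._order: List[Tuple[str, int]] = []
        offset = 0
        # Offsets can be computed from dtype and shape alone, before any data exists.
        for key, dtype, shape in layout:
            nbytes = math.prod(shape) * DTYPE_SIZES[dtype]
            header[key] = {"dtype": dtype, "shape": shape, "data_offsets": [offset, offset + nbytes]}
            self._order.append((key, nbytes))
            offset += nbytes
        header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
        # Pad with spaces so the data section starts on an 8-byte boundary.
        header_bytes += b" " * (-(8 + len(header_bytes)) % 8)
        self._f = open(path, "wb")
        self._f.write(struct.pack("<Q", len(header_bytes)))
        self._f.write(header_bytes)
        self._next = 0

    def write(self, key: str, raw: bytes) -> None:
        """Append one tensor's raw bytes; only this tensor needs to be in memory."""
        if self._next >= len(self._order):
            raise ValueError("all declared tensors were already written")
        expected_key, nbytes = self._order[self._next]
        if key != expected_key or len(raw) != nbytes:
            raise ValueError(f"expected {expected_key!r} ({nbytes} bytes), got {key!r} ({len(raw)} bytes)")
        self._f.write(raw)
        self._next += 1

    def close(self, check: bool = True) -> None:
        self._f.close()
        if check and self._next != len(self._order):
            raise ValueError("not all declared tensors were written")

    def __enter__(self) -> "StreamingSafetensorsWriter":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Skip the completeness check if an exception is already propagating.
        self.close(check=exc_type is None)
```

The trade-off of this design is that shapes and dtypes must be known before the first tensor is written, and tensors must arrive in the declared order.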

With these, memory usage looks like this:

[image: memory usage with the streaming implementation]

Note that my implementation is relatively slow compared to using safetensors directly (approximately 1.1x to 1.3x slower, according to a few quick tests I ran). Is there a way to achieve the same thing more efficiently through the Rust bindings? Specifically, I need to stream the keys and tensors without them being held anywhere else in memory.

Your contribution

I don't really know Rust, but if nobody has time for this and the API suggested above is acceptable, I will eventually have to implement this efficiently one way or another for my merging library.

ljleb, Feb 18 '24 23:02

I am assuming that holding the tensors in memory for the lifetime of the `f` object is necessary; if that assumption is wrong, this may be a bug.

ljleb, Feb 18 '24 23:02

Hi, I'm working on a similar problem at the moment but couldn't reproduce such increased memory usage. Have you tried running gc.collect() explicitly to free up memory?

mar-muel, Mar 01 '24 03:03

@mar-muel "Have you tried running gc.collect() explicitly to free up memory?"

I did try calling gc.collect() after every key, to no avail. I concede it may be caused by my setup, but I'm not sure how, given that I am using the library in the expected way (i.e. calling f.get_tensor(key) for each key in the state dict).

IIUC, tensors are currently mmap'd. I am extrapolating a little outside my domain of expertise, but if that is the case, I believe the OS paging mechanism is what should take care of the unused memory. If it doesn't, then I'm as confused as you are.
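If the mmap explanation is right, the pages touched through the mapping are clean, file-backed pages: the kernel can reclaim them under memory pressure, but on Linux they still count toward the process's resident set size until it does, which could make memory look like it never drops. One way to test that hypothesis outside safetensors is to copy each range out of a read-only mapping and then ask the kernel to drop those pages with `madvise(MADV_DONTNEED)`. The sketch below does this with Python's `mmap` module. `iter_ranges_with_release` is a hypothetical helper, this is not how safetensors manages its mapping, and the release step assumes a platform that exposes `mmap.MADV_DONTNEED` (such as Linux).

```python
import mmap
from typing import Iterable, Iterator, Tuple


def iter_ranges_with_release(path: str, ranges: Iterable[Tuple[int, int]]) -> Iterator[bytes]:
    """Yield a copy of each [start, end) byte range of `path`, then drop its mapped pages.

    Hypothetical helper. The release step assumes mmap.MADV_DONTNEED is
    available (e.g. on Linux); elsewhere the ranges are only copied.
    """
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        for start, end in ranges:
            # Slicing an mmap returns a new bytes object, i.e. an owned copy.
            yield mm[start:end]
            # Resumed: the consumer has moved on, so tell the kernel these pages
            # are no longer needed. For a read-only file mapping they are simply
            # re-read from the file if touched again.
            if hasattr(mmap, "MADV_DONTNEED") and end > start:
                aligned = start - start % mmap.PAGESIZE  # madvise needs a page-aligned start
                mm.madvise(mmap.MADV_DONTNEED, aligned, end - aligned)
```

To apply it to a safetensors file, the ranges would be each entry's `data_offsets` shifted by the start of the data section (8 plus the header length).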

ljleb, Mar 11 '24 19:03

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot], Apr 11 '24 01:04