[RFC] TorchStore - A Shared-Memory Tensor Store
🚀 Feature
TorchStore is a key-value store that holds ATen tensors in shared memory so that they can be accessed across process boundaries without any expensive copy operations.
Problem
With ever-increasing model sizes, the conventional approach of serving inference loads with multiple worker processes has become a memory bottleneck. Since each worker requires its own instance of the model (or models), memory use grows proportionally with the number of workers. Considering that today’s multi-core systems can easily handle a large number of workers, a big portion of physical memory is potentially wasted on redundant copies of the same model. Some users, as described here, have successfully leveraged Apache Arrow’s Plasma Object Store to circumvent the problem, but the overall user experience is less than ideal and a production-ready solution requires quite a bit of manual work.
Similarly, training tasks in various machine learning fields, such as graph neural networks or knowledge distillation, often use read-only models during training. The same problem as in inference exists here as well: many redundant copies of the same model waste the precious memory resources of the host machine. The problem is even more pronounced in the training case since memory is typically already under pressure due to the optimizer and backpropagation states. This issue in the PyTorch Lightning repository describes a use case that suffers from this problem.
Yet another negative implication is the load time of such replicated models. Some large models can take up to several minutes to load, and during that time they typically saturate the I/O bandwidth of the host machine. Attempting to load the same model concurrently from many worker processes only prolongs the load times, often significantly. The blog post mentioned in the first problem statement describes this problem in more depth.
Why not Tensor.share_memory_()?
Since PyTorch already provides the share_memory_() method as part of its Tensor API, a legitimate question is why we need a separate solution. There are two major reasons:
- In training tasks, in order to leverage `share_memory_()`, one has to control both the parent process and the worker processes. The typical implementation loads a tensor or a module in the parent process and then shares it with the spawned worker processes via `torch.multiprocessing`. However with TorchElastic, which is now part of PyTorch, and the upcoming `torchrun` launcher script, the task of spawning worker processes is handled by PyTorch itself. This means user code has no chance to influence the initialization logic.
- For inference workloads the lifetime of worker processes is not bound by a parent process. They are independently initiated and terminated, usually by a third-party request coordinator. In such a setting `share_memory_()` is of no use.
- ~~With `share_memory_()` a tensor can only be shared from the parent process to its child processes. If a child process wishes to share a tensor with the rest of its peers, there is no mechanism for it.~~ This turned out to be not true. See Edward's feedback below.
What is TorchStore?
TorchStore is a key-value store that holds ATen tensors in shared memory. The tensors can be accessed across process boundaries (e.g. between workers) without any expensive serialization or de-serialization. In other words sharing tensors between processes via TorchStore is a zero-copy operation.
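Under the hood, this kind of zero-copy sharing maps to a POSIX shared-memory segment that multiple processes attach to. The stdlib sketch below illustrates that primitive only; it is not TorchStore code, and the proposed API would hide all of these details behind the server.

```python
# Illustration of the OS primitive a store like this builds on: a named
# shared-memory segment that a peer process can attach to without copying.
from multiprocessing import shared_memory

# "Server" side: create a segment and write into it.
shm = shared_memory.SharedMemory(create=True, size=16)
shm.buf[:4] = b"\x01\x02\x03\x04"

# "Client" side: attach by name; attaching copies no data.
peer = shared_memory.SharedMemory(name=shm.name)
data = bytes(peer.buf[:4])  # both views alias the same physical memory

peer.close()
shm.close()
shm.unlink()
```

In the real store the segment lifetime would be managed by the server rather than by the creating process.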
Starting the Store
The store can be run either as a daemon or as an in-process service. In order to run it as a daemon the torchstored program can be used:
$ torchstored --name my-store
# Or alternatively run in foreground. This allows easy monitoring in terminal.
$ torchstored --name my-store --no-fork
The daemon mode is useful if a persistent store is desired that is not bound to any particular task. For instance an inference server can start the daemon and bootstrap it by loading models and tensors (e.g. large embedding tables). The ephemeral worker processes can simply connect to the store and get access to its entries.
In contrast the in-process mode is more appropriate if the entries held by the store are specific to a particular task. For instance if the store holds a teacher model for a knowledge distillation task that is run on a multi-tenant cluster, there is no need to preserve it after the training is finished. The in-process service can be started via a simple API call.
import torchstore
store = torchstore.start(name="my-store")
Connecting to the Store
To connect to a store, simply call connect() using the name of the store. In order to support use cases where the store is started by one of the worker processes (i.e. in-process mode) while its peers attempt to connect to it, connect() has built-in retry logic and can seamlessly handle connection attempts made to a store that has not started yet, but is about to start.
import torchstore
store = torchstore.connect(name="my-store")
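The retry behavior described above can be sketched with a plain Unix Domain Socket loop. Everything below (the function name, the socket path, the intervals) is illustrative, not part of the proposed API.

```python
# Sketch of connect()-style retry logic over a Unix Domain Socket: keep
# retrying until the store's socket exists and accepts connections.
import os
import socket
import tempfile
import threading
import time

def connect_with_retry(path, timeout=5.0, interval=0.05):
    deadline = time.monotonic() + timeout
    while True:
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(path)
            return sock
        except (FileNotFoundError, ConnectionRefusedError):
            sock.close()
            if time.monotonic() >= deadline:
                raise TimeoutError(f"store at {path} never came up")
            time.sleep(interval)

# Demo: the "store" starts listening slightly after the client begins retrying.
path = os.path.join(tempfile.mkdtemp(), "my-store")

def start_store_later():
    time.sleep(0.2)
    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    server.bind(path)
    server.listen(1)
    server.accept()

threading.Thread(target=start_store_later, daemon=True).start()
client = connect_with_retry(path)  # succeeds despite the late start
client.close()
```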
Saving to the Store
The store supports both CPU and CUDA tensors. A store entry can hold a tensor, a list of tensors, or a map of string-tensor pairs. Moreover, saving multiple tensors as part of an entry (i.e. a list or a map) preserves their view relationships, meaning a shared “storage” is only saved once. One limitation though is that, similar to Tensor.share_memory_(), a tensor in the store cannot be resized.
The save() API is fairly straightforward:
import torch
import torchstore
cpu_tensor = torch.ones([10, 10])
gpu_tensor = torch.ones([10, 10], device="cuda:0")
list_of_tensors = [torch.rand([10, 10]) for _ in range(4)]
dict_of_tensors = {
"tensor1": torch.rand([10, 10]),
"tensor2": torch.rand([10, 10]),
}
store = torchstore.connect(name="my-store")
store.save("entry1", cpu_tensor)
store.save("entry2", gpu_tensor)
store.save("entry3", list_of_tensors)
store.save("entry4", dict_of_tensors)
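The view-preservation guarantee mentioned earlier (a shared storage saved only once) can be modeled with plain Python objects. This is a toy model with made-up `View` and `save_entry` names; it only illustrates the dedup-by-storage-identity idea, not the actual layout.

```python
# Toy model: saving several views that alias one storage keeps that storage
# only once, while each view records its own (storage, offset, length) metadata.

class View:
    """A tensor-like object: metadata plus a reference to a storage buffer."""
    def __init__(self, storage, offset, length):
        self.storage, self.offset, self.length = storage, offset, length

def save_entry(views):
    storages = {}  # id(storage) -> storage; each shared storage kept once
    metas = []
    for v in views:
        storages.setdefault(id(v.storage), v.storage)
        metas.append((id(v.storage), v.offset, v.length))
    return storages, metas

buf = bytearray(64)          # one 64-byte "storage"
a = View(buf, 0, 32)         # two views over the same storage
b = View(buf, 32, 32)
storages, metas = save_entry([a, b])
assert len(storages) == 1    # the shared storage is saved only once
assert len(metas) == 2       # but both views keep their own metadata
```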
Although using the save() API as demonstrated above is perfectly fine, there is one issue with that example. Each tensor is first initialized locally and then copied over to the store. This means, albeit temporarily, twice as much memory is required to save a tensor. The storage() API demonstrated below mitigates this problem. It temporarily redirects PyTorch’s CPU and CUDA allocators to their store versions so that later save() calls won’t require any copy operation.
import torch
import torchstore
store = torchstore.connect(name="my-store")
with store.storage():
# The storage for this tensor is allocated from the store.
tensor_alloc_in_store = torch.ones([10, 10])
# This is now a zero-copy operation.
store.save("my-tensor", tensor_alloc_in_store)
If a tensor gets allocated from the store as demonstrated above, but does not get saved into the store via a save() call, its storage will be deallocated once it goes out of scope. This means there is no harm in allocating from the store, but no gain either. The general advice is to save tensors allocated in the scope of storage() to the store.
Another difficulty with a regular save() call is the management of an entry's lifetime. In particular, in publish/subscribe scenarios where an entry has no specific owner, determining the right time to delete it from the store is difficult. To mitigate this, the save() API accepts a time-to-live argument that specifies how long an entry should be kept in the store once its external reference count drops to zero. Thanks to @VoVAllen for bringing up this problem.
import torch
import torchstore
store = torchstore.connect(name="my-store")
# Delete 'my-tensor' from the store after 30 seconds unless a client
# retrieves it in the meantime.
store.save("my-tensor", torch.ones([10, 10]), ttl=30)
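Server-side, the TTL rule amounts to a small piece of bookkeeping: remember when an entry's external reference count last hit zero and compare against the TTL. The class below is a toy model of that rule, not the proposed implementation.

```python
# Toy model of TTL bookkeeping: an entry becomes evictable only after its
# external refcount has stayed at zero for `ttl` seconds.
import time

class Entry:
    def __init__(self, value, ttl=None):
        self.value, self.ttl = value, ttl
        self.refcount = 0
        self.zero_since = time.monotonic()  # refcount is zero right after save

    def incref(self):
        self.refcount += 1
        self.zero_since = None

    def decref(self):
        self.refcount -= 1
        if self.refcount == 0:
            self.zero_since = time.monotonic()

    def expired(self, now=None):
        if self.ttl is None or self.zero_since is None:
            return False
        return (now or time.monotonic()) - self.zero_since >= self.ttl

e = Entry("tensor-bytes", ttl=0.1)
e.incref()                      # a client loads the entry
assert not e.expired()          # referenced entries never expire
e.decref()                      # last client releases it
time.sleep(0.15)
assert e.expired()              # now eligible for eviction
```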
Saving a Model to the Store
Since a model’s state_dict is a map of string-tensor pairs, saving a model to the store closely resembles saving it to a file.
import torch
import torchstore
class MyModel(torch.nn.Module):
...
store = torchstore.connect(name="my-store")
with store.storage():
m = MyModel()
# Save the state dictionary of the model to the store.
store.save("my-model", m.state_dict())
Loading from the Store
In order to retrieve a store entry, the load() API should be called. Its simplest use is demonstrated below.
import torchstore
store = torchstore.connect(name="my-store")
tensor = store.load("my-tensor")
Beyond its basic usage, load() also accepts a timeout argument. If the specified timeout is greater than zero, the call will wait until either the entry is in the store or the request times out. This mechanism helps peer processes (i.e. workers) to coordinate; one process can add a new entry while the other waits for the entry to appear in the store. A timeout value of zero (default) means to fail immediately if the entry cannot be found. A value of -1 means to wait indefinitely.
import torchstore
store = torchstore.connect(name="my-store")
# Wait up to 30 seconds for 'my-tensor' to show up in the store.
tensor = store.load("my-tensor", timeout=30)
Besides a single entry, the load() API can also be used to load a set of entries at the same time. This reduces the amount of IPC communication required and is also helpful to specify a single timeout for a group of logically related entries. Thanks to @yifuwang for this feature suggestion.
import torchstore
store = torchstore.connect(name="my-store")
# Load the entries 'entry1', 'entry2', and 'entry3' at the same time. Wait up
# to 30 seconds for all three entries to show up in the store.
list_of_tensors = store.load(["entry1", "entry2", "entry3"], timeout=30)
Closely related to load() is the contains() API that, as its name suggests, simply checks whether the store contains a specified entry.
import torchstore
store = torchstore.connect(name="my-store")
if store.contains("my-tensor"):
...
Loading a Model from the Store
A model can be loaded from the store similar to how it's loaded from a file. A state_dict that was saved earlier via save() can be retrieved and passed to the model.
import torch
import torchstore
store = torchstore.connect(name="my-store")
class MyModel(torch.nn.Module):
...
m = MyModel()
m.load_state_dict(store.load("my-model"))
Deleting from the Store
The store internally uses a reference counting mechanism. Removing an entry from the store decrements its reference count and marks the entry as removed so that it cannot be retrieved anymore. However entries that are already in use (i.e. loaded in processes) will be kept alive until they go out of scope.
import torchstore
store = torchstore.connect(name="my-store")
tensor = store.load("my-tensor")
# The store will mark 'my-tensor' as removed and release its internal reference; however
# the outstanding references will keep the entry alive.
store.remove("my-tensor")
# Once the reference count of `tensor` drops to zero with the next statement, the
# store will completely discard and deallocate the entry.
tensor = None
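The remove() semantics above boil down to a refcount state machine: the store holds one internal reference, remove() drops it and hides the entry, and deallocation happens only when the last client reference is gone. Below is a toy model of that state machine; all names are made up.

```python
# Toy model of remove(): hiding an entry is separate from freeing its storage.

class Store:
    def __init__(self):
        self.entries = {}  # name -> [refcount, removed, alive]

    def save(self, name):
        self.entries[name] = [1, False, True]  # the store's own reference

    def load(self, name):
        entry = self.entries[name]
        if entry[1]:
            raise KeyError(name)               # removed entries are hidden
        entry[0] += 1
        return entry

    def release(self, entry):
        entry[0] -= 1
        if entry[0] == 0:
            entry[2] = False                   # storage deallocated

    def remove(self, name):
        entry = self.entries[name]
        entry[1] = True                        # no further load() succeeds
        self.release(entry)                    # drop the store's reference

store = Store()
store.save("my-tensor")
handle = store.load("my-tensor")
store.remove("my-tensor")
assert handle[2]        # the client's reference keeps the allocation alive
store.release(handle)
assert not handle[2]    # last reference gone: storage is freed
```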
Shortcoming
One shortcoming of PyTorch is the lack of immutable tensors, which requires users of TorchStore to be extra cautious. Setting aside the slight nuance between a read-only and an immutable tensor: if a tensor gets accidentally modified, the changes will be immediately reflected in other processes. There is unfortunately no mechanism to prevent this kind of behavior. Having said that, this “feature” can be helpful in certain scenarios where one writer process is allowed to modify the shared tensor while the others treat it as read-only (this use case also clarifies the difference between read-only and immutable).
Alternatives
For training tasks one alternative is to keep using Tensor.share_memory_() and extending TorchElastic with some form of pre-processing hook. The user will provide a function that performs the necessary initialization in the launcher process and TorchElastic will forward the return value of the function to the worker processes via torch.multiprocessing. Unfortunately this alternative won’t address the problem for inference tasks or for frequently-run tasks that want to store models or tensors persistently.
Another alternative is to leverage Apache Arrow’s Plasma Object Store. The advantage is that we won’t need to implement our own store and can simply wrap Plasma. However this alternative also has its disadvantages: 1) there is no support for the storage() API, so saving a tensor to the store will always require a double allocation; 2) there is no support for an in-process store, meaning the user has to manually start a separate process for the store regardless of the use case; 3) it introduces a dependency on the Apache Arrow library, which has a large dependency closure.
Design Considerations
- TorchStore will be available only for Linux and macOS systems. Since the implementation is highly OS-dependent, there are no plans to support non-POSIX systems (e.g. Windows).
- There will be Python and C++ APIs that follow the conventions, idioms, and guidelines of the PyTorch code base.
- The project will be hosted in its own repository for easier iteration and a faster release cadence. Similar to the evolution of the TorchElastic project, this decision can be reconsidered once the project reaches a certain maturity.
cc @VitalyFedyunin @mrshenli @pritamdamania @rohan-varma @kiukchung @bowangbj @aivanou @SciPioneer @H-Huang @tchaton @rusty1s @BarclayII @VoVAllen @houseroad @anj-s @divchenko
Triage review to discuss the right tags to apply to this
@jbschlosser I think there was a race condition :) looks like we updated the tag at the same time. I had no intent to remove your tag.
Since this is shared-memory, would I be correct if I assume torchstore aims to support multiple processes on the same machine, instead of across machines?
Looks very nice to me. Some thoughts
- For the remove part, one thing that needs to be decided is whether to free the memory immediately, or to set an upper limit on the shared memory and use an LRU cache to invalidate some entries. One thing I ran into before when transferring tensors between processes using PlasmaStore: the producer might exit before the consumer receives the tensor, which means the reference count goes to zero when the producer exits. If the tensor is freed directly, an error happens. Thus it's tricky for the producer to keep track of such a scenario. Another solution is adding some callback/event notification mechanism for the producer to hold the reference until it's consumed. I think this RFC might also help simplify the shared memory logic in the multiprocessing dataloader.
- For the `with store.storage()` API, I like this design. This reminds me of an old feature request about an external memory allocator: https://github.com/pytorch/pytorch/issues/43144. Is it possible to bring this into the design scope also?
- PlasmaStore supports hugepages as one option. However I'm not sure how much performance improvement we can gain from this. I think it's fine to skip this in the first version.
CC @wenleix @VitalyFedyunin
The problem solved here seems reasonable (TorchStore is a management process that can be used to deal with shared memory in situations where there is no obvious parent process to do memory management) but it seems to me that there are some underlying tools that you have to build that would be of interest to OSS users, instead of having to buy into a complete key-value store. In particular:
- You still need some way of connecting to the store without a preexisting queue that is already shared between the processes (e.g., in the inference case)
- You need some way of allocating a tensor but having it be owned by a different process from the start
- You need a distributed refcounted mechanism (maybe you're planning to use the built-in one here?)
This also makes me raise an eyebrow:
It temporarily redirects PyTorch’s CPU and CUDA allocators to their store versions so that later save() calls won’t require any copy operation.
I'm not sure how you are actually going to implement this in a reasonable way. The CPU/CUDA allocators are global state and not really designed to be context manager overwritten in this way.
This is not important, but there is a factual misstatement here:
With share_memory_() a tensor can only be shared from the parent process to its child processes. If a child process wishes to share a tensor with the rest of its peers, there is no mechanism for it.
So long as the child processes share a queue, they can directly transmit to one another, without needing to go through a parent process:
import torch
import torch.multiprocessing as mp
def send(q):
x = torch.randn(4)
x.share_memory_()
print(f"send: {x}")
q.put(x)
def recv(q):
print(f"recv: {q.get()}")
if __name__ == '__main__':
q = mp.Queue()
p1 = mp.Process(target=send, args=(q,))
p2 = mp.Process(target=recv, args=(q,))
p1.start()
p2.start()
p1.join()
p2.join()
This applies for CUDA tensors too, although when @VitalyFedyunin implemented the cross-reference refcounting, we made it so that you couldn't forward on a received CUDA tensor to yet another process.
The proposal sounds reasonable to me, but there are a couple of aspects which makes implementation complicated:
- Distributed reference counting is hard, especially because you are proposing server-(many)clients architecture. This means you would have to take care of the 'dead' client situation when reference is claimed but never released.
- It is hard to manage ownership and scope, you might want to simplify it and take an approach like: server allocates tensors in CPU/CUDA/Device memory and shares IPC Handles with clients. Meaning if someone saved tensor to the server and modified it after, changes are not going to be reflected on other clients.
- Taking into account p2. what will happen if I save a new tensor under the same name?
Thanks for your feedback folks!
@VitalyFedyunin below my answers:
- Distributed reference counting is hard, especially because you are proposing server-(many)clients architecture. This means you would have to take care of the 'dead' client situation when reference is claimed but never released.
This is right. Although I deliberately avoided going into the technical details of the solution and tried to get feedback on the general idea, I think I should include at least a brief section in the RFC.
Once a client calls connect() it will open up a Unix Domain Socket to the server. The server will keep record of every entry that has been requested by that connection (more technically incref'ed by that connection). If for some reason the connection breaks (e.g. the worker process crashes), the server will perform a clean up, meaning it will decrement the ref count of all outstanding entries.
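That cleanup rule is easy to picture as a per-connection ledger. The sketch below is a toy model of the bookkeeping (simplified to one reference per entry per connection), not actual server code:

```python
# Toy model of dead-client cleanup: the server tracks which entries each
# connection has incref'ed, so a broken connection can be unwound in one pass.

class Server:
    def __init__(self):
        self.refcounts = {}  # entry name -> server-side refcount
        self.by_conn = {}    # connection id -> entry names it holds

    def incref(self, conn, name):
        self.refcounts[name] = self.refcounts.get(name, 0) + 1
        self.by_conn.setdefault(conn, set()).add(name)

    def on_disconnect(self, conn):
        # e.g. the worker process crashed: release everything it held
        for name in self.by_conn.pop(conn, set()):
            self.refcounts[name] -= 1

srv = Server()
srv.incref("conn-1", "my-tensor")
srv.incref("conn-2", "my-tensor")
srv.on_disconnect("conn-1")  # crashed client's references are released
assert srv.refcounts["my-tensor"] == 1
```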
- It is hard to manage ownership and scope, you might want to simplify it and take an approach like: server allocates tensors in CPU/CUDA/Device memory and shares IPC Handles with clients. Meaning if someone saved tensor to the server and modified it after, changes are not going to be reflected on other clients.
This is exactly how I plan to implement it. The clients will never perform an allocation themselves. They will request from the server to allocate some buffer and in response they will get a device-specific IPC handle back. They will be responsible to "decref" that allocation once it is not needed anymore.
Even though the API overall gives an impression that this is a "distributed" solution, it is in fact much more a client-server solution. The majority of the logic and bookkeeping happens in the server (including memory management).
- Taking into account p2. what will happen if I save a new tensor under the same name?
Having an allocation handed back from the server and saving a tensor are two separate actions. You can think of the IPC handle that you get back from the allocation request as a temporary (in C++ lingo, a prvalue), while saving a tensor is assigning that temporary to a named variable (an lvalue in C++).
Having said that, there are three cases related to your last question:
- The most straightforward use case; if there is no entry with that name, your tensor will simply be saved to the store.
- If there is already an entry with that name, you will get an error. This is like trying to define a second variable with the same name.
- If an earlier call to `remove()` deleted an entry with the same name, the allocation of the old entry will still be kept alive as long as there are clients using it. However this won't prevent another entry with the same name from being saved to the store.
@ezyang please see my answers below.
You still need some way of connecting to the store without a preexisting queue that is already shared between the processes (e.g., in the inference case)
This will be a typical client-server architecture. Each client will have a persistent connection to the server via Unix Domain Sockets. The name specified with the connect() call will be used to determine the socket path (e.g. for "my-socket" it will be /var/run/torchstore/my-socket)
You need some way of allocating a tensor but having it be owned by a different process from the start
The tensors will be always allocated and owned by the server. The clients will get back an IPC handle from the server to access the allocated buffer. Refcounting is also centralized. The server will have its own ref to an entry until remove() is called, and it will increment the refcount each time an entry handle is sent to a client. The clients are responsible to decref their handle once they go out of scope (using RAII or a similar mechanism). This is pretty similar to how native Python container and sequence types refcount their elements.
You need a distributed refcounted mechanism (maybe you're planning to use the built-in one here?)
See my answer to the previous item.
This also makes me raise an eyebrow:
It temporarily redirects PyTorch’s CPU and CUDA allocators to their store versions so that later save() calls won’t require any copy operation.
I'm not sure how you are actually going to implement this in a reasonable way. The CPU/CUDA allocators are global state and not really designed to be context manager overwritten in this way.
My idea is to record the thread that called the context manager in a thread-local variable, and then override the existing allocator using SetAllocator(). Later on if an allocation is requested I either forward it to the store if it is the context-owning thread, or call the original allocator. And finally at the end of the context, I revert back to the original allocator.
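In plain Python terms the scheme looks like this: record the owning thread in a thread-local flag, swap in a forwarding allocator, and have it consult the flag on every allocation. This is a behavioral sketch only; the real work would happen in C++ against PyTorch's allocator hooks, and the allocator stand-ins below are made up.

```python
# Behavioral sketch of the thread-local allocator redirect: only allocations
# made by the thread that entered the context go to the "store" allocator.
import contextlib
import threading

def _original_alloc(n):
    return ("heap", n)       # stand-in for the regular CPU/CUDA allocator

def _store_alloc(n):
    return ("store", n)      # stand-in for the store-backed allocator

_current_alloc = _original_alloc
_owner = threading.local()

def alloc(n):
    return _current_alloc(n)

def _redirecting_alloc(n):
    if getattr(_owner, "active", False):
        return _store_alloc(n)    # context-owning thread: use the store
    return _original_alloc(n)     # every other thread is unaffected

@contextlib.contextmanager
def storage():
    global _current_alloc
    _owner.active = True
    prev, _current_alloc = _current_alloc, _redirecting_alloc
    try:
        yield
    finally:
        _current_alloc = prev     # revert to the original allocator
        _owner.active = False

with storage():
    assert alloc(10)[0] == "store"   # redirected for the owning thread
assert alloc(10)[0] == "heap"        # back to normal outside the context
```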
This is not important, but there is a factual misstatement here:
With share_memory_() a tensor can only be shared from the parent process to its child processes. If a child process wishes to share a tensor with the rest of its peers, there is no mechanism for it.
Oh, okay. I wasn't aware of that capability, my bad. I will update the RFC. Thanks for pointing it out!
@VoVAllen please see my answers below.
Looks very nice to me.
Thanks!
For the remove part, one thing that needs to be decided is whether to free the memory immediately, or to set an upper limit on the shared memory and use an LRU cache to invalidate some entries. One thing I ran into before when transferring tensors between processes using PlasmaStore: the producer might exit before the consumer receives the tensor, which means the reference count goes to zero when the producer exits. If the tensor is freed directly, an error happens. Thus it's tricky for the producer to keep track of such a scenario. Another solution is adding some callback/event notification mechanism for the producer to hold the reference until it's consumed. I think this RFC might also help simplify the shared memory logic in the multiprocessing dataloader.
In this scenario, does the producer explicitly call delete/remove on the tensor before exiting? In TorchStore the store itself will always keep a reference to its entries besides the references kept by the clients. A tensor will only be deallocated if all clients release their references and one explicitly calls remove() on the store. Would that work in this scenario?
For the `with store.storage()` API, I like this design. This reminds me of an old feature request about an external memory allocator (Using external memory allocator with PyTorch #43144). Is it possible to bring this into the design scope also?
A general mechanism for external memory allocators is a much broader topic and requires much more thought than this constrained use case. Therefore I would prefer to keep the two discussions separate. We have several subject matter experts in this thread though (e.g. @ezyang) and they might already have some information on that topic.
PlasmaStore supports hugepages as one option. However I'm not sure how much performance improvement we can gain from this. I think it's fine to skip this at the first version
Yes, it is something I am already considering. Working with (transparent) huge pages can be a hit-and-miss though, so I am cautious about them. Let's see how things perform without them, and depending on the results we can decide on it.
Hey @mrshenli
Since this is shared-memory, would I be correct if I assume torchstore aims to support multiple processes on the same machine, instead of across machines?
That is correct. TorchStore is intra-host only.
Thanks Can, overall the direction looks pretty nice! I feel this can be pretty useful for inference. For training, we need to come up with some hero examples. NLP.pre_training_embedding_table lookup support for text related task like sentiment analysis might be one concrete example. A few comments and questions below:
Q1: re2 inference, can we assume all workers (same machine) refer to entire module is the main use case?
Q2: git repo, have you considered torch/distributed/_store which is similar as torch/distributed/_sharded_tensor given that _store seems an integral part of PT/PT-D?
Q3: related to Q2, does it make sense to make the namespace consistent like torch.distributed.store?
Q4: best practice to start the store, given that all dependent workers need to wait for the store be ready. would it simplify the design if we recommend store as separate process?
Q5: do we want to allow multiple store instances on one machine?
Q6: naive comment for the writing API, I feel store might be something close to torch.device primitives, would it be possible to have something like
- device = torch.device(“store/{store-name}/cuda:0”) and support
- tensor = torch.ones(10, 20, device=device), or saved = tensor.to(device)
Q7: loading API is not truly loading from somewhere, probably something like store.get({key}) is more natural.
Q8: do we really need dist ref count? I feel it might be useful but it also has lots of implications. Probably treat it as an advanced feature (P2) for now, we can focus on the basic APIs and validate how useful the overall store is first?
Q9: API user guide. It would be nice to have a concrete example to showcase:
- store prep (list of tensors are written)
- how would the distributed training script be update to consume the shared store? Ideally the cmt of code (compared to current experience) would be quite minimum
- what would be the RAM saving?
Thanks for the feedback @bowangbj!
For training, we need to come up with some hero examples.
Please see my answer to Q9 regarding this.
Q1: re2 inference, can we assume all workers (same machine) refer to entire module is the main use case?
I assume that it would be the main use case, but it is also possible to save an arbitrary tensor or a list of tensors. It doesn't have to be necessarily a module.
Q2: git repo, have you considered torch/distributed/_store which is similar as torch/distributed/_sharded_tensor given that _store seems an integral part of PT/PT-D?
At this time I much prefer this to be in its own repo. Since this is more than just a library (it is also a daemon and a CLI), I want it to have a faster release cadence than PyTorch, at least until the v1 release. This is also how TorchElastic evolved over time, and having a separate repo certainly helped it move faster.
Q3: related to Q2, does it make sense to make the namespace consistent like torch.distributed.store?
We can reevaluate the namespace once we decide to migrate the project into the PyTorch code base. Note though that we already use the name "Store" for our key-value store used by process groups and collective functions. We should think of a different name :)
Q4: best practice to start the store, given that all dependent workers need to wait for the store be ready. would it simplify the design if we recommend store as separate process?
The current API supports both use cases. The reason why I wanted to support an in-process store is to make it easier for users to start using it without fiddling with their launcher scripts. For instance if you are on Slurm, an in-process store makes it possible to srun your script without worrying about daemons or external processes. In summary it is a convenience feature.
Q5: do we want to allow multiple store instances on one machine?
Yes, as long as they have unique names, there can be multiple stores. Each store will listen for its clients under /var/run/torchstore/<store-name>
Q6: naive comment for the writing API, I feel store might be something close to torch.device primitives, would it be possible to have something like
- device = torch.device(“store/{store-name}/cuda:0”) and support
- tensor = torch.ones(10, 20, device=device), or saved = tensor.to(device)
Well there are two issues with that approach: 1) introducing a new device type in PyTorch is a substantial task, and goes beyond just storage allocation. It also directly influences the op dispatcher and many other aspects of the API. Also I am not entirely sure whether you can have "dynamic" parts such as {store-name} in the device name. 2) I want users to be able to save their existing modules and tensors to the store. Introducing a new device type would require users to modify their code.
Q7: loading API is not truly loading from somewhere, probably something like store.get({key}) is more natural.
Well, the reason why I picked the names save/load instead of get/set was the lack of immutability of PyTorch tensors. Although I agree that technically we are not loading anything, we are still instantiating a Tensor instance per process. And the user is free to modify certain aspects of that local tensor, such as its size or its strides, which do not get reflected to other processes. Using reference semantics (i.e. get) would imply that all processes are getting the same instance, while in fact they are getting the same storage wrapped in a locally instantiated tensor. This is why I used API names that imply value semantics (save/load).
Q8: do we really need dist ref count? I feel it might be useful but it also has lots of implications. Probably treat it as an advanced feature (P2) for now, we can focus on the basic APIs and validate how useful the overall store is first?
You can check out my answers to Edward and Vitaly above. We are not using a distributed ref counting mechanism. It is true that the server will keep track of which clients have references to which store entries, but other than that its ref counting mechanism is no different than a local one. Note also that all allocations happen in the server. The clients will never allocate storage on behalf of the store, which makes reference tracking substantially easier (and non-distributed).
Q9: API user guide. It would be nice to have a concrete example to showcase:
In fact the best outcome of this RFC would be if I could find a customer who is already willing and committed to use TorchStore once it becomes available. It would be way easier for me to get leadership buy-in if I could demonstrate its potential impact with a concrete real-world use case.
Dear @cbalioglu,
Very cool proposal !
I would have several questions:
- Q1: Do you envision the store be used with sharded models ?
store.save("my-model", m.state_dict(), rank=..., world_size=...)
- Q2: Would the store be usable within the workers of the PyTorch DataLoader ?
Best, T.C
Thank you @tchaton!
Q1: Do you envision the store being used with sharded models?
Do you mind clarifying what type of sharded model you mean here? There are various techniques to shard a model (e.g. horizontally, vertically, intra-layer, etc.)
Q2: Would the store be usable within the workers of the PyTorch DataLoader?
I see no reason why it wouldn't be usable. Again, do you have a specific use case in mind?
Cheers.
@VoVAllen regarding your first question again. I think I now get what you mean by the producer/consumer problem. What do you think about an optional time-to-live argument to save() API? The producer can call it like:
torchstore.save("my-tensor", tensor, ttl=60)
which means if there are no outstanding references to "my-tensor" for 60 seconds (i.e. the tensor is not used by any client), automatically delete it from the store.
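To make the intended semantics concrete, here is a toy in-process sketch of the TTL bookkeeping (the class and all of its internals are hypothetical; in the real store this accounting would live on the server):

```python
import time

class TtlStore:
    """Toy sketch of the proposed TTL semantics: an entry is evicted once
    it has had no outstanding references for `ttl` seconds."""

    def __init__(self):
        # key -> [value, ttl, refcount, time of last release]
        self._entries = {}

    def save(self, key, value, ttl=None):
        self._entries[key] = [value, ttl, 0, time.monotonic()]

    def load(self, key):
        entry = self._entries[key]
        entry[2] += 1  # one more outstanding reference
        return entry[0]

    def release(self, key):
        entry = self._entries[key]
        entry[2] -= 1
        entry[3] = time.monotonic()

    def evict_expired(self):
        """Periodically called by the store to drop unreferenced, expired keys."""
        now = time.monotonic()
        for key, (_, ttl, refs, released) in list(self._entries.items()):
            if ttl is not None and refs == 0 and now - released >= ttl:
                del self._entries[key]

store = TtlStore()
store.save("my-tensor", b"weights", ttl=0.1)
t = store.load("my-tensor")
store.evict_expired()
assert "my-tensor" in store._entries   # an outstanding reference blocks eviction
store.release("my-tensor")
time.sleep(0.15)
store.evict_expired()
assert "my-tensor" not in store._entries  # unreferenced past its TTL
```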
That's a great proposal, @cbalioglu, as indeed with the growing models we run into a lot of memory duplication issues.
The only bit that I feel needs clarifying in the API is the ability to lock a tensor; otherwise, as you wrote, not only can any process modify it, there is also a potential race condition when doing so. Surely this can be implemented in the user's code, but it would be useful to have it as a built-in feature.
Thanks for the feedback @stas00! I agree that this is not ideal. #44027 is a long standing proposal to introduce immutable tensors, but apparently it hasn't gained much traction due to some subtle issues it brings.
One idea might be to have a debug flag like:
store = torchstore.connect(name="my-store", debug=True)
which would map the shared memory as read-only on the client. This would cause any write attempts to segfault. Unfortunately there is no similar mechanism for CUDA device memory, so this would be a half-baked solution.
Well, I was thinking of adding an optional lock API which, once acquired, blocks other peers until it's released - i.e. following flock - and, as mentioned, this can of course be done on the user side, since flock is a great locking tool when limited to a single node.
And then it could have store- and key-granularity.
To clarify: the proposal is to have an "elective" locking and not forced, and requiring all peers to agree to use it.
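As an illustration, an elective per-key lock along these lines could be built entirely in user code with flock (the lock directory and the helper here are hypothetical, not part of the proposed API):

```python
import fcntl
import os
import tempfile
from contextlib import contextmanager

# Hypothetical per-store directory holding one lock file per key.
LOCK_DIR = os.path.join(tempfile.gettempdir(), "torchstore-locks")

@contextmanager
def key_lock(key, exclusive=True):
    """Advisory per-key lock; only peers that opt in are affected."""
    os.makedirs(LOCK_DIR, exist_ok=True)
    fd = os.open(os.path.join(LOCK_DIR, key), os.O_CREAT | os.O_RDWR)
    try:
        fcntl.flock(fd, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
        yield
    finally:
        fcntl.flock(fd, fcntl.LOCK_UN)
        os.close(fd)

# Writer: exclusive lock while mutating the shared tensor in place.
with key_lock("my-tensor"):
    pass  # e.g. tensor.add_(update)

# Readers: shared lock, so any number of readers can proceed concurrently.
with key_lock("my-tensor", exclusive=False):
    pass  # e.g. total = tensor.sum()
```

Because flock is advisory, peers that do not take the lock are unaffected, which matches the "elective, not forced" semantics above.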
@cbalioglu TTL seems like a graceful solution to me. I'd also like to add another scenario that needs torchstore. In the RPC scenario, the server may take a long time to handle a single task, but due to the GIL, resource utilization might be low. To increase throughput, one solution is to use multiprocessing to start multiple replicated servers to handle the requests. However, it's quite common for those servers to share some state, which could be handled by torchstore. As for the mutability problem, I think torchstore can let the user handle it, for example with a lock built on TCPStore or one of the other Store classes in PyTorch.
This is a great feature. I really wish we had had this earlier so that we did not need to hack torch.multiprocessing for Hogwild.
I remember @pritamdamania did a lot of work to make sure we did not lose any gradients during backward with multiple processes; do we need to implement the equivalent for torchstore?
This is a great feature.
Thank you @zzzwen!
I remember @pritamdamania did a lot of work to make sure we did not lose any gradients during backward with multiple processes; do we need to implement the equivalent for torchstore?
I do not have much context about that past work. @pritamdamania do you mind chiming in?
Thanks for the great write up @cbalioglu!
Q6: naive comment for the writing API, I feel store might be something close to torch.device primitives, would it be possible to have something like
I also had the similar question as @bowangbj on whether it's better for the store to offer a virtual device semantic. While I'm not suggesting it's a better or even feasible design, I think it's an interesting idea worth more discussion. I have a few comments on the issues you pointed out:
Well, there are two issues with that approach: 1) introducing a new device type in PyTorch is a substantial task and goes beyond just storage allocation; it also directly influences the op dispatcher and many other aspects of the API. 2) I am not entirely sure whether you can have "dynamic" parts such as {store-name} in the device name.
While I'm not very familiar with the op dispatching mechanism, isn't the meta device already similar in that only the tensor's storage gets special treatment? IIUC, to use meta tensors with operators, they first have to be moved to a physical device. I imagine the semantic would probably work well here (if the destination device is the same device backing the store, the move would not result in any copy).
- I want users to be able to save their existing modules and tensors to the store. Introducing a new device type would require users to modify their code.
Since the high level description of the project is "TorchStore is a key-value store that holds ATen tensors in shared memory ...", I wonder if it's ideal to support non-tensor objects (or treating it as a primary concern). If torchstore only handles tensor storage, I think object lifetime can be handled completely with reference counting (may need to borrow some ideas from the distributed reference counting protocol of RRef implemented by @mrshenli).
Hi @cbalioglu, interesting. It makes sense and seems straightforward for CPU backend(s), but you don't detail much how that would work for the CUDA backend. It's indeed easy for processes A and B to read a shared tensor in (POSIX) shared memory, shared memory being 'just' a sharable system RAM space: https://man7.org/linux/man-pages/man7/shm_overview.7.html
NVIDIA GPU RAM is not really open to that kind of sharing. Are you consequently referring to an NVLink solution: https://www.nvidia.com/en-us/data-center/nvlink/ and/or CUDA unified memory: https://developer.nvidia.com/blog/unified-memory-cuda-beginners/ ? Refs: the default allocator of pytorch/c10 as of today allocates 'normal' CUDA memory, not managed (unified): https://github.com/pytorch/pytorch/blob/b7adb3350ac9115869a0c54eec326e387ccbcf6b/c10/cuda/CUDACachingAllocator.cpp#L306
fbgemm seems to have a managed tensor alloc: https://github.com/pytorch/FBGEMM/blob/dccc5736615b1fcd268293833ab3b337b86a54ef/fbgemm_gpu/src/cumem_utils.cu
The managed cuda malloc: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1gd228014f19cc0975ebe3e0dd2af6dd1b
kind W.
Thanks for the feedback @yifuwang!
While I'm not very familiar with the op dispatching mechanism, isn't the meta device already similar in that only the tensor's storage gets special treatment? IIUC, to use meta tensors with operators, they first have to be moved to a physical device. I imagine the semantic would probably work well here (if the destination device is the same device backing the store, the move would not result in any copy).
Although I agree that a device-based approach would neatly integrate with the rest of the PyTorch API, beyond its organizational challenges (introducing a new device type needs rigorous validation from the core engineering team to ensure that we do not accidentally break anything) there are also some technical shortcomings: 1) How would you specify a timeout (or any additional future parameter) when retrieving an entry via a device? 2) How would the device abstraction behave if the entry is not found? We would be overloading the semantics of a device and a store, which I believe would be confusing. 3) How would you initialize higher-level constructs (e.g. a model represented by an arbitrary torch.nn.Module type)? In summary, I see many risks and edge cases in a device-based API compared to the regular good old save/load approach. And although I believe an API should offer only a single way to accomplish a task, we can still evaluate a device-based API once we have a working version of the store.
Since the high level description of the project is "TorchStore is a key-value store that holds ATen tensors in shared memory ...", I wonder if it's ideal to support non-tensor objects (or treating it as a primary concern). If torchstore only handles tensor storage, I think object lifetime can be handled completely with reference counting (may need to borrow some ideas from the distributed reference counting protocol of RRef implemented by @mrshenli).
In fact there will be a lower-level API that works with simple blobs. The API described here is the high-level, PyTorch-specific one; I would also like to have some native wrappers for DLPack and the Python Buffer Protocol. I did not mention them in the proposal so that the discussions do not get sidetracked. And as I mentioned in my previous comments, since internally the store uses a client-server architecture, reference counting will be much simpler than a truly distributed refcounting mechanism.
Hi @WilliamTambellini, thanks for your feedback!
The CUDA Runtime offers an IPC mechanism for interprocess memory sharing. Using the cudaIpcGetMemHandle() and cudaIpcOpenMemHandle() functions you can pass a "handle" to (non-managed) device memory from one process to another. The approach is a bit different from how you open a regular shared memory region, but it ultimately works the same way. The only limitation is that it is supported only on 64-bit Linux systems. However, that is not a deal breaker for us since we intend to support only Linux and macOS, and CUDA support on macOS is basically non-existent anyway.
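For reference, the CPU-side analogue of this attach step is easy to demonstrate with Python's standard library; a CUDA-backed store would exchange a cudaIpcMemHandle_t (via cudaIpcGetMemHandle()/cudaIpcOpenMemHandle()) instead of a segment name, but the zero-copy attach is conceptually the same:

```python
from multiprocessing import shared_memory

# Producer: place raw tensor bytes into a named POSIX shared memory segment.
shm = shared_memory.SharedMemory(create=True, size=16)
shm.buf[:4] = b"\x01\x02\x03\x04"

# Consumer (normally another process): attach by name. No copy is made;
# both views alias the same physical pages.
view = shared_memory.SharedMemory(name=shm.name)
payload = bytes(view.buf[:4])
assert payload == b"\x01\x02\x03\x04"

view.close()
shm.close()
shm.unlink()  # the CUDA analogue would be cudaIpcCloseMemHandle() + cudaFree()
```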
Since the high level description of the project is "TorchStore is a key-value store that holds ATen tensors in shared memory ...", I wonder if it's ideal to support non-tensor objects (or treating it as a primary concern). If torchstore only handles tensor storage, I think object lifetime can be handled completely with reference counting (may need to borrow some ideas from the distributed reference counting protocol of RRef implemented by @mrshenli).
One thing I realized from this comment is that we might need some APIs to store several tensors and some metadata together, for example to support sparse COO/CSR tensors, which consist of three tensors (COO: row, col, data) plus some metadata (num_rows). However, I'm not sure whether this is better handled at a higher abstraction level (in Python, e.g. with a customized pickler) or at a lower level.
Hi @cbalioglu, we all agree that this is a nice feature to have. But it has become hard to see the history of RFC changes and track conversations. Could you please create a Google Doc and share it in "anyone can comment" mode? As soon as we solidify the RFC we can migrate it back to GitHub. Let me know if you need my help with it.
One thing I realized from this comment is that we might need some APIs to store several tensors and some metadata together, for example to support sparse COO/CSR tensors, which consist of three tensors (COO: row, col, data) plus some metadata (num_rows). However, I'm not sure whether this is better handled at a higher abstraction level (in Python, e.g. with a customized pickler) or at a lower level.
I totally agree that the API should be flexible to accommodate different use cases. I deliberately avoided diving into lower level details, but ultimately for the core API an entry will be made of a data pointer + some optional metadata. All higher level constructs (tensors, list of tensors, modules) will be ultimately serialized into a data buffer + some metadata. The store itself will provide native support for PyTorch tensors, DLPack tensors, and Python Buffer Protocol (and lists/dictionaries of them), however the low-level API would enable anyone to add support for other constructs as well, be it a CSR or COO matrix.
Hi @cbalioglu we all agree that is nice feature to have. But it became hard to see the history of RFC changes and track conversations. Could you please create google doc and share it in all can comment mode. As soon as we solidify RFC we can migrate it back to the GitHub. Let me know if you need my help with it.
Thanks for the suggestion @VitalyFedyunin. It definitely makes sense. Let me migrate it to a Google Doc. I will let you know if I get stuck anywhere; tbh I don't have much experience with Google Docs :)
This would be very useful. Are there any new updates on this RFC @cbalioglu?
Hi @lamhoangtung, it is still on our roadmap, but we don't have an exact ETA yet. Do you have a particular use case in mind?
Sure. I do!
I'm working on a pretty complex Python real-time video processing pipeline that involves multiple machine learning model inferences and analytics running simultaneously.
Thanks to the NVIDIA Video Processing Framework and TorchScript, I have built a pretty efficient end-to-end GPU pipeline for our use cases. Since the whole code base is written mostly in Python, I run into the GIL very frequently in the current multi-threaded design, and avoiding it is a pain in the ass.
I am moving to a multi-process design so I don't have to worry about the GIL as much, but managing everything via PyTorch multiprocessing and Tensor.share_memory_() properly to avoid memory leaks and seg faults is also a pain in the ass. It would be great if we had a Redis-like KV store to share lots of GPU tensors directly between processes without any memory copies, serialization/deserialization, or device transfers.
TorchStore would be perfect for this and would simplify the code base a ton, help me truly decouple each step in the pipeline into a separate process, and make the application much more flexible.
So I am really looking forward to TorchStore, thanks for bringing this up @cbalioglu
This will be a very useful feature.
Was this ever implemented?
Hi @c-heat16, it is in our backlog, but we don't have a date yet.
Nice proposal! It's good to see the general approach from my blog posts getting traction in this community. A shared-memory store for PyTorch would be helpful for inference use cases involving large numbers of models, especially when there's a "long tail" of little-used models that need to be ready for inference on short notice.
I do have one concern about the current proposal and some suggestions.
Concern: Model deserialization.
The mechanism for loading a model from the shared memory store described in the proposal is incorrect. Specifically, the proposal recommends loading models from shared memory by the following process:
- Construct the model's graph of torch.nn.Module objects.
- Read a state dictionary of shared-memory tensors.
- Pass the state dictionary to the root Module object's load_state_dict() method.
The problem with this approach is that the load_state_dict() method always deep-copies tensors (See the implementation code here). The steps described above will create a buffer in local process memory for each weight tensor of the model and will copy the weights from shared memory into those local buffers. This is not the intended behavior.
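The copy is easy to observe, and so is the zero-copy alternative that a dedicated API could use internally (this snippet assumes PyTorch is installed and is purely illustrative):

```python
import torch

m = torch.nn.Linear(2, 2, bias=False)
w = torch.zeros(2, 2)  # stands in for a tensor backed by shared memory

# load_state_dict() deep-copies: the module keeps its own local storage,
# so the shared-memory bytes get duplicated into every process.
m.load_state_dict({"weight": w})
assert m.weight.data_ptr() != w.data_ptr()

# Zero-copy alternative: point the parameter at the shared storage instead
# of copying into it (a sketch of what a dedicated API could do internally).
m.weight = torch.nn.Parameter(w, requires_grad=False)
assert m.weight.data_ptr() == w.data_ptr()
```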
I recommend adding a dedicated API for loading a model directly from TorchStore to avoid this problem. My preference would be to have two APIs:
- An API similar to torch.nn.Module.load_state_dict() that does not copy data but instead modifies all the tensors in the Module such that their storage buffers become pointers to the appropriate blocks of shared memory.
- An API similar to torch.load() that deserializes the entire model from TorchStore, without copying weights data and without initializing empty buffers for each of the tensors that the model uses to store weights. This API would of course require a second method for serializing a Module object to TorchStore, similar to torch.save().
Suggestion: Optional persistence
As currently proposed, TorchStore's shared memory segment is owned by the TorchStore daemon. If the daemon process exits or is killed, all the data in the store will be lost.
This design would be inconvenient for inference use cases that rely on rapid loading of models from shared memory. If the daemon crashes or is restarted, all models that were in the store will need to be reloaded via a "slow-path" mechanism, such as copying the model weights from S3. The resulting delay would impact the availability of the model serving layer.
It would be better if there were an option to keep the data of the shared memory store (i.e. the objects and the mapping from keys to objects) inside a memory-mapped file. Then it would be possible to restart the TorchStore daemon and restore the contents of the store much more quickly.
Suggestion: Read-only access
Most inference use cases for TorchStore will write model weights into the store exactly once. It would be useful if there was a way to write models into TorchStore at model build time, save the entire store to a file, and open this file later on as a read-only instance of TorchStore.
Deploying models for inference this way would give several advantages:
- Models would be available for inference immediately upon container startup.
- No TorchStore daemon would be required for inference; each inference process can run its own in-process TorchStore instance, all of them memory-mapping the same read-only file.
- Different containers on the same physical host would be able to share a single shared memory segment. Linux containers that memory-map the same file in read-only mode share the pages of the file in the host OS's page table.
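A minimal sketch of this read-only deployment path using only the Python standard library (the single-file layout here is hypothetical): the weights are written once at build time, and any number of processes can then map the file with PROT_READ, sharing one set of physical pages.

```python
import mmap
import os
import struct
import tempfile

# Model build time: serialize the "weights" (here, four floats) to a
# store file exactly once.
path = os.path.join(tempfile.mkdtemp(), "store.bin")
with open(path, "wb") as f:
    f.write(struct.pack("4f", 0.1, 0.2, 0.3, 0.4))

# Inference time: each process maps the file read-only. The OS page cache
# backs every mapping with the same physical pages, even across containers.
with open(path, "rb") as f:
    buf = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ)

weights = struct.unpack_from("4f", buf)
assert abs(weights[2] - 0.3) < 1e-6

buf.close()  # any write through this mapping would raise an error
```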
Following up on my previous comment: Now that I think about it, it is actually possible to implement zero-copy loading of entire models for inference directly from a read-only memory-mapped file. I've put together a quick proof of concept here: https://github.com/project-codeflare/zero-copy-model-loading/blob/main/notebooks/h5_poc.ipynb
Another alternative is to leverage Apache Arrow’s Plasma Object Store. The advantage is that we won’t need to implement our own store and will simply wrap Plasma. However this alternative also has its disadvantages: 1) no support for the storage() API, saving a tensor to the store will always require double-allocation,
I think we can still implement torchstore as a thin wrapper on top of Plasma without the disadvantages listed. Plasma separates "allocation" and "sealing" of the object. So it should be possible to wrap it into the allocator without issues.
Similarly, the "view" relationship can easily be layered on top by storing "tensor" objects and "storage" objects separately.
- no support for in-process store meaning the user has to manually start a separate process for the store regardless of the use case
One could try to link Plasma statically, but dependencies might indeed be a problem. But why would you want to run it in-process instead of forking?