Support Key/Value databases
🚀 The feature
The existing cacheholder leverages a Python dictionary:
```python
from typing import Any, Dict, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.map import MapDataPipe

T_co = TypeVar("T_co", covariant=True)

@functional_datapipe("in_memory_cache")
class InMemoryCacheHolderMapDataPipe(MapDataPipe[T_co]):
    def __init__(self, source_dp: MapDataPipe[T_co]) -> None:
        self.source_dp: MapDataPipe[T_co] = source_dp
        self.cache: Dict[Any, T_co] = {}

    def __getitem__(self, index) -> T_co:
        if index not in self.cache:
            self.cache[index] = self.source_dp[index]
        return self.cache[index]  # type: ignore[index]
```
But it could instead provide a generic interface for plugging in different cache providers, such as Redis or Memcached:
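```python
from abc import ABC, abstractmethod
from typing import TypeVar

T_co = TypeVar("T_co", covariant=True)


class Cache(ABC):
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __getitem__(self, index) -> T_co:
        pass


class RedisCache(Cache):
    def __init__(self, url):
        # setup_redis_client: client-creation helper, not shown here
        self._client = setup_redis_client(url)

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError


class MemCacheCache(Cache):
    def __init__(self, url):
        # setup_memcache_client: client-creation helper, not shown here
        self._client = setup_memcache_client(url)

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
```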
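To make the "plug in any provider" idea concrete, here is a minimal sketch assuming a backend only needs get, set, and membership checks. The names `KVCache`, `DictCache`, and `cached_lookup` are hypothetical, not torchdata APIs; a Redis or Memcached backend would implement the same three methods on top of its client.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Hashable, TypeVar

T = TypeVar("T")


class KVCache(ABC, Generic[T]):
    """Hypothetical minimal interface every cache backend would implement."""

    @abstractmethod
    def __getitem__(self, key: Hashable) -> T: ...

    @abstractmethod
    def __setitem__(self, key: Hashable, value: T) -> None: ...

    @abstractmethod
    def __contains__(self, key: object) -> bool: ...


class DictCache(KVCache[T]):
    """In-process backend, equivalent to the existing dictionary-based cache."""

    def __init__(self) -> None:
        self._store: Dict[Hashable, T] = {}

    def __getitem__(self, key: Hashable) -> T:
        return self._store[key]

    def __setitem__(self, key: Hashable, value: T) -> None:
        self._store[key] = value

    def __contains__(self, key: object) -> bool:
        return key in self._store


def cached_lookup(cache: KVCache[T], key: Hashable, load: Callable[[Hashable], T]) -> T:
    """Return the cached value, calling `load` (e.g. source_dp[key]) only on a miss."""
    if key not in cache:
        cache[key] = load(key)
    return cache[key]
```

The cache holder DataPipe would then depend only on `KVCache`, so swapping providers means swapping the object passed in.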
Motivation, pitch
Python dictionaries have a few limitations when used as a cache:
- They need to be copied into every process
- Updating the cache in one process requires manually synchronizing it with the caches in other processes
- The entire dictionary has to be loaded into memory, even to look at a single element
So the goal of this work would be to reduce memory overhead and cache misses in multiprocessing environments, at the cost of higher latency on cache hits, since a Python dictionary is faster to access than a remote KV store.
The 0.4 release was very much about leveraging remote object stores, so this work would follow that trend.
Alternatives
No response
Additional context
Our queues currently use Python lists (https://github.com/pytorch/data/blob/main/torchdata/dataloader2/communication/queue.py#L11) but could instead leverage queues like Kafka or RabbitMQ, so imagine a similar solution there.
It's a reasonable feature for MapDataPipe. Would it be possible to extend this cache interface to support an in-memory cache for IterDataPipe, which follows a FIFO order?
BTW, I think `__contains__` is also required for the cache object, to check whether the request is already in the cache.
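One way to read the FIFO suggestion, sketched under the assumption that the cache is a bounded buffer keyed by element index: once full, the oldest entry is evicted first, and `__contains__` lets the caller check for a hit before fetching. `FIFOCache` is a hypothetical name, not an existing torchdata class.

```python
from collections import OrderedDict
from typing import Generic, Hashable, TypeVar

T = TypeVar("T")


class FIFOCache(Generic[T]):
    """Hypothetical bounded cache that evicts entries in first-in, first-out order."""

    def __init__(self, capacity: int) -> None:
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self._store: "OrderedDict[Hashable, T]" = OrderedDict()

    def __contains__(self, key: object) -> bool:
        return key in self._store

    def __getitem__(self, key: Hashable) -> T:
        return self._store[key]

    def __setitem__(self, key: Hashable, value: T) -> None:
        if key in self._store:
            # Overwriting keeps the original insertion position (pure FIFO, not LRU).
            self._store[key] = value
            return
        if len(self._store) >= self.capacity:
            self._store.popitem(last=False)  # evict the oldest inserted entry
        self._store[key] = value

    def __len__(self) -> int:
        return len(self._store)
```

Reads do not refresh an entry's position, which is the main difference from an LRU policy.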
Summarizing notes from a meeting with Erjia:
- Let's compare open-source caching solutions and their tradeoffs
- Do we need to come up with our own format? What are the benefits and restrictions?
- Aggregate feedback from https://github.com/pytorch/data/issues?q=is%3Aissue+is%3Aopen+cache
- For a single node, it may make sense to host a cache/queue service locally; that way latency stays good and we use far less memory, which fits the original goals of DataPipes
- For distributed settings, we can use a remote cache, since otherwise we need to duplicate the cache on every node
After talking with @msaroufim, I wanted to take a stab at this implementation. Below is prototype code I put together solely for the purpose of discussion. It uses the IterDataPipe class, but the map one is probably even simpler.
Once we figure out a good strategy for handling some of the design items below, I can:
- abstract the cache interface
- clean up loose ends like authentication and TLS
- add proper error handling
- add tests/examples
- submit a PR
```python
import pickle
from typing import Iterator, Optional, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

T_co = TypeVar("T_co", covariant=True)


def setup_redis_client(url, username=None, password=None):
    # TODO Implement with ACL support
    # TODO Implement with TLS support?
    # TODO Redis cluster client??
    try:
        import redis
    except ImportError:
        print("Redis needs to be installed in order to use Redis cache for datapipes")
        raise
    # Treats `url` as a host name (e.g. "localhost"); redis.Redis.from_url accepts redis:// URLs
    return redis.Redis(url)


@functional_datapipe("redis_cache")
class RedisCacheHolderIterDataPipe(IterDataPipe[T_co]):
    def __init__(self, source_dp: IterDataPipe[T_co], redis_url: str, cached_elements: Optional[int] = None) -> None:
        self.source_dp: IterDataPipe[T_co] = source_dp
        self._client = setup_redis_client(redis_url)
        self._key = "tpipe"
        self._start_idx = 0
        # use number of cached elements rather than cache size
        # avoids problem of using Redis DB size when Redis being used for
        # more than just a datapipe cache
        self.cached_elements = cached_elements

    def _iter_stored(self):
        # index always starts at 0 for redis list
        # _start_idx solely for tracking number of stored elements
        for idx in range(0, self._cache_list_len()):
            # LRANGE? or Pipeline??
            yield self._deserialize(self._client.lindex(self._key, idx))

    def _deserialize(self, response):
        return pickle.loads(response)

    def _serialize(self, value):
        # TODO store datatype in Redis upon init? how to assert datatype
        # don't serialize for primitive datatypes? only collections?
        return pickle.dumps(value)

    def __iter__(self) -> Iterator[T_co]:
        if self._cache_list_len() > 0:
            # Cache populated: re-read the evicted prefix from the source...
            for idx, data in enumerate(self.source_dp):
                if idx < self._start_idx:
                    yield data
                else:
                    break
            # ...then serve the remaining elements from Redis
            yield from self._iter_stored()
        else:
            for data in self.source_dp:
                self._client.rpush(self._key, self._serialize(data))
                # Cache reaches element limit
                if self.cached_elements is not None and self._cache_list_len() > self.cached_elements:
                    self._client.lpop(self._key)
                    self._start_idx += 1
                yield data

    def __contains__(self, key):
        # EXISTS returns the number of matching keys
        return bool(self._client.exists(key))

    def _cache_list_len(self):
        return self._client.llen(self._key)

    def __len__(self) -> int:
        try:
            return len(self.source_dp)
        except TypeError:
            # if list has been created in the database
            if self._key in self:
                return self._start_idx + self._cache_list_len()
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.")
```
When running the following simple example:
```python
from torchdata.datapipes.iter import IterableWrapper

source_dp = IterableWrapper(range(10))
cache_dp = source_dp.redis_cache(redis_url="localhost")
print(list(cache_dp))
```
I get the expected answer:
```
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```
I've also tested with the ag_news example, augmented to use the Redis cache:
```python
# Stack CSV Parser directly on top of web-stream
dp = HttpReader([URL[split]]).parse_csv()
cache_dp = dp.redis_cache(redis_url="localhost")
return cache_dp.map(_process_tuple)
```
A couple of design points to mention:
- I went for `cached_elements` instead of `size`, since it'll be tough to directly correlate list item size to cache size unless that is tracked separately. This provides an optional layer on top of the existing eviction strategies within Redis.
- `start_idx` is used in the case where `cached_elements` is triggered. This simply accounts for the fact that Redis lists always start at index 0.
- I wasn't sure how to handle the `__len__` function, as this is my first pass at this library, so please chime in if that doesn't look right.
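To make the `cached_elements`/`start_idx` bookkeeping concrete, here is a stdlib-only sketch that replaces the Redis list with a `collections.deque` (the class name `BoundedListCache` is hypothetical). The first pass appends every element and trims the head once the limit is exceeded, counting evictions in `start_idx`; later passes re-read the evicted prefix from the source and serve the rest from the cache, mirroring the prototype's `__iter__`.

```python
from collections import deque
from itertools import islice
from typing import Deque, Generic, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


class BoundedListCache(Generic[T]):
    """Hypothetical stand-in for the prototype's Redis list (deque instead of RPUSH/LPOP)."""

    def __init__(self, source: Iterable[T], cached_elements: Optional[int] = None) -> None:
        self.source = source  # must be re-iterable, like a DataPipe
        self.cached_elements = cached_elements
        self.start_idx = 0  # number of elements evicted from the head of the list
        self._list: Deque[T] = deque()
        self._filled = False  # True only after a complete first pass

    def __iter__(self) -> Iterator[T]:
        if self._filled:
            # Evicted prefix: re-read the first start_idx elements from the source.
            yield from islice(self.source, self.start_idx)
            # Everything after the prefix comes from the cache.
            yield from list(self._list)
            return
        # (Re)build the cache from scratch, discarding any partial earlier pass.
        self._list.clear()
        self.start_idx = 0
        for data in self.source:
            self._list.append(data)  # like RPUSH
            if self.cached_elements is not None and len(self._list) > self.cached_elements:
                self._list.popleft()  # like LPOP
                self.start_idx += 1
            yield data
        self._filled = True
```

Unlike the prototype's length check, this sketch marks the cache as usable only after a full first pass, so an interrupted first iteration is not mistaken for a complete cache; whether that matters in practice is a judgment call.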
Design Points
- How to optimize retrieval?
  - A few options for optimizing Redis retrieval times:
    - pipelining
    - binning (use `HSET` instead of Redis lists)
    - async
- Client init
  - Is this class, once initialized, expected to be used across threads? Processes? I would assume so, and hence we might need to think about where/how we initialize the client for best performance.
- Serialization
  - I used `pickle` above, but that is not the ideal library to use here, primarily for security reasons. Other options? Protobuf? Dill?
  - We should be able to detect the datatype and, if it's a primitive, just use simple byte encoding, which will perform better than something heavier like protobuf.
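A sketch of the "detect primitives" idea, assuming values are stored in Redis as `bytes`: a one-byte type tag is prepended so that `bytes`, `str`, `int`, and `float` are encoded directly, while anything else falls back to `pickle` (so the security caveat still applies to those values). The tag scheme and the `encode`/`decode` names are hypothetical.

```python
import pickle
import struct
from typing import Any

# Hypothetical one-byte type tags prepended to each stored value.
_BYTES, _STR, _INT, _FLOAT, _PICKLE = b"b", b"s", b"i", b"f", b"p"


def encode(value: Any) -> bytes:
    """Encode primitives directly; fall back to pickle for everything else."""
    kind = type(value)  # exact type, so subclasses such as bool go through pickle
    if kind is bytes:
        return _BYTES + value
    if kind is str:
        return _STR + value.encode("utf-8")
    if kind is int:
        # Signed big-endian, with enough bytes for the magnitude plus a sign bit.
        length = (value.bit_length() + 8) // 8
        return _INT + value.to_bytes(length, "big", signed=True)
    if kind is float:
        return _FLOAT + struct.pack(">d", value)  # 8-byte IEEE 754 double
    return _PICKLE + pickle.dumps(value)


def decode(data: bytes) -> Any:
    """Inverse of encode(), dispatching on the leading type tag."""
    tag, payload = data[:1], data[1:]
    if tag == _BYTES:
        return payload
    if tag == _STR:
        return payload.decode("utf-8")
    if tag == _INT:
        return int.from_bytes(payload, "big", signed=True)
    if tag == _FLOAT:
        return struct.unpack(">d", payload)[0]
    if tag == _PICKLE:
        return pickle.loads(payload)
    raise ValueError(f"unknown type tag {tag!r}")
```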
I'll keep chipping away at this, but wanted to post early to gather feedback.
Hi @Spartee, thank you for your patience. Here are my thoughts.
Overall, I think a PR doing what you describe should be pretty easy to merge; the value is that a Redis cache is yet another data source that anyone could leverage.
Regarding some of your specific design questions:
- `cached_elements` is a good idea, although I wonder: is it common for people to use Redis caches for more than one use case, or do folks typically have one cache per application?
- `start_idx` looks good.
- `__len__`: for an iterable pipe I'd just raise an error (https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/load/huggingface.py#L79-L80); otherwise, return the cache length.
On the future design points:
- Optimize retrieval: I think we can punt on those questions for a first PR, but do mention them in your PR. What we can do is merge a simple prototype, find a real human user, and then add more optimizations over time.
- Client init: Yeah, for now we can think about multiple processes only.
- Serialization: We use dill in some parts of our codebase: https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/util/cacheholder.py#L35
@Spartee Thank you for putting up a prototype! Here are my thoughts.
Aside from optimizing retrieval, we might be able to provide a multiple-layer cache to reduce cache misses. In the Redis example, `cached_elements` defines the cache size. When we run out of cache, we might fall back to another layer of cache (maybe an on-disk file by default; this might need the format discussed with @msaroufim).
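A minimal sketch of the multiple-layer idea, assuming a small in-memory first layer (standing in for a Redis list bounded by `cached_elements`) that spills its oldest entries to an on-disk second layer. It uses the standard library `shelve` module purely for illustration, since the on-disk format is exactly the open question above; `TwoLayerCache` and `in_temp_dir` are hypothetical names.

```python
import os
import shelve
import shutil
import tempfile
from collections import OrderedDict
from typing import Any, Optional


class TwoLayerCache:
    """Hypothetical two-layer cache: in-memory FIFO layer spilling to a shelve file."""

    def __init__(self, path: str, capacity: int) -> None:
        self.capacity = capacity  # max entries kept in the first (in-memory) layer
        self._memory: "OrderedDict[str, Any]" = OrderedDict()
        self._disk = shelve.open(path)  # second layer; shelve keys must be str
        self._tmpdir: Optional[str] = None

    @classmethod
    def in_temp_dir(cls, capacity: int) -> "TwoLayerCache":
        """Put the shelve file in a fresh temporary directory, removed by close()."""
        directory = tempfile.mkdtemp()
        cache = cls(os.path.join(directory, "cache"), capacity)
        cache._tmpdir = directory
        return cache

    def __setitem__(self, key: str, value: Any) -> None:
        self._memory[key] = value
        if len(self._memory) > self.capacity:
            # First layer is full: move its oldest entry to the second layer.
            old_key, old_value = self._memory.popitem(last=False)
            self._disk[old_key] = old_value

    def __contains__(self, key: str) -> bool:
        return key in self._memory or key in self._disk

    def __getitem__(self, key: str) -> Any:
        if key in self._memory:
            return self._memory[key]
        return self._disk[key]  # fall back to the second layer; KeyError on a full miss

    def close(self) -> None:
        self._disk.close()
        if self._tmpdir is not None:
            shutil.rmtree(self._tmpdir, ignore_errors=True)
```

Each miss in the first layer costs an extra lookup in the second, which is the client-side overhead that server-side tiering would avoid.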
client init
Noob question about how the Redis client/server works in Python. Let's say we have multiple processes running the same pipeline: would each Redis client attach to the same server? If so, it seems like we need a way to make sure the order of data is preserved, because normally the data should be in round-robin order (Process 0: [0, 2, 4, 6, ...], Process 1: [1, 3, 5, 7, ...]).
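To illustrate the ordering concern, a sketch of round-robin sharding across `num_workers` processes, simulated with plain lists (no Redis or multiprocessing involved): worker `w` keeps the elements whose index satisfies `idx % num_workers == w`, and the consumer restores the original order by taking one element from each worker in turn. The helper names are hypothetical. If all workers pushed to one shared Redis list, their writes would interleave unpredictably, so one option might be a separate key per worker.

```python
from itertools import zip_longest
from typing import Iterable, List, Sequence, TypeVar

T = TypeVar("T")
_MISSING = object()  # sentinel for workers that ran out of elements


def shard(data: Iterable[T], num_workers: int, worker_id: int) -> List[T]:
    """Elements assigned to one worker under round-robin sharding."""
    return [x for idx, x in enumerate(data) if idx % num_workers == worker_id]


def interleave(shards: Sequence[Sequence[T]]) -> List[T]:
    """Merge per-worker outputs back into the original round-robin order."""
    merged: List[T] = []
    for group in zip_longest(*shards, fillvalue=_MISSING):
        merged.extend(x for x in group if x is not _MISSING)
    return merged
```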
Serialization
We currently still rely on `pickle`, as PyTorch core still depends on `pickle`. But the cache itself could implement its own serialization strategy inside `__getstate__` and `__setstate__` to eliminate the security issue. Within each DataPipe, `pickle` will invoke the `__getstate__` function, so the inner serialization logic can run before the cache data is pickled.
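A sketch of that pattern, assuming the cached values are JSON-compatible: `__getstate__` turns the cache contents into JSON bytes when the surrounding object is pickled, and `__setstate__` rebuilds them with `json`, so the cached values themselves are never reconstructed by `pickle`. `JsonBackedCache` and the choice of JSON are illustrative assumptions, not the torchdata implementation; unpickling the outer object from an untrusted source is still unsafe in general.

```python
import json
import pickle
from typing import Any, Dict


class JsonBackedCache:
    """Hypothetical cache whose contents travel as JSON inside its pickled state."""

    def __init__(self) -> None:
        self.store: Dict[str, Any] = {}  # JSON requires str keys

    def __getstate__(self) -> Dict[str, Any]:
        # Called by pickle: hand over the contents as JSON bytes, not raw objects.
        return {"payload": json.dumps(self.store).encode("utf-8")}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Called on unpickling: rebuild the contents with json, not pickle.
        self.store = json.loads(state["payload"].decode("utf-8"))


def send_to_worker(cache: JsonBackedCache) -> JsonBackedCache:
    """Simulate the cache being pickled and unpickled, e.g. when shipped to a worker."""
    return pickle.loads(pickle.dumps(cache))
```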
Happy to! Some responses:
cached_elements is a good idea, although I wonder: is it common for people to use Redis caches for more than one use case, or do folks typically have one cache per application?
I think it's most common to have one cache per application, but I have seen some places where a number of microservices use the same cache for a number of purposes (pubsub/brokering/caching). Whatever we decide here should be well documented. I lean towards a more "hands off" approach here, as the size checks would need to occur regularly on the client side, which, under load, is not a negligible cost.
Serialization
Both options presented by @ejguan and @msaroufim are valid. The out-of-band performance of pickle (protocol 5) is quite good, and I know other PyData libraries use it. If we use dill, I would want to include it in a pip extra like `torchdata[redis]` for consistency.
I'm going to play around with serialization/compression and come back with some options.
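For reference, a sketch of pickle protocol 5 out-of-band serialization (PEP 574) using only the standard library: wrapping a large payload in `pickle.PickleBuffer` and passing a `buffer_callback` keeps the payload out of the pickle stream, so a Redis-backed cache could, for example, store the small header and the raw buffers under separate keys. That key layout, and the helper names below, are assumptions for illustration.

```python
import pickle
from typing import List, Tuple


def dumps_out_of_band(payload: bytearray, meta: dict) -> Tuple[bytes, List[bytes]]:
    """Pickle `meta` in-band and `payload` out-of-band (requires protocol 5)."""
    buffers: List[pickle.PickleBuffer] = []
    header = pickle.dumps(
        {"meta": meta, "payload": pickle.PickleBuffer(payload)},
        protocol=5,
        buffer_callback=buffers.append,  # returning None marks the buffer out-of-band
    )
    # Copy each buffer's memory into bytes, e.g. to store it as its own Redis value.
    return header, [bytes(b.raw()) for b in buffers]


def loads_out_of_band(header: bytes, raw_buffers: List[bytes]) -> dict:
    """Reassemble the object; the payload comes back as the buffer object supplied."""
    return pickle.loads(header, buffers=raw_buffers)
```

The header stays small regardless of payload size, which is what makes storing large array-like values cheap to (de)serialize.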
multiple-layer cache to reduce cache misses.
This is a strategy that I think should be handled on the server side. Some variants of Redis, especially managed ones like Redis Enterprise, provide flash support for handling tiered caching. This is more performant than multiple requests on the client side. There are also some OSS variants that provide tiered caching and are consistent with the OSS Redis API.
client init
If we focus solely on multiprocessing settings, I think the best solution would be to init a client within each process at startup and keep it alive until the object is destroyed. How would the processes be initialized? Does the user pass them in? Can you point me to any examples that use a single DataPipe with multiple processes?
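A sketch of the "one client per process" idea, assuming the client is created lazily and tied to the process that created it: the live connection is excluded from the pickled state (so it is never shipped to worker processes) and is recreated whenever the current PID differs from the creator's, e.g. after a fork. `LazyClient` and its `factory` argument are hypothetical; with redis-py the factory could be something like `functools.partial(redis.Redis, "localhost")`, which, unlike a lambda, can itself be pickled.

```python
import os
from typing import Any, Callable, Dict, Optional


class LazyClient:
    """Hypothetical wrapper that creates one client per process, on first use."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory  # callable returning a new client/connection
        self._client: Optional[Any] = None
        self._pid: Optional[int] = None  # process that created the current client

    @property
    def client(self) -> Any:
        pid = os.getpid()
        if self._client is None or self._pid != pid:
            # First use in this process (or we were forked): open a fresh connection.
            self._client = self._factory()
            self._pid = pid
        return self._client

    def __getstate__(self) -> Dict[str, Any]:
        # Never pickle a live connection; the receiving process builds its own.
        state = self.__dict__.copy()
        state["_client"] = None
        state["_pid"] = None
        return state
```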
How would the processes be initialized? Does the user pass them in? Can you point me to any examples that use a single DataPipe with multiple processes?
Here are some references. DataLoader2 would rely on MultiprocessingReadingService to spawn processes.
https://github.com/pytorch/data/blob/86df1a09c0f649aca195a233508669de35b8623b/torchdata/dataloader2/reading_service.py#L147-L166
The DataPipe should be automatically sharded based on the process number. Here is a minimal example for multiprocessing:
```python
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService

if __name__ == "__main__":
    input_dp = IterableWrapper(list(range(100)))
    dp = input_dp.shuffler().sharding_filter()
    rs = PrototypeMultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    for d in dl:
        print(d)
```