Caching strategy for long-running processes like BERT QA inference
It would be good to be able to register a key/keyspace for a particular function and cache/memoise its output.
One implementation option is a memory-mapped fxHash map with optional on-disk persistence (TBD).
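A minimal sketch of the memoisation idea (the memoise decorator, its prefix parameter, and the qa stub are illustrative names; a plain in-process dict stands in for the memory-mapped fxHash backend, which is still TBD):

import functools

def memoise(prefix):
    """Register a cache keyspace for a function and memoise its output."""
    cache = {}  # stand-in for a memory-mapped fxHash store

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            key = '%s_%s' % (prefix, '_'.join(map(str, args)))
            if key not in cache:  # miss: run the long-running process once
                cache[key] = fn(*args)
            return cache[key]
        return wrapper
    return decorator

@memoise(prefix='bertqa')
def qa(context_key, question):
    ...  # long-running BERT QA inference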
Here is how it can be achieved in Redis using the RedisGears module.
Register a function on the keyspace, triggered on the keymiss event:
# read keys matching bertqa* and run qa_cached_keymiss on each event
gb = GB('KeysReader')
gb.map(qa_cached_keymiss)
# fire only on keymiss events raised by GET, asynchronously on the local shard
gb.register(prefix='bertqa*', commands=['get'], eventTypes=['keymiss'], mode="async_local")
which runs the qa_cached_keymiss function:
async def qa_cached_keymiss(record):
    # requested key format: bertqa{shard}_<context key>_<question>
    val = record['key'].split('_')
    cache_key = 'bertqa{%s}_%s_%s' % (hashtag(), val[1], val[2])
    # asynchronous call to BERT QA inference
    res = await qa(val)
    # store output of BERT QA in cache via standard SET command
    execute('set', cache_key, res)
    # reply to the client's GET with the inference result instead of nil
    override_reply(res)
    return res
The API client only ever calls GET on a bertqa* key
and is unaware of the implementation details of the BERT QA inference function:
redis-cli -c -p 30003 -h 127.0.0.1 get "bertqa{8YG}_PMC302072.xml:{8YG}:10_Who performs viral transmission among adults"
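The same call from Python is a plain GET as well; a hypothetical sketch using redis-py's cluster client (host, port, and the key are taken from the redis-cli example above):

from redis.cluster import RedisCluster

r = RedisCluster(host='127.0.0.1', port=30003)
# the first call triggers the keymiss handler, which runs inference and replies;
# later calls are served straight from the cached key
answer = r.get('bertqa{8YG}_PMC302072.xml:{8YG}:10_Who performs viral transmission among adults')
print(answer)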
Proposal: contribute this caching strategy to the Transformers library.
Blog post write-up. I know how to do this type of caching in Python/Redis, but not in Rust (yet).
One thing we need to consider is how this would work on edge devices. Caching could be a capability that the transformer consumes.
Phones have dedicated AI chips for ML inference, and the capability could be defined in terms of available RAM; see the sketch below.
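A rough sketch of what such a capability descriptor could look like (every name here is hypothetical, and the RAM threshold is an arbitrary placeholder):

from dataclasses import dataclass

@dataclass
class CacheCapability:
    # advertised by the device, consumed by the transformer
    max_ram_bytes: int
    persistent: bool  # whether on-disk persistence is available

def pick_cache_backend(cap: CacheCapability):
    # small devices: a bounded in-process dict; larger ones could
    # use a memory-mapped or Redis-backed store instead
    if cap.max_ram_bytes < 64 * 1024 * 1024:
        return {}
    raise NotImplementedError('mmap/Redis-backed store TBD')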
"bertqa{8YG}_PMC302072.xml:{8YG}:10_Who performs viral transmission among adults" decyphers like this: shard {8YG} (called hash id in Redis speak) contains key PMC302072.xml:{8YG}:10 with pre-tokenised context. When running inference the question is tokenised and then appended to (pre-tokenised) context. Allows achieving thigh throughput even on CPU with no quantisation or ONNX optimisations.