acme
Converting Tensorflow Dataset to `iterator` does not sync well with client
Hi, I stumbled upon a problem in the tutorial notebook while experimenting with the acme and reverb APIs: sampling from a reverb table and updating its priorities do not stay in sync. A related artifact is that the very first transition is repeated consistently until some hidden TensorFlow buffer is flushed.
When I mutate the priorities in a reverb table using `client.mutate_priorities(table_name, my_dict)` and read samples through an iterator created from the `tf.data.Dataset` object, the new priorities only show up after a large number of samples have been flushed. In contrast, if I don't convert the `tf.data.Dataset` to an iterator and instead use the `dataset.batch(n)` / `dataset.take(n)` interface, the samples reflect the new priorities immediately.
The problem seems to lie in the implementation of `__iter__` in `tf.data.Dataset`, but I am posting the issue here because the Colab calls `as_numpy_iterator()` on the dataset object, and the D4PG jax agent is implemented the same way. Because the bug is silent and obscure, it effectively rules out changing the baseline D4PG agent to use Prioritized Experience Replay.
Minimal Reproducible example:
import warnings
warnings.filterwarnings('ignore')
import acme
from acme import wrappers
from acme.datasets import reverb as datasets
from acme.adders.reverb import sequence
from acme.jax import utils
import tree
import reverb
import jax
import numpy as np
from dm_control import suite
# Create dummy environment with short episodes to easily dichotomize samples
env = suite.load('cartpole', 'balance')
env = wrappers.step_limit.StepLimitWrapper(env, step_limit=5)
spec = acme.make_environment_spec(env)
# Danger: reverb.Table crashes kernel if run > once
table = reverb.Table(
    name='priority_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=10_000,
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=sequence.SequenceAdder.signature(spec)
)
server = reverb.Server([table], port=None)
client = reverb.Client(f'localhost:{server.port}')
# Construct adder such that only 1 sample is added to table after an episode.
adder = sequence.SequenceAdder(client, sequence_length=6, period=5)
def new_dataset():
    # Clear old data
    client.reset(table.name)
    return datasets.make_reverb_dataset(
        table=table.name, server_address=client.server_address, batch_size=3
    )
def fill_dataset():
    step = env.reset()
    adder.add_first(step)
    action = env.action_spec().generate_value()
    i = 0
    while (not step.last()) and i < 10:
        step = env.step(action)
        adder.add(action, step)
        i += 1
    env.close()
    adder.reset()
### Example of expected behaviour
dataset = new_dataset()
fill_dataset()
print('before mutation')
for s in dataset.take(1):
    k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', p)
    # Halve the priorities
    new_priorities = dict(zip(k, p * 0.5))
    client.mutate_priorities(table.name, new_priorities)
print()
print('after mutation')
for s in dataset.take(1):
    # Priorities have been updated --> all probabilities should now be adjusted.
    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', s.info.priority.numpy())
### Test-cases
print('\nUsing dataset.take')
dataset = new_dataset()
fill_dataset()
# This runs fine
for repeat in range(5):
    for i in range(30):  # Flush count guess
        for s in dataset.take(1):
            k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)
        for s in dataset.take(1):
            new_p = s.info.priority.numpy().ravel()
        assert not np.isclose(new_p, p).any(), "priorities did not update!"
    else:
        # No break in for loop
        print('No errors!')
print('\nUsing next on iter(dataset) - Problems start here.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)
# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):
    for i in range(30):  # Flush count guess
        s = next(it)
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)
        s = next(it)
        new_p = s.info.priority.numpy().ravel()
        # Priority mutations now sync extremely slowly
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')
Output:
before mutation
[[-1. -1. -1. -1. -1. 0.]
[-1. -1. -1. -1. -1. 0.]
[-1. -1. -1. -1. -1. 0.]]
sample priority: [1. 1. 1.]
after mutation
[[-1. -1. -1. -1. -1. 0.]
[-1. -1. -1. -1. -1. 0.]
[-1. -1. -1. -1. -1. 0.]]
sample priority: [0.5 0.5 0.5]
Using dataset.take
No errors!
No errors!
No errors!
No errors!
No errors!
Using next on iter(dataset) - Problems start here.
Priorities updated at flush-step 24
Priorities updated at flush-step 5
Priorities updated at flush-step 18
Priorities updated at flush-step 5
Priorities updated at flush-step 18
Proposed Solution
The problem disappears if `iter(dataset)` is called again at each call to `next`. Because of this, I wasn't sure whether to post this issue here or on the TensorFlow GitHub, since the problem is with `tf.data.Dataset`. I would suggest creating a wrapper around `tf.data.Dataset` that either uses the `take` and `batch` API or reinitializes the iterator at every call. Because of how reverb implements sampling, reinitializing the dataset iterator should have no side effects.
Example solution:
print('\nReinitializing iter on every next call - Problem Solved.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset) # Ignore this iterator
# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):
    for i in range(30):  # Flush count guess
        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)
        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        new_p = s.info.priority.numpy().ravel()
        # Priority mutations should now be visible on the very next sample
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')
Output (priorities are updated after every call, as expected):
Reinitializing iter on every next call - Problem Solved.
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Another quick fix that I use is to wrap the reverb dataset inside the following class:
class RefreshIterator:
    """tf.data.Dataset fix for slow reverb client synchronization. Wrap around reverb-dataset."""
    __slots__ = ["_iterable"]
    def __init__(self, iterable):
        self._iterable = iterable
    def __iter__(self):
        return self
    def __next__(self):
        return next(iter(self._iterable))
    def next(self):
        return self.__next__()
Use:
dataset = datasets.make_reverb_dataset(
    table=my_table.name, server_address=reverb_client.server_address, batch_size=..., ...
)
jax_dataset = utils.multi_device_put(_NumpyIterator(RefreshIterator(dataset)), ...)
Unfortunately, `_NumpyIterator` is a private class in TensorFlow's `dataset_ops` module.
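One possible way to avoid the private class is to build the wrapper around an iterator factory instead of the dataset itself, and pass in the public `tf.data.Dataset.as_numpy_iterator` method. The sketch below is untested against reverb. `RefreshingNumpyIterator` and `FakeDataset` are hypothetical names, and `FakeDataset` only stands in for a real dataset so that the snippet runs without TensorFlow. Whether the resulting iterator is accepted by `utils.multi_device_put` is an assumption.

```python
class RefreshingNumpyIterator:
    """Creates a fresh iterator on every `next`, so no iterator outlives one sample."""

    __slots__ = ["_make_iterator"]

    def __init__(self, make_iterator):
        # `make_iterator` is any zero-argument callable returning a new iterator,
        # e.g. the bound method `dataset.as_numpy_iterator` (assumed usage).
        self._make_iterator = make_iterator

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._make_iterator())


class FakeDataset:
    """Hypothetical stand-in that counts how many iterators were created."""

    def __init__(self):
        self.iterators_created = 0

    def as_numpy_iterator(self):
        self.iterators_created += 1
        # Each new iterator yields a value identifying when it was created.
        return iter([self.iterators_created])


fake = FakeDataset()
it = RefreshingNumpyIterator(fake.as_numpy_iterator)
first, second = next(it), next(it)
print(first, second, fake.iterators_created)
```

With a real dataset the call would presumably be `RefreshingNumpyIterator(dataset.as_numpy_iterator)`. Creating a new iterator for every sample is likely to cost throughput, so this trades speed for up-to-date priorities.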