RVT
Wrong length of test dataset
Hi,
I set `sample_mode='enumerate'` with a batch size of 48. The two tasks are `close_jar` and `insert_onto_square_peg`, and the length reported for the `test_dataset` is 64. I take this to mean that iterating 64 times should enumerate every sample in the test replay buffer. However, when I do that, the `insert_onto_square_peg` task is never sampled. Only after increasing the number of iterations to about 263 do I begin to see `insert_onto_square_peg` samples.
Do you have any insight into why this is happening?
`class PyTorchIterableReplayDataset(IterableDataset):

    def __init__(self, replay_buffer: ReplayBuffer, sample_mode, sample_distribution_mode='transition_uniform'):
        self._replay_buffer = replay_buffer
        self._sample_mode = sample_mode
        if self._sample_mode == 'enumerate':
            self._num_samples = self._replay_buffer.prepare_enumeration()
        self._sample_distribution_mode = sample_distribution_mode

    def _generator(self):
        while True:
            if self._sample_mode == 'random':
                yield self._replay_buffer.sample_transition_batch(pack_in_dict=True, distribution_mode=self._sample_distribution_mode)
            elif self._sample_mode == 'enumerate':
                yield self._replay_buffer.enumerate_next_transition_batch(pack_in_dict=True)

    def __iter__(self):
        return iter(self._generator())

    def __len__(self):  # enumeration will throw away the last incomplete batch
        return self._num_samples // self._replay_buffer._batch_size`
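
For scale: a length of 64 at batch size 48 implies about 64 × 48 = 3,072 enumerated samples, whereas first seeing the second task near iteration 263 would mean roughly 262 × 48 = 12,576 samples were served before it. A small diagnostic along the following lines might help pin down where each task first appears. This is only a sketch: the `"tasks"` batch key, `first_batch_per_task`, and `toy_enumeration` are hypothetical names for illustration and are not part of RVT or the replay buffer API. The toy generator assumes enumeration walks samples in storage order, which may not match the real buffer.

```python
from itertools import islice


def expected_num_batches(num_samples, batch_size):
    # Mirrors __len__ above: the last incomplete batch is dropped.
    return num_samples // batch_size


def first_batch_per_task(batches, max_batches, task_key="tasks"):
    # Record the index of the first batch in which each task name appears.
    # `task_key` is a hypothetical key; adapt it to the real batch dict.
    first_seen = {}
    for i, batch in enumerate(islice(batches, max_batches)):
        for task in batch[task_key]:
            first_seen.setdefault(task, i)
    return first_seen


def toy_enumeration(samples, batch_size):
    # Hypothetical stand-in for enumerate mode: walks `samples` in storage
    # order, drops the incomplete last batch, then wraps around forever.
    num_batches = len(samples) // batch_size
    if num_batches == 0:
        raise ValueError("need at least one full batch")
    while True:
        for b in range(num_batches):
            start = b * batch_size
            yield {"tasks": samples[start:start + batch_size]}


# Toy data: the second task is stored after the first one.
samples = ["close_jar"] * 100 + ["insert_onto_square_peg"] * 44
n = expected_num_batches(len(samples), 48)
seen = first_batch_per_task(toy_enumeration(samples, 48), max_batches=n)
```

In this toy setup the second task first shows up in the last of the three batches, so one pass of `len(dataset)` iterations does cover both tasks. If the same check on the real dataset shows a task first appearing well beyond `len(test_dataset)`, that would suggest the reported length and the enumeration order are counting different things.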