
Wrong length of test dataset

Open LPY1219 opened this issue 10 months ago • 1 comments

Hi,

I have set sample_mode='enumerate' with a batch size of 48. The two tasks are close_jar and insert_onto_square_peg, and the returned length of the test_dataset is 64. I assumed this meant that iterating 64 times would enumerate every sample in the test replay buffer. However, when I do that, the insert_onto_square_peg task is never sampled; only after increasing the number of iterations to about 263 do insert_onto_square_peg samples start to appear.

Do you have any insight into why this is happening?

```python
class PyTorchIterableReplayDataset(IterableDataset):

    def __init__(self, replay_buffer: ReplayBuffer, sample_mode,
                 sample_distribution_mode='transition_uniform'):
        self._replay_buffer = replay_buffer
        self._sample_mode = sample_mode
        if self._sample_mode == 'enumerate':
            self._num_samples = self._replay_buffer.prepare_enumeration()
        self._sample_distribution_mode = sample_distribution_mode

    def _generator(self):
        while True:
            if self._sample_mode == 'random':
                yield self._replay_buffer.sample_transition_batch(
                    pack_in_dict=True,
                    distribution_mode=self._sample_distribution_mode)
            elif self._sample_mode == 'enumerate':
                yield self._replay_buffer.enumerate_next_transition_batch(
                    pack_in_dict=True)

    def __iter__(self):
        return iter(self._generator())

    def __len__(self):  # enumeration will throw away the last incomplete batch
        return self._num_samples // self._replay_buffer._batch_size
```
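For reference, here is a minimal, torch-free sketch (the class and the numbers are hypothetical, not the actual ReplayBuffer) showing that a `while True` generator happily yields past what `__len__` reports, so iterating `len(test_dataset)` times is not guaranteed to line up with exactly one pass over the buffer:

```python
class ToyEnumerateDataset:
    """Mimics the structure above: an infinite generator plus an advisory __len__."""

    def __init__(self, num_samples, batch_size):
        self._num_samples = num_samples
        self._batch_size = batch_size

    def _generator(self):
        idx = 0
        while True:                        # never terminates on its own
            yield idx % self._num_samples  # silently wraps past the "end"
            idx += 1

    def __iter__(self):
        return iter(self._generator())

    def __len__(self):
        # like the original: drops the last incomplete batch
        return self._num_samples // self._batch_size


ds = ToyEnumerateDataset(num_samples=3072, batch_size=48)
print(len(ds))  # 64, analogous to the reported test_dataset length

it = iter(ds)
first_300 = [next(it) for _ in range(300)]
print(len(first_300))  # 300 -- iteration does not stop at len(ds)
```

If the replay buffer enumerates its per-task buffers sequentially, the first `len(test_dataset)` batches could also be dominated by whichever task is stored first, which would be consistent with insert_onto_square_peg only showing up after many more iterations.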

LPY1219 avatar Jan 16 '25 07:01 LPY1219