
Dataset-based Trainer

Open • redknightlois opened this issue on Apr 29, 2019 • 5 comments

This example dataset-based trainer also does expert signal recollection, which is why I didn't open a PR; I'll let you decide which parts make sense for rlkit.

import abc

import gtimer as gt
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import rlkit.torch.pytorch_util as ptu
from rlkit.core.rl_algorithm import BaseRLAlgorithm
from rlkit.data_management.replay_buffer import ReplayBuffer
from rlkit.samplers.data_collector import PathCollector

# `printout` and `AtariPathCollectorWithEmbedder` are helpers from my own setup
# and are not defined in this snippet.


class OptimizedBatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            max_num_steps_before_training=1e5,
            expert_data_collector: PathCollector = None,
    ):
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )

        assert isinstance(replay_buffer, Dataset), \
            "The replay buffer must be compatible with PyTorch's Dataset to use this version."

        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.max_num_steps_before_training = max_num_steps_before_training
        self.expert_data_collector = expert_data_collector

    def _train(self):
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )

            self.replay_buffer.add_paths(init_expl_paths)

            if self.expert_data_collector is not None:
                self.expert_data_collector.end_epoch(-1)
            self.expl_data_collector.end_epoch(-1)

        if self.expert_data_collector is not None:
            new_expl_paths = self.expert_data_collector.collect_new_paths(
                self.max_path_length,
                min(int(self.replay_buffer.max_buffer_size * 0.5), self.max_num_steps_before_training),
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(new_expl_paths)

        dataset_loader = torch.utils.data.DataLoader(
            self.replay_buffer,
            pin_memory=True,
            batch_size=self.batch_size,
            num_workers=0,
        )

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            printout('Evaluation sampling')
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):

                printout('Exploration sampling')
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)

                i = 0
                with tqdm(total=self.num_trains_per_train_loop) as pbar:
                    while True:

                        for _, data in enumerate(dataset_loader, 0):
                            if i >= self.num_trains_per_train_loop:
                                break  # We are done with this train loop

                            observations = data[0].to(ptu.device)
                            actions = data[1].to(ptu.device)
                            rewards = data[2].to(ptu.device)
                            terminals = data[3].to(ptu.device).float()
                            next_observations = data[4].to(ptu.device)
                            env_infos = data[5]

                            train_data = dict(
                                observations=observations,
                                actions=actions,
                                rewards=rewards,
                                terminals=terminals,
                                next_observations=next_observations,
                            )

                            for key in env_infos.keys():
                                train_data[key] = env_infos[key]

                            self.trainer.train(train_data)
                            pbar.update(1)
                            i += 1

                        if i >= self.num_trains_per_train_loop:
                            break

                gt.stamp('training', unique=False)
                self.training_mode(False)

                if isinstance(self.expl_data_collector, AtariPathCollectorWithEmbedder):
                    eval_policy = self.eval_data_collector.get_snapshot()['policy']
                    self.expl_data_collector.evaluate(eval_policy)

            self._end_epoch(epoch)

    def to(self, device):
        for net in self.trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        for net in self.trainer.networks:
            net.train(mode)

redknightlois commented Apr 29 '19 15:04

Thanks for defining this class. Can you share an example of how to use this trainer class along with DDPG and SAC?

nm-narasimha commented Jun 08 '19 11:06

The standard examples show how to do that; there is no difference in usage between the current algorithm and this one. I use #52 for dataset-size reasons, but the rest is pretty straightforward.

redknightlois commented Jun 08 '19 12:06
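For reference, a minimal usage sketch. It assumes the setup from rlkit's standard SAC example (2019-era API: `SACTrainer`, `FlattenMlp`, `TanhGaussianPolicy`, `MdpPathCollector`) and a Dataset-compatible replay buffer, here given the hypothetical name `DatasetEnvReplayBuffer` as a stand-in for whatever class you use (e.g. the one from #52). The only change from the standard script is swapping in `OptimizedBatchRLAlgorithm`; DDPG works the same way with rlkit's DDPG trainer and networks instead of SAC's.

import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
from rlkit.torch.sac.sac import SACTrainer

expl_env = NormalizedBoxEnv(gym.make('HalfCheetah-v2'))
eval_env = NormalizedBoxEnv(gym.make('HalfCheetah-v2'))
obs_dim = expl_env.observation_space.low.size
action_dim = expl_env.action_space.low.size

# Standard SAC networks, as in rlkit's example script.
M = 256
qf1 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
qf2 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
target_qf1 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
target_qf2 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
policy = TanhGaussianPolicy(obs_dim=obs_dim, action_dim=action_dim, hidden_sizes=[M, M])

trainer = SACTrainer(
    env=eval_env,
    policy=policy,
    qf1=qf1,
    qf2=qf2,
    target_qf1=target_qf1,
    target_qf2=target_qf2,
    discount=0.99,
    soft_target_tau=5e-3,
    policy_lr=3e-4,
    qf_lr=3e-4,
    use_automatic_entropy_tuning=True,
)

expl_path_collector = MdpPathCollector(expl_env, policy)
eval_path_collector = MdpPathCollector(eval_env, MakeDeterministic(policy))

# Hypothetical Dataset-compatible buffer (see the sketch further down).
replay_buffer = DatasetEnvReplayBuffer(int(1e6), expl_env)

# Drop-in replacement for the usual TorchBatchRLAlgorithm.
algorithm = OptimizedBatchRLAlgorithm(
    trainer=trainer,
    exploration_env=expl_env,
    evaluation_env=eval_env,
    exploration_data_collector=expl_path_collector,
    evaluation_data_collector=eval_path_collector,
    replay_buffer=replay_buffer,
    batch_size=256,
    max_path_length=1000,
    num_epochs=1000,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
)
ptu.set_gpu_mode(True)
algorithm.to(ptu.device)
algorithm.train()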

Thanks, @redknightlois. Do you have a sample replay_buffer compatible with the PyTorch Dataset class? Is env_replay_buffer or any other class in rlkit.data_management compatible?

Thanks, Narasimha

nm-narasimha commented Jun 10 '19 00:06

#52 is a PyTorch Dataset class.

redknightlois commented Jun 10 '19 01:06
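This is not the actual code from #52, just a minimal sketch of what Dataset compatibility means for the `_train` loop above: the stock env_replay_buffer is not a torch Dataset (hence the assert in `__init__`), so you subclass rlkit's `EnvReplayBuffer` together with `torch.utils.data.Dataset` and have `__getitem__` return the six fields in the order the loop unpacks them. The class name `DatasetEnvReplayBuffer` is the same hypothetical one used in the usage sketch above.

import numpy as np
from torch.utils.data import Dataset

from rlkit.data_management.env_replay_buffer import EnvReplayBuffer


class DatasetEnvReplayBuffer(EnvReplayBuffer, Dataset):
    """EnvReplayBuffer that can also be indexed as a torch Dataset."""

    @property
    def max_buffer_size(self):
        # _train above reads replay_buffer.max_buffer_size.
        return self._max_replay_buffer_size

    def __len__(self):
        # Only expose the part of the buffer that has been filled so far.
        return self.num_steps_can_sample()

    def __getitem__(self, idx):
        # Order matters: the training loop unpacks
        # (obs, action, reward, terminal, next_obs, env_info).
        return (
            self._observations[idx].astype(np.float32),
            self._actions[idx].astype(np.float32),
            self._rewards[idx].astype(np.float32),
            self._terminals[idx].astype(np.float32),
            self._next_obs[idx].astype(np.float32),
            {},  # env_infos; the default collate_fn handles (possibly empty) dicts
        )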

Hmmm, so it looks like the main difference is the addition of expert_data_collector. Is that correct? In that case, I'm not sure if we need to create an entirely new class for this. One option would be to add that data to the replay buffer before passing the replay buffer to the algorithm. What do you think of that? It would help separate out the algorithm from the pretraining phase.

vitchyr commented Jun 11 '19 21:06
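A sketch of that option, reusing the hypothetical objects from the usage sketch above plus an `expert_policy` assumed to exist: fill the buffer with expert paths first, then pass it to the unmodified TorchBatchRLAlgorithm, so the pretraining phase lives entirely outside the algorithm class.

from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

# Collect expert data up front and store it before the algorithm ever sees the buffer.
expert_data_collector = MdpPathCollector(expl_env, expert_policy)
expert_paths = expert_data_collector.collect_new_paths(
    max_path_length=1000,
    num_steps=50000,
    discard_incomplete_paths=False,
)
replay_buffer.add_paths(expert_paths)

# The standard algorithm then trains from a buffer that already contains
# the expert transitions; no new algorithm class is needed.
algorithm = TorchBatchRLAlgorithm(
    trainer=trainer,
    exploration_env=expl_env,
    evaluation_env=eval_env,
    exploration_data_collector=expl_path_collector,
    evaluation_data_collector=eval_path_collector,
    replay_buffer=replay_buffer,
    batch_size=256,
    max_path_length=1000,
    num_epochs=1000,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
)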