
Improving design of Learner.step: adding optional parameter sample

Open wookayin opened this issue 3 years ago • 4 comments

This could be a huge change in acme's API design, but I would like to make a proposal, or at least start a discussion, about how we can improve the design of one of acme's central APIs: Learner.step().

In short, I think we could benefit greatly from adding an optional parameter sample to Learner.step. Specifically, we could design the method so that it takes an optional parameter sample and performs the usual learning step on the provided sample if one is given, or samples from self._iterator (the current behavior) otherwise.

 class Learner(...):
   @abc.abstractmethod
-  def step(self):
+  def step(self, sample: Optional[reverb.ReplaySample] = None):
     """Perform an update step of the learner's parameters."""

Then, a typical implementation would look like:

def step(self, sample: Optional[reverb.ReplaySample] = None):
  # Fall back to the learner's own iterator only when no sample is provided.
  if sample is None:
    sample = next(self._iterator)
  transitions = types.Transition(*sample.data)

  self._state, metrics = self._update_step(self._state, transitions)

Why do we need it?

It will allow much easier extension and customization of learning behaviors.

One example is overriding reward functions, for instance with intrinsic rewards. Examples that make use of intrinsic reward functions are AIL (imitation learning) and RND, where the agent is built on top of an underlying RL algorithm (called the "direct RL" learner) but its reward function is redefined by some other component: the intrinsic reward computed by the RND network, or by the discriminator (AIL). In such cases, one needs to process and override the reward of a sample because the extrinsic/task reward may not be used.

However, doing this is very complicated with the current form of the API, because there is no way to access or override the individual samples drawn from the dataset iterator (usually a reverb dataset) from outside the learner. In my view, the way this is currently achieved looks intimidating and overly complicated: the iterator of the underlying direct RL learner is tee-ed and passed through process_sample, which can arbitrarily process the sample data. To inject this iterator into the direct RL learner, AILLearner or RNDLearner needs to be provided with a direct_rl_learner_factory that creates a direct RL learner when passed such a "processed" replay sample iterator. This approach works with the current API design, but duplicating the iterator is prone to errors in the update-to-sample ratio when a rate_limiter is engaged. The readability of the code, given this encapsulation and the difficult-to-track lambda functions, is another concern.
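To make the bookkeeping concern concrete, here is a minimal sketch using only the standard library. CountingSource is a hypothetical stand-in for a replay dataset iterator, not an Acme or reverb class. It only illustrates how itertools.tee buffers items so that the number of items pulled from the source can differ from the number consumed by one branch; how this interacts with an actual reverb rate limiter is an assumption and is not modeled here.

```python
import itertools


class CountingSource:
    """Hypothetical stand-in for a replay iterator; counts items pulled."""

    def __init__(self):
        self.pulled = 0

    def __iter__(self):
        return self

    def __next__(self):
        self.pulled += 1
        return {"reward": 0.0, "index": self.pulled}


source = CountingSource()
# Duplicate the iterator, as the outer learner would do in order to feed
# a processed copy of the stream to the inner (direct RL) learner.
outer_branch, inner_branch = itertools.tee(source)

# The outer branch runs ahead by three samples.
for _ in range(3):
    next(outer_branch)
pulled_after_outer = source.pulled

# The inner branch now reads from tee's internal buffer, so taking its
# first item does not pull anything new from the source.
first_inner = next(inner_branch)
pulled_after_inner = source.pulled
```

In this toy setup the source has been sampled three times while the inner branch has consumed a single item, which is the kind of mismatch that makes the update-to-sample ratio hard to reason about.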

On the other hand, if we had this optional parameter, the design of nested Learners would be much simpler. For example:

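class AILLearner:
  # ...
  def step(self, sample: Optional[AILSample] = None):
      # Use the provided sample if given; otherwise draw one from the iterator.
      if sample is None:
          sample = next(self._iterator)
      # ...
      self._direct_rl_learner.step(self._process_sample(sample.rl_sample))
      # ...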

It would also enable implementing custom learning processes, for instance (optional) on-policy learning that is not necessarily tied to the reverb replay buffer through a dataset or learner._iterator.
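The nesting idea can be tried out without Acme at all. The following is a toy sketch, not the real AILLearner or RNDLearner: ToyDirectLearner, ToyIntrinsicRewardLearner, the dict-based sample format, and reward_fn are all hypothetical names chosen for illustration. It shows an outer learner overriding the reward and handing the processed sample to the inner learner through the optional sample argument.

```python
from typing import Any, Callable, Dict, Iterator, List, Optional

Sample = Dict[str, Any]  # Hypothetical stand-in for a replay sample.


class ToyDirectLearner:
    """Inner learner: records the reward it was asked to learn from."""

    def __init__(self, iterator: Iterator[Sample]):
        self._iterator = iterator
        self.seen_rewards: List[float] = []

    def step(self, sample: Optional[Sample] = None) -> None:
        if sample is None:
            sample = next(self._iterator)
        self.seen_rewards.append(sample["reward"])


class ToyIntrinsicRewardLearner:
    """Outer learner: replaces the task reward before delegating."""

    def __init__(self,
                 iterator: Iterator[Sample],
                 direct_learner: ToyDirectLearner,
                 reward_fn: Callable[[Sample], float]):
        self._iterator = iterator
        self._direct_learner = direct_learner
        self._reward_fn = reward_fn

    def step(self, sample: Optional[Sample] = None) -> None:
        if sample is None:
            sample = next(self._iterator)
        # Build a new sample with the reward overridden; the input is not mutated.
        processed = {**sample, "reward": self._reward_fn(sample)}
        self._direct_learner.step(processed)


data = iter([{"obs": 1.0, "reward": 10.0}, {"obs": 2.0, "reward": 20.0}])
inner = ToyDirectLearner(iter([]))  # The inner iterator is never used here.
outer = ToyIntrinsicRewardLearner(data, inner, lambda s: -s["obs"])
outer.step()
outer.step()
```

Only the outer learner owns a data source in this sketch, so no iterator has to be duplicated and no learner factory is needed.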

What would be concerns of making such changes?

  • We need to ensure that the type of sample is compatible with that of self._iterator. Usually it is reverb.ReplaySample, but in principle a learner should not be heavily coupled with ReplaySample itself. As in the above example, AILLearner expects samples to be AILSample rather than the raw ReplaySample, so the sample type can differ across learners and actually needs to be generic.
  • Since this is a breaking change to the API, some third-party or external uses of Learner might not be compatible with it (see the alternative below).

In addition, I believe there might be other reasons why DM hasn't taken this approach in the first place. Any thoughts behind the design rationale?

Alternatives?

An alternative, less intrusive approach is to add a unified method such as def update_step(sample) to acme.Learner or to a specific (abstract) subclass, similar to SACLearner._update_step(learning_state, transitions) or self._sgd_step(learning_state, sample), etc. This way we could keep acme.Learner.step(...) intact. Currently, the Learner subclasses all do this differently, without a common interface: the protected method names differ, the signatures differ (some take Transition and others ReplaySample), and the functions are often hidden behind JIT compilation (jax.jit or tf.function), which makes overriding difficult because they are usually defined in the constructor.

If we take this approach, a refactoring will be required to unify the types (with some use of Generics, of course) and the signature of sample. We may want to enable this polymorphic extension only for a more specific, concrete subclass (e.g., GenericLearner, in a similar fashion to how GenericActor extends Actor) to avoid a breaking API change in acme.Learner.

Example:

Sample = acme.jax.types.Sample   # TypeVar

class GenericLearner(Learner, Generic[Sample]):
   @abstractmethod
   def step(self, sample: Optional[Sample] = None):
     ...

Whether subclass methods should be allowed to have a different (extended) signature from the base class is debatable, so this might not be a good idea. Instead, one could do something like the following (apologies for the tentative naming and sketchy design):

Sample = acme.jax.types.Sample   # TypeVar
LearnerState = TypeVar('LearnerState')

class GenericLearner(Learner, Generic[LearnerState, Sample]):
   def __init__(...):
      self._logger = ...
      self._state = self.make_initial_state(random_key)
      self._counter = ...
      
   def step(self):
     """A unified, common implementation of the learner step."""
     start_time = time.time()  # Requires `import time`.
     sample = next(self._iterator)
     transitions = types.Transition(*sample.data)
     self._state, metrics = self.update_step(self._state, transitions)
     # Optionally, measure elapsed time, etc.
     elapsed_time = time.time() - start_time
     counts = self._counter.increment(steps=1, walltime=elapsed_time)
     self._logger.write({**metrics, **counts})

   @abstractmethod
   def make_initial_state(self, key: networks_lib.PRNGKey
                         ) -> LearnerState:
     ...

   @abstractmethod
   def update_step(self, state: LearnerState, sample: Sample
                   ) -> Tuple[LearnerState, Metrics]:
     # TODO: what about jax.jit?
     ...

wookayin, Apr 05 '22 06:04

I'd like to offer my two cents on this.

I like @wookayin's proposal. I have found the requirement that the learner takes the iterator during construction to be a bit inflexible at times. In fact, I think that the learner should be stateless, in that it only needs to know how to update the learner state given samples. I also believe that a stateless learner that is decoupled from the iterator, counter and logger would offer further parallelism opportunities. For example, it would become possible to vmap over learners and embed them in an environment loop written in JAX that can be pmapped and jitted (something requested in https://github.com/deepmind/acme/issues/220).

I noticed that there were some attempts in this direction. For example, https://github.com/deepmind/acme/commit/e8c90feca0e22fc47d9075658fefb5453a298399 introduced the concept of a LearnerCore which I found to be quite interesting. This was however reverted in https://github.com/deepmind/acme/commit/f8811c4e8a1a3abe82adfee51192099e03e90e5e. I am not sure why there's no further development in this direction. Maybe the Acme developers can share some insights on this.

From my perspective, the idea of a LearnerCore would help de-duplicate some of the bookkeeping that currently lives in each learner implementation. For example, computing the elapsed time and writing to the logger could be shared by a single implementation in the learner core. However, I believe that it also introduces other problems: sometimes there are things in the learner that are specific to one algorithm. For example, DQN with prioritized replay also needs to update reverb with the new priorities. Implementing this in the learner core would be nice, but it would also place an extra burden on users who do not need such behavior, since they would have to opt out.

ethanluoyc, Apr 13 '22 15:04

Thank you @ethanluoyc for your opinions!

> In fact, I think that the learner should be stateless

+100. This is the gist of everything. Learner (or at least one specific subclass) should be stateless, which would make it much easier and more flexible to use. Currently the learning logic (which can be stateless) and the worker-like behavior (running by itself, sampling data, ...) are strongly coupled, but learning logic and sampling logic should be decoupled. This could be achieved by exposing such learners through the Worker interface (i.e., having .step()) only when they are augmented with a dataset iterator to sample from.

To avoid a breaking interface change, we might leave Learner as-is, i.e., Learner = LearnerCore + Sampler (which is stateful), and extract the core learning logic as LearnerCore, as @ethanluoyc pointed out (but not as a 'dataclass' nor a "final" class). LearnerCore should be stateless in the sense that it doesn't maintain state about how to pull training samples, but stateful in the sense that it stores all the parameters and learning state as necessary. I don't see any major challenges in this direction; it feels like a matter of straightforward refactoring.
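As a rough illustration of the "Learner = LearnerCore + Sampler" split, here is a toy sketch in plain Python. ToyLearnerCore and ToyLearner are hypothetical names, and the "learning" is a simple running estimate of a scalar rather than anything Acme does; the point is only that the core owns the parameters and the update rule, while the wrapper owns the data source.

```python
from typing import Dict, Iterator


class ToyLearnerCore:
    """Owns parameters and the update rule; knows nothing about sampling."""

    def __init__(self, learning_rate: float = 0.5):
        self.param = 0.0
        self._learning_rate = learning_rate

    def update(self, sample: float) -> Dict[str, float]:
        # Move the parameter a fraction of the way toward the sample.
        error = sample - self.param
        self.param += self._learning_rate * error
        return {"error": error}


class ToyLearner:
    """Core plus a sample source: exposes the familiar step() interface."""

    def __init__(self, core: ToyLearnerCore, iterator: Iterator[float]):
        self.core = core
        self._iterator = iterator

    def step(self) -> Dict[str, float]:
        return self.core.update(next(self._iterator))


core = ToyLearnerCore()
learner = ToyLearner(core, iter([1.0, 1.0]))
learner.step()               # param: 0.0 -> 0.5
learner.step()               # param: 0.5 -> 0.75
metrics = core.update(1.0)   # Bypass the sampler: param 0.75 -> 0.875
```

The last line is the part the current API makes awkward: the same core can be driven by an externally supplied sample without touching the iterator.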

Looking forward to input from the DM folks. Please forward this to anyone whom it may concern.

wookayin, Apr 13 '22 17:04

@wookayin @ethanluoyc thanks for this discussion. Indeed, we made a very serious attempt to make the learner stateless. Unfortunately, we found this to be harder than it looked on the surface, and we have abandoned the effort. The main reason was that we wanted to keep supporting learners that call arbitrary external libraries in the step method. Some of those cannot be refactored to pull out the state, which makes it impossible to have a stateless step function.

nikolamomchev, May 10 '22 14:05

Thanks @nikolamomchev for the clarification. That makes sense. I guess a stateless learner is useful to have, but not necessarily something that should be enforced for all learners.

@nikolamomchev @wookayin One idea is to have stateless learners wherever that makes sense. One way to do this is to implement a stateless learner for some agents (e.g. TD3 or SAC), with the stateful learner being a wrapper built on top of it. Most of the training logic would live in the stateless learner, and users could opt to use the stateless learner directly if they find that more convenient. WDYT? From Acme's side, I think it would be really useful to have an official wrapper into which users can plug stateless learners when they need to. A concrete thing to do would be to revive the JaxDefaultLearner that was reverted in previous commits.
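A minimal sketch of what "stateless learner plus stateful wrapper" could look like, written in plain Python so it runs without JAX. ToyState, toy_update and ToyStatefulLearner are hypothetical names and do not correspond to JaxDefaultLearner or any other Acme class; the update rule is a placeholder.

```python
from typing import Callable, Dict, Iterator, NamedTuple, Tuple


class ToyState(NamedTuple):
    """Immutable learner state (hypothetical)."""
    param: float
    steps: int


Metrics = Dict[str, float]


def toy_update(state: ToyState, sample: float) -> Tuple[ToyState, Metrics]:
    """Pure update: returns a new state and leaves the input untouched."""
    error = sample - state.param
    new_state = ToyState(param=state.param + 0.5 * error,
                         steps=state.steps + 1)
    return new_state, {"error": error}


class ToyStatefulLearner:
    """Wrapper that holds the state and the iterator around a pure update."""

    def __init__(self,
                 update_fn: Callable[[ToyState, float], Tuple[ToyState, Metrics]],
                 initial_state: ToyState,
                 iterator: Iterator[float]):
        self._update_fn = update_fn
        self._state = initial_state
        self._iterator = iterator

    @property
    def state(self) -> ToyState:
        return self._state

    def step(self) -> Metrics:
        self._state, metrics = self._update_fn(self._state, next(self._iterator))
        return metrics


# Direct use of the stateless function.
initial = ToyState(param=0.0, steps=0)
after_one, _ = toy_update(initial, 2.0)

# The same function behind the usual step() interface.
wrapped = ToyStatefulLearner(toy_update, initial, iter([2.0, 2.0]))
wrapped.step()
wrapped.step()
```

Learners that must call into external, stateful libraries would simply not provide such a pure function and would keep implementing step() directly, which matches the point that statelessness need not be enforced everywhere.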

ethanluoyc, May 11 '22 13:05