
what is bit-packing and is it environment specific?

turmeric-blend opened this issue 4 years ago · 8 comments

I am new to the concept of bit-packing and have limited knowledge of binary vs. decimal representations.

I see there is some bit-packing code in the stack_frames() method:


# Unpacked 'frame_stacking_state'. Ordered from oldest to most recent.
unstacked_state = []
for i in range(stack_size - 1):
  # [batch_size, height, width]
  unstacked_state.append(tf.cast(tf.bitwise.bitwise_and(
      tf.bitwise.right_shift(frame_stacking_state, i * 8), 0xFF),
                                 tf.float32))

and


shifted = tf.bitwise.left_shift(
    tf.cast(stacked_frames[-1, ..., :-1], tf.int32),
    # We want to shift so that MSBs are newest frames.
    [8 * i for i in range(stack_size - 2, -1, -1)])
# This is really a reduce_or, because bits don't overlap.
new_state = tf.reduce_sum(shifted, axis=-1)

I am assuming this is to make processing faster. However, is this code specific to frame/image-based environments? (I am concerned about the value 8 and the 0xFF mask, which is 255 in decimal.) For example, will the same seed_rl R2D2 stack_frames() method work with CartPole?

turmeric-blend avatar Apr 07 '20 15:04 turmeric-blend

The bit packing code is used to make CPU<->Accelerator transfers faster. For this particular code, we have the following in the documentation:

frames: [time, batch_size, height, width, channels]. These should be un-normalized frames in range [0, 255]. channels must be equal to 1 when we actually stack frames (stack_size > 1).

The 8 and 0xFF do mean it packs values in the range 0-255; anything outside that range would be truncated.
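
For intuition, here is a toy sketch of the trick (illustrative only, not the exact seed_rl code): several 8-bit frames fit in the bytes of a single int32, so only one tensor has to cross the CPU<->accelerator boundary.

import tensorflow as tf

# Pack three single-pixel frames, each a value in [0, 255], into one
# int32. The newest frame goes into the most significant byte.
frames = tf.constant([7, 130, 255], dtype=tf.int32)  # oldest to newest
shifts = tf.constant([0, 8, 16], dtype=tf.int32)     # 8 bits per frame
packed = tf.reduce_sum(tf.bitwise.left_shift(frames, shifts))
# packed == 7 + 130 * 2**8 + 255 * 2**16 == 16744967

# Unpack: shift right by multiples of 8 and mask with 0xFF to recover
# each byte, ordered from oldest to most recent.
unpacked = [
    int(tf.bitwise.bitwise_and(tf.bitwise.right_shift(packed, 8 * i), 0xFF))
    for i in range(3)
]
print(unpacked)  # [7, 130, 255]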

lespeholt avatar Apr 08 '20 20:04 lespeholt

Also note that it would be easy to replace this with a simpler, less specialized implementation that uses a state of shape [batch_size, height, width, stack_size - 1] to perform frame stacking (you should be able to do that just by replacing stack_frames() and initial_frame_stacking_state()), or to start experimenting without frame stacking (set stack_size to 1).
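
For illustration, a minimal sketch of that unpacked state layout (the dimensions here are assumed Atari-like values, not taken from the repo):

import tensorflow as tf

batch_size, height, width, stack_size = 2, 84, 84, 4

# Instead of packing the stack_size - 1 previous frames into the bytes
# of an int32 tensor, keep them as an explicit trailing float32 axis.
state = tf.zeros([batch_size, height, width, stack_size - 1], tf.float32)

# Appending a new frame then becomes a plain concat on the last axis,
# e.g. something like:
#   new_state = tf.concat([state[..., 1:], new_frame], axis=-1)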

RaphaelMarinier avatar Apr 08 '20 20:04 RaphaelMarinier

@lespeholt

frames: [time, batch_size, height, width, channels]. These should be un-normalized frames in range [0, 255]. channels must be equal to 1 when we actually stack frames (stack_size > 1).

The 8 and 0xFF does mean it packs values in the range 0-255 and anything outside would be truncated.

so this (bit-packing) code is specific to frames/images and would not be suitable for environments with floating-point observations/states such as LunarLander/CartPole?

@RaphaelMarinier

simpler, less specialized implementation that uses a state of shape [batch_size, height, width, stack_size - 1] to perform frame stacking

could you elaborate on this more?

(you should be able to do that just by replacing stack_frames() and initial_frame_stacking_state())

does this mean not using stack_frames() and initial_frame_stacking_state() at all?

turmeric-blend avatar Apr 09 '20 06:04 turmeric-blend

In an effort to make the stacking code more generic to accommodate other environments, I've modified the frame stacking code to the following:

STACKING_STATE_DTYPE = tf.float32

def initial_frame_stacking_state(stack_size, batch_size, observation_shape):
  if stack_size == 1:
    return ()
  return tf.zeros(tf.concat([[batch_size],
                             [tf.math.reduce_prod(observation_shape)],
                             [stack_size - 1]], axis=0),
                  dtype=STACKING_STATE_DTYPE)

def stack_frames(frames, frame_stacking_state, done, stack_size):
  if frames.shape[0:2] != done.shape[0:2]:
    raise ValueError('Expected same first 2 dims for frames and dones. Got {} vs {}.'
      .format(frames.shape[0:2], done.shape[0:2]))

  batch_size = frames.shape[1]
  obs_shape = frames.shape[2:-1]

  if stack_size > 1 and frames.shape[-1] != 1:
    raise ValueError('Due to frame stacking, we require last observation '
                     'dimension to be 1. Got {}'.format(frames.shape[-1]))
  if stack_size == 1:
    return frames, ()
  if frame_stacking_state[0].dtype != STACKING_STATE_DTYPE:
    raise ValueError('Expected dtype {} got {}'.format(
        STACKING_STATE_DTYPE, frame_stacking_state[0].dtype))

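  # Split the stored frames off the last axis of the state; reverse so
  # they are ordered from oldest to most recent.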
  unstacked_state = tf.unstack(frame_stacking_state, axis=-1)[::-1]

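  # Prepend the stored frames to the incoming frames along the time axis.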
  extended_frames = tf.concat(
      [tf.reshape(frame, [1] + frame.shape + [1])
       for frame in unstacked_state] +
      [frames],
      axis=0)

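  # Stack along the channel axis: output channel i holds the frame from
  # i steps in the past.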
  stacked_frames = tf.concat(
      [extended_frames[stack_size - 1 - i:extended_frames.shape[0] - i]
       for i in range(stack_size)],
      axis=-1)

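  # Build boolean masks that zero out frames from before an episode
  # boundary, so stacks don't leak across episodes.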
  done_mask_row_shape = done.shape[0:2] + [1] * (frames.shape.rank - 2)
  done_masks = [
      tf.zeros(done_mask_row_shape, dtype=tf.bool),
      tf.reshape(done, done_mask_row_shape)
  ]
  while len(done_masks) < stack_size:
    previous_row = done_masks[-1]
    # Add 1 zero in front (time dimension).
    done_masks.append(
        tf.math.logical_or(
            previous_row,
            tf.pad(previous_row[:-1],
                   [[1, 0]] + [[0, 0]] * (previous_row.shape.rank - 1))))

  stacked_done_masks = tf.concat(done_masks, axis=-1)
  stacked_frames = tf.where(
      stacked_done_masks,
      tf.zeros_like(stacked_frames), stacked_frames)

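  # New state: the stack_size - 1 most recent frames (drop the oldest).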
  new_state = stacked_frames[-1, ..., :-1]

  return stacked_frames, new_state

main changes are:

  • removed bit-packing related code
  • added unstacked_state = tf.unstack(frame_stacking_state, axis=-1)[::-1]
  • added new_state = stacked_frames[-1, ..., :-1]
  • AgentState.frame_stacking_state.shape is now (batch_size, <flattened features>, stack_size - 1)

'frame' in this case represents an observation.
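
For example, a hypothetical call to the two functions above (assuming a flattened 1-D observation such as CartPole's 4 floats, with a channels dim of 1):

import tensorflow as tf

stack_size, batch_size = 4, 2
state = initial_frame_stacking_state(stack_size, batch_size,
                                     observation_shape=[4])
# frames: [time=3, batch_size=2, features=4, channels=1]
frames = tf.random.uniform([3, 2, 4, 1])
done = tf.zeros([3, 2], dtype=tf.bool)

stacked, new_state = stack_frames(frames, state, done, stack_size)
print(stacked.shape)    # (3, 2, 4, 4) -- last axis is the stack
print(new_state.shape)  # (2, 4, 3)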

@lespeholt @RaphaelMarinier what do you think? any feedback is appreciated (:

turmeric-blend avatar Apr 11 '20 09:04 turmeric-blend

Looks sensible (see tests in networks_test.py), although frame stacking can be done quite simply when the specialized tricks aren't used. Instead of using a single tensor with all frames, one can do something like the following (pseudo code) using separate tensors:

def initial_state(obs_shape, stack_size):
  return (tf.zeros(obs_shape, ....),) * stack_size

def call(obs, ..., state):
  if done:
    new_state = initial_state()[1:] + (obs,)
  else:
    new_state = state[1:] + (obs,)
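
A hedged, concrete reading of that pseudo code might look like this (names are illustrative; done is a Python bool here for simplicity, and a graph-mode version would use tf.cond or tf.where):

import tensorflow as tf

def initial_state(obs_shape, stack_size):
  # One zero observation per slot in the stack.
  return (tf.zeros(obs_shape, tf.float32),) * stack_size

def stack_step(obs, done, state, obs_shape, stack_size):
  if done:
    # Episode boundary: reset the history, keep only the new observation.
    new_state = initial_state(obs_shape, stack_size)[1:] + (obs,)
  else:
    # Drop the oldest observation, append the newest.
    new_state = state[1:] + (obs,)
  # The network input is the stacked tuple, newest observation last.
  return tf.stack(new_state, axis=-1), new_state

# CartPole-like 4-float observations with stack_size 4:
state = initial_state([4], 4)
stacked, state = stack_step(tf.random.uniform([4]), False, state, [4], 4)
print(stacked.shape)  # (4, 4)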

lespeholt avatar Apr 14 '20 10:04 lespeholt

hi @lespeholt, I've taken a look at networks_test.py and was wondering why the input frames=[[[1]]] in test_stack_frames(self) has shape [time=1, batch_size=1, channels=1]?

def test_stack_frames(self):
    zero_state = networks.DuelingLSTMDQNNet(2, [1], stack_size=4).initial_state(
        1).frame_stacking_state
    # frames: [time=1, batch_size=1, channels=1].     <----- ?
    # done: [time=1, batch_size=1].
    output, state = stack_frames(
        frames=[[[1]]], done=[[False]], frame_stacking_state=zero_state,
        stack_size=4)

The running code, on the other hand, requires frames of shape <float32>[time, batch_size, height, width, channels] as input.

Why don't we include height and width in test_stack_frames(self) as well?

turmeric-blend avatar Apr 16 '20 04:04 turmeric-blend

Good observation. In practice, stack_frames() is agnostic to the spatial dimensions of the observation. It probably works on observations of shape [time, batch_size, <spatial_dims>, channels] regardless of what <spatial_dims> are. The test is written like that to keep it as simple as possible (heavily nested tensors are cumbersome to specify and to debug).

RaphaelMarinier avatar Apr 21 '20 14:04 RaphaelMarinier

I seem to be getting an error when I run networks_test.py's test_stack_frames() and test_stack_frames_done() functions with my converted networks.py code at:

extended_frames = tf.concat(
    [tf.reshape(frame, [1] + frame.shape + [1])
     for frame in unstacked_state] +
    [frames],
    axis=0)

ConcatOp : Ranks of all input tensors should match: shape[0] = [1,1,1,1] vs. shape[3] = [1,1,1] [Op:ConcatV2]

this may be due to my initial_frame_stacking_state() function returning an extra stack_size dimension, and unstacked_state = tf.unstack(frame_stacking_state, axis=-1)[::-1] unpacking it differently than the original code (i.e. with vs. without bit-packing).

My question is: is networks_test.py bit-packing specific, so this might not be an implementation error in my generic stacking code? As a workaround, I modified test_stack_frames() and test_stack_frames_done() in networks_test.py by adding an extra dimension to the input frames, which is basically the 'spatial dimension', e.g. frames=[[[1]]] --> frames=[[[[1]]]]. But I'm not sure if this is the right way.
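
For reference, a sketch of that workaround against the modified functions posted above (shapes are illustrative, with the zero state built to match):

import tensorflow as tf

stack_size, batch_size = 4, 1
zero_state = initial_frame_stacking_state(stack_size, batch_size,
                                          observation_shape=[1])
# frames now has rank 4: [time=1, batch_size=1, spatial=1, channels=1],
# matching the rank of the reshaped state entries.
output, state = stack_frames(
    frames=tf.constant([[[[1.0]]]]),
    frame_stacking_state=zero_state,
    done=tf.constant([[False]]),
    stack_size=stack_size)
print(output.shape)  # (1, 1, 1, 4)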

turmeric-blend avatar Apr 23 '20 01:04 turmeric-blend