
[baselines] Refactor Observations Encoders

Open Skylion007 opened this issue 4 years ago • 3 comments

🚀 Feature

Motivation

  • Currently, the observation heads for the various observation types are hard-coded as a chain of if statements in the ResNet encoder and elsewhere.
  • We should factor these out and make them useful outside of the policy.
  • We should also standardize where/when RGB images are normalized and the channel orders are changed in our baselines pipeline.

Pitch

  • Doing so would make it easier to integrate new observation types and try out new architectures without duplicating large amounts of code from the ResNet encoder.

Alternatives

  • Move these multi-head encoders into an abstract class that can be extended by multiple policies through multiple inheritance? Seems a bit messy.
  • Make all the observation transforms (normalization, channel reordering, etc.) forward hooks on the nn.Module (see the sketch below).
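
A rough sketch of the forward-hook alternative (hypothetical code, not existing habitat-lab API): normalization and channel reordering live in a pre-hook attached to the encoder instead of being hard-coded in the policy.

# sketch only: the hook and encoder below are placeholders
import torch
import torch.nn as nn


def rgb_preprocess_hook(module, inputs):
    (obs,) = inputs
    obs = obs.float() / 255.0      # normalize uint8 RGB to [0, 1]
    obs = obs.permute(0, 3, 1, 2)  # NHWC -> NCHW
    return (obs,)


encoder = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=2), nn.ReLU())
encoder.register_forward_pre_hook(rgb_preprocess_hook)

rgb_obs = torch.randint(0, 256, (4, 64, 64, 3), dtype=torch.uint8)
features = encoder(rgb_obs)  # the hook runs before encoder.forward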

Additional context

Skylion007 avatar Jan 05 '21 16:01 Skylion007

@rpartsey, could you share your experience or code snippets here on how you did it? Thank you!

mathfac avatar Jan 14 '21 02:01 mathfac

@mathfac, thanks for letting me know about this issue. We were also trying to solve a similar problem: simple yet flexible code (to run experiments and iterate fast). It was important for us to have a model with configurable encoders and input dimensions (number of channels), with FC layers on top.

We also considered implementing each encoder type as part of our project, but after investigating the available libraries we decided to add segmentation_models_pytorch as a dependency. It already has a wide range of encoders implemented and is actively supported by the community (plus some useful functionality like patching the first Conv layer, loading pre-trained weights without FC layers, etc.).

Our model code looks something like this (we use a from_config classmethod to instantiate objects):

# model.py

import torch
import torch.nn as nn
from segmentation_models_pytorch.encoders import get_encoder


class Net(nn.Module):
    def __init__(self, encoder, fc):
        super().__init__()
        self.encoder = encoder
        self.fc = fc

    def forward(self, x):
        x = self.encoder(x)[-1]  # get last stage output
        x = self.fc(x)  # fc (created below) is expected to flatten the feature map before the linear layers

        return x

    @classmethod
    def from_config(cls, model_config):
        model_params = model_config.params
        encoder_params = model_params.encoder.params
        fc_params = model_params.fc.params

        encoder = get_encoder(
            name=model_params.encoder.type,
            in_channels=encoder_params.in_channels,
            depth=encoder_params.depth,
            weights=encoder_params.weights
        )

        fc = cls.create_fc_layers(
            input_size=cls.compute_output_size(encoder, encoder_params),
            hidden_size=fc_params.hidden_size,
            output_size=fc_params.output_size,
            p_dropout=fc_params.p_dropout
        )

        return cls(encoder, fc)

    @staticmethod
    def create_fc_layers(input_size: int, hidden_size: list, output_size: int, p_dropout: float = 0.0):
        fc_layers = ...  # FC layers code

        return fc_layers

    @staticmethod
    def compute_output_size(encoder, config):
        input_size = (1, config.in_channels, config.in_height, config.in_width)

        encoder_input = torch.randn(*input_size)
        with torch.no_grad():
            output = encoder(encoder_input)

        return output[-1].view(-1).size(0)
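
For context, the segmentation_models_pytorch encoders return one feature map per stage (shallowest to deepest), which is why forward indexes with [-1]. A quick check, as a sketch assuming the resnet18 encoder:

# sketch: inspect the per-stage outputs of an smp encoder
import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder(name="resnet18", in_channels=3, depth=5, weights=None)

with torch.no_grad():
    features = encoder(torch.randn(1, 3, 224, 224))

for f in features:
    print(f.shape)  # spatial resolution shrinks stage by stage; the last entry is the deepest feature map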

Swapping encoders is as easy as changing the encoder name in the config.yaml file:

# config.yaml

model:
  type: Net
  save: True
  params:
    encoder:
      type: resnet18  # change to any other encoder name from segmentation_models_pytorch
      params:
        depth: 5
        weights: imagenet
        in_channels: 8
        in_height: 360
        in_width: 640
    fc:
      params:
        hidden_size: [512, 512]
        output_size: 4
        p_dropout: 0
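
To make the wiring concrete, a minimal usage sketch (OmegaConf is an assumption here; any config object with attribute-style access like model_config.params would work the same way):

# usage sketch
import torch
from omegaconf import OmegaConf

config = OmegaConf.load("config.yaml")
model = Net.from_config(config.model)

batch = torch.randn(2, 8, 360, 640)  # matches in_channels / in_height / in_width above
output = model(batch)                # shape: (2, output_size)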

rpartsey avatar Jan 14 '21 20:01 rpartsey

One word of warning: if you are training things, BatchNorm can be problematic due to the highly correlated data seen in RL and IL.
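
A common mitigation (a general note, not something prescribed in this thread) is to swap BatchNorm for a normalization layer that does not use batch statistics, such as GroupNorm. A minimal sketch:

# sketch: recursively replace BatchNorm2d with GroupNorm
# note: this discards any pretrained BatchNorm statistics and affine parameters
import torch.nn as nn


def replace_batchnorm_with_groupnorm(module: nn.Module, num_groups: int = 32) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # num_groups must divide num_features; fall back to one group otherwise
            groups = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            replace_batchnorm_with_groupnorm(child, num_groups)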

erikwijmans avatar Jan 14 '21 21:01 erikwijmans