habitat-lab
[baselines] Refactor Observations Encoders
🚀 Feature
Motivation
- Currently, all the observation heads for various observation types are hard coded in a list of if statements in the ResNet encoder and elsewhere.
- We should factor these out and make them useful outside of the policy.
- We should also standardize where/when RGB images are normalized and the channel orders are changed in our baselines pipeline.
Pitch
- Doing so would make it easier to integrate new observation types and try out new architectures without duplicating large amounts of code from the ResNet encoder.
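To make the idea concrete, here is a minimal sketch of what a factored-out encoder could look like: one sub-encoder per observation key, held in an `nn.ModuleDict`. The names `ObservationEncoder` and `small_cnn` are hypothetical and not part of habitat-lab; the tiny CNN only stands in for a real backbone.

```python
import torch
import torch.nn as nn


class ObservationEncoder(nn.Module):
    """Hypothetical container: one sub-encoder per observation key."""

    def __init__(self, encoders: dict):
        super().__init__()
        # ModuleDict registers every sub-encoder's parameters with this module.
        self.encoders = nn.ModuleDict(encoders)

    def forward(self, observations: dict) -> torch.Tensor:
        # Encode each registered key in sorted order and concatenate the features.
        # Observation keys that have no registered encoder are simply not used.
        features = [
            self.encoders[key](observations[key])
            for key in sorted(self.encoders.keys())
        ]
        return torch.cat(features, dim=1)


def small_cnn(in_channels: int, out_features: int) -> nn.Module:
    # Stand-in backbone (an assumption for this sketch, not the ResNet encoder).
    return nn.Sequential(
        nn.Conv2d(in_channels, 8, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, out_features),
    )


encoder = ObservationEncoder({
    "rgb": small_cnn(3, 32),
    "depth": small_cnn(1, 16),
})
batch = {
    "rgb": torch.rand(2, 3, 64, 64),
    "depth": torch.rand(2, 1, 64, 64),
}
features = encoder(batch)  # one feature vector per sample, heads concatenated
```

With this layout, adding a new observation type would mean adding one entry to the dict rather than another `if` branch inside the encoder.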
Alternatives
- Move these multi-head encoders into an abstract class that multiple policies can extend through multiple inheritance? Seems a bit messy.
- Implement the observation transforms and similar preprocessing as forward hooks on the nn.Module.
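For the hook alternative, a rough sketch using PyTorch's `register_forward_pre_hook` might look like the following. This assumes a forward pre-hook is what is meant, since that is the hook type that can rewrite a module's inputs. The transform itself (uint8 NHWC to float NCHW in [0, 1]) and the name `rgb_pre_hook` are illustrative assumptions, not the current baselines behaviour.

```python
import torch
import torch.nn as nn


def rgb_pre_hook(module, args):
    # Hypothetical transform: uint8 NHWC in [0, 255] -> float NCHW in [0, 1].
    (rgb,) = args
    rgb = rgb.permute(0, 3, 1, 2).float() / 255.0
    return (rgb,)


conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
handle = conv.register_forward_pre_hook(rgb_pre_hook)

frames = torch.randint(0, 256, (2, 32, 32, 3), dtype=torch.uint8)
out = conv(frames)  # the hook runs before conv.forward and replaces its input
handle.remove()     # detach the hook once it is no longer wanted
```

The upside is that normalization and channel order live in exactly one place; the downside is that the transform is invisible when reading the module's `forward`.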
Additional context
@rpartsey, could you share your experience or some code snippets here showing how you did it? Thank you!
@mathfac, thanks for letting me know about this issue. We were also trying to solve a similar problem: keeping the code simple yet flexible, so that we can run experiments and iterate fast. It was important for us to have a model with configurable encoders and input dimensions (number of channels), with FC layers on top.
We also considered implementing each encoder type as part of our project, but after investigating the available libraries we decided to add segmentation_models_pytorch as a dependency. It already implements a wide range of encoders, is actively supported by the community, and offers useful functionality such as patching the first Conv layer and loading pre-trained weights without the FC layers.
Our model code looks something like this (we use a from_config classmethod to instantiate objects):
# model.py
import torch
import torch.nn as nn
from segmentation_models_pytorch.encoders import get_encoder


class Net(nn.Module):
    def __init__(self, encoder, fc):
        super().__init__()
        self.encoder = encoder
        self.fc = fc

    def forward(self, x):
        x = self.encoder(x)[-1]  # get last stage output
        x = self.fc(x)
        return x

    @classmethod
    def from_config(cls, model_config):
        model_params = model_config.params
        encoder_params = model_params.encoder.params
        fc_params = model_params.fc.params

        encoder = get_encoder(
            name=model_params.encoder.type,
            in_channels=encoder_params.in_channels,
            depth=encoder_params.depth,
            weights=encoder_params.weights
        )
        fc = cls.create_fc_layers(
            input_size=cls.compute_output_size(encoder, encoder_params),
            hidden_size=fc_params.hidden_size,
            output_size=fc_params.output_size,
            p_dropout=fc_params.p_dropout
        )
        return cls(encoder, fc)

    @staticmethod
    def create_fc_layers(input_size: int, hidden_size: list, output_size: int, p_dropout: float = 0.0):
        fc_layers = ...  # FC layers code
        return fc_layers

    @staticmethod
    def compute_output_size(encoder, config):
        input_size = (1, config.in_channels, config.in_height, config.in_width)
        encoder_input = torch.randn(*input_size)
        with torch.no_grad():
            output = encoder(encoder_input)
        return output[-1].view(-1).size(0)
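The FC layers code is elided in the snippet above. Purely as an illustration, and not the code from that project, one way such a builder could be written is sketched below under the hypothetical name `create_fc_layers_sketch`. It assumes the stack starts with a `Flatten`, because `forward` passes the 4D last-stage feature map straight to `self.fc`, and that ReLU plus dropout follow each hidden layer.

```python
import torch
import torch.nn as nn


def create_fc_layers_sketch(input_size: int, hidden_size: list,
                            output_size: int, p_dropout: float = 0.0) -> nn.Sequential:
    # Flatten first so a (N, C, H, W) feature map can be fed to Linear layers.
    layers = [nn.Flatten()]
    in_features = input_size
    for out_features in hidden_size:
        # Assumed block layout: Linear -> ReLU -> Dropout for every hidden layer.
        layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(p_dropout)]
        in_features = out_features
    # Final projection to the requested output size, with no activation.
    layers.append(nn.Linear(in_features, output_size))
    return nn.Sequential(*layers)


# A small fake feature map stands in for the encoder output here.
fc = create_fc_layers_sketch(input_size=16 * 4 * 4, hidden_size=[512, 512], output_size=4)
out = fc(torch.randn(2, 16, 4, 4))
```

Note that `input_size` has to equal the flattened per-sample size (channels × height × width), which is what `compute_output_size` returns for a batch of one.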
Swapping encoders is as easy as changing the encoder name in the config.yaml file:
# config.yaml
model:
  type: Net
  save: True
  params:
    encoder:
      type: resnet18  # change to any other encoder name from segmentation_models_pytorch
      params:
        depth: 5
        weights: imagenet
        in_channels: 8
        in_height: 360
        in_width: 640
    fc:
      params:
        hidden_size: [512, 512]
        output_size: 4
        p_dropout: 0
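The thread does not say which config library provides the attribute-style access (`model_config.params.encoder.type`) that `from_config` relies on. As a dependency-free sketch, the same nesting can be reproduced with `types.SimpleNamespace`; the helper `to_namespace` is hypothetical, and the dict below just mirrors the YAML above rather than parsing the file.

```python
from types import SimpleNamespace


def to_namespace(obj):
    # Recursively turn dicts into objects with attribute access.
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(v) for v in obj]
    return obj


# Same content as config.yaml, written as a dict to avoid a YAML dependency.
raw = {
    "model": {
        "type": "Net",
        "save": True,
        "params": {
            "encoder": {
                "type": "resnet18",
                "params": {
                    "depth": 5,
                    "weights": "imagenet",
                    "in_channels": 8,
                    "in_height": 360,
                    "in_width": 640,
                },
            },
            "fc": {
                "params": {"hidden_size": [512, 512], "output_size": 4, "p_dropout": 0},
            },
        },
    }
}

model_config = to_namespace(raw).model
encoder_params = model_config.params.encoder.params
# This is the dummy input shape that compute_output_size builds from the config.
input_size = (1, encoder_params.in_channels, encoder_params.in_height, encoder_params.in_width)
```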
One word of warning: if you are training these encoders, BatchNorm can be problematic because the data seen in RL and IL is highly correlated.
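The thread does not prescribe a fix. One commonly used workaround, shown here only as a sketch, is to swap each `BatchNorm2d` for a `GroupNorm`, which does not depend on batch statistics. The helper name `replace_batchnorm` and the default of 8 groups are assumptions, and any pre-trained BatchNorm parameters and running statistics are discarded by the swap.

```python
import torch
import torch.nn as nn


def replace_batchnorm(module: nn.Module, num_groups: int = 8) -> nn.Module:
    # Recursively swap every BatchNorm2d child for a GroupNorm over the same channels.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # Fall back to a single group when the channel count is not divisible.
            groups = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            replace_batchnorm(child, num_groups)
    return module


net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
net = replace_batchnorm(net)
out = net(torch.randn(1, 3, 8, 8))
```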