Refactor: Make every output feature define their own forward()
When studying the PyTorch code for Ludwig's output feature modules, there is no forward() to be found.
Only upon looking at the base class does one learn that it defines the real (and default) forward pass, which invokes a special abstract method, logits(), that all subclasses are required to implement. The base class's forward() thus serves as a thin wrapper around the real forward computation, which lives in each subclass's logits() method.
This is rather unusual for PyTorch modules, which usually define their own forward().
One benefit of this design is that the boilerplate is consolidated in a single location: the code that unpacks combiner outputs into a representation that can be fed to an output feature's decoder, and the code that packages the decoder's outputs into the final tensors passed to prediction.
However, a more canonical design would be to factor these boilerplate operations out into a shared module (e.g. feature_utils.py) and have each output feature call them, as needed, in its own forward(). This is more idiomatic PyTorch and more flexible: output features such as sequence or text features may need more customized ways to unpack combiner tensors. A sketch of this layout follows the two examples below.
Instead of:
from abc import ABC, abstractmethod

class BaseClass(ABC):
    def forward(self, tensor):
        tensor = prep_input_tensor(tensor)
        tensor = self.logits(tensor)  # Calls the subclass implementation, which is confusing.
        return prep_output_tensor(tensor)

    @abstractmethod
    def logits(self, tensor):
        raise NotImplementedError()

    @abstractmethod
    def other_method(self):
        raise NotImplementedError()

class LudwigModule(BaseClass):
    def logits(self, tensor):
        return stuff(tensor)

    def other_method(self):
        pass
This is better:
from abc import ABC, abstractmethod

class BaseClass(ABC):
    @abstractmethod
    def other_method(self):
        raise NotImplementedError()

class LudwigModule(BaseClass):
    def forward(self, tensor):
        # Boilerplate helpers now live in a shared module (e.g. feature_utils.py)
        # and are called explicitly where needed.
        tensor = prep_input_tensor(tensor)
        tensor = stuff(tensor)
        return prep_output_tensor(tensor)
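To make the proposed layout concrete, here is a minimal sketch under stated assumptions: the helper names (unpack_combiner_output, prepare_prediction_tensors), the dictionary keys, and both feature classes are hypothetical, shown only to illustrate how shared boilerplate could live in feature_utils.py while a sequence-style feature customizes the unpacking step in its own forward().

# Sketch of the proposed layout. Helper names, dict keys, and classes
# below are hypothetical, not Ludwig's actual API.
import torch

# feature_utils.py: shared boilerplate, called explicitly by each feature.
def unpack_combiner_output(combiner_outputs: dict) -> torch.Tensor:
    # Pull the pooled hidden representation out of the combiner's output dict.
    return combiner_outputs["combiner_output"]

def prepare_prediction_tensors(logits: torch.Tensor) -> dict:
    # Package decoder outputs into the final tensors passed to prediction.
    return {"logits": logits}

# A simple output feature uses the helpers unchanged.
class CategoryOutputFeature(torch.nn.Module):
    def __init__(self, input_size: int, num_classes: int):
        super().__init__()
        self.decoder = torch.nn.Linear(input_size, num_classes)

    def forward(self, combiner_outputs: dict) -> dict:
        hidden = unpack_combiner_output(combiner_outputs)
        return prepare_prediction_tensors(self.decoder(hidden))

# A sequence feature swaps in its own unpacking while reusing the rest.
class SequenceOutputFeature(torch.nn.Module):
    def __init__(self, input_size: int, vocab_size: int):
        super().__init__()
        self.decoder = torch.nn.Linear(input_size, vocab_size)

    def forward(self, combiner_outputs: dict) -> dict:
        # Customized unpacking: keep the per-timestep sequence rather than
        # the pooled vector (hypothetical key).
        hidden = combiner_outputs["combiner_sequence_output"]
        return prepare_prediction_tensors(self.decoder(hidden))

The point is that each feature's forward() now reads top to bottom, and a feature that needs different unpacking simply writes it inline instead of working around a base-class hook.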