
Refactor: Make every output feature define their own forward()

Open justinxzhao opened this issue 3 years ago • 0 comments

When reading the PyTorch code for Ludwig's output feature modules, one finds that they define no forward().

Only by looking at the base class does one learn that it defines the real (default) forward pass, which invokes a special abstract method, logits(), that all subclasses are required to implement. The base class's forward() is thus a thin wrapper around the real forward computation, which is supplied by each subclass's logits() method.

This is rather unusual for PyTorch modules, which conventionally define their own forward().

One benefit of this design is that the boilerplate is consolidated in a single location: the code that unpacks combiner outputs into a representation the output feature's decoder can consume, and the code that packages the decoder's outputs into the final tensors passed to prediction.

However, a more canonical design would factor these boilerplate operations out into a separate module (e.g. feature_utils.py) and have each output feature call those helpers, as needed, from its own forward(). This is more idiomatic PyTorch and more flexible: output features such as sequence or text features may need more customized ways of unpacking combiner tensors.

Instead of:

class LudwigModule(BaseClass):
  def logits(self, tensor):
    return stuff(tensor)

  def other_method(self):
    pass

class BaseClass:
  def forward(self, tensor):
    tensor = prep_input_tensor(tensor)
    tensor = self.logits(tensor)  # Calls the subclass implementation, which is confusing.
    return prep_output_tensor(tensor)

  @abstractmethod
  def logits(self, tensor):
    raise NotImplementedError()

  @abstractmethod
  def other_method(self):
    raise NotImplementedError()

This is better:

class LudwigModule(BaseClass):
  def forward(self, tensor):
    tensor = prep_input_tensor(tensor)
    tensor = stuff(tensor)
    return prep_output_tensor(tensor)

class BaseClass:
  @abstractmethod
  def other_method(self):
    raise NotImplementedError()
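To make the proposed shape concrete, here is a minimal runnable sketch of that design. All names are hypothetical (they are not Ludwig's actual helpers or feature classes), and plain integers stand in for tensors: the shared prep helpers live in a feature_utils-style module, the base class keeps only the genuinely shared abstract contract, and each feature owns its forward().

```python
from abc import ABC, abstractmethod

# Hypothetical shared helpers, factored into a feature_utils-style module.
def prep_input_tensor(x):
    # Stand-in for unpacking combiner outputs into a decoder-ready form.
    return x + 1

def prep_output_tensor(x):
    # Stand-in for packaging decoder outputs into final prediction tensors.
    return x * 2

class BaseOutputFeature(ABC):
    """Base class no longer owns the forward pass."""

    @abstractmethod
    def forward(self, x):
        ...

class CategoryOutputFeature(BaseOutputFeature):
    """Each feature defines its own forward() and calls helpers as needed."""

    def forward(self, x):
        x = prep_input_tensor(x)   # shared boilerplate, opt-in per feature
        x = x ** 2                 # feature-specific decoder logic ("stuff")
        return prep_output_tensor(x)

feature = CategoryOutputFeature()
print(feature.forward(3))  # (3 + 1) ** 2 * 2 = 32
```

A feature with unusual needs (say, a sequence feature) could simply skip or replace prep_input_tensor in its own forward(), which the template-method design does not allow without overriding the base class.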

justinxzhao — May 19 '22 23:05