
ENH: Make Model just an attrs class with a `from_config` classmethod

NickleDave opened this issue 3 years ago

To decouple Model from Engine, add a Model attrs class with a from_config classmethod.

This will be a kind of "interface": if you want to be a Model, have a from_config method that returns an instance with network, loss, optimizer, and metrics attributes.

This way, the user doesn't have to actually subclass this Model attrs class. They can use the built-in attrs class if they want, or they can make some other dataclass that can be used with the CLI, as long as it obeys the "interface". vak will "know" to instantiate the model using from_config.
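A minimal sketch of that "interface", assuming Python's dataclasses in place of attrs (attrs supports the same decorator-based pattern) and typing.Protocol to spell out the required attributes. `ModelLike`, `MyOwnModel`, and `instantiate` are hypothetical names used for illustration, not vak API, and the config sections are stored as-is rather than used to build real networks.

```python
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ModelLike(Protocol):
    """Hypothetical 'interface': these attributes plus a from_config classmethod."""
    network: Any
    loss: Any
    optimizer: Any
    metrics: dict

    @classmethod
    def from_config(cls, config: dict) -> "ModelLike": ...


@dataclass
class Model:
    """Built-in option (a dataclass standing in for the attrs class)."""
    network: Any
    loss: Any
    optimizer: Any
    metrics: dict = field(default_factory=dict)

    @classmethod
    def from_config(cls, config: dict) -> "Model":
        # config sections are passed through unchanged in this sketch;
        # a real implementation would instantiate a network, loss, etc.
        return cls(network=config["network"], loss=config["loss"],
                   optimizer=config["optimizer"], metrics=config.get("metrics", {}))


class MyOwnModel:
    """A user class that does NOT subclass Model but obeys the same interface."""

    def __init__(self, network, loss, optimizer, metrics):
        self.network, self.loss = network, loss
        self.optimizer, self.metrics = optimizer, metrics

    @classmethod
    def from_config(cls, config: dict) -> "MyOwnModel":
        return cls(config["network"], config["loss"], config["optimizer"], {})


def instantiate(model_class, config: dict):
    """What the CLI could do: rely only on from_config, never on subclassing."""
    model = model_class.from_config(config)
    if not isinstance(model, ModelLike):
        raise TypeError(f"{model_class.__name__}.from_config did not return a model")
    return model
```

Either class can be handed to `instantiate`; the only contract is the shape of what `from_config` returns.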

NickleDave, Jul 08 '22 01:07

The goal of this is to make it easier to instantiate a model and to fix #362 (linking here so I can close that one).

NickleDave, Jul 29 '22 23:07

The one thing this still does not give us, though, is a __call__ dunder method like the one on a torch.nn.Module :(

It would be really nice to be able to write y = network(x) and then easily visualize both x and y, e.g., where x is a spectrogram and y is the predicted segments.

NickleDave, Jul 29 '22 23:07

picking this up again

what we want:

  • a way to get a Model instance for any model we declare where we can do two things:
    • just get the network output, e.g. for training
    • get an output with an optional post-processing transform applied, e.g. for prediction

how to implement it:

  • make a Model class with two methods:
    • forward, which just passes a tensor into Model.network and returns the output tensor, the same as one would do with a "raw" torch.nn.Module
    • the __call__ method, which internally calls Model.forward and then passes that output through a post-processing transform if one is specified, which will be something like torchvision.transforms.Compose

and additionally, implement a decorator that accepts a user-specified model MyModel, in the form of a class with a required set of class variables corresponding to the attributes expected by both vak.Model and vak.Engine; the decorator returns a new class that has MyModel as its own class attribute, which it uses when making a new instance of Model

this starts to get pretty meta, I know (is this meta-programming yet?); I will post a code snippet next to illustrate a little better
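A minimal sketch of the forward/__call__ split from the list above, assuming plain Python callables in place of a torch.nn.Module and a small hypothetical `Compose` stand-in for `torchvision.transforms.Compose` (which likewise applies a list of transforms in order). `toy_network` and `threshold` are made-up placeholders.

```python
from typing import Callable, Optional, Sequence


class Compose:
    """Stand-in for torchvision.transforms.Compose: apply transforms in order."""

    def __init__(self, transforms: Sequence[Callable]):
        self.transforms = list(transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


class Model:
    def __init__(self, network: Callable, post_tfm: Optional[Callable] = None):
        self.network = network
        self.post_tfm = post_tfm

    def forward(self, x):
        # raw network output, e.g. for computing the loss during training
        return self.network(x)

    def __call__(self, x):
        # network output plus optional post-processing, e.g. for prediction
        output = self.forward(x)
        if self.post_tfm is not None:
            output = self.post_tfm(output)
        return output


def toy_network(frames):
    # hypothetical "network": one score per frame
    return [2 * f for f in frames]


def threshold(scores):
    # hypothetical post-processing: turn scores into binary labels
    return [1 if s > 4 else 0 for s in scores]


model = Model(toy_network, post_tfm=Compose([threshold]))
print(model.forward([1, 3]))  # [2, 6]
print(model([1, 3]))          # [0, 1]
```

Keeping `forward` separate means training code can call it directly and never sees the post-processing.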

NickleDave, Nov 23 '22 21:11

Here's a code snippet that (I think) illustrates what I have in mind:

def model(model):
    """a decorator that creates a model"""

    class Model:
        # the user-specified class; the class body can read the enclosing function's `model`
        _model = model

        def __init__(self,
                     network,
                     loss,
                     optimizer,
                     metrics,
                     post_tfm=None):
            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.forward(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config):
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)

    return Model

@vak.model
class TweetyNetModel:
    # class variables are classes, not instances: from_config instantiates them using the config
    network = TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}

NickleDave, Nov 23 '22 21:11

Unfinished business:

  • I want to be able to do the following, so that I get a new model with defaults and don't need to pass in any config:
VakMyModel = vak.model(MyModel)  # using the decorator directly as a function
VakMyModel()  # I don't have to pass in any config or anything, just make a new instance

We could do this by setting the default values of the Model.__init__ parameters to instances of the user class attributes, I think? Or do I need to somehow override __new__ here? This is getting to be way meta.
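On the first option: putting instances directly into the __init__ signature would make Python evaluate them once, when the function is defined, so every model created without arguments would share one network. Below is a small stdlib-only sketch of that pitfall and of the usual `None`-sentinel alternative, which needs no `__new__` override. `Net`, `SharedDefault`, and `FreshDefault` are hypothetical stand-ins, not vak classes.

```python
class Net:
    """Hypothetical stand-in for a network class such as TweetyNet."""

    def __init__(self):
        self.weights = [0.0]


class SharedDefault:
    # Pitfall: Net() runs once, when `def __init__` is executed,
    # so every instance created without an argument shares the same Net.
    def __init__(self, network=Net()):
        self.network = network


class FreshDefault:
    # None as a sentinel: a new Net is built on each call to __init__.
    def __init__(self, network=None):
        self.network = network if network is not None else Net()


a, b = SharedDefault(), SharedDefault()
c, d = FreshDefault(), FreshDefault()
print(a.network is b.network)  # True: both models would train the same weights
print(c.network is d.network)  # False
```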

NickleDave, Nov 23 '22 21:11

I think I have a very minimal MVP (minimum viable product) of this working:

import functools

import vak
import torch
import tweetynet

def model(model):
    """a decorator that creates a model"""

    @functools.wraps(model, updated=()) 
    class Model:
        _model = model

        def __init__(self,
                     network=None, 
                     loss=None,
                     optimizer=None,
                     metrics=None,
                     post_tfm=None):
            if network is None:
                network = self._model.network()
            if loss is None:
                loss = self._model.loss()
            if optimizer is None:
                optimizer = self._model.optimizer(params=network.parameters())
            if metrics is None:
                metrics = {metric_name: metric_class()
                           for metric_name, metric_class in self._model.metrics.items()}

            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.forward(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config: dict):
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)
    
    return Model

@model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""
    network = tweetynet.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}

>>> tweety = TweetyNetModel(network=tweetynet.TweetyNet(num_classes=10))

>>> type(tweety)
__main__.TweetyNetModel

>>> tweety.__doc__
'Model that uses TweetyNet architecture'

NickleDave, Nov 24 '22 03:11

A feature branch is already in progress, but I'm documenting here how I ended up implementing this. Basically, as follows:

  • [ ] Have an attrs-like or dataclasses-like ModelDefinition that specifies a model's network, loss function, optimizer, and metrics as class variables, as in #406
  • [ ] Have a base Model class that sub-classes the lightning.LightningModule but additionally has a definition attribute, which is a ModelDefinition class (the whole class! not just an instance).
    • this Model does two things
      • it either accepts instances of the classes on the definition, i.e. an instance of the network that the definition specified, or if no instance is passed in, it makes a default version of the instance
      • it alternatively accepts a config or a path to a config that loads up and instantiates the classes from the definition using the config, i.e. through a class method
  • [ ] finally, specific families of models will further sub-class the Model class and implement their own logic for the lightning.LightningModule methods like training_step, validation_step, predict_step, etc.
  • [ ] each of these will have a decorator that does the following:
    • make a new subclass of the model family class that has the same name as the model definition, and has the model definition's class set as its definition attribute

So, from a user's perspective, they can in theory declare a model for any task they want by writing a model definition that they decorate with the decorator for a family of models. From their POV, they never do any sub-classing; they write a definition and apply a decorator to it.
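A stdlib-only sketch of that structure, under stated assumptions: a plain base class stands in for lightning.LightningModule, only `network` and `loss` get default instances, and `FrameClassificationModel`, `frame_classification_model`, `Doubler`, `AbsError`, and `ToyModel` are hypothetical names for illustration. The decorator uses the three-argument form of `type()` to build the new subclass.

```python
from typing import ClassVar


class Model:
    """Base class; in vak this would also sub-class lightning.LightningModule."""
    definition: ClassVar[type]  # set by a decorator: the whole definition class
    REQUIRED_CLASSVARS = ("network", "loss", "optimizer", "metrics")

    def __init__(self, network=None, loss=None):
        if not hasattr(self, "definition"):
            raise ValueError("This model does not have a definition.")
        missing = [name for name in self.REQUIRED_CLASSVARS
                   if not hasattr(self.definition, name)]
        if missing:
            raise ValueError(f"definition is missing class variables: {missing}")
        # use instances passed in, otherwise make default instances from the definition
        self.network = network if network is not None else self.definition.network()
        self.loss = loss if loss is not None else self.definition.loss()

    @classmethod
    def from_config(cls, config: dict):
        # alternative constructor: instantiate the definition's classes from a config
        return cls(network=cls.definition.network(**config.get("network", {})),
                   loss=cls.definition.loss(**config.get("loss", {})))


class FrameClassificationModel(Model):
    """A family of models; would implement training_step, predict_step, etc."""

    def training_step(self, batch):
        x, y = batch
        return self.loss(self.network(x), y)


def frame_classification_model(definition: type) -> type:
    """Decorator: make a subclass of the family, named after the definition."""
    return type(definition.__name__, (FrameClassificationModel,),
                {"definition": definition, "__doc__": definition.__doc__})


class Doubler:
    def __call__(self, x):
        return [2 * v for v in x]


class AbsError:
    def __call__(self, y_pred, y):
        return sum(abs(a - b) for a, b in zip(y_pred, y))


@frame_classification_model
class ToyModel:
    """A toy model definition."""
    network = Doubler
    loss = AbsError
    optimizer = None  # placeholder: this sketch does not build optimizers
    metrics = {}


model = ToyModel()
print(type(model).__name__)                   # ToyModel
print(model.training_step(([1, 2], [2, 5])))  # 1
```

After decoration, the name `ToyModel` refers to the new subclass, while `ToyModel.definition` still gives back the original definition class, so the user never writes a subclass themselves.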

NickleDave, Dec 24 '22 23:12

Here's the Minimum Viable Implementation of that:

import functools
import pathlib
from typing import Callable, ClassVar, NewType

import lightning
import torch
import vak
from vak import labeled_timebins

class ModelDefinition:
    """A class that represents the definition of a model.

    A model definition should specify the following class variables:
        network: torch.nn.Module
        loss: torch.nn.Module
        optimizer: torch.optim.Optimizer
        metrics: dict
    """
    network: torch.nn.Module | dict[str, torch.nn.Module]
    loss: torch.nn.Module | dict[str, torch.nn.Module]
    optimizer: torch.optim.Optimizer
    metrics: dict[str, Callable]

class Model(lightning.LightningModule):
    definition: ClassVar[ModelDefinition]
    REQUIRED_CLASSVARS = ('network', 'loss', 'optimizer', 'metrics')

    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None):
        super().__init__()

        # check that we are a sub-class of some other class with required class variables
        if not hasattr(self, 'definition'):
            raise ValueError(
                'This model does not have a definition. '
                'Define a model by wrapping a class with the required class variables with '
                'the ``vak.models.model`` decorator.'
            )
        if not all(
                [hasattr(self.definition, reqd_classvar)
                 for reqd_classvar in self.REQUIRED_CLASSVARS]
        ):
            raise ValueError(
                'vak.Model classes should have all the following class variables defined:\n'
                f'{self.REQUIRED_CLASSVARS}'
            )

        if network is None:
            network = self.definition.network()
        if isinstance(network, dict):
            # wrap a dict of networks so the LightningModule registers their parameters
            network = torch.nn.ModuleDict(network)
        if loss is None:
            loss = self.definition.loss()
        if optimizer is None:
            optimizer = self.definition.optimizer(params=network.parameters())
        if metrics is None:
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in self.definition.metrics.items()}

        # store everything on the instance; training_step etc. use these attributes
        self.network = network
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = metrics

    @classmethod
    def from_config(cls, config: dict, post_tfm: Callable | None = None):
        # the definition's class variables are classes, not instances, so
        # dispatch on dict vs. not-dict instead of isinstance(..., torch.nn.Module)
        if isinstance(cls.definition.network, dict):
            network = {net_name: net_class(**config['network'][net_name])
                       for net_name, net_class in cls.definition.network.items()}
        else:
            network = cls.definition.network(**config['network'])

        if isinstance(cls.definition.optimizer, dict):
            # TODO: handle network parameters here
            # simplest case: make parameters from all nets first as flattened list, then pass in
            # more complex case: allow for net_opt config,
            # not in either opt or net config so we can still just **unpack those
            raise NotImplementedError('multiple optimizers are not supported yet')
        else:
            optimizer = cls.definition.optimizer(params=network.parameters(), **config['optimizer'])

        loss = cls.definition.loss(**config['loss'])
        metrics = {metric_name: metric_class()
                   for metric_name, metric_class in cls.definition.metrics.items()}
        return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics, post_tfm=post_tfm)

    @classmethod
    def from_config_path(cls, config_path: str | pathlib.Path, post_tfm: Callable | None = None):
        # config = config.model  # need to figure out better config for models here
        # self.from_config(config)
        pass

# define this here instead of vak.typing to avoid circular imports
ModelSubclass = NewType('ModelSubclass', Model)

class WindowedFrameClassificationModel(Model):
    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None,
                 post_tfm: Callable | None = None,
                 ):
        """A LightningModule that represents
        a model that predicts a label for each frame
        in a window, e.g., each time bin in
        a window from a spectrogram.

        This task represents one way of
        predicting annotations for a vocalization,
        where the annotations consist of a sequence
        of segments, each with an onset, offset,
        and label.
        The model maps the spectrogram window
        to a vector of labels for each frame, i.e.,
        each time bin.

        To annotate a vocalization with such a model,
        the spectrogram is converted into a batch of
        consecutive non-overlapping windows,
        for which the model produces predictions.
        These predictions are then concatenated
        into a vector of labeled frames,
        from which the segments can be recovered.

        Post-processing can be applied to the vector
        to clean up noisy predictions
        before recovering the segments."""
        super().__init__(network=network, loss=loss,
                         optimizer=optimizer, metrics=metrics)
        self.lbl_tb2labels = labeled_timebins.lbl_tb2labels
        self.post_tfm = post_tfm

    def configure_optimizers(self):
        return self.optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        y_pred = self.network(x)
        loss = self.loss(y_pred, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # TODO: rename "source" -> "spect"
        # TODO: a sample can have "spect", "audio", "annot", optionally other things ("padding"?)
        x, y = batch["source"], batch["annot"]
        # remove "batch" dimension added by collate_fn to x
        # we keep for y because loss still expects the first dimension to be batch
        # TODO: fix this weirdness. Diff't collate_fn?
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        else:
            raise ValueError(f"invalid shape for x: {x.shape}")

        out = self.network(x)
        # permute and flatten out
        # so that it has shape (1, number classes, number of time bins)
        # ** NOTICE ** just calling out.reshape(1, out.shape(1), -1) does not work, it will change the data
        out = out.permute(1, 0, 2)
        out = torch.flatten(out, start_dim=1)
        out = torch.unsqueeze(out, dim=0)
        # reduce to predictions, assuming class dimension is 1
        y_pred = torch.argmax(
            out, dim=1
        )  # y_pred has dims (batch size 1, predicted label per time bin)

        if "padding_mask" in batch:
            padding_mask = batch[
                "padding_mask"
            ]  # boolean: 1 where valid, 0 where padding
            # remove "batch" dimension added by collate_fn
            # because this extra dimension just makes it confusing to use the mask as indices
            if padding_mask.ndim == 2:
                if padding_mask.shape[0] == 1:
                    padding_mask = torch.squeeze(padding_mask, dim=0)
            else:
                raise ValueError(
                    f"invalid shape for padding mask: {padding_mask.shape}"
                )

            out = out[:, :, padding_mask]
            y_pred = y_pred[:, padding_mask]

        if self.post_tfm:
            y_pred = self.post_tfm(y_pred)

        y_labels = self.lbl_tb2labels(
            y.cpu().numpy(),
        )
        y_pred_labels = self.lbl_tb2labels(
            y_pred.cpu().numpy()
        )

        # TODO: figure out smarter way to do this
        for metric_name, metric_callable in self.metrics.items():
            if metric_name == "loss":
                self.log(f'val_{metric_name}', self.loss(out, y), batch_size=1)
            elif metric_name == "acc":
                self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1)
            elif metric_name == "levenshtein" or metric_name == "segment_error_rate":
                self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)

    def predict_step(self, batch, batch_idx: int):
        x, spect_path = batch["source"].to(self.device), batch["spect_path"]
        if isinstance(spect_path, list) and len(spect_path) == 1:
            spect_path = spect_path[0]
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        y_pred = self.network(x)
        return {spect_path: y_pred}

def windowed_frame_classification_model(modeldef: ModelDefinition) -> ModelSubclass:
    """A decorator that creates a model"""    
    attributes = dict(WindowedFrameClassificationModel.__dict__)
    attributes.update({'definition': modeldef})
    wrapped_model = type(modeldef.__name__, (WindowedFrameClassificationModel,), attributes)
    # realized when testing that we don't actually need the next line: we add the ModelDefinition
    # as an attribute of the subclass we just made, i.e., we're not really wrapping a class here.
    # As written, it also copies this decorator's own attributes, including __name__ and __doc__, onto the new class
    wrapped_model = functools.wraps(windowed_frame_classification_model, updated=())(wrapped_model)

    return wrapped_model

@windowed_frame_classification_model
class TweetyNetModel:
    network = vak.nets.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}

After all this, one can do the following to instantiate a model without needing a config:

tweetynet = vak.nets.TweetyNet(num_classes=10)
model = TweetyNetModel(network=tweetynet)

and its methods will be the underlying LightningModule methods, specifically those implemented by WindowedFrameClassificationModel:

model.predict_step
<bound method WindowedFrameClassificationModel.predict_step of windowed_frame_classification_model()>

NickleDave, Dec 24 '22 23:12

Closed by #605

NickleDave, Jan 22 '23 02:01