vak
ENH: Make Model just an attrs class with a `from_config` classmethod
To decouple Model from Engine, add a Model attrs class with a `from_config` classmethod.
This will be a kind of "interface": if you want to be a Model, have a `from_config` method that results in `network`, `loss`, `optimizer`, and `metrics` attributes.
This way the user doesn't have to actually subclass this Model attrs class.
They can use the built-in attrs class if they want, or they can make some other dataclass that can be used with the CLI, as long as it obeys the "interface".
vak will "know" to instantiate the model using `from_config`.
The goal of this is to make it easier to instantiate a model, and to fix #362 (linking here so I can close that one).
The one thing this still does not give us is a `__call__` dunder method, like the one a `torch.nn.Module` has, though :(
It would be really nice to be able to say `y = network(x)` and then easily visualize both `x` and `y`, e.g., assuming one is a spectrogram and the other is predicted segments.
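A minimal sketch of that "interface", using plain-Python stand-ins so it runs without torch. `SupportsFromConfig`, `TinyNet`, `MyModel`, and `instantiate_model` are hypothetical names, not vak API, and strings stand in for the loss and optimizer objects. The idea it illustrates: the CLI only has to look for a callable `from_config`, so `MyModel` never subclasses anything vak provides.

```python
from dataclasses import dataclass, field
from typing import Any, Protocol


class SupportsFromConfig(Protocol):
    """Hypothetical 'interface': a class whose from_config returns an object
    with network, loss, optimizer, and metrics attributes."""

    network: Any
    loss: Any
    optimizer: Any
    metrics: dict

    @classmethod
    def from_config(cls, config: dict) -> "SupportsFromConfig": ...


class TinyNet:
    """Stand-in for a torch.nn.Module so this sketch runs without torch."""

    def __init__(self, num_classes: int):
        self.num_classes = num_classes


@dataclass
class MyModel:  # a plain dataclass; it does not subclass any vak class
    network: Any
    loss: Any
    optimizer: Any
    metrics: dict = field(default_factory=dict)

    @classmethod
    def from_config(cls, config: dict) -> "MyModel":
        # strings stand in for loss and optimizer instances
        return cls(
            network=TinyNet(**config["network"]),
            loss=config["loss"],
            optimizer=config["optimizer"],
        )


def instantiate_model(model_cls: type, config: dict) -> SupportsFromConfig:
    """Duck typing: accept any class that has a callable from_config."""
    from_config = getattr(model_cls, "from_config", None)
    if not callable(from_config):
        raise TypeError(f"{model_cls.__name__} has no from_config method")
    return from_config(config)


model = instantiate_model(
    MyModel,
    {"network": {"num_classes": 10}, "loss": "cross_entropy", "optimizer": "adam"},
)
print(model.network.num_classes)  # 10
```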
Picking this up again.

What we want:
- a way to get a `Model` instance for any model we declare, where we can do two things:
  - just get the network output, e.g., for training
  - get an output with an optional post-processing transform applied, e.g., for prediction

How to implement it:
- make a `Model` class with two methods:
  - `forward`, which just passes a tensor into `Model.network` and returns the output tensor, the same as one would do with a "raw" `torch.nn.Module`
  - the `__call__` method, which internally calls `Model.forward` and then passes that output through a post-processing transform if one is specified, which will be something like `torchvision.transforms.Compose`
- additionally, implement a decorator that accepts a user-specified model `MyModel` in the form of a class with a required set of class variables that correspond to the attributes expected by both `vak.Model` and `vak.Engine`; the decorator returns a new class that has the user-specified `MyModel` as its own class attribute, which it uses when making a new instance of `Model`

This starts to get pretty meta, I know (is this metaprogramming yet?); I will post a code snippet next to illustrate a little better.
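A dependency-free sketch of the `forward` / `__call__` split described above. `Compose` here is a stand-in that mirrors what `torchvision.transforms.Compose` does (apply each callable in order); `toy_network`, `argmax_per_frame`, and `to_labels` are hypothetical placeholders for a real network and real post-processing steps.

```python
from typing import Callable, Optional, Sequence


class Compose:
    """Stand-in for torchvision.transforms.Compose: apply transforms in order."""

    def __init__(self, transforms: Sequence[Callable]):
        self.transforms = list(transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


class Model:
    def __init__(self, network: Callable, post_tfm: Optional[Callable] = None):
        self.network = network
        self.post_tfm = post_tfm

    def forward(self, input):
        # raw network output, e.g. what the loss needs during training
        return self.network(input)

    def __call__(self, input):
        # forward pass, then optional post-processing, e.g. for prediction
        output = self.forward(input)
        if self.post_tfm is not None:
            output = self.post_tfm(output)
        return output


def toy_network(frames):
    """Hypothetical network: scores for two classes per frame."""
    return [[1.0 - f, f] for f in frames]


def argmax_per_frame(scores):
    return [max(range(len(s)), key=s.__getitem__) for s in scores]


def to_labels(class_ids, labelmap="ab"):
    return "".join(labelmap[i] for i in class_ids)


model = Model(toy_network, post_tfm=Compose([argmax_per_frame, to_labels]))
x = [0.1, 0.9, 0.8]
print(model(x))  # abb
```

Here `model.forward(x)` still returns the raw per-frame scores, which is what a training loop would pass to the loss, while `model(x)` returns post-processed labels.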
Here's a code snippet that (I think) illustrates what I have in mind:
```python
import torch
import tweetynet
import vak


def model(model):
    """a decorator that creates a model"""
    class Model:
        _model = model  # the user-specified class, kept as a class attribute

        def __init__(self,
                     network,
                     loss,
                     optimizer,
                     metrics,
                     post_tfm=None):
            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.forward(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config):
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)

    return Model


@model  # i.e., the proposed ``vak.model`` decorator defined above
class TweetyNet:
    # declare classes, not instances: ``from_config`` instantiates them with the config
    network = tweetynet.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}
```
Unfinished business:
- I want to be able to do the following, i.e., get a new model with defaults, so I don't need to pass in any config:

```python
VakMyModel = vak.model(MyModel)  # using the decorator directly as a function
VakMyModel()  # I don't have to pass in any config or anything, just make a new instance
```

We could do this by setting the default values for the `Model.__init__` function to be instances of the user class attributes, I think? Or do I need to somehow override `__new__` here? This is getting to be way meta.
I think I have a very minimal MVP of this working:
```python
import functools

import torch
import tweetynet
import vak


def model(model):
    """a decorator that creates a model"""
    # updated=() because a class's __dict__ is a read-only mappingproxy
    @functools.wraps(model, updated=())
    class Model:
        _model = model

        def __init__(self,
                     network=None,
                     loss=None,
                     optimizer=None,
                     metrics=None,
                     post_tfm=None):
            # if no instance is passed in, make a default one from the class the user declared
            if network is None:
                network = self._model.network()
            if loss is None:
                loss = self._model.loss()
            if optimizer is None:
                optimizer = self._model.optimizer(params=network.parameters())
            if metrics is None:
                metrics = {metric_name: metric_class()
                           for metric_name, metric_class in self._model.metrics.items()}
            self.network = network
            self.loss = loss
            self.optimizer = optimizer
            self.metrics = metrics
            self.post_tfm = post_tfm

        def forward(self, input):
            return self.network(input)

        def __call__(self, input):
            output = self.forward(input)
            if self.post_tfm:
                output = self.post_tfm(output)
            return output

        @classmethod
        def from_config(cls, config: dict):
            network = cls._model.network(**config['network'])
            loss = cls._model.loss(**config['loss'])
            optimizer = cls._model.optimizer(params=network.parameters(), **config['optimizer'])
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in cls._model.metrics.items()}
            # what to do about post_tfm?
            # do we want to be able to write a transform in a config file?
            # or declare with a class, or both?
            return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics)

    return Model


@model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""
    network = tweetynet.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}
```

```python
>>> tweety = TweetyNetModel(network=tweetynet.TweetyNet(num_classes=10))
>>> type(tweety)
__main__.TweetyNetModel
>>> tweety.__doc__
'Model that uses TweetyNet architecture'
```
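A torch-free sketch of what `functools.wraps(model, updated=())` does when applied to a class; the stripped-down decorator and the `Plain` class are illustrative only. `wraps` copies `__module__`, `__name__`, `__qualname__`, and `__doc__` onto the new class and records the original class in `__wrapped__`. Its default `updated=('__dict__',)` would call `.update()` on the new class's `__dict__`, which for a class is a read-only `mappingproxy`, so that default fails.

```python
import functools


def model(model_cls):
    """Toy version of the decorator, without the network/loss/optimizer logic."""

    @functools.wraps(model_cls, updated=())
    class Model:
        _model = model_cls

    return Model


@model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""


print(TweetyNetModel.__name__)  # TweetyNetModel
print(TweetyNetModel.__doc__)  # Model that uses TweetyNet architecture
print(TweetyNetModel.__wrapped__ is TweetyNetModel._model)  # True


# With the default updated=('__dict__',), wraps tries Plain.__dict__.update(...),
# but a class's __dict__ is a mappingproxy, which has no update method.
class Plain:
    pass


try:
    functools.wraps(TweetyNetModel._model)(Plain)
except AttributeError as err:
    print(type(err).__name__)  # AttributeError
```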
The feature branch is already in progress, but I'm documenting here how I ended up implementing this. Basically, as follows:
- [ ] Have an attrs-like or dataclasses-like `ModelDefinition` that specifies a model's network, loss function, optimizer, and metrics as class variables, as in #406
- [ ] Have a base `Model` class that sub-classes `lightning.LightningModule` but additionally has a `definition` attribute, which is a `ModelDefinition` class (the whole class! not just an instance).
  - this `Model` does two things:
    - it either accepts instances of the classes on the definition, i.e., an instance of the network that the definition specified, or, if no instance is passed in, it makes a default version of the instance
    - it alternatively accepts a config, or a path to a config, that it loads and uses to instantiate the classes from the definition, i.e., through a class method
- [ ] Finally, specific families of models will further sub-class the `Model` class, and implement their own logic for the `lightning.LightningModule` methods like `training_step`, `validation_step`, `predict_step`, etc.
- [ ] Each of these will have a decorator that does the following:
  - make a new subclass of the model family class that has the same name as the model definition, and has the model definition's class set as its `definition` attribute
So, from a user's perspective, they can in theory declare a model for any task they want by writing a model definition that they decorate with the decorator for a family of models. From their POV, they never do any sub-classing; they write a definition and apply a decorator to it.
Here's the Minimum Viable Implementation of that:
```python
import functools
import pathlib
from typing import Callable, ClassVar, NewType

import lightning
import torch

import vak
from vak import labeled_timebins


class ModelDefinition:
    """A class that represents the definition of a model.

    A model definition should specify the following class variables:
        network: torch.nn.Module
        loss: torch.nn.Module
        optimizer: torch.optim.Optimizer
        metrics: dict
    """
    network: torch.nn.Module | dict[str, torch.nn.Module]
    loss: torch.nn.Module | dict[str, torch.nn.Module]
    optimizer: torch.optim.Optimizer
    metrics: dict[str, Callable]


class Model(lightning.LightningModule):
    definition: ClassVar[ModelDefinition]
    REQUIRED_CLASSVARS = ('network', 'loss', 'optimizer', 'metrics')

    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None):
        super().__init__()
        # check that we are a sub-class of some other class with required class variables
        if not hasattr(self, 'definition'):
            raise ValueError(
                'This model does not have a definition. '
                'Define a model by wrapping a class with the required class variables with '
                'the ``vak.models.model`` decorator.'
            )
        if not all(
            [hasattr(self.definition, reqd_classvar)
             for reqd_classvar in self.REQUIRED_CLASSVARS]
        ):
            raise ValueError(
                'vak.Model classes should have all the following class variables defined:\n'
                f'{self.REQUIRED_CLASSVARS}'
            )

        # if no instance is passed in, make a default one from the definition's classes
        if network is None:
            network = self.definition.network()
        if loss is None:
            loss = self.definition.loss()
        if optimizer is None:
            optimizer = self.definition.optimizer(params=network.parameters())
        if metrics is None:
            metrics = {metric_name: metric_class()
                       for metric_name, metric_class in self.definition.metrics.items()}

        # store the instances, so that e.g. configure_optimizers can return self.optimizer
        self.network = network
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = metrics

    @classmethod
    def from_config(cls, config: dict, post_tfm: Callable | None = None):
        # the definition holds classes, not instances, so check with issubclass
        if isinstance(cls.definition.network, dict):
            network = {net_name: net_class(**config['network'][net_name])
                       for net_name, net_class in cls.definition.network.items()}
        elif issubclass(cls.definition.network, torch.nn.Module):
            network = cls.definition.network(**config['network'])

        if isinstance(cls.definition.optimizer, dict):
            # TODO: handle network parameters here
            # simplest case: make parameters from all nets first as flattened list, then pass in
            # more complex case: allow for net_opt config,
            # not in either opt or net config so we can still just **unpack those
            optimizer = {opt_name: opt_class(**config['optimizer'][opt_name])
                         for opt_name, opt_class in cls.definition.optimizer.items()}
        elif issubclass(cls.definition.optimizer, torch.optim.Optimizer):
            optimizer = cls.definition.optimizer(params=network.parameters(), **config['optimizer'])

        loss = cls.definition.loss(**config['loss'])
        metrics = {metric_name: metric_class()
                   for metric_name, metric_class in cls.definition.metrics.items()}
        # assumes the subclass's __init__ accepts post_tfm, as WindowedFrameClassificationModel does
        return cls(network=network, optimizer=optimizer, loss=loss, metrics=metrics, post_tfm=post_tfm)

    @classmethod
    def from_config_path(cls, config_path: str | pathlib.Path, post_tfm: Callable | None = None):
        # config = config.model  # need to figure out better config for models here
        # self.from_config(config)
        pass


# define this here instead of vak.typing to avoid circular imports
ModelSubclass = NewType('ModelSubclass', Model)


class WindowedFrameClassificationModel(Model):
    def __init__(self,
                 network: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 loss: torch.nn.Module | dict[str, torch.nn.Module] | None = None,
                 optimizer: torch.optim.Optimizer | None = None,
                 metrics: dict[str, Callable] | None = None,
                 post_tfm: Callable | None = None,
                 ):
        """A LightningModule that represents
        a model that predicts a label for each frame
        in a window, e.g., each time bin in
        a window from a spectrogram.

        This task represents one way of
        predicting annotations for a vocalization,
        where the annotations consist of a sequence
        of segments, each with an onset, offset,
        and label.
        The model maps the spectrogram window
        to a vector of labels for each frame, i.e.,
        each time bin.

        To annotate a vocalization with such a model,
        the spectrogram is converted into a batch of
        consecutive non-overlapping windows,
        for which the model produces predictions.
        These predictions are then concatenated
        into a vector of labeled frames,
        from which the segments can be recovered.

        Post-processing can be applied to the vector
        to clean up noisy predictions
        before recovering the segments."""
        super().__init__(network=network, loss=loss,
                         optimizer=optimizer, metrics=metrics)
        self.lbl_tb2labels = labeled_timebins.lbl_tb2labels
        self.post_tfm = post_tfm

    def configure_optimizers(self):
        return self.optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        y_pred = self.network(x)
        loss = self.loss(y_pred, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # TODO: rename "source" -> "spect"
        # TODO: a sample can have "spect", "audio", "annot", optionally other things ("padding"?)
        x, y = batch["source"], batch["annot"]
        # remove "batch" dimension added by collate_fn to x
        # we keep for y because loss still expects the first dimension to be batch
        # TODO: fix this weirdness. Diff't collate_fn?
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
            else:
                raise ValueError(f"invalid shape for x: {x.shape}")

        out = self.network(x)
        # permute and flatten out
        # so that it has shape (1, number classes, number of time bins)
        # ** NOTICE ** just calling out.reshape(1, out.shape[1], -1) does not work, it will change the data
        out = out.permute(1, 0, 2)
        out = torch.flatten(out, start_dim=1)
        out = torch.unsqueeze(out, dim=0)
        # reduce to predictions, assuming class dimension is 1
        y_pred = torch.argmax(
            out, dim=1
        )  # y_pred has dims (batch size 1, predicted label per time bin)

        if "padding_mask" in batch:
            padding_mask = batch[
                "padding_mask"
            ]  # boolean: 1 where valid, 0 where padding
            # remove "batch" dimension added by collate_fn
            # because this extra dimension just makes it confusing to use the mask as indices
            if padding_mask.ndim == 2:
                if padding_mask.shape[0] == 1:
                    padding_mask = torch.squeeze(padding_mask, dim=0)
                else:
                    raise ValueError(
                        f"invalid shape for padding mask: {padding_mask.shape}"
                    )

            out = out[:, :, padding_mask]
            y_pred = y_pred[:, padding_mask]

        if self.post_tfm:
            y_pred = self.post_tfm(y_pred)

        y_labels = self.lbl_tb2labels(
            y.cpu().numpy(),
        )
        y_pred_labels = self.lbl_tb2labels(
            y_pred.cpu().numpy()
        )

        # TODO: figure out smarter way to do this
        for metric_name, metric_callable in self.metrics.items():
            if metric_name == "loss":
                self.log(f'val_{metric_name}', self.loss(out, y), batch_size=1)
            elif metric_name == "acc":
                self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1)
            elif metric_name == "levenshtein" or metric_name == "segment_error_rate":
                self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)

    def predict_step(self, batch, batch_idx: int):
        x, spect_path = batch["source"].to(self.device), batch["spect_path"]
        if isinstance(spect_path, list) and len(spect_path) == 1:
            spect_path = spect_path[0]
        if x.ndim == 5:
            if x.shape[0] == 1:
                x = torch.squeeze(x, dim=0)
        y_pred = self.network(x)
        return {spect_path: y_pred}


def windowed_frame_classification_model(modeldef: ModelDefinition) -> ModelSubclass:
    """A decorator that creates a model"""
    attributes = dict(WindowedFrameClassificationModel.__dict__)
    attributes.update({'definition': modeldef})
    wrapped_model = type(modeldef.__name__, (WindowedFrameClassificationModel,), attributes)
    # realized when testing we don't actually need the next line, since we're adding the ModelDefinition
    # as an attribute to the subclass we just made, i.e., we're not really wrapping a class here,
    # so this doesn't make sense
    wrapped_model = functools.wraps(windowed_frame_classification_model, updated=())(wrapped_model)
    return wrapped_model


@windowed_frame_classification_model
class TweetyNetModel:
    network = vak.nets.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    metrics = {'acc': vak.metrics.Accuracy,
               'levenshtein': vak.metrics.Levenshtein,
               'segment_error_rate': vak.metrics.SegmentErrorRate,
               'loss': torch.nn.CrossEntropyLoss}
```
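The `from_config` classmethod has to tell a single network class apart from a dict of named network classes. Because a definition stores classes rather than instances, a check like `isinstance(cls.definition.network, torch.nn.Module)` is `False` for a class, so the test has to be made on the class itself (e.g., with `issubclass`). The plain-Python sketch below shows that dispatch; `build_networks`, `TinyNet`, `Encoder`, and `Decoder` are hypothetical, and the nested config layout (one sub-config per network name) is an assumption that mirrors `config['network'][net_name]`.

```python
from typing import Any


class TinyNet:
    def __init__(self, num_classes: int):
        self.num_classes = num_classes


class Encoder:
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size


class Decoder:
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size


def build_networks(network_def: Any, network_config: dict) -> Any:
    """Instantiate one network class, or a dict of named network classes.

    `network_def` holds classes, not instances, so check for a dict
    first and otherwise require a class (a `type`).
    """
    if isinstance(network_def, dict):
        # one sub-config per network name, like config['network'][net_name]
        return {
            net_name: net_class(**network_config[net_name])
            for net_name, net_class in network_def.items()
        }
    if isinstance(network_def, type):
        return network_def(**network_config)
    raise TypeError(f"expected a class or a dict of classes, got {network_def!r}")


single = build_networks(TinyNet, {"num_classes": 10})
pair = build_networks(
    {"encoder": Encoder, "decoder": Decoder},
    {"encoder": {"hidden_size": 64}, "decoder": {"hidden_size": 32}},
)
print(single.num_classes)  # 10
print(pair["decoder"].hidden_size)  # 32
```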
After all this, one can do the following to instantiate a model, without needing a config:

```python
tweetynet = vak.nets.TweetyNet(num_classes=10)
model = TweetyNetModel(network=tweetynet)
```

and its methods will be the underlying `LightningModule` methods, specifically those implemented by `WindowedFrameClassificationModel`:

```python
>>> model.predict_step
<bound method WindowedFrameClassificationModel.predict_step of windowed_frame_classification_model()>
```
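A torch-free sketch of the same "definition + family decorator" pattern, to show the design in isolation. All names are stand-ins: `Model` replaces the lightning-based base class, `FrameClassificationFamily` plays the role of `WindowedFrameClassificationModel`, `frame_classification_model` is a hypothetical decorator, and `NoOpOptimizer` takes the network rather than `params=network.parameters()`. Because the class built by `type()` already inherits from the family class, this sketch does not copy the family's `__dict__`; it passes the definition's `__doc__` and `__module__` directly instead of calling `functools.wraps`, so the new class keeps the definition's name.

```python
class Model:
    """Stand-in for the lightning-based base class."""

    REQUIRED_CLASSVARS = ("network", "loss", "optimizer", "metrics")

    def __init__(self, network=None, loss=None, optimizer=None, metrics=None):
        definition = getattr(self, "definition", None)
        if definition is None:
            raise ValueError("This model does not have a definition.")
        missing = [var for var in self.REQUIRED_CLASSVARS if not hasattr(definition, var)]
        if missing:
            raise ValueError(f"model definition is missing class variables: {missing}")
        # use the instances passed in, or make defaults from the definition's classes
        self.network = network if network is not None else definition.network()
        self.loss = loss if loss is not None else definition.loss()
        self.optimizer = (
            optimizer if optimizer is not None else definition.optimizer(self.network)
        )
        self.metrics = (
            metrics
            if metrics is not None
            else {name: metric_cls() for name, metric_cls in definition.metrics.items()}
        )


class FrameClassificationFamily(Model):
    """Stand-in for a model family such as WindowedFrameClassificationModel."""

    def predict_step(self, x):
        return self.network(x)


def frame_classification_model(modeldef: type) -> type:
    """Hypothetical family decorator: a subclass named after the definition."""
    attributes = {
        "definition": modeldef,
        "__doc__": modeldef.__doc__,
        "__module__": modeldef.__module__,
    }
    return type(modeldef.__name__, (FrameClassificationFamily,), attributes)


class TinyNet:
    def __init__(self, num_classes: int = 2):
        self.num_classes = num_classes

    def __call__(self, x):
        return [0 for _ in x]  # always predicts class 0


class ZeroLoss:
    def __call__(self, y_pred, y):
        return 0.0


class NoOpOptimizer:
    def __init__(self, network):
        self.network = network


class Accuracy:
    def __call__(self, y_pred, y):
        return sum(p == t for p, t in zip(y_pred, y)) / len(y)


@frame_classification_model
class TweetyNetModel:
    """Model that uses TweetyNet architecture"""

    network = TinyNet
    loss = ZeroLoss
    optimizer = NoOpOptimizer
    metrics = {"acc": Accuracy}


model = TweetyNetModel(network=TinyNet(num_classes=10))
print(type(model).__name__)  # TweetyNetModel
print(model.predict_step([1, 2, 3]))  # [0, 0, 0]
```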
Closed by #605