keras-nlp

Add `from_preset` constructor to `BertPreprocessor`

Open jbischof opened this issue 3 years ago • 2 comments

Proposal

In #387 we allowed construction of a BERT model from a "preset" model architecture and weights; for example Bert.from_preset("bert_base_uncased_en"). I propose to do the same with BertPreprocessor, automatically generating the matching preprocessing functionality using the same string id: BertPreprocessor.from_preset("bert_base_uncased_en").

Preset configuration

We will need to add preprocessing metadata to each preset id. One way to do this is keeping the vocabulary information normalized in a separate struct and joining it at call time. However, we could follow the pattern in #387 more closely by inlining the preprocessing config and weights:

backbone_presets = {
    "bert_tiny_uncased_en": {
        "config": {
            "vocabulary_size": 30522,
            "num_layers": 2,
            "num_heads": 2,
            "hidden_dim": 128,
            "intermediate_dim": 512,
            "dropout": 0.1,
            "max_sequence_length": 512,
            "num_segments": 2,
        },
        "description": (
            "Tiny size of BERT where all input is lowercased. "
            "Trained on English Wikipedia + BooksCorpus."
        ),
        "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_tiny_uncased_en/model.h5",
        "weights_hash": "c2b29fcbf8f814a0812e4ab89ef5c068",
        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_uncased_en/vocab.txt",
        "vocabulary_hash": "64800d5d8528ce344256daf115d4965e",
        "lowercase": True,
    },
}
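For comparison, the normalized alternative mentioned above could look roughly like the sketch below. The names `vocabulary_presets`, `normalized_backbone_presets`, the `"vocabulary"` key, and `preprocessing_metadata` are hypothetical and exist only for illustration; the backbone config is trimmed for brevity.

```python
# Hypothetical layout: vocabulary metadata lives in its own table,
# keyed by a short vocabulary id.
vocabulary_presets = {
    "uncased_en": {
        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_uncased_en/vocab.txt",
        "vocabulary_hash": "64800d5d8528ce344256daf115d4965e",
        "lowercase": True,
    },
}

# Each backbone preset only references a vocabulary id (config trimmed).
normalized_backbone_presets = {
    "bert_tiny_uncased_en": {
        "config": {"vocabulary_size": 30522, "max_sequence_length": 512},
        "vocabulary": "uncased_en",
    },
}


def preprocessing_metadata(preset):
    """Join a backbone preset with its vocabulary entry at call time."""
    vocabulary_id = normalized_backbone_presets[preset]["vocabulary"]
    # Return a copy so callers cannot mutate the shared table.
    return dict(vocabulary_presets[vocabulary_id])


metadata = preprocessing_metadata("bert_tiny_uncased_en")
```

The join avoids repeating the vocabulary URL and hash across presets that share a vocabulary, at the cost of an extra lookup and a second table to keep in sync.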

If we add more preprocessor arguments than lowercase in the future then a separate preprocessor_config field could be added. However, we have already hardcoded most preprocessing metadata in the BertPreprocessor class. Additional information is generally added in files like vocab.txt and merges.txt rather than flat parameters, so this doesn't seem helpful at this time.

Signature

The arguments sequence_length and truncate are not determined by the preset alone and should be added to the signature. However, we must enforce that sequence_length does not exceed the maximum sequence length the model was trained with.

One complication is that most users will want to use the biggest sequence length that the model allows. We might want to set the default to None so that we can fill in this value unless the user has a different preference.

class BertPreprocessor(keras.layers.Layer):
    # Generic constructor
    def __init__(
        self,
        vocabulary="uncased_en",
        lowercase=False,
        sequence_length=512,
        truncate="round_robin",
        **kwargs,
    ):
        pass

    @classmethod
    def from_preset(
        cls,
        preset,
        sequence_length=512,
        truncate="round_robin",
        **kwargs,
    ):
        # Reject lengths beyond what the preset's model was trained with.
        max_length = backbone_presets[preset]["config"]["max_sequence_length"]
        if sequence_length > max_length:
            raise ValueError(...)
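If the default were None as suggested above, the resolution logic might look like the following sketch. The helper name `resolve_sequence_length` and the error message wording are assumptions, and the preset table is a trimmed stand-in that keeps only the field used here.

```python
# Trimmed stand-in for the preset table; only the field used here.
backbone_presets = {
    "bert_tiny_uncased_en": {"config": {"max_sequence_length": 512}},
}


def resolve_sequence_length(preset, sequence_length=None):
    """Return the sequence length to use for `preset` (hypothetical helper)."""
    max_length = backbone_presets[preset]["config"]["max_sequence_length"]
    if sequence_length is None:
        # No user preference: use the largest length the model allows.
        return max_length
    if sequence_length > max_length:
        raise ValueError(
            f"`sequence_length` cannot exceed {max_length} for preset "
            f"'{preset}'. Received: sequence_length={sequence_length}"
        )
    return sequence_length
```

With this shape, `from_preset` could call the helper once and pass the result to the generic constructor, so the validation lives in a single place.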

Preset namespaces

In the long term we will likely add task specific presets as well as backbone presets. Unlike model classes, Preprocessors could pertain to any preset type. Therefore long term we will have to create a single namespace of presets or explicitly check multiple namespaces and enforce uniqueness between them.
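As a rough sketch of the second option, separate preset tables could be merged into one lookup while rejecting duplicate ids. The table `classifier_presets`, the id `bert_tiny_uncased_en_sst2`, and the function `merge_preset_namespaces` are hypothetical, and the entries are left empty for brevity.

```python
# Hypothetical preset tables; entries trimmed to empty dicts.
backbone_presets = {"bert_tiny_uncased_en": {}, "bert_base_uncased_en": {}}
classifier_presets = {"bert_tiny_uncased_en_sst2": {}}  # hypothetical id


def merge_preset_namespaces(*namespaces):
    """Merge preset dicts, raising if any id appears in more than one."""
    merged = {}
    for namespace in namespaces:
        duplicates = merged.keys() & namespace.keys()
        if duplicates:
            raise ValueError(f"Duplicate preset ids: {sorted(duplicates)}")
        merged.update(namespace)
    return merged


all_presets = merge_preset_namespaces(backbone_presets, classifier_presets)
```

A preprocessor could then look up any id in `all_presets` without knowing which preset type it belongs to.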

jbischof, Oct 13 '22 00:10

Overall, sounds good to me! Definitely like adding BertPreprocessor.from_preset("bert_base_uncased_en"), that will really improve our end-to-end, mid-level usage pattern.

I wonder if it's a little more correct to have sequence_length=None in the from_preset signature, and infer it from the preset config if the user does not provide it. It just so happens that all BERT models we currently ship use a 512 sequence length, but I wouldn't want to encode that forever. The default (and maximum) sequence length a model can take is a model hyperparameter.

Re the structure of the metadata, I think it only really makes sense to talk through lowercase in the context of exposing a preset dict directly. When we do, I would see more of an argument for keeping the config options in a dict, because we can start to really push this simple mental model for our users that our presets are just a config and weights. There are more questions here (like how to handle vocabulary, which unlike weights is in the config).

But if we want to keep things incremental, maybe let's just ship any structure that is most convenient now and not worry too much? And then follow up on this when exposing presets?

mattdangerw, Oct 13 '22 18:10

Re namespaces, I do think we are signing up for a global namespace of IDs. We probably will store presets for different uses separately out of convenience, but the namespace is global IMO. Long term, we may want to figure out a test for this.

It seems relatively easy to think of backbone_presets and classifier_presets as separate dicts, and to get the full list of preprocessing presets we just concat those two dicts together. Would that work?

If we don't want to do that, we could definitely keep a separate flat dict of preprocessing presets, and add a test that every backbone preset has a preprocessing preset. That seems slightly more clunky and merge-conflict-prone at scale, but generally OK as well.
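A minimal sketch of such a test, assuming a separate flat dict named `preprocessor_presets` (a hypothetical name) and preset entries trimmed to empty dicts:

```python
# Hypothetical tables; real entries would carry configs and URLs.
backbone_presets = {"bert_tiny_uncased_en": {}, "bert_base_uncased_en": {}}
preprocessor_presets = {"bert_tiny_uncased_en": {}, "bert_base_uncased_en": {}}


def missing_preprocessor_presets(backbone, preprocessor):
    """Return backbone preset ids that have no preprocessing preset."""
    return sorted(backbone.keys() - preprocessor.keys())


def test_every_backbone_has_preprocessor():
    missing = missing_preprocessor_presets(backbone_presets, preprocessor_presets)
    assert missing == [], f"Backbone presets without preprocessing: {missing}"
```

Returning the sorted list of missing ids, rather than a bare boolean, makes the failure message name exactly which presets need an entry.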

mattdangerw, Oct 13 '22 18:10