
Idea for pre-training / separating models

Open · danpovey opened this issue on Mar 15, 2021 · 2 comments

I have an idea for how we could organize our models to make pre-training easier to deal with. We could divide our model into "lower" and "upper" parts. The idea is that we could train, let's say, a "lower" part that's a 1-d CNN together with an "upper" part that's an LSTM, and then train that same 1-d CNN with a conformer network as the "upper" part. Any ideas on how best to organize the code to make this possible, e.g. regarding the AcousticModel interface?
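For concreteness, a minimal sketch of one possible shape for this; the class and member names here are hypothetical, not the current AcousticModel interface:

import torch
import torch.nn as nn

class TwoPartAcousticModel(nn.Module):
    """Hypothetical sketch: a pre-trainable 'lower' part plus a swappable 'upper' part."""
    def __init__(self, lower: nn.Module, upper: nn.Module):
        super().__init__()
        self.lower = lower  # e.g. a 1-d CNN, kept across experiments
        self.upper = upper  # e.g. an LSTM or a conformer, swapped per experiment

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The lower part extracts features; the upper part maps them to outputs.
        return self.upper(self.lower(x))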

danpovey avatar Mar 15 '21 07:03 danpovey

I'm a bit familiar with the huggingface/transformers library; the way they do it is by matching module names and loading weights from the checkpoint only for the matching modules. For example:

import torch.nn as nn

class NetworkToTransferFrom(nn.Module):
    def __init__(self, ...):
        super().__init__()  # required before registering submodules
        self.cnn = CnnEncoder(...)  # shared "lower" part
        self.lstm = nn.LSTM(...)    # "upper" part used for pre-training
        ...

class NetworkToTransferInto(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.cnn = CnnEncoder(...)       # same name, so its weights transfer
        self.conformer = Conformer(...)  # new "upper" part, trained from scratch
        ...

You write the checkpoint-loading logic so that cnn is loaded, a message is displayed saying that lstm was ignored because it's not present in the new network, and conformer is left randomly initialized because its weights were not in the checkpoint. The code responsible for that starts here: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L1062
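In plain PyTorch the same behavior can be approximated with load_state_dict(strict=False), which skips keys missing on either side and reports them instead of raising. A minimal sketch, where the checkpoint filename is a placeholder:

import torch

model = NetworkToTransferInto(...)  # placeholder args, as above
state_dict = torch.load("ckpt.pt", map_location="cpu")

# strict=False loads parameters whose names match and returns the rest.
result = model.load_state_dict(state_dict, strict=False)
print("ignored (in checkpoint only):", result.unexpected_keys)       # e.g. lstm.*
print("randomly initialized (in model only):", result.missing_keys)  # e.g. conformer.*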

pzelasko avatar Mar 15 '21 14:03 pzelasko

Cool, I may try this kind of thing.

danpovey avatar Mar 15 '21 14:03 danpovey