
Frontend API and PyTorch backend

albertz opened this issue 2 years ago · 200 comments

Originally, this issue was about a proof-of-concept for a new PyTorch backend in RETURNN. It has since evolved into a whole new generic frontend API (original idea here: https://github.com/rwth-i6/returnn/issues/1264), which very much follows the API of RETURNN-common nn and covers multiple backends. This API is accessible to the user via returnn.frontend, and the intended usage convention is:

import returnn.frontend as rf

class MyModel(rf.Module):
  def __call__(self, a: rf.Tensor, b: rf.Tensor):
    return rf.matmul(a, b, ...)

We also made sure that our Tensor class (earlier called Data) supports any raw tensor type (type of Tensor.placeholder, or Tensor.raw_tensor now) and is backend independent. This is all in returnn.tensor now.
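
For illustration, a minimal sketch of attaching a raw torch.Tensor to such a backend-independent Tensor (static dims for simplicity; the exact constructor arguments here are assumptions, not a fixed API reference):

import torch
from returnn.tensor import Tensor, Dim

# Sketch only: exact Dim/Tensor constructor arguments are assumptions.
time_dim = Dim(7, name="time")
feat_dim = Dim(40, name="feature")

x = Tensor("x", dims=[time_dim, feat_dim], dtype="float32")
x.raw_tensor = torch.zeros(7, 40)  # the raw tensor type for the PT backend is torch.Tensor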

Currently, the following backends are relevant:

  • RETURNN layers net dict backend for TF. The raw tensor type is NameCtx.
  • PyTorch. The raw tensor type is torch.Tensor.
  • TF directly / low-level. Raw tensor is tf.Tensor.

The terminology on frontend/backend is sometimes used a bit inconsistently. By frontend or frontend API we basically mean what the user sees, i.e. what you have in the returnn.frontend namespace: functions like reduce and dot, or also modules like Linear. This API is very much the same as RETURNN-common nn.

We have an abstract base class Backend, which defines some core functions of the API and allows reimplementing them for different backends. The derived classes are e.g. TorchBackend, ReturnnLayersBackend, TFBackend. The earlier terminology was maybe a bit confusing: these classes implement the frontend functions for a specific backend, so sometimes we referred to this as "different frontend (implementations)".

The Backend class and its different backend implementations are supposed to be internal to RETURNN and not directly exposed to the user. The user has some helper functions to switch between backends.

There is also a lot of code which builds on top of the backend functions: e.g. the rf.Module class, modules like rf.Linear, and functions like cross_entropy are all independent of the backend.

The user in the end needs to define the following functions in the config:

def get_model(*, epoch: int, step: int, **_unused_kwargs) -> Union[rf.Module, torch.nn.Module]:
  ...

def train_step(*, model, extern_data: TensorDict):
  ...

def forward_step(*, model, extern_data: TensorDict):
  ...

train_step would only be used in training. Here the user should call mark_as_loss on some tensors.
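
For illustration, a hedged sketch of such a train_step with the RF (the data keys "data"/"classes", the model call, and model.out_dim are assumptions here, not part of the fixed API):

import returnn.frontend as rf
from returnn.tensor import TensorDict

def train_step(*, model: rf.Module, extern_data: TensorDict, **_kwargs):
    data = extern_data["data"]        # assumed input key
    targets = extern_data["classes"]  # assumed target key
    logits = model(data)              # assumed to return logits over model.out_dim
    loss = rf.cross_entropy(
        estimated=logits, target=targets, estimated_type="logits", axis=model.out_dim
    )
    loss.mark_as_loss("ce")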

forward_step would be used for recognition. Here the user should call mark_as_output on some tensors. See #1336 for more discussion on this, how it would be used then for beam search, forwarding, or whatever you want to do with the outputs.

To also support PyTorch modules more directly, get_model() can also return a torch.nn.Module. See https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1482924842.

To add losses when using a raw torch.nn.Module, the API inside train_step would look like:

rf.get_run_ctx().mark_as_loss(..., loss_raw_torch_tensor, ...)

But this is intended mainly for external code; for our own code, we recommend using the RF. In any case, it will be easy to mix pure PT and RF code together (with the PT backend).
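
For illustration, a sketch of a train_step mixing a raw torch.nn.Module with the run ctx (the exact mark_as_loss keyword names, data keys, and tensor layouts are assumptions here):

import torch
import returnn.frontend as rf

def train_step(*, model: torch.nn.Module, extern_data, **_kwargs):
    # extern_data entries are Tensors; .raw_tensor gives the underlying torch.Tensor
    feats = extern_data["data"].raw_tensor       # assumed shape [batch, time, feature]
    targets = extern_data["classes"].raw_tensor  # assumed shape [batch, time]
    logits = model(feats)                        # assumed shape [batch, time, classes]
    loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), targets.long())
    rf.get_run_ctx().mark_as_loss(name="ce", loss=loss)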

To get access to the step inside train_step, there will be sth like rf.global_train_step().

Related:

  • https://github.com/rwth-i6/returnn/issues/1264
  • https://github.com/rwth-i6/returnn_common/issues/252
  • https://github.com/rwth-i6/returnn/issues/1185
  • https://github.com/rwth-i6/returnn/issues/1165 / https://github.com/rwth-i6/returnn/pull/1261

Some summary of the current open discussions, items, or status updates:

  • rf.control_flow_ctx or not? (https://github.com/rwth-i6/returnn/issues/1288)
  • How to design re-parameterization like weight norm, pre-forward hooks? (https://github.com/rwth-i6/returnn_common/issues/250)
  • Pretraining (#1447)
  • data-apis.org Python array API standard, type promotion rules
  • check_matched_dataset for PT? (https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1476556666)
  • Flash attention support
  • Easy way for the user to have custom training loops, e.g. to have custom AMP implementation, custom multi-GPU implementation, whatever. See some discussion in #1306.
  • Automatic mixed precision (AMP) support for RF (already supported for pure PT) (#1311)
  • Quantization support? (@JackTemaki)
  • How to deal with devices (#1331)
  • Multi-GPU training (#1332)
  • PyTorch ONNX export further tests (#1333)
  • AED/Transducer models, beam search (@albertz)

Done:

  • PyTorch search interface (#1336), refactored PyTorch model loading
  • seq_tag in extern_data (also see #1330)
  • RF: Accumulate Dim in rf.scan for beam search (#1327)
  • rf.top_k
  • Conformer AED model, same decoder outputs including the final label logits (https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1549715526)
  • PyTorch ONNX export basic support (#1333)
  • Automatic mixed precision (AMP) support for pure PT (#1334)
  • rf.cond, rf.while_loop, rf.scan (https://github.com/rwth-i6/returnn/issues/1282)
  • LSTM
  • Conformer (https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1479156121)
  • Relative positional encoding
  • Convolution / pooling: rf.Conv1d, rf.max_pool1d, etc
  • rf.BatchNorm, rf.LayerNorm and other normalizations
  • rf.SelfAttention, rf.dot_attention
  • rf.State as base (https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1478108726)
  • rf.dropout
  • rf.cond (minimal initial support) (see #1282)
  • rf.get_run_ctx().train_flag
  • RF works now both in TF engine (via TF-net-dict backend currently) and PT engine.
  • Initial support to mix PT code/modules with RF code/modules (https://github.com/rwth-i6/returnn/issues/1287)
  • We now have rf.cross_entropy.
  • Basic models work now in both PT and TF-net-dict backends, e.g. rf.Linear, some activation functions.
  • Implement RF user entry points (get_model etc) for PT and TF (@albertz), also #1290
  • RETURNN PyTorch-specific extensions (RPE), modules/functions which make use of Tensor/Dim? (https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1483250726) -> for now, we continue and focus on RF development, and the RF itself can be used as PyTorch extensions
  • rf.random in tests (https://github.com/rwth-i6/returnn/issues/1283)
  • Support inplace random init when possible (#1299)
  • Do we need to support TorchScript graph capture (torch.jit.script/torch.jit.trace) for ONNX? Tracing (torch.jit.trace) should anyway always work, except it would not cover dynamic control flow. torch.compile for ONNX will be the future but does not work yet? Maybe also OpenXLA instead of ONNX. Any other reasonable way to get PT models running within RASR? We can also simply use the RASR-Python bridge. (https://github.com/rwth-i6/returnn/issues/1289)
  • preload_from_files (https://github.com/rwth-i6/returnn/pull/1292)
  • We have basic PT support, wrapping RETURNN datasets (ReturnnDatasetIterDataPipe), chunking, batching, using DataLoader2.

This issue here is also to discuss and report on implementation details of the new PyTorch backend.

The initial step would be to just get a proof-of-concept, meaning the goals currently are:

  • Get some experience with PyTorch, some better ideas how to integrate PyTorch into RETURNN, etc.
  • Get training and inference with some basic model, using existing RETURNN datasets. It would be some hybrid NN-HMM for ASR. We can take existing PyTorch code, for example some Conformer encoder from ESPnet.

Getting it compatible with the TF backend is maybe the ultimate goal, but this is on a completely different level than the proof-of-concept discussed here.

For the dataset, we would implement a PyTorch IterableDataset to wrap any RETURNN dataset.

albertz avatar Sep 12 '22 13:09 albertz

The current goal is that you can run python3 rnn.py demos/demo-torch.config and it starts the training.

albertz avatar Sep 12 '22 14:09 albertz

Note that I created the config option backend, so you do backend = "torch" in the config to enable the PyTorch backend.
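
For reference, a minimal config sketch along those lines (illustrative only, not the actual demos/demo-torch.config; the dataset and other options are just the usual RETURNN ones):

# Minimal sketch of a config enabling the PyTorch backend.
backend = "torch"
task = "train"

train = {"class": "Task12AXDataset", "num_seqs": 1000}
dev = {"class": "Task12AXDataset", "num_seqs": 100}

batch_size = 1000
num_epochs = 5
learning_rate = 1e-3
# model / train step definitions as discussed further below in this thread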

I also created an initial returnn.torch package/directory, with a dummy Engine in engine.py.

Feel free to create a data_pipeline.py file when you start working on implementing the dataset and related code.

albertz avatar Sep 12 '22 14:09 albertz

Btw, as we work in master, please make sure the tests are passing. There are no PyTorch tests at all yet, so the only checks relevant for us currently are the code inspections, like PEP8 etc. Please double-check that there are no warnings on the code before you commit.

For commit messages, maybe prefix them with "PyTorch" or "Torch" or so.

albertz avatar Sep 12 '22 14:09 albertz

For the main train loop, see the PyTorch quickstart tutorial (slightly simplified, adapted):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class NeuralNetwork(nn.Module):
  ...

model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    ...

# epoch loop
epochs = 5
for t in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

albertz avatar Sep 12 '22 14:09 albertz

So, in the RETURNN Engine, we would also have the same kind of loop over the batches:

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

We would also have such a model instance. This is somewhat like the root TFNetwork, or the root module in returnn-common. Such a root module is necessary to get unique param names. See the corresponding discussion in returnn-common.

As I understand the PyTorch code, I think we need to have one single loss in the end, such that we have a single loss.backward() call. Multiple backward() calls would run backprop multiple times, which would be wasteful. But this is not really a problem: we just need to sum up all the losses you potentially have. We also do this for TF/Theano.

Maybe, like in returnn-common, we can have the same mark_as_loss API. So, some initial config API suggestions:

def get_model(*, epoch: int, **_unused_kwargs) -> torch.nn.Module:

Would create the model instance. I'm not sure if we really need to or should pass extern_data here, as you should have all the relevant information anyway in the config. But maybe it makes it simpler to write it? Originally I also thought about just providing class Model(torch.nn.Module), with the API that Model() should work.

def train_step(*, extern_data: ExternData, model: Model, ctx: TrainCtx):

Here you would do:

        pred = model(extern_data.data["inputs"].placeholder)
        targets = extern_data.data["targets"].placeholder  # target key name just for illustration
        loss = F.cross_entropy(pred, targets)
        ctx.mark_as_loss(loss)

The TrainCtx provides this mark_as_loss method, which would just collect the losses. Maybe together with a name. Then in the engine, the train loop could look sth like this:

    model.train()
    ctx = TrainCtx()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        extern_data.set(X, y)  # ...

        pred = train_step(extern_data=extern_data, model=model, ctx=ctx)
        loss = ctx.total_loss()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
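
A minimal sketch of such a TrainCtx, just to illustrate the proposal (names follow the snippets above; everything else is an assumption):

from typing import Dict
import torch

class TrainCtx:
    """Collects the losses marked by the user inside train_step."""

    def __init__(self):
        self.losses: Dict[str, torch.Tensor] = {}

    def mark_as_loss(self, loss: torch.Tensor, name: str = "loss", scale: float = 1.0):
        """Register a scalar loss tensor under a name; overwritten in the next step."""
        self.losses[name] = loss * scale if scale != 1.0 else loss

    def total_loss(self) -> torch.Tensor:
        """Sum of all registered losses, used for the single loss.backward() call."""
        assert self.losses, "no losses were marked in train_step"
        return sum(self.losses.values())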

albertz avatar Sep 12 '22 15:09 albertz

I'm implementing a dataset wrapper. The code is still too ugly for a pull request, so first a comment: 😄

The general interface is clear:

from torch.utils.data import IterableDataset


class DatasetWrapper(IterableDataset):
  def __init__(self, returnn_dataset):
    ...
    
  def __iter__(self):
    ...

And then in engine.train() something like:

from torch.utils.data import DataLoader

train_data = DatasetWrapper(self.train_dataset)

data_loader = DataLoader(train_data)

for batch_index, batch_data in enumerate(data_loader):
  ...

Most intuitive would be to let DatasetWrapper.__iter__() return a generator over single sequences and then do batching via the arguments to DataLoader. However, for IterableDataset the batch_sampler and collate_fn arguments are not available. You can set DataLoader(train_data, batch_size=42), but this would really just put a constant number of sequences into a batch, not a constant number of time frames or the other, more sophisticated logic we want to have. So I think DatasetWrapper.__iter__() already has to provide the batches. I'm currently trying to figure out how to best reuse the existing batching code, i.e. Dataset.generate_batches(), BatchSetGenerator, FeedDictDataProvider.get_next_batch() etc. Will continue tomorrow... 😄
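
To illustrate the kind of batching logic meant here, independent of the existing RETURNN batching code, a rough sketch that groups sequences until a maximum total number of time frames is reached:

from typing import Any, Dict, Iterator, List

def batches_by_frames(seqs: Iterator[Dict[str, Any]], max_frames: int = 20000) -> Iterator[List[Dict[str, Any]]]:
    """Group single sequences into batches bounded by the total number of time frames.
    Each seq dict is assumed to have a "data" entry of shape [time, feature]."""
    batch, num_frames = [], 0
    for seq in seqs:
        seq_len = seq["data"].shape[0]
        if batch and num_frames + seq_len > max_frames:
            yield batch
            batch, num_frames = [], 0
        batch.append(seq)
        num_frames += seq_len
    if batch:
        yield batch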

patrick-wilken avatar Sep 14 '22 21:09 patrick-wilken

@albertz, it looks to me that having a separate loss function in the training loop is not strictly necessary. Those available loss functions output normal Tensors. So you could have the loss as part of the Module if you want "all calculations to be equal". But that would rather be suitable when defining a network dict, not so much when loading an existing Module, because people normally don't do it that way...

patrick-wilken avatar Sep 14 '22 21:09 patrick-wilken

No, the IterableDataset is supposed to return individual sequences, not batches.

I think the DataLoader is supposed to do that. Yes, I have read somewhere that it is a bit limited in functionality in that it only supports fixed size batches. But I'm sure there are other solutions, other DataLoaderExt implementations or whatever. Maybe just check how other frameworks like Fairseq have done this.

albertz avatar Sep 14 '22 22:09 albertz

it looks to me that having a separate loss function in the training loop is not strictly necessary.

I never said that? See my suggestion on how to do it, i.e. using the TrainCtx.

So you could have the loss as part of the Module

This is now totally up to the user. You could either totally decouple it, i.e. your model just is the model, without losses, and then you have separate code to calculate the losses. Or you could mix it together, define the losses already within the model (when you pass TrainCtx to it). In returnn-common, we have the same situation now.

albertz avatar Sep 14 '22 22:09 albertz

Btw, I don't like having PRs in this early phase of development. PRs just slow it down. I think we can all directly work in the master. PRs are there to avoid breaking anything, and to discuss and review the code. We cannot really break anything, as it is not functional anyway. And for discussion, we can also discuss right here. And for review, you should feel responsible to do that anyway.

If sth is unclear, better discuss here before starting to implement anyway. If we have some rough understanding on how to implement it, we should just start, and then iterate on the code, in master directly.

albertz avatar Sep 14 '22 22:09 albertz

I checked the implementation of datasets for Fairseq and k2 (lhotse repo). Actually, both inherit from torch.utils.data.Dataset as opposed to torch.utils.data.IterableDataset, so both are map-style datasets. I've seen some tendency from the community to choose map-style datasets as opposed to iterable-style datasets even for big datasets, for instance here and here, the arguments being that the DataLoader can handle the batches and iteration over the dataset (or one can do it in a custom way, see k2 below), memory management can still be efficient with a Dataset and additional things have to be taken care of when using IterableDataset (see here, example 2).

k2

The __getitem__() method of the K2SpeechRecognitionDataset class directly returns a batch, which could allow for "custom" batches. Because of this, the DataLoader doesn't actually compute any batches, so when declaring it, batch_size must be equal to None. They also use a custom sampler defined here. An example usage as shown in the last link is:

dataset = K2SpeechRecognitionDataset(cuts)
sampler = SimpleCutSampler(cuts, shuffle=True)  # cuts == input data + text + potentially useful data
loader = DataLoader(dataset, sampler=sampler, batch_size=None)
for epoch in range(start_epoch, n_epochs):
    sampler.set_epoch(epoch)  # Only for shuffling purposes
    train(loader)

I wasn't able to find the definition of __len__() for that class though. Maybe if the data is never indexed, its definition can be avoided?

Fairseq

For Fairseq I focused on the SpeechToTextDataset subclass of FairseqDataset. The __getitem__() method returns an object of the class SpeechToTextDatasetItem, which also has the source audio and the target text. The collater() method is in charge of creating the batches. I assume it's the collate_fn passed to the DataLoader (maybe through *EpochBatchIterator classes), but I couldn't find any evidence of it.

The __len__() method simply returns the number of audios, calculated when the dataset is initialized.

In the LibriSpeech data preparation example, the data is processed and saved to a .tsv file, which is then loaded in the from_tsv() method when executing $ fairseq-train.

Fairseq also has a FairseqIterableDataset class which inherits from torch.utils.data.IterableDataset, but it doesn't seem to be used anywhere.

Icemole avatar Sep 15 '22 10:09 Icemole

In general, I would trust the Fairseq developers more, esp w.r.t. how to use PyTorch in the best way. So, let's follow how they do it.

albertz avatar Sep 15 '22 10:09 albertz

So this means that we should implement DatasetWrapper as inheriting from torch.utils.data.Dataset instead of torch.utils.data.IterableDataset?

I have found two short tutorials on how to handle a large dataset: a short tutorial from Stanford University on efficiently loading data from disk using a torch.utils.data.Dataset, and a Medium post whose comments I think have good insight. I leave them here in case they're useful in any way.

Icemole avatar Sep 15 '22 12:09 Icemole

When wrapping the RETURNN dataset, I don't think any of these tutorials can be applied. You can just use the RETURNN Dataset API, nothing else.

The Torch map-style dataset unfortunately does not fit RETURNN too well, as most RETURNN datasets do not really allow for efficient random access. However, I think HDFDataset actually should be fine. If it works with that, this is ok for now.

We can also do both torch.utils.data.Dataset and torch.utils.data.IterableDataset and then the user can decide. I'm still sure that even for torch.utils.data.IterableDataset, you can do everything what we need.

Remember that for the beginning, we just want to have some proof-of-concept. It's totally ok if we have static batch sizes for that.

albertz avatar Sep 15 '22 12:09 albertz

The interface of RETURNN datasets is iterable-style. I would wrap that first, and a map-style Dataset instead of IterableDataset wouldn't really fit here. Maybe later a map-style wrapper for HDFDataset or something like that would be nice though.

As Nahuel said, Fairseq does not use IterableDataset. I found one usage in huggingface/transformers; there they use one instance to provide sequences and another instance which wraps the first one and does the batching: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L851. That sounds like a good idea to me.

patrick-wilken avatar Sep 19 '22 11:09 patrick-wilken

I moved the init_seq_order into the DatasetWrapper.__iter__. I think this is cleaner.

albertz avatar Sep 21 '22 09:09 albertz

python3 rnn.py demos/demo-torch.config runs now. I.e. it constructs the net, calculates and optimizes the loss.

Many things are missing. The model is randomly re-initialized in each epoch; no checkpoint saving/loading is implemented yet.

albertz avatar Sep 21 '22 11:09 albertz

In this current state, I don't use ExternData. Probably we want to introduce this. Not sure.

But for getting a first proof-of-concept, we maybe also don't need it yet.

albertz avatar Sep 21 '22 11:09 albertz

Current state:

  • PR #1137, done.
  • We discussed that it is probably simpler for everyone to not open PRs but directly push to master, as long as the code changes are just about the PyTorch backend and do not touch any other code.
    • We should still discuss about how to implement things, better before we actually do it, at least on a high level. We can discuss right here in this issue.
    • We should still review the code changes made by others. In case there are problems, these can be discussed here as well.
    • We don't want to break the CI tests, i.e. the code style must be correct. For that, we can either just use PyCharm for editing, or run the code style checks locally. @patrick-wilken had some issues with that. In case these do not get resolved, let's open a separate issue about that and discuss the problem there.
    • Once you touch other code, e.g. refactor some TF code (e.g. move out Data/ExternData or whatever), this must be via PR.
  • We want a hybrid NN-HMM for now. Either a multi-layer BLSTM (via torch.nn.LSTM) or a Conformer, just using code from ESPnet.
  • For inference, we want to export to ONNX. It's a bit unclear how that would look for the user. We might use extern_data to define the inputs, and have a new model_outputs just like extern_data to define the model outputs (as Data templates), and a model_forward function, which gets those inputs and returns such outputs.
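
A rough sketch of how that could look in the config (the option names follow the text above; the template format and the model_forward signature are assumptions):

# Sketch only: dims/templates here are illustrative, not a fixed format.
extern_data = {"data": {"dim": 40}}                                # model inputs
model_outputs = {"log_probs": {"dim": 5000, "dtype": "float32"}}   # output templates

def model_forward(*, model, extern_data):
    """Gets the inputs (extern_data) and returns outputs matching model_outputs."""
    ...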

albertz avatar Oct 17 '22 11:10 albertz

For inference, we want to export to ONNX. It's a bit unclear how that would look for the user. We might use extern_data to define the inputs, and have a new model_outputs just like extern_data to define the model outputs (as Data templates), and a model_forward function, which gets those inputs and returns such outputs.

This basically means that we should also abstract ExternData, Data and Dim, remove the TF-specific code there and move it to a common location. -> #1165 (Done now)

Such model_outputs can also be useful for the TF backends, e.g. for the compile-TF-graph script, to clearly define what outputs are expected. -> #1166

albertz avatar Oct 21 '22 12:10 albertz

I'm working on replicating an already existing TensorFlow BLSTM setup. To replicate that experiment, in order of priority I think we'd need:

  1. Learning rate scheduling + learning rate control (newbob, ...)
  2. Chunking
  3. Specaugment
  4. Gradient noise (helper: https://discuss.pytorch.org/t/how-to-add-gradient-noise/158298/3)

I think that with the basics that we have I should be able to kickstart a basic training, though.

A list of things that would be handy for future trainings:

  1. Regularization: L2, dropout...
  2. Usage of extern_data

As far as I know, L2 is implemented as the weight_decay parameter in the optimizers. However, it looks like it also affects layers like BatchNorm: https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994. It doesn't need to be applied to some layers, and one can filter them out, like so: https://discuss.pytorch.org/t/weight-decay-only-for-weights-of-nn-linear-and-nn-conv/114348 (see the sketch below).
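
A sketch of that kind of filtering via optimizer parameter groups (which parameters to exclude, e.g. biases and norm parameters, is a choice here, not a fixed rule):

import torch

def param_groups_for_weight_decay(model: torch.nn.Module, weight_decay: float = 0.01):
    """Split params into one group with weight decay and one without (biases, norm params)."""
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if name.endswith("bias") or isinstance(module, (torch.nn.BatchNorm1d, torch.nn.LayerNorm)):
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# usage: optimizer = torch.optim.AdamW(param_groups_for_weight_decay(model), lr=1e-3)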

Icemole avatar Oct 24 '22 10:10 Icemole

Note, I was thinking a bit about how to go forward later with the PyTorch backend of RETURNN, and whether to support the net dict, or how else we should do it. I think what we have in RETURNN-common, the nn API, could be also a more direct API for it. See https://github.com/rwth-i6/returnn/issues/1264.

albertz avatar Oct 26 '22 09:10 albertz

I heard good things about PyTorch-lightning. One important feature is that it easily supports multi-GPU training, where you don't really need to change anything in the user code.

As we already discussed, PyTorch for RETURNN should be modular, such that the user can easily combine it with PyTorch-lightning or other things.

However, when the user just uses the standard RETURNN training loop, we should maybe copy some of the aspects of PyTorch-lightning which allow for this easy multi-GPU training support. Someone should do some research on how this is done in PyTorch-lightning.

albertz avatar Oct 26 '22 09:10 albertz

I implemented and pushed dev set evaluation and learning rate scheduling. I just reused the existing LearningRateControl code, so all the different scheduling options should be supported. For now I only calculate the "score" and not the "error" of the eval sets as that's what the train_step function calculates at the moment. (We could rename it if we keep using it for eval too...) And the normalization of the loss could be improved: right now I just average over the steps, which is what is done in the PyTorch "Quickstart" example. Better would be to weight the steps by the total number of time steps, which is what the TF backend does...

patrick-wilken avatar Oct 31 '22 18:10 patrick-wilken

I just noticed that another thing still missing is keeping the optimizer state. Both between epochs - currently we recreate it in every epoch - as well as saving the state to file to continue training. Will look into that on Wednesday. In general, do we want to make the kind of optimizer configurable or just always use AdamW for now?

patrick-wilken avatar Oct 31 '22 18:10 patrick-wilken

For now I only calculate the "score" and not the "error" of the eval sets as that's what the train_step function calculates at the moment.

This is wrong. train_step only returns the total loss, only for the gradient computation. We must change the interface to support individual losses (scores, and maybe also errors), and also to be able to properly do normalization and accumulation.

albertz avatar Nov 01 '22 12:11 albertz

I just noticed that another thing still missing is keeping the optimizer state. Both between epochs - currently we recreate it in every epoch

Yes this should be created only once, after the model is created.

as well as saving the state to file to continue training.

We don't even do that for TF. (Intentionally, to keep the behavior the same as Theano...)

In general, do we want to make the kind of optimizer configurable or just always use AdamW for now?

At some point it should be configurable (obviously). But if your question is whether this is now very urgent and important for this proof-of-concept, then I guess no?

albertz avatar Nov 01 '22 12:11 albertz

So then we need to move the optimizer creation to init_train_from_config(). And the learning rate would be controlled by a torch.optim.lr_scheduler.LRScheduler instance, at least that's the standard approach. There's no built-in LRScheduler which is designed to set an adaptive learning rate from outside. It should be doable with LambdaLR though. Of course we could also use ReduceLROnPlateau and let PyTorch do the scheduling, but then we would lose all the fine-grained config options that we have for LR scheduling. Writing a custom LRScheduler also seems simple, but I would rather stick to the standard API. Then we could, for example, also make the LRScheduler modular, i.e. user-configurable at some point.
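
A sketch of the LambdaLR idea, with the multiplier controlled from outside (e.g. set per epoch from the existing LearningRateControl); the model and the new learning rate here are just placeholders:

import torch

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)  # base lr 1.0, effective lr = factor

class _LrFactor:
    """Mutable multiplier, updated from outside (e.g. by the LR control logic)."""
    value = 1e-3

lr_factor = _LrFactor()
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: lr_factor.value)

# per epoch, after the LR control decided on a new learning rate (placeholder value here):
lr_factor.value = 5e-4
scheduler.step()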

patrick-wilken avatar Nov 11 '22 16:11 patrick-wilken

I think we should just use our own LR scheduler, exactly following what we have in the TF engine.

albertz avatar Nov 12 '22 02:11 albertz

Next steps:

  • Support multiple losses (extend TrainCtx.mark_as_loss(name, loss, scale=1)), and save them properly in the learning rate control scores (@patrick-wilken)
  • Chunking, create ChunkingDataset (PyTorch dataset) (@patrick-wilken)
  • Optimizer, create only once, not per epoch. But then also update learning rate correctly in each epoch. See here. (@Icemole)
  • Optimizer, save and load its state, separate from the model checkpoint. (@Icemole)
  • Optimizer, allow configuring it, similar to TF, e.g. optimizer being a string or dict. For AdamW, have a variant KarpathyAdamW or so (see Karpathy MinGPT here), which excludes some params from weight decay. Also allow optimizer as a callable, so the user can create any custom optimizer with any custom settings.
  • Add support for dynamic_learning_rate at some point (not so important now though).

albertz avatar Nov 14 '22 11:11 albertz