
Consider making the decoding classes subclasses of torch.nn.Module and registering them via add_module()

Open galv opened this issue 1 year ago • 8 comments

Is your feature request related to a problem? Please describe.

@artbataev pointed out to me an issue with calling the following when the model is using the CUDA-graph RNN-T decoder:

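```python
from nemo.collections.asr.models import ASRModel
import torch

# model_name: a pretrained RNN-T model name (not specified in the original report)
asr_model = ASRModel.from_pretrained(model_name, map_location=torch.device("cuda:0"))
# Moves registered parameters/buffers to cuda:1, but plain tensor attributes
# held by the CUDA-graph RNN-T decoder stay on cuda:0.
asr_model = asr_model.to(torch.device("cuda:1"))
```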

Basically, if I initialize the CUDA decoder with torch.Tensor instance attributes on device cuda:0 (plain attributes, not registered buffers or parameters), the to() method won't move them to cuda:1, because to() only converts registered parameters and buffers and recurses only into members that are themselves torch.nn.Modules: https://github.com/pytorch/pytorch/blob/c3b4d78e175920141de210f44d292971d7c52ff0/torch/nn/modules/module.py#L572
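A minimal CPU-only sketch of that behavior, using a dtype change in place of a device change so it runs without two GPUs. `ToyDecoder` and `PlainStateHolder` are hypothetical stand-ins, not NeMo classes: `Module.to()` converts the registered parameters and buffer, but leaves a plain tensor attribute and a tensor inside a non-Module helper object untouched.

```python
import torch
from torch import nn


class PlainStateHolder:
    """Hypothetical stand-in for a non-Module decoding helper."""

    def __init__(self) -> None:
        # Plain tensor on a plain object: Module.to() never sees it.
        self.state = torch.zeros(4)


class ToyDecoder(nn.Module):
    """Hypothetical module mixing tracked and untracked tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)                     # parameters: tracked
        self.register_buffer("tracked", torch.zeros(4))   # buffer: tracked
        self.untracked = torch.zeros(4)                   # plain attribute: NOT tracked
        self.helper = PlainStateHolder()                  # non-Module: NOT recursed into


model = ToyDecoder().to(torch.float64)  # stands in for .to("cuda:1")

print(model.linear.weight.dtype)  # torch.float64
print(model.tracked.dtype)        # torch.float64
print(model.untracked.dtype)      # torch.float32 (left behind)
print(model.helper.state.dtype)   # torch.float32 (left behind)
```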

To make that work, I would need to make every class that transitively uses an instance of RNNTGreedyDecodeCudaGraph inherit from torch.nn.Module, so that to() acts properly. This seems like a lot of work, but it would allow this API call to work as expected. Otherwise, you get the error Vladimir shows here: https://github.com/NVIDIA/NeMo/pull/8191#discussion_r1491227311
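A minimal sketch of what the issue title proposes, with hypothetical class names (`GreedyDecoderState`, `ToyModel`), not NeMo's actual classes: the helper becomes an `nn.Module`, its state is registered with `register_buffer(..., persistent=False)`, and the owner registers it with `add_module()`. A dtype change again stands in for a device move. Marking the buffer non-persistent is one assumed way to keep the decoder state out of checkpoints.

```python
import torch
from torch import nn


class GreedyDecoderState(nn.Module):
    """Hypothetical decoder helper rewritten as an nn.Module."""

    def __init__(self, batch_size: int, hidden: int) -> None:
        super().__init__()
        # persistent=False keeps this buffer out of state_dict(), so the
        # checkpoint keys do not change, yet .to() still converts it.
        self.register_buffer("state", torch.zeros(batch_size, hidden), persistent=False)


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.joint = nn.Linear(8, 8)
        # add_module() registers the child so .to()/.cuda()/.half() recurse into it.
        self.add_module("decoding", GreedyDecoderState(batch_size=2, hidden=8))


model = ToyModel().to(torch.float64)
print(model.decoding.state.dtype)        # torch.float64
print(sorted(model.state_dict().keys()))  # ['joint.bias', 'joint.weight']
```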

.to() is called transitively here https://github.com/NVIDIA/NeMo/blob/21990e4e2182cfa1f33d88c21ec4797f79e3ef24/nemo/core/connectors/save_restore_connector.py#L179 by setup_model() here: https://github.com/NVIDIA/NeMo/blob/21990e4e2182cfa1f33d88c21ec4797f79e3ef24/examples/asr/transcribe_speech.py#L217-L241, which runs as part of his transcribe_speech.py command line.

Basically, any code in NeMo that allocates a torch.Tensor that PyTorch does not track (i.e., one not wrapped as a parameter or buffer inside a torch.nn.Module) will not be converted by to(). This could also cause subtle bugs when converting a model from float32 to bfloat16. @erastorgueva-nv I don't think this is what you are seeing in your Canary debugging, but FYI.
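One way to look for such tensors is to scan a model for tensors that live outside `_parameters` and `_buffers`. The sketch below is a hypothetical audit helper, not a NeMo or PyTorch utility: it reports plain tensor attributes on every submodule, plus one level into non-Module attribute objects (such as a decoding helper held by a module). It applies equally to device moves and dtype casts like bfloat16.

```python
import torch
from torch import nn


def find_untracked_tensors(model: nn.Module) -> list[str]:
    """Hypothetical helper: names of tensors that Module.to() will not convert."""
    found = []
    for mod_name, module in model.named_modules():
        prefix = f"{mod_name}." if mod_name else ""
        for attr, value in vars(module).items():
            if isinstance(value, torch.Tensor):
                # Parameters and buffers live in _parameters/_buffers, not in
                # the instance __dict__, so a tensor found here is untracked.
                found.append(prefix + attr)
            elif not isinstance(value, nn.Module) and hasattr(value, "__dict__"):
                # Look one level into plain helper objects.
                for inner, inner_value in vars(value).items():
                    if isinstance(inner_value, torch.Tensor):
                        found.append(f"{prefix}{attr}.{inner}")
    return found


class Helper:  # hypothetical non-Module decoding helper
    def __init__(self) -> None:
        self.state = torch.zeros(3)


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(3, 3)
        self.register_buffer("ok", torch.zeros(3))
        self.cache = torch.zeros(3)
        self.helper = Helper()


print(find_untracked_tensors(Toy()))  # ['cache', 'helper.state']
```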

I ultimately don't think this is a good idea, since it seems like a lot of work. I think a better fix is to avoid calling to() at all, and instead have users set the CUDA_VISIBLE_DEVICES environment variable appropriately before starting a process (and remove the cuda=1 config option in transcribe_speech.py). Note that using CUDA_VISIBLE_DEVICES to expose a single GPU and simply using torch.device("cuda"), instead of specifying a device index in code via torch.device("cuda", my_index), is what PyTorch recommends: https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html#torch.cuda.set_device
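A minimal sketch of the CUDA_VISIBLE_DEVICES approach from Python, assuming the work runs in a child process: the variable must be set before that process initializes CUDA, and inside it the selected GPU appears as index 0. `env_pinned_to_gpu` is a hypothetical helper, not part of NeMo.

```python
import os
import subprocess
import sys


def env_pinned_to_gpu(gpu_index: int) -> dict[str, str]:
    """Return a copy of the current environment that exposes a single GPU.

    In the child process that GPU is device 0, so code there can use
    torch.device("cuda") with no index.
    """
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


if __name__ == "__main__":
    # Hypothetical usage: pin a child process to physical GPU 1. The child
    # would then use torch.device("cuda") rather than "cuda:1".
    child = [sys.executable, "-c", "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"]
    subprocess.run(child, env=env_pinned_to_gpu(1), check=True)
```

Because the parent's own environment is copied rather than mutated, the parent process keeps seeing all GPUs.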

Do we have any use for multiple cuda devices in a single process in NeMo?

galv, Feb 15 '24 23:02

There's a reason those decoding classes aren't modules. When RNN-T was being developed, it turned out there was a memory leak caused by PyTorch's autograd graph tracking during the autoregressive calls to the decoder and joint. That's why .freeze() and .as_frozen() were added to NeMo's core to deal with it.

I don't know whether that reason still holds now that PyTorch has the inference_mode() decorator; it may be safe to make them modules again.
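A small illustration of the mechanism, not a reproduction of the original leak: in an autoregressive loop each output feeds the next call, so with autograd enabled every step stays attached to a growing graph, while under `torch.inference_mode()` no graph is recorded. `joint` is a hypothetical stand-in `nn.Linear`, and this alone does not establish that the original leak is gone.

```python
import torch
from torch import nn

joint = nn.Linear(4, 4)  # hypothetical stand-in for a joint network
x = torch.randn(1, 4)


def autoregressive_loop(steps: int) -> torch.Tensor:
    out = x
    for _ in range(steps):
        out = joint(out)  # each step consumes the previous step's output
    return out


tracked = autoregressive_loop(3)
with torch.inference_mode():
    untracked = autoregressive_loop(3)

print(tracked.grad_fn is not None)  # True: autograd graph kept alive
print(untracked.grad_fn is None)    # True: no graph recorded
print(untracked.is_inference())     # True
```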

As you've mentioned, this change is a lot of work, and I don't feel it is super high priority, but if someone wants to take a look, that's fine.

To get the device inside the decoding classes, the preferred mechanism is to read next(module.parameters()).device and cache that value inside the inner forward calls.
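A minimal sketch of that pattern with hypothetical names (`GreedyInferSketch` is not a NeMo class): the helper stays a plain object holding a reference to the joint module, reads `next(self.joint.parameters()).device` once per call, and allocates everything for that call on that device, so it follows the model wherever `.to()` moved it.

```python
import torch
from torch import nn


class GreedyInferSketch:
    """Hypothetical non-Module decoding helper that borrows the joint's device."""

    def __init__(self, joint: nn.Module) -> None:
        self.joint = joint  # reference only; the model owns the module

    @torch.inference_mode()
    def __call__(self, encoder_output: torch.Tensor) -> torch.Tensor:
        # Read the device once and reuse it for every allocation in this call.
        device = next(self.joint.parameters()).device
        batch_size = encoder_output.shape[0]
        labels = torch.full((batch_size,), -1, dtype=torch.long, device=device)
        logits = self.joint(encoder_output.to(device))
        labels.copy_(logits.argmax(dim=-1))
        return labels


joint = nn.Linear(4, 5)  # stand-in for a joint network
decoder = GreedyInferSketch(joint)
labels = decoder(torch.randn(2, 4))
```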

> CUDA_VISIBLE_DEVICES environment variable appropriately before starting a process (and remove the cuda=1

This is not a good user experience and I want to avoid it. Users should not be required to set any environment variables to use NeMo ASR.

> Do we have any use for multiple cuda devices in a single process in NeMo?

Well, if you're training, yes, but that is handled by PTL (PyTorch Lightning). During inference, we stick to a single GPU for now, but with larger models we may consider multi-GPU inference in the future.

titu1994, Feb 15 '24 23:02

In notebooks I frequently call model.to("cuda:1") (I have multiple GPUs) and expect it to work. Using CUDA_VISIBLE_DEVICES with notebooks is inconvenient.

I think that there are two possible approaches:

  • if the decoder wants to store anything on the GPU (or MPS), it should be responsible for transferring its own state to the device automatically. This is necessary because the *Infer classes are not nn.Module instances (though they hold the Joint/Prednet modules). With this approach everything works, but the code's author has to take care of it.
  • redesign the RNN-T in the following way using nn.Modules
    • DecodingWrapper is an nn.Module instance and owns Joint and Prednet
    • decoding strategies inherit from DecodingWrapper (e.g., BatchedGreedyDecodingWrapper)
    • the core RNN-T model does not have joint/prednet members anymore, only through DecodingWrappers subclasses
    • so, everything inherits nn.Module, no duplication of ownership of joint/prednet, buffers are automatically moved to the appropriate device
    • changing the decoding strategy implies changing a subclass of DecodingWrapper (transferring Joint/Prednet from the current wrapper)
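A rough sketch of the second approach, using class names taken from the bullets above (`DecodingWrapper`, `BatchedGreedyDecodingWrapper`); these are hypothetical and not defined by NeMo, and the layer sizes are arbitrary. The model owns only the encoder and one wrapper, the wrapper owns the joint and prediction network, and switching strategy hands the existing modules to a new wrapper.

```python
import torch
from torch import nn


class DecodingWrapper(nn.Module):
    """Hypothetical owner of the joint and prediction networks."""

    def __init__(self, prednet: nn.Module, joint: nn.Module) -> None:
        super().__init__()
        self.prednet = prednet
        self.joint = joint


class BatchedGreedyDecodingWrapper(DecodingWrapper):
    """Hypothetical strategy subclass whose state is a registered buffer."""

    def __init__(self, prednet: nn.Module, joint: nn.Module, max_batch: int = 4) -> None:
        super().__init__(prednet, joint)
        self.register_buffer(
            "last_labels", torch.zeros(max_batch, dtype=torch.long), persistent=False
        )


class ToyRNNT(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        # No direct joint/prednet attributes: they are reachable only via the wrapper.
        self.decoding = BatchedGreedyDecodingWrapper(nn.Embedding(10, 8), nn.Linear(8, 10))

    def change_strategy(self, wrapper_cls: type) -> None:
        # Transfer the existing joint/prednet to a new wrapper instance.
        self.decoding = wrapper_cls(self.decoding.prednet, self.decoding.joint)


model = ToyRNNT()
print(sum(p.numel() for p in model.parameters()))  # 242
```

Because every piece is an `nn.Module`, `.to()` reaches the wrapper's buffer and each parameter is owned exactly once. The `state_dict()` key names move under the `decoding.` prefix, which illustrates why such a redesign would break current checkpoints.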

However, such a redesign will break the current checkpoints and will require a lot of work and testing.

artbataev, Feb 16 '24 10:02

Managing the device and dtype manually from the encoder's, decoder's, or joint's parameters takes only a few lines of code; we shouldn't make such breaking changes for this.

titu1994, Feb 16 '24 10:02

On top of this, the parameter count for the decoder and joint would be double counted with such a redesign, which is also bad.

titu1994, Feb 16 '24 10:02

I agree that, for now, such breaking changes are undesirable.

> On top of this, yes the fact that parameter count would be double counted for decoder and joint with such a redesign is also bad

Nope, the whole aim of the concept is to avoid parameter duplication.

Currently, rnnt model owns:

  • encoder
  • joint
  • prediction network
  • *Infer (not nn.Module), owns
    • joint (duplicate)
    • prediction network (duplicate)

Proposed, rnnt model owns:

  • encoder
  • *DecodingWrapper, owns
    • joint
    • prediction network

artbataev, Feb 16 '24 12:02

That's just bad design: why merge the transcription and prediction networks when all of the literature treats them as separate modules? Also, you don't always call the prednet and decoder network with the same set of inputs (at train time a blank is prepended; at eval time decoding starts with blank as the first token for autoregressive decoding).

This is not a viable proposal in my opinion.

titu1994, Feb 16 '24 14:02

I don't really see the big issue with the decoding framework not being a neural module. Its responsibility is not to act as a parameter-based operation in the network's forward pass; it is an agnostic layer that provides a stable interface (Hypothesis) to map the encoder/decoder/joint log-probs to text. It's a logical separation, not a module dependency.

I understand that certain issues can arise from the current design, because we are using a more advanced design. However, a very simple solution exists, which Daniel has already implemented and which is trivial to do (and is also recommended by PyTorch, by the way: create new tensors on the device of the active tensors they will interact with).

I don't see why such a thing requires refactoring the decoding framework. If there's a bug, we fix it; we don't scrap the entire thing and redo it because of a "PyTorch design pattern" (which NeMo deliberately does not fully follow; we use the PTL design pattern instead).

titu1994, Feb 16 '24 14:02

My overall conclusion is that I have encountered an edge case, and the right approach is simply to recreate the appropriate state tensors whenever the device of the input tensors differs from the one used before.

galv, Feb 16 '24 19:02

Closing this. Given how NeMo is currently designed, lazy initialization is the better approach.
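A minimal sketch of lazy initialization combined with re-creation on a device or dtype change, under hypothetical names (`LazyStateDecoder` is not NeMo's code). State is allocated on first use with `x.new_zeros(...)`, which inherits `x`'s device and dtype, and is rebuilt whenever a later input arrives on a different device, with a different dtype, or with a larger batch. A dtype change stands in for a device move so the example runs on CPU.

```python
from typing import Optional

import torch


class LazyStateDecoder:
    """Hypothetical non-Module decoder with lazily (re)allocated state."""

    def __init__(self, hidden: int) -> None:
        self.hidden = hidden
        self.state: Optional[torch.Tensor] = None  # allocated on first use

    def _maybe_reinit(self, x: torch.Tensor) -> None:
        batch = x.shape[0]
        if (
            self.state is None
            or self.state.device != x.device
            or self.state.dtype != x.dtype
            or self.state.shape[0] < batch
        ):
            # new_zeros inherits device and dtype from the incoming tensor.
            self.state = x.new_zeros((batch, self.hidden))

    def step(self, x: torch.Tensor) -> torch.Tensor:
        self._maybe_reinit(x)
        state = self.state[: x.shape[0]]  # view into the cached state
        state.add_(x)  # toy in-place update; a real decoder does much more
        return state


dec = LazyStateDecoder(hidden=3)
first = dec.step(torch.ones(2, 3))                        # allocates float32 state
second = dec.step(torch.ones(2, 3, dtype=torch.float64))  # dtype changed: reallocates
```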

galv, Feb 28 '24 23:02