
Add training/inference functions that accept multiple input/output arguments, and support custom loss functions via TorchScript.

Open · opened by romerojosh 1 year ago

This PR introduces more flexible training/inference functions, torchfort_train_multiarg and torchfort_inference_multiarg, designed for models that take multiple input tensors and/or produce multiple output tensors. To pass multiple tensors from Fortran/C, we introduce a new type, torchfort_tensor_list_t, which is used for the arguments of these new routines.
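To make the tensor-list idea concrete, here is a minimal C sketch of what a list of tensor descriptors might hold: a raw data pointer, rank, shape, and element type per tensor, so that several arrays can be handed to a single call. All names here (`tensor_desc`, `tensor_list`, `tensor_list_add`, `tensor_numel`, the `DTYPE_*` values, and the fixed capacity limits) are hypothetical; this is not the actual torchfort_tensor_list_t API or its internal layout.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical capacity limits that keep the sketch allocation-free. */
#define MAX_DIMS 8
#define MAX_TENSORS 16

/* Hypothetical element types. */
typedef enum { DTYPE_FLOAT32, DTYPE_FLOAT64, DTYPE_INT32, DTYPE_INT64 } dtype_t;

/* One tensor passed from Fortran/C: the list does not own the buffer,
   so the caller must keep `data` alive for as long as the list is used. */
typedef struct {
    void *data;
    size_t ndim;
    int64_t shape[MAX_DIMS];
    dtype_t dtype;
} tensor_desc;

/* An ordered list of tensor descriptors (e.g. all model inputs). */
typedef struct {
    size_t count;
    tensor_desc tensors[MAX_TENSORS];
} tensor_list;

static void tensor_list_init(tensor_list *list) { list->count = 0; }

/* Append a descriptor; returns 0 on success, -1 if the list is full. */
static int tensor_list_add(tensor_list *list, void *data, size_t ndim,
                           const int64_t *shape, dtype_t dtype) {
    assert(ndim <= MAX_DIMS);
    if (list->count >= MAX_TENSORS) return -1;
    tensor_desc *t = &list->tensors[list->count];
    t->data = data;
    t->ndim = ndim;
    for (size_t i = 0; i < ndim; ++i) t->shape[i] = shape[i];
    t->dtype = dtype;
    list->count++;
    return 0;
}

/* Number of elements in a tensor: the product of its shape. */
static int64_t tensor_numel(const tensor_desc *t) {
    int64_t n = 1;
    for (size_t i = 0; i < t->ndim; ++i) n *= t->shape[i];
    return n;
}
```

For a GNN-style model, for instance, node features and edge connectivity could be two entries in one input list, so one entry point handles any number of arguments instead of needing a separate variant per argument count. The fixed-size arrays only keep the sketch simple; a real implementation may store entries dynamically.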

Full support for multiple output tensors also requires loss functions that accept multiple outputs/predictions from the model. The current built-in loss functions expect a single prediction/label pair and are therefore incompatible with multi-output models. To address this, we add support for custom loss functions exported from TorchScript, similar to our existing support for TorchScript-exported models. For maximum flexibility, custom loss functions can also take extra tensor arguments (e.g. a mask tensor that excludes data entries from the loss computation), passed via an optional extra_loss_args tensor list argument to torchfort_train_multiarg.
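The arithmetic that a multi-output loss with an extra mask argument might perform is sketched below in plain C, assuming the mask is shared by all outputs and the loss is the sum of per-output masked mean-squared errors. This only illustrates the computation; in TorchFort the custom loss itself would be a TorchScript module, and the function name `masked_multi_mse` and its argument layout are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical masked loss over several prediction/label pairs.
   preds[k] and labels[k] each point to n values for output k.
   mask is the "extra" argument: n entries, 1.0 = keep, 0.0 = ignore.
   Returns the sum over outputs of the masked mean squared error. */
static double masked_multi_mse(const double *const *preds,
                               const double *const *labels,
                               size_t n_outputs, size_t n,
                               const double *mask) {
    assert(n_outputs > 0 && n > 0);

    /* Number of entries that participate in the loss. */
    double mask_sum = 0.0;
    for (size_t i = 0; i < n; ++i) mask_sum += mask[i];
    if (mask_sum == 0.0) return 0.0; /* every entry masked out: nothing to score */

    double total = 0.0;
    for (size_t k = 0; k < n_outputs; ++k) {
        double sq = 0.0;
        for (size_t i = 0; i < n; ++i) {
            double d = preds[k][i] - labels[k][i];
            sq += mask[i] * d * d; /* masked entries contribute zero */
        }
        total += sq / mask_sum; /* masked mean for output k */
    }
    return total;
}
```

Because each built-in loss sees only one prediction/label pair, a combination like this (several outputs plus a mask) is exactly what needs the custom-loss path; summing per-output terms is just one plausible reduction, and a weighted sum would fit the same signature.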

Please refer to the new Graph Neural Network example added in this PR and the new documentation for more details.

This PR also adds new tests for supervised training/inference and loss functions, covering both the newly added and the existing functionality.

romerojosh, Oct 03 '24 22:10