graphstorm
[Draft] [Multi-task Learning] Add support for multi-task learning
Issue #, if available: #789
Description of changes:
Graph construction
Update GraphStorm input config parsing to support multi-task learning. Users can now specify multiple training tasks for a single training job by providing the multi_task_learning configuration in the yaml file. The following config defines two training tasks, one for node classification and one for edge classification.
---
version: 1.0
gsf:
  basic:
    ...
  ...
  multi_task_learning:
    - node_classification:
        target_ntype: "movie"
        label_field: "label"
        mask_fields:
          - "train_mask_field_nc"
          - "val_mask_field_nc"
          - "test_mask_field_nc"
        task_weight: 1.0
    - edge_classification:
        target_etype:
          - "user,rating,movie"
        label_field: "rate"
        mask_fields:
          - "train_mask_field_ec"
          - "val_mask_field_ec"
          - "test_mask_field_ec"
        task_weight: 0.5 # weight of the task
Task-specific hyperparameters in multi-task learning are the same as those in single-task learning, except that two new configs are required: mask_fields and task_weight. mask_fields provides the training, validation, and test masks for the task, and task_weight gives its loss weight.
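As a rough illustration of how these two fields might be consumed, the sketch below checks each task entry from the config above and combines per-task losses using task_weight. The helpers `validate_task` and `weighted_total_loss` are hypothetical names, not GraphStorm API. It also assumes that mask_fields always lists exactly three masks in train/validation/test order and that the total loss is a plain weighted sum; this PR description does not confirm either detail.

```python
# Hypothetical helpers for illustration only; not part of the GraphStorm API.
REQUIRED_FIELDS = ("mask_fields", "task_weight")


def validate_task(task_type, task_conf):
    """Check that a task entry carries the two multi-task fields."""
    for field in REQUIRED_FIELDS:
        if field not in task_conf:
            raise ValueError(f"{task_type}: missing required field '{field}'")
    # Assumption: three masks, ordered train / validation / test.
    if len(task_conf["mask_fields"]) != 3:
        raise ValueError(f"{task_type}: mask_fields must list train, val and test masks")
    return float(task_conf["task_weight"])


def weighted_total_loss(losses, weights):
    """Combine per-task losses as sum(weight_i * loss_i) (assumed scheme)."""
    return sum(w * l for l, w in zip(losses, weights))


# The two tasks from the config, as plain dicts (the shape YAML parsing would give).
multi_task_learning = [
    {"node_classification": {
        "target_ntype": "movie",
        "label_field": "label",
        "mask_fields": ["train_mask_field_nc", "val_mask_field_nc", "test_mask_field_nc"],
        "task_weight": 1.0,
    }},
    {"edge_classification": {
        "target_etype": ["user,rating,movie"],
        "label_field": "rate",
        "mask_fields": ["train_mask_field_ec", "val_mask_field_ec", "test_mask_field_ec"],
        "task_weight": 0.5,
    }},
]

weights = [
    validate_task(task_type, task_conf)
    for task in multi_task_learning
    for task_type, task_conf in task.items()
]

# Made-up per-task loss values: total is 1.0 * 0.8 + 0.5 * 1.2.
total = weighted_total_loss([0.8, 1.2], weights)
```

With this scheme, the edge classification loss contributes half as much to the total as it would in a single-task run with the same loss value.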
DataLoader for multi-task learning
Add GSgnnMultiTaskDataLoader to support multi-task learning.
When initializing a GSgnnMultiTaskDataLoader, users need to provide two inputs: 1) a list of config.TaskInfo objects recording the information of each task and 2) a list of dataloaders corresponding to each training task.
In each training iteration, GSgnnMultiTaskDataLoader calls each task dataloader in turn to generate a mini-batch and then returns the resulting list of mini-batches to the trainer.
The length of the dataloader (number of batches for an epoch) is determined by the largest task in the GSgnnMultiTaskDataLoader.
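The behaviour described above can be sketched roughly as follows. This is not the GSgnnMultiTaskDataLoader implementation: the class name `MultiTaskLoaderSketch` is hypothetical, plain strings stand in for config.TaskInfo objects, and plain lists stand in for the per-task dataloaders. Two details are assumptions on my part: that a shorter task restarts from its first batch once exhausted, and that each mini-batch is returned paired with its task info. Every task dataloader is assumed to be non-empty.

```python
class MultiTaskLoaderSketch:
    """Illustrative stand-in for a multi-task dataloader (not GraphStorm code)."""

    def __init__(self, task_infos, dataloaders):
        if len(task_infos) != len(dataloaders):
            raise ValueError("need exactly one dataloader per task")
        self.task_infos = task_infos
        self.dataloaders = dataloaders

    def __len__(self):
        # Epoch length follows the task with the most batches.
        return max(len(dl) for dl in self.dataloaders)

    def __iter__(self):
        iters = [iter(dl) for dl in self.dataloaders]
        for _ in range(len(self)):
            batches = []
            for i, dl in enumerate(self.dataloaders):
                try:
                    batch = next(iters[i])
                except StopIteration:
                    # Assumption: a shorter task restarts from its beginning.
                    iters[i] = iter(dl)
                    batch = next(iters[i])
                batches.append((self.task_infos[i], batch))
            # One list per iteration, holding one mini-batch per task.
            yield batches


# Stand-in dataloaders: three batches for one task, two for the other.
nc_loader = ["nc0", "nc1", "nc2"]
ec_loader = ["ec0", "ec1"]
loader = MultiTaskLoaderSketch(
    ["node_classification", "edge_classification"],
    [nc_loader, ec_loader],
)
epoch = list(loader)
```

Here the epoch has three iterations, and in the third one the edge classification task has wrapped around to its first batch.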
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.