[Draft] [Multi-task Learning] Add support for multi-task learning

Open · opened by classicsong

Issue #, if available: #789

Description of changes:

Graph construction

Update GraphStorm input config parsing to support multi-task learning. Users can define multiple training tasks for a training job by adding a multi_task_learning section to the YAML config file. The following config defines two training tasks, one for node classification and one for edge classification.

```yaml
---
version: 1.0
gsf:
  basic:
    ...
  ...
  multi_task_learning:
    - node_classification:
      target_ntype: "movie"
      label_field: "label"
      mask_fields:
        - "train_mask_field_nc"
        - "val_mask_field_nc"
        - "test_mask_field_nc"
      task_weight: 1.0
    - edge_classification:
      target_etype:
        - "user,rating,movie"
      label_field: "rate"
      mask_fields:
        - "train_mask_field_ec"
        - "val_mask_field_ec"
        - "test_mask_field_ec"
      task_weight: 0.5 # weight of the task
```

Task-specific hyperparameters in multi-task learning are the same as those in single-task learning, except that two new configs are required: mask_fields and task_weight. mask_fields provides the training, validation and test masks for the task, and task_weight gives the weight of the task's loss.
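To make the expected shape of each task entry concrete, here is a minimal Python sketch of how the `multi_task_learning` list, once loaded from YAML, could be turned into per-task records while enforcing the two new required configs. `TaskSpec`, `parse_multi_task_config` and `KNOWN_TASK_TYPES` are hypothetical names, not GraphStorm's actual parsing code. The sketch assumes `mask_fields` lists the train, validation and test masks in that order, as in the example config. With the indentation used in that config, a YAML loader puts the task-type key (loaded as `None`) and its settings in the same mapping, so the sketch accepts both that flat form and a nested mapping.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

# Only the two task types used in the example config; other task types are
# left out of this sketch.
KNOWN_TASK_TYPES = ("node_classification", "edge_classification")


@dataclass
class TaskSpec:
    """Hypothetical record for one parsed task (not GraphStorm's config.TaskInfo)."""
    task_type: str
    settings: Dict[str, Any]  # remaining task-specific hyperparameters
    train_mask: str
    val_mask: str
    test_mask: str
    task_weight: float


def parse_multi_task_config(tasks: List[Dict[str, Any]]) -> List[TaskSpec]:
    """Turn the loaded `multi_task_learning` list into one TaskSpec per task."""
    specs = []
    for entry in tasks:
        task_types = [key for key in entry if key in KNOWN_TASK_TYPES]
        if len(task_types) != 1:
            raise ValueError(f"expected one task type per entry, got {task_types}")
        task_type = task_types[0]
        # Flat form: settings are siblings of the task-type key (whose value is None).
        settings = {k: v for k, v in entry.items() if k != task_type}
        # Nested form: settings live in a mapping under the task-type key.
        settings.update(entry[task_type] or {})
        # The two configs that are new in multi-task learning are mandatory.
        if "mask_fields" not in settings or "task_weight" not in settings:
            raise ValueError(f"{task_type}: mask_fields and task_weight are required")
        masks = settings.pop("mask_fields")
        if len(masks) != 3:
            raise ValueError(f"{task_type}: mask_fields must be [train, val, test]")
        weight = float(settings.pop("task_weight"))
        specs.append(TaskSpec(task_type, settings, masks[0], masks[1], masks[2], weight))
    return specs
```

Keeping `mask_fields` and `task_weight` out of `settings` leaves the rest of each entry identical to a single-task config, which mirrors the statement that the other task-specific hyperparameters are unchanged.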

DataLoader for multi-task learning

Add GSgnnMultiTaskDataLoader to support multi-task learning.

When initializing a GSgnnMultiTaskDataLoader, users need to provide two inputs: 1) a list of config.TaskInfo objects, each recording the information of one task, and 2) a list of dataloaders, one for each training task.
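A minimal sketch of the contract between the two inputs, assuming they are parallel lists where the i-th TaskInfo describes the i-th dataloader. `TaskInfoSketch` is a hypothetical stand-in for `config.TaskInfo` whose field names (`task_type`, `task_id`, `task_config`) are assumptions, and `pair_tasks_with_loaders` is a hypothetical helper; the uniqueness check on `task_id` is likewise an illustrative addition, not documented behavior.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence, Tuple


@dataclass
class TaskInfoSketch:
    """Hypothetical stand-in for config.TaskInfo; field names are assumptions."""
    task_type: str
    task_id: str
    task_config: Dict[str, Any] = field(default_factory=dict)


def pair_tasks_with_loaders(
    task_infos: Sequence[TaskInfoSketch], dataloaders: Sequence[Any]
) -> List[Tuple[TaskInfoSketch, Any]]:
    """Match the i-th TaskInfo with the i-th dataloader (assumed parallel lists)."""
    if len(task_infos) != len(dataloaders):
        raise ValueError(f"{len(task_infos)} tasks but {len(dataloaders)} dataloaders")
    task_ids = [info.task_id for info in task_infos]
    # Illustrative check: each task should be identifiable on its own.
    if len(set(task_ids)) != len(task_ids):
        raise ValueError("task_id values must be unique")
    return list(zip(task_infos, dataloaders))
```

For example, `pair_tasks_with_loaders([TaskInfoSketch("node_classification", "nc-movie")], [nc_loader])` returns a single `(info, nc_loader)` pair, where `nc_loader` is any per-task dataloader.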

In each training iteration, GSgnnMultiTaskDataLoader calls each task's dataloader in turn to generate one mini-batch per task, and returns the resulting list of mini-batches to the trainer.

The length of the dataloader (the number of batches in an epoch) is determined by the largest task, i.e., the task whose dataloader yields the most mini-batches.
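The iteration and length rules above can be sketched as a small Python class. `MultiTaskLoaderSketch` is a hypothetical name, not GraphStorm's GSgnnMultiTaskDataLoader, and it accepts any per-task loaders that support `len()` and iteration. The PR does not say what happens to a smaller task once its loader is exhausted; this sketch assumes that loader simply restarts, so every iteration still returns one mini-batch per task.

```python
from typing import Any, Iterable, Iterator, List, Sequence


class MultiTaskLoaderSketch:
    """Hypothetical multi-task loader: one mini-batch per task per iteration."""

    def __init__(self, task_infos: Sequence[Any], dataloaders: Sequence[Iterable[Any]]):
        if len(task_infos) != len(dataloaders):
            raise ValueError("need exactly one dataloader per task")
        if any(len(dl) == 0 for dl in dataloaders):
            raise ValueError("every task dataloader must yield at least one batch")
        self.task_infos = list(task_infos)  # kept so the trainer can match batches to tasks
        self.dataloaders = list(dataloaders)

    def __len__(self) -> int:
        # An epoch lasts as long as the task with the most mini-batches.
        return max(len(dl) for dl in self.dataloaders)

    def __iter__(self) -> Iterator[List[Any]]:
        iters = [iter(dl) for dl in self.dataloaders]
        for _ in range(len(self)):
            batches = []
            for i, dl in enumerate(self.dataloaders):
                try:
                    batch = next(iters[i])
                except StopIteration:
                    # Assumption: a smaller task's loader restarts when exhausted.
                    iters[i] = iter(dl)
                    batch = next(iters[i])
                batches.append(batch)
            # The i-th mini-batch belongs to the i-th task.
            yield batches
```

With a node classification loader of three batches and an edge classification loader of two, the sketch runs three iterations per epoch, and in the last one the edge task starts over from its first batch.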

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

classicsong, May 15 '24 20:05