Convention for tensor device

Open · RemyLau opened this issue 2 years ago

Problem

Currently, there is no convention for which device data is stored on when a function such as fit() is called. For example, even when the compute device is cuda, the input graph may not be on the GPU and must be transferred later, after the neighborhoods have been subsampled for mini-batch training. This inconsistency makes it hard to know where a given tensor lives and causes many issues during development.
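To make the mini-batch case concrete, here is a minimal sketch of where the transfer would happen if the graph and features stay on the CPU. It assumes PyTorch and uses a toy one-hop sampler over an adjacency dict; `sample_neighborhood` and `minibatches` are hypothetical helpers for illustration, not DANCE or DGL APIs.

```python
import torch


def sample_neighborhood(adj_list, seeds, fanout, generator=None):
    """Hypothetical toy sampler: keep each seed plus up to `fanout` of its neighbors (CPU only)."""
    nodes = set(seeds.tolist())
    for s in seeds.tolist():
        nbrs = torch.as_tensor(adj_list[s], dtype=torch.long)
        if len(nbrs) > fanout:
            # Randomly subsample the neighborhood without replacement.
            nbrs = nbrs[torch.randperm(len(nbrs), generator=generator)[:fanout]]
        nodes.update(nbrs.tolist())
    return torch.tensor(sorted(nodes), dtype=torch.long)


def minibatches(features, adj_list, batch_size, fanout, device, generator=None):
    """Yield (seeds, sampled nodes, features) per batch; only the sampled subset is moved to `device`."""
    num_nodes = features.shape[0]
    order = torch.randperm(num_nodes, generator=generator)
    for start in range(0, num_nodes, batch_size):
        seeds = order[start:start + batch_size]
        block = sample_neighborhood(adj_list, seeds, fanout, generator)
        # Sampling and indexing run on the CPU; the device transfer happens once per batch, here.
        yield seeds.to(device), block.to(device), features[block].to(device)
```

A typical call keeps the full feature matrix on the CPU and picks the compute device only at this point, e.g. `minibatches(feats, adj, batch_size=256, fanout=10, device="cuda" if torch.cuda.is_available() else "cpu")`.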

One naive solution is to call tensor.to(device) with the correct device every time a computation is performed. This is unsatisfactory: device transfers end up scattered throughout the code base, making it less clean and harder to maintain.

Solution

  • By default, all data should reside on the CPU.
  • When a compute function such as fit() is called, it performs any necessary device conversion internally.
  • The only exceptions are predict() and score(), which may be more flexible because they are called in various places under different configurations, e.g., mini-batch training on the GPU followed by full-batch evaluation on the CPU.
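The three rules above could look roughly like this in a model class. This is a minimal sketch assuming PyTorch; `BaseModel`, its `device` constructor argument, and the `device` parameter on `predict()`/`score()` are hypothetical names chosen for illustration, not the existing DANCE interface.

```python
import torch
from torch import nn


class BaseModel(nn.Module):
    """Hypothetical model illustrating the proposed convention: data lives on the CPU by default."""

    def __init__(self, in_dim: int, out_dim: int, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)  # compute device used by fit()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def fit(self, x: torch.Tensor, y: torch.Tensor, epochs: int = 10, lr: float = 0.1):
        # Callers pass CPU tensors; fit() alone moves the model and data to the compute device.
        self.to(self.device)
        x, y = x.to(self.device), y.to(self.device)
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(epochs):
            optimizer.zero_grad()
            loss = loss_fn(self(x), y)
            loss.backward()
            optimizer.step()
        return self

    @torch.no_grad()
    def predict(self, x: torch.Tensor, device=None) -> torch.Tensor:
        # Exception to the rule: predict() may run on a different device than fit(),
        # e.g. full-batch evaluation on the CPU. Note this moves the model to `device`.
        device = torch.device(device) if device is not None else self.device
        self.to(device)
        # Return predictions on the CPU so callers never have to move them back.
        return self(x.to(device)).argmax(dim=1).cpu()

    def score(self, x: torch.Tensor, y: torch.Tensor, device=None) -> float:
        pred = self.predict(x, device=device)
        return (pred == y.cpu()).float().mean().item()
```

Because fit() owns the transfer, callers can keep passing CPU tensors regardless of the compute device (for example `BaseModel(4, 2, device="cuda").fit(x, y)`), and evaluation can be pinned to the CPU with `score(x, y, device="cpu")`.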

PR #30 is an example of fixing the issue according to this convention.

Need to check

grep -rl "\.to(" dance | sort
  • [ ] dance/datasets/multimodality.py
  • [ ] dance/modules/multi_modality/joint_embedding/dcca.py
  • [ ] dance/modules/multi_modality/joint_embedding/jae.py
  • [ ] dance/modules/multi_modality/joint_embedding/scmogcn.py
  • [ ] dance/modules/multi_modality/joint_embedding/scmogcnv2.py
  • [ ] dance/modules/multi_modality/joint_embedding/scmvae.py
  • [ ] dance/modules/multi_modality/match_modality/scmm.py
  • [ ] dance/modules/multi_modality/match_modality/scmogcn.py
  • [ ] dance/modules/multi_modality/predict_modality/babel.py
  • [ ] dance/modules/multi_modality/predict_modality/scmm.py
  • [ ] dance/modules/multi_modality/predict_modality/scmogcn.py
  • [ ] dance/modules/single_modality/cell_type_annotation/actinn.py
  • [ ] dance/modules/single_modality/cell_type_annotation/scdeepsort.py
  • [ ] dance/modules/single_modality/clustering/graphsc.py
  • [ ] dance/modules/single_modality/clustering/scdcc.py
  • [ ] dance/modules/single_modality/clustering/scdeepcluster.py
  • [ ] dance/modules/single_modality/clustering/scdsc.py
  • [ ] dance/modules/single_modality/clustering/sctag.py
  • [ ] dance/modules/single_modality/imputation/deepimpute.py
  • [ ] dance/modules/single_modality/imputation/graphsci.py
  • [ ] dance/modules/single_modality/imputation/scgnn.py
  • [ ] dance/modules/spatial/cell_type_deconvo/dstg.py
  • [ ] dance/modules/spatial/cell_type_deconvo/spatialdecon.py
  • [ ] dance/modules/spatial/cell_type_deconvo/spotlight.py
  • [ ] dance/modules/spatial/spatial_domain/spagcn.py
  • [ ] dance/modules/spatial/spatial_domain/stagate.py
  • [ ] dance/transforms/graph_construct.py
  • [ ] dance/transforms/preprocess.py
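To help prioritize the checklist, the grep above could be mirrored in Python with a per-file count of `.to(` calls. This is a minimal sketch assuming only `.py` files matter (grep scans every file); `find_to_calls` is a hypothetical helper, and like the grep pattern it also matches `.to(` inside comments and strings.

```python
import re
from pathlib import Path

# Same pattern as the grep command: a literal ".to(".
TO_CALL = re.compile(r"\.to\(")


def find_to_calls(root):
    """Return {path: count} for Python files under `root` that contain `.to(`."""
    hits = {}
    for path in sorted(Path(root).rglob("*.py")):
        count = len(TO_CALL.findall(path.read_text(encoding="utf-8")))
        if count:
            # as_posix() keeps the "dance/..." prefix, matching the grep output.
            hits[path.as_posix()] = count
    return hits
```

For example, `for name, n in find_to_calls("dance").items(): print(f"- [ ] {name} ({n} calls)")` reproduces the checklist with counts, so files with the most scattered transfers can be tackled first.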

RemyLau, Oct 07 '22 00:10