Convention for tensor device
Problem
Currently, there is no convention for which device data is stored on when a function, e.g., `fit()`, is called. For example, even though the computation device is `cuda`, the input graph may not be on the GPU and will need to be transferred later, after subsampling neighborhoods for mini-batch training. This inconsistency causes many issues during development.
One naive solution is to call `tensor.to(device)` with the correct device every time a computation is performed. This is unsatisfactory and makes the code base less clean.
Solution
- By default, all data should sit on `cpu`.
- When a compute function, e.g., `fit()`, is called, perform any necessary device conversion inside that function.
- The only exceptions are `predict()` and `score()`, which can be more flexible since they are used in various places with various configurations, e.g., mini-batch training on GPU followed by full-batch evaluation on CPU.
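A minimal sketch of the convention, assuming a PyTorch model with a scikit-learn-style `fit()`/`predict()`/`score()` interface. `ToyModel`, its constructor arguments, and the optional `device` argument on `predict()`/`score()` are hypothetical and are not DANCE's actual API. Inputs stay on CPU; `fit()` moves only each mini-batch to the compute device, while `predict()` lets the caller choose where evaluation runs.

```python
import torch
from torch import nn


class ToyModel:
    """Hypothetical model following the proposed device convention.

    Not DANCE's actual API: the class name, constructor arguments, and the
    ``device`` argument of predict()/score() are assumptions for illustration.
    """

    def __init__(self, in_dim: int, n_classes: int, device: str = "cpu"):
        # Remember the computation device, but do not move anything yet.
        self.device = torch.device(device)
        self.net = nn.Linear(in_dim, n_classes)

    def fit(self, x: torch.Tensor, y: torch.Tensor, epochs: int = 20,
            batch_size: int = 16, lr: float = 0.1) -> None:
        # Inputs are expected on CPU; all device conversion happens here.
        self.net.to(self.device)
        opt = torch.optim.SGD(self.net.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(epochs):
            perm = torch.randperm(x.shape[0])
            for start in range(0, x.shape[0], batch_size):
                idx = perm[start:start + batch_size]
                # Index on CPU first, then transfer only the mini-batch.
                xb, yb = x[idx].to(self.device), y[idx].to(self.device)
                opt.zero_grad()
                loss_fn(self.net(xb), yb).backward()
                opt.step()

    @torch.no_grad()
    def predict(self, x: torch.Tensor, device=None) -> torch.Tensor:
        # Exception to the rule: the caller may pick the device, e.g. run
        # full-batch inference on CPU after training on GPU.
        device = self.device if device is None else torch.device(device)
        self.net.to(device)
        # Assumption: results are returned on CPU to match the default.
        return self.net(x.to(device)).argmax(dim=1).cpu()

    def score(self, x: torch.Tensor, y: torch.Tensor, device=None) -> float:
        # Accuracy; y is assumed to be a CPU tensor of class indices.
        pred = self.predict(x, device=device)
        return (pred == y).float().mean().item()


# Usage: data stays on CPU; pass device="cuda" to ToyModel if a GPU exists.
torch.manual_seed(0)
x = torch.randn(64, 4)
y = (x[:, 0] > 0).long()
model = ToyModel(in_dim=4, n_classes=2, device="cpu")
model.fit(x, y)
print(model.score(x, y, device="cpu"))
```

Because the transfer lives inside `fit()`, callers never need to know which device training uses, and the input tensors themselves are never moved (`.to()` returns a copy on the target device).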
PR #30 is an example of fixing this issue.
Need to check
```shell
grep "\.to(" -r dance | awk -F":" '{print $1}' | sort -u
```
- [ ] dance/datasets/multimodality.py
- [ ] dance/modules/multi_modality/joint_embedding/dcca.py
- [ ] dance/modules/multi_modality/joint_embedding/jae.py
- [ ] dance/modules/multi_modality/joint_embedding/scmogcn.py
- [ ] dance/modules/multi_modality/joint_embedding/scmogcnv2.py
- [ ] dance/modules/multi_modality/joint_embedding/scmvae.py
- [ ] dance/modules/multi_modality/match_modality/scmm.py
- [ ] dance/modules/multi_modality/match_modality/scmogcn.py
- [ ] dance/modules/multi_modality/predict_modality/babel.py
- [ ] dance/modules/multi_modality/predict_modality/scmm.py
- [ ] dance/modules/multi_modality/predict_modality/scmogcn.py
- [ ] dance/modules/single_modality/cell_type_annotation/actinn.py
- [ ] dance/modules/single_modality/cell_type_annotation/scdeepsort.py
- [ ] dance/modules/single_modality/clustering/graphsc.py
- [ ] dance/modules/single_modality/clustering/scdcc.py
- [ ] dance/modules/single_modality/clustering/scdeepcluster.py
- [ ] dance/modules/single_modality/clustering/scdsc.py
- [ ] dance/modules/single_modality/clustering/sctag.py
- [ ] dance/modules/single_modality/imputation/deepimpute.py
- [ ] dance/modules/single_modality/imputation/graphsci.py
- [ ] dance/modules/single_modality/imputation/scgnn.py
- [ ] dance/modules/spatial/cell_type_deconvo/dstg.py
- [ ] dance/modules/spatial/cell_type_deconvo/spatialdecon.py
- [ ] dance/modules/spatial/cell_type_deconvo/spotlight.py
- [ ] dance/modules/spatial/spatial_domain/spagcn.py
- [ ] dance/modules/spatial/spatial_domain/stagate.py
- [ ] dance/transforms/graph_construct.py
- [ ] dance/transforms/preprocess.py
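As a hedged alternative to the `grep` pipeline, the following standard-library Python sketch finds the same `.to(` call sites but limits the search to `.py` files and keeps line numbers, which may be convenient when working through the checklist. The function name `find_to_calls` is hypothetical.

```python
import os
import re

# Matches a literal ".to(" call, the same pattern as the grep command.
TO_CALL = re.compile(r"\.to\(")


def find_to_calls(root: str) -> dict[str, list[int]]:
    """Map each .py file under root to the line numbers containing ".to(".

    Unlike the grep pipeline, this skips non-Python files and keeps
    line numbers so each checklist item can be reviewed directly.
    """
    hits: dict[str, list[int]] = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if not name.endswith(".py"):
                continue
            path = os.path.join(dirpath, name)
            with open(path, encoding="utf-8", errors="replace") as f:
                lines = [i for i, line in enumerate(f, start=1)
                         if TO_CALL.search(line)]
            if lines:
                hits[path] = lines
    # Sort by path, mirroring `sort -u` in the shell version.
    return dict(sorted(hits.items()))


if __name__ == "__main__":
    for path, lines in find_to_calls("dance").items():
        print(f"{path}: {lines}")
```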