Cassiopeia
Failed Numbaization of Distance Function is Not Caught
When I wrap a dissimilarity function from cas.solver.dissimilarity and pass the wrapper as the dissimilarity_function argument of a DistanceSolver, I get a numba error.
To reproduce the issue, I installed Cassiopeia with the pip install command from the repo's README and ran this Python script:
## From cass.py
from typing import Dict, List, Optional

import cassiopeia as cas
import pandas as pd
import pickle as pic
import os

gt_tree_dir = "/data/yosef2/users/richardz/projects/CassiopeiaV2-Reproducibility/trees/exponential_plus_c/400cells/no_fit/char40/"
gt_tree_file = os.path.join(gt_tree_dir, "tree0.pkl")
gt_tree = pic.load(open(gt_tree_file, "rb"))

cm_file = os.path.join(gt_tree_dir, f"cm0.txt")
cm = pd.read_table(cm_file, index_col = 0)

recon_tree = cas.data.CassiopeiaTree(
    character_matrix=cm,
    missing_state_indicator = -1
)


def my_distance_function(
    s1: List[int],
    s2: List[int],
    missing_state_indicator=-1,
    weights: Optional[Dict[int, Dict[int, float]]] = None,
) -> float:
    return cas.solver.dissimilarity.weighted_hamming_distance(
        s1,
        s2,
        missing_state_indicator=missing_state_indicator,
        weights=weights,
    )

solver = cas.solver.NeighborJoiningSolver(
    add_root = True,
    dissimilarity_function=my_distance_function
)
solver.solve(recon_tree)
Upon running the script above, the following error pops up:
## From stderr
Traceback (most recent call last):
  File "cass.py", line 38, in <module>
    solver.solve(recon_tree)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/DistanceSolver.py", line 140, in solve
    dissimilarity_map = self.get_dissimilarity_map(cassiopeia_tree, layer)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/DistanceSolver.py", line 106, in get_dissimilarity_map
    self.setup_dissimilarity_map(cassiopeia_tree, layer)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/DistanceSolver.py", line 227, in setup_dissimilarity_map
    self.setup_root_finder(cassiopeia_tree)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/NeighborJoiningSolver.py", line 264, in setup_root_finder
    self.dissimilarity_function, self.prior_transformation
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/CassiopeiaTree.py", line 1855, in compute_dissimilarity_map
    self.missing_state_indicator,
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py", line 214, in compute_dissimilarity_map
    cm, C, missing_state_indicator, nb_weights
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'weighted_hamming_distance' of type Module(<module 'cassiopeia.solver.dissimilarity_functions' from '/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/dissimilarity_functions.py'>)
File "cass.py", line 27:
def my_distance_function(
<source elided>
return cas.solver.dissimilarity.weighted_hamming_distance(
^
During: typing of get attribute at cass.py (27)
File "cass.py", line 27:
def my_distance_function(
<source elided>
return cas.solver.dissimilarity.weighted_hamming_distance(
^
During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)
During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)
During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)
During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)
File "../../../../../home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py", line 197:
def _compute_dissimilarity_map(cm, C, missing_state_indicator, nb_weights):
<source elided>
dm[k] = dissimilarity_func(
s1, s2, missing_state_indicator, nb_weights
^
When inspecting the source code, I noticed that /home/eecs/ivalexander13/datadir/Cassiopeia/cassiopeia/data/utilities.py seems to contain safeguards that are supposed to catch numba failures, as follows:
## From utilities.py at lines 159 to 171
numbaize = True
try:
    dissimilarity_func = numba.jit(dissimilarity_function, nopython=True)
except TypeError:
    warnings.warn(
        "Failed to numbaize dissimilarity function. Falling back to Python.",
        CassiopeiaTreeWarning,
    )
    numbaize = False
    dissimilarity_func = dissimilarity_function
## From utilities.py at lines 206 to 215
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=numba.NumbaDeprecationWarning)
    warnings.simplefilter("ignore", category=numba.NumbaWarning)
    _compute_dissimilarity_map = numba.jit(
        _compute_dissimilarity_map, nopython=numbaize
    )
    return _compute_dissimilarity_map(
        cm, C, missing_state_indicator, nb_weights
    )
When both snippets are changed to avoid numba entirely, the bug disappears. So I think the bug is that the numbaization fails and the failure somehow bypasses the try/except.
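A possible explanation, offered as a sketch rather than a confirmed diagnosis: numba.jit called without a signature compiles lazily, so wrapping the function succeeds and the compilation error only surfaces on the first call (the traceback above raises from _compile_for_args). In addition, numba's TypingError does not derive from TypeError, so `except TypeError` would not match it in any case. The snippet below illustrates the pattern without requiring numba: `lazy_jit`, `CompileError`, and `jit_with_fallback` are hypothetical stand-ins invented for this illustration, not part of numba or Cassiopeia.

```python
import warnings


class CompileError(Exception):
    """Stand-in for numba's TypingError: note it is NOT a TypeError subclass."""


def lazy_jit(func):
    """Stand-in for numba.jit without a signature.

    Wrapping always succeeds; the simulated compilation only runs
    (and can only fail) when the returned dispatcher is called.
    """
    def dispatcher(*args):
        if getattr(func, "uncompilable", False):
            raise CompileError(f"cannot compile {func.__name__}")
        return func(*args)
    return dispatcher


def jit_with_fallback(func, jit, probe_args, compile_errors):
    """Return (callable, compiled_ok).

    The probe call sits inside the try block so that a deferred
    compilation failure is caught here rather than later.
    """
    try:
        compiled = jit(func)       # lazy: does not raise here
        compiled(*probe_args)      # first call triggers compilation
    except compile_errors as exc:
        warnings.warn(
            f"Failed to compile {func.__name__} ({exc!r}). Falling back to Python."
        )
        return func, False
    return compiled, True


def hamming(s1, s2):
    return sum(a != b for a, b in zip(s1, s2))


def wrapped_hamming(s1, s2):
    return hamming(s1, s2)


# Pretend the compiler rejects the wrapper, as in the report above.
wrapped_hamming.uncompilable = True

errors = (CompileError, TypeError)
fast, ok_fast = jit_with_fallback(hamming, lazy_jit, ([0], [0]), errors)
slow, ok_slow = jit_with_fallback(wrapped_hamming, lazy_jit, ([0], [0]), errors)
```

With real numba, the analogous choice would presumably be to pass numba.njit as `jit` and to include numba.core.errors.NumbaError (the base class of TypingError) in `compile_errors`. Whether a probe call is acceptable inside compute_dissimilarity_map is a design decision for the maintainers.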