Failed Numbaization of Distance Function is Not Caught

Open ivalexander13 opened this issue 2 years ago • 1 comments

When I wrap a dissimilarity function from cas.solver.dissimilarity in my own function and pass the wrapper as the dissimilarity_function argument of a DistanceSolver, numba raises a TypingError that Cassiopeia does not catch.

To reproduce the issue, I installed Cassiopeia with the pip install command from the repo's README and ran this Python script:

## From cass.py
from typing import Dict, List, Optional
import cassiopeia as cas
import pandas as pd
import pickle as pic
import os

gt_tree_dir = "/data/yosef2/users/richardz/projects/CassiopeiaV2-Reproducibility/trees/exponential_plus_c/400cells/no_fit/char40/"
gt_tree_file = os.path.join(gt_tree_dir, "tree0.pkl")
gt_tree = pic.load(open(gt_tree_file, "rb"))

cm_file = os.path.join(gt_tree_dir, f"cm0.txt")
cm = pd.read_table(cm_file, index_col = 0)

recon_tree = cas.data.CassiopeiaTree(
    character_matrix=cm, 
    missing_state_indicator = -1
    )

def my_distance_function(
    s1: List[int],
    s2: List[int],
    missing_state_indicator=-1,
    weights: Optional[Dict[int, Dict[int, float]]] = None,
) -> float:

    return cas.solver.dissimilarity.weighted_hamming_distance(
        s1,
        s2,
        missing_state_indicator=missing_state_indicator,
        weights=weights,
    )

solver = cas.solver.NeighborJoiningSolver(
    add_root = True, 
    dissimilarity_function=my_distance_function
    )

solver.solve(recon_tree)

Running the script above produces the following error:

## From stderr
Traceback (most recent call last):
  File "cass.py", line 38, in <module>
    solver.solve(recon_tree)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/DistanceSolver.py", line 140, in solve
    dissimilarity_map = self.get_dissimilarity_map(cassiopeia_tree, layer)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/DistanceSolver.py", line 106, in get_dissimilarity_map
    self.setup_dissimilarity_map(cassiopeia_tree, layer)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/DistanceSolver.py", line 227, in setup_dissimilarity_map
    self.setup_root_finder(cassiopeia_tree)
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/NeighborJoiningSolver.py", line 264, in setup_root_finder
    self.dissimilarity_function, self.prior_transformation
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/CassiopeiaTree.py", line 1855, in compute_dissimilarity_map
    self.missing_state_indicator,
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py", line 214, in compute_dissimilarity_map
    cm, C, missing_state_indicator, nb_weights
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'weighted_hamming_distance' of type Module(<module 'cassiopeia.solver.dissimilarity_functions' from '/home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/solver/dissimilarity_functions.py'>)

File "cass.py", line 27:
def my_distance_function(
    <source elided>

    return cas.solver.dissimilarity.weighted_hamming_distance(
    ^

During: typing of get attribute at cass.py (27)

File "cass.py", line 27:
def my_distance_function(
    <source elided>

    return cas.solver.dissimilarity.weighted_hamming_distance(
    ^

During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)

During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)

During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)

During: resolving callee type: type(CPUDispatcher(<function my_distance_function at 0x7f1322269f80>))
During: typing of call at /home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py (197)


File "../../../../../home/eecs/ivalexander13/datadir/miniconda3/envs/fake_cass/lib/python3.7/site-packages/cassiopeia/data/utilities.py", line 197:
    def _compute_dissimilarity_map(cm, C, missing_state_indicator, nb_weights):
        <source elided>
                dm[k] = dissimilarity_func(
                    s1, s2, missing_state_indicator, nb_weights
                    ^

When inspecting the source code, I noticed that /home/eecs/ivalexander13/datadir/Cassiopeia/cassiopeia/data/utilities.py seems to contain safeguards that are supposed to catch numba failures:

## From utilities.py at lines 159 to 171
numbaize = True
try:
    dissimilarity_func = numba.jit(dissimilarity_function, nopython=True)
except TypeError:
    warnings.warn(
        "Failed to numbaize dissimilarity function. Falling back to Python.",
        CassiopeiaTreeWarning,
    )
    numbaize = False
    dissimilarity_func = dissimilarity_function

## From utilities.py at lines 206 to 215
with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=numba.NumbaDeprecationWarning)
        warnings.simplefilter("ignore", category=numba.NumbaWarning)
        _compute_dissimilarity_map = numba.jit(
            _compute_dissimilarity_map, nopython=numbaize
        ) 

        return _compute_dissimilarity_map(
            cm, C, missing_state_indicator, nb_weights
        )

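When both snippets are changed to avoid numba entirely, the error disappears. So I think the failed numbaization is bypassing the try/except: judging by the traceback, the TypingError is raised at the _compute_dissimilarity_map(...) call on line 214 of utilities.py, not inside the try block around numba.jit.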
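
One likely explanation, offered with some hedging: numba.jit called without an explicit signature compiles lazily, so the try block only wraps the function and nothing is compiled until the first call. The TypingError in the traceback also comes from numba.core.errors rather than being the built-in TypeError, so except TypeError would probably not match it even if it were raised there. The sketch below shows one way to surface such a failure early, by forcing a sample call inside the try. It is not Cassiopeia's code: jit_or_fallback, lazy_jit and LazyJitError are hypothetical names, and the last two are stand-ins so the example runs without numba installed.

```python
import warnings


class LazyJitError(Exception):
    """Hypothetical stand-in for a JIT compilation error (e.g. a typing failure)."""


def lazy_jit(func):
    """Hypothetical stand-in for a lazy JIT decorator.

    Wrapping always succeeds; a function flagged as uncompilable only
    fails when the returned dispatcher is first called.
    """
    def dispatcher(*args):
        if getattr(func, "uncompilable", False):
            raise LazyJitError(f"cannot compile {func.__name__}")
        return func(*args)
    return dispatcher


def jit_or_fallback(func, sample_args, jit, jit_errors):
    """Return (callable, compiled_flag), compiling eagerly inside the try."""
    try:
        compiled = jit(func)
        compiled(*sample_args)  # first call forces a lazy JIT to compile
    except jit_errors as exc:
        warnings.warn(
            f"Failed to compile {func.__name__}: {exc}. Falling back to Python."
        )
        return func, False
    return compiled, True


def hamming(s1, s2):
    """Plain-Python Hamming distance used as the compilable example."""
    return sum(a != b for a, b in zip(s1, s2))


def wrapped_hamming(s1, s2):
    """Wrapper analogous to my_distance_function in the report."""
    return hamming(s1, s2)


wrapped_hamming.uncompilable = True  # simulate a wrapper the JIT rejects

sample = ([0, 1, 1], [0, 2, 1])

# Wrapping alone raises nothing, which is why a try around it catches nothing.
naive = lazy_jit(wrapped_hamming)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fast_func, fast_ok = jit_or_fallback(
        hamming, sample, lazy_jit, (LazyJitError,)
    )
    slow_func, slow_ok = jit_or_fallback(
        wrapped_hamming, sample, lazy_jit, (LazyJitError,)
    )

print(fast_ok, slow_ok, len(caught))
```

With numba, the analogous call would presumably pass numba.njit as jit and (numba.core.errors.NumbaError,) as jit_errors, with sample_args being two short state arrays plus the missing state indicator and weights. The returned flag would play the role of numbaize in the utilities.py snippet, so the outer _compute_dissimilarity_map is only compiled in nopython mode when the inner function actually compiled.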

ivalexander13, Apr 12 '22 03:04