
AssertionError: Unreachable condition reached (op code RERAISE executed) with batch_size > 1

Open sivecow opened this issue 11 months ago • 6 comments

System Information:

  • Operating System: Ubuntu 24.04.1
  • Numba version: 0.61.0
  • NumPy version: 2.1.3
  • Python version: 3.12.3

Description: I encountered an internal Numba error when running matrix operations with batch sizes larger than 1. The error occurs in a JIT-compiled function that performs matrix inversion operations. The code works fine with batch_size=1 but fails with batch_size=10 or larger.

The error message indicates this is an internal Numba issue: "This should not have happened, a problem has occurred in Numba's internals."

Steps to Reproduce:

  1. Run the code below
  2. The code runs successfully with batch_size=1
  3. The code fails with batch_size=10 with an "Unreachable condition reached" error

Code to reproduce:

import numpy as np
import numba as nb

@nb.njit(fastmath=True)
def matrix_update(feature_matrix, A_inv, temp_arrays):
    """
    JIT-compiled function that focuses on matrix inversion operations
    which likely cause the Numba error
    """
    batch_size, n_features = feature_matrix.shape
    
    # Unpack temporary arrays
    lambda_A_inv_feat_T, M, M_inv = temp_arrays
    
    # Compute A_inv @ feature_matrix.T
    for j in range(n_features):
        for i in range(batch_size):
            val = 0.0
            for k in range(n_features):
                val += A_inv[j, k] * feature_matrix[i, k]
            lambda_A_inv_feat_T[j, i] = val
    
    # Compute M = I + feature_matrix @ A_inv @ feature_matrix.T
    for i in range(batch_size):
        for j in range(batch_size):
            if i == j:
                M[i, j] = 1.0  # Identity matrix diagonal
            else:
                M[i, j] = 0.0  # Zero off-diagonal elements
    
    for i in range(batch_size):
        for j in range(batch_size):
            for k in range(n_features):
                M[i, j] += feature_matrix[i, k] * lambda_A_inv_feat_T[k, j]
    
    # Matrix inversion - this is likely where the error happens
    # Reset M_inv to zeros
    for i in range(batch_size):
        for j in range(batch_size):
            M_inv[i, j] = 0.0
    
    # Handle different batch sizes with explicit code paths
    if batch_size == 1:
        # Direct inversion for 1x1 matrices
        M_inv[0, 0] = 1.0 / M[0, 0]
    elif batch_size == 2:
        # 2x2 matrix inversion
        det = M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0]
        if abs(det) > 1e-10:
            M_inv[0, 0] = M[1, 1] / det
            M_inv[0, 1] = -M[0, 1] / det
            M_inv[1, 0] = -M[1, 0] / det
            M_inv[1, 1] = M[0, 0] / det
        else:
            # Fallback to diagonal regularization
            for i in range(batch_size):
                M_inv[i, i] = 1.0 / (M[i, i] + 1e-8)
    else:
        # Simplified approach for larger matrices
        try:
            # Simple regularized diagonal approach
            for i in range(batch_size):
                M_inv[i, i] = 1.0 / (M[i, i] + 1e-8)
        except:
            # Fallback - just in case
            for i in range(batch_size):
                M_inv[i, i] = 1.0 / (M[i, i] + 1e-8)
    
    # Update A_inv using Woodbury identity (simplified)
    new_A_inv = np.zeros_like(A_inv)
    for i in range(n_features):
        for j in range(n_features):
            new_A_inv[i, j] = A_inv[i, j]
            for k in range(batch_size):
                for l in range(batch_size):
                    new_A_inv[i, j] -= lambda_A_inv_feat_T[i, k] * M_inv[k, l] * lambda_A_inv_feat_T[j, l]
    
    return new_A_inv

class MatrixUpdater:
    def __init__(self, n_features, max_batch_size=10):
        self.n_features = n_features
        self.max_batch_size = max_batch_size
        self.dtype = np.float32
        
        # Initialize model parameters
        self._A_inv = np.eye(n_features, dtype=self.dtype)
        
        # Pre-allocate temporary arrays
        self._pre_allocate_temp_arrays(max_batch_size)
        
        # Warm up JIT
        self._warm_up_jit()
    
    def _pre_allocate_temp_arrays(self, batch_size):
        """Pre-allocate temporary arrays"""
        if batch_size > self.max_batch_size:
            self.max_batch_size = batch_size
        
        # Pre-allocate temporary arrays
        self._lambda_A_inv_feat_T = np.zeros((self.n_features, self.max_batch_size), dtype=self.dtype)
        self._M = np.zeros((self.max_batch_size, self.max_batch_size), dtype=self.dtype)
        self._M_inv = np.zeros((self.max_batch_size, self.max_batch_size), dtype=self.dtype)
        
        # Package arrays for Numba function
        self._temp_arrays = (self._lambda_A_inv_feat_T, self._M, self._M_inv)
    
    def _warm_up_jit(self):
        """Warm up the JIT compiler"""
        try:
            # Create tiny dummy data
            X_dummy = np.ones((1, self.n_features), dtype=self.dtype)
            
            # Create small temporary arrays for warm-up
            lambda_A_inv_feat_T = np.zeros((self.n_features, 1), dtype=self.dtype)
            M = np.zeros((1, 1), dtype=self.dtype)
            M_inv = np.zeros((1, 1), dtype=self.dtype)
            
            # Run JIT function to compile it
            _ = matrix_update(X_dummy, self._A_inv, (lambda_A_inv_feat_T, M, M_inv))
            print("JIT compilation warmed up for matrix operations")
        except Exception as e:
            print(f"JIT warm-up failed: {e}")
    
    def update(self, X):
        """Update the matrix with new data"""
        if isinstance(X, list):
            X = np.array(X, dtype=self.dtype)
        elif X.dtype != self.dtype:
            X = X.astype(self.dtype)
        
        # Ensure correct shape
        if len(X.shape) == 1:
            X = X.reshape(1, -1)
        
        batch_size = X.shape[0]
        
        # Ensure temporary arrays are large enough
        if batch_size > self.max_batch_size:
            print(f"Resizing arrays for batch size {batch_size}")
            self._pre_allocate_temp_arrays(batch_size)
        
        # Perform the update operation
        self._A_inv = matrix_update(X, self._A_inv, self._temp_arrays)
        
        return self._A_inv

def run_test():
    import time
    
    # Create sample data
    np.random.seed(42)
    n_samples = 1000
    n_features = 5
    
    # Generate data
    X = np.random.randn(n_samples, n_features).astype(np.float32)
    
    # Test with different batch sizes
    for batch_size in [1, 10, 100]:
        print(f"\nTesting with batch_size = {batch_size}")
        
        # Initialize updater
        updater = MatrixUpdater(n_features=n_features, max_batch_size=batch_size)
        
        # Track time
        start_time = time.time()
        
        # Update in batches
        for i in range(0, n_samples, batch_size):
            end_idx = min(i + batch_size, n_samples)
            X_batch = X[i:end_idx]
            A_inv = updater.update(X_batch)
        
        # Print results
        elapsed = time.time() - start_time
        print(f"Elapsed time: {elapsed:.6f} seconds")
        print(f"Final A_inv shape: {A_inv.shape}")

if __name__ == "__main__":
    run_test()

Error Log:

python3 -m test_numba

Testing with batch_size = 1
JIT compilation warmed up for matrix operations
Elapsed time: 0.001606 seconds
Final A_inv shape: (5, 5)

Testing with batch_size = 10
JIT compilation warmed up for matrix operations
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "test_numba.py", line 181, in <module>
    run_test()
  File "test_numba.py", line 173, in run_test
    A_inv = updater.update(X_batch)
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "test_numba.py", line 144, in update
    self._A_inv = matrix_update(X, self._A_inv, self._temp_arrays)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "test_numba.py", line 66, in matrix_update
    for i in range(batch_size):
AssertionError: Unreachable condition reached (op code RERAISE executed)
-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.
You are currently using Numba version 0.61.0.

Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new?template=bug_report.md

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!

sivecow, Mar 22 '25 23:03

Many thanks for the report and reproducer. I can also reproduce the issue with Numba main and Python 3.11.9.

gmarkall, Mar 24 '25 10:03

Minimal reproducer:

import numba as nb
import numpy as np

@nb.njit(fastmath=True)
def f(n, v):
    try:
        for i in range(n):
            return 1.0 / v
    except:
        return 1.0 / v

f(1, np.nan)

gives:

Traceback (most recent call last):
  File "/home/gmarkall/numbadev/issues/10015/split.py", line 14, in <module>
    f(1, np.nan)
  File "/home/gmarkall/numbadev/issues/10015/split.py", line 11, in f
    return 1.0 / v
AssertionError: Unreachable condition reached (op code RERAISE executed)

The error seems to occur when we catch an exception raised in a for loop, and then raise an exception again.

gmarkall, Mar 24 '25 11:03

Actually, we don't need the for loop, fastmath, a division by zero, or any arguments; we can make this much simpler:

import numba as nb

@nb.njit
def f():
    try:
        raise RuntimeError("oops")
    except:
        raise RuntimeError("oops some more")

f()

gives:

Traceback (most recent call last):
  File "/home/gmarkall/numbadev/issues/10015/split.py", line 12, in <module>
    f()
  File "/home/gmarkall/numbadev/issues/10015/split.py", line 9, in f
    raise RuntimeError("oops some more")
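To see where the opcode named in the error comes from, one can disassemble the same function with CPython's standard `dis` module, with no Numba involved. On Python 3.11 and later, the `except:` body appears to be wrapped in a compiler-generated cleanup block that ends in `RERAISE`, so a `raise` inside the handler is routed through that opcode, which Numba treats as unreachable. This is a sketch for illustration only; the exact bytecode differs between Python versions.

```python
import dis
import sys

def f():
    try:
        raise RuntimeError("oops")
    except:
        raise RuntimeError("oops some more")

# Collect the opcode names CPython generated for f (no Numba needed)
ops = [ins.opname for ins in dis.get_instructions(f)]
print(sys.version_info[:2], "RERAISE present:", "RERAISE" in ops)
```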

gmarkall, Mar 24 '25 11:03

From https://numba.readthedocs.io/en/stable/reference/pysupported.html#exception-handling:

Currently, exception objects are not materialized inside compiled functions. As a result, it is not possible to store an exception object into a user variable or to re-raise an exception. With this limitation, the only realistic use-case would look like: ...

So, technically, the code in the issue is unsupported. However, the problem with pushing this constraint onto users is that exceptions can be nearly invisible to them, as with the ZeroDivisionError that triggers the failure in the original reproducer, whose code contains no raise statement at all.

gmarkall, Mar 24 '25 11:03

@sivecow, as a workaround, you can add error_model="numpy" to the njit decorator to prevent the zero-division error from being raised, like this:

@nb.njit(fastmath=True, error_model="numpy")

This should let you continue debugging your code: with error_model="numpy", dividing by zero produces inf or NaN instead of raising ZeroDivisionError (see the @jit decorator documentation), so Numba no longer blows up. You do still have a problem in your code, though: the trigger is that the matrix is filled with infs and NaNs at the point where the error occurs.

gmarkall, Mar 24 '25 11:03
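On the infs and NaNs: one plausible contributor, which the thread does not confirm, is that for batch_size > 2 `matrix_update` replaces M^{-1} with a diagonal approximation, so the update is no longer the Woodbury identity (A + X^T X)^{-1} = A^{-1} - A^{-1} X^T (I + X A^{-1} X^T)^{-1} X A^{-1}, and `A_inv` can drift away from a valid inverse. Below is a minimal NumPy sketch of the exact update for comparison, assuming `A_inv` is symmetric (as it is when starting from the identity). `woodbury_update` is a hypothetical name, and the sketch uses float64 rather than the float32 in the report.

```python
import numpy as np

def woodbury_update(A_inv, X):
    """Return (A + X^T X)^{-1} given symmetric A_inv = A^{-1} and a batch X of shape (b, n)."""
    L = A_inv @ X.T                      # (n, b), same role as lambda_A_inv_feat_T
    M = np.eye(X.shape[0]) + X @ L       # (b, b): I + X A^{-1} X^T
    # Solve M Z = L^T rather than forming M^{-1}; exact for any batch size
    Z = np.linalg.solve(M, L.T)          # (b, n)
    return A_inv - L @ Z                 # A^{-1} - L M^{-1} L^T

# Usage: compare against a direct inverse after several batches
rng = np.random.default_rng(42)
n_features, batch_size = 5, 10
A = np.eye(n_features)
A_inv = np.eye(n_features)
for _ in range(20):
    X = rng.standard_normal((batch_size, n_features))
    A += X.T @ X                         # track A itself only for the check
    A_inv = woodbury_update(A_inv, X)

print(np.allclose(A_inv, np.linalg.inv(A)))
print(np.isfinite(A_inv).all())
```

Solving the small b x b system with `np.linalg.solve` keeps the update exact for every batch size, so there is no need for separate 1x1, 2x2, and "larger" code paths.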

The action required on the Numba side: replace the error message about reaching an unreachable condition with one that explains more clearly what has happened (a re-raise occurred, which is unsupported).

gmarkall, Mar 25 '25 14:03