
fix: fix tensor grouping bug & optimize MMD calculation in causal_prediction/algorithms/regularization.py

Open JPZ4-5 opened this issue 1 month ago • 3 comments

This PR refactors the Regularizer class in dowhy/causal_prediction/algorithms/regularization.py to fix a critical logic error in cross-environment grouping and to significantly improve computational efficiency. Because this bug renders the current implementation mathematically incorrect and potentially harmful to model performance without raising any errors, prompt review is highly recommended.

Key Improvements

  1. Replaced Grouping Logic:

    • Legacy: Relied on a manual hashing approach using a factors vector and dot product (grouping_data @ factors). Since GPU matrix multiplication (@) does not support long types, this required inefficient type casting between float and long.
    • New: Adopted torch.unique(dim=0, return_inverse=True) to handle grouping. This method is more robust, concise, and leverages native PyTorch optimizations without unnecessary type conversions.
  2. Bug Fix (Dictionary Key Issue):

    • Issue: The legacy implementation used PyTorch Tensors as keys for Python dictionaries. Tensors hash by object identity rather than by value, so identical scalar tensors from different environments (e.g., tensor(1) from Env0 and tensor(1) from Env1) were treated as distinct keys. Consequently, keys failed to collide across environments and incorrect MMD noise was added to the penalty: identical labels from different envs were split into separate groups, producing an erroneously larger penalty (confirmed while debugging).
    • Fix: The new implementation resolves this by using torch.unique indices (or by handling scalar keys by value), ensuring that data from different environments is correctly merged into the same pool; see the sketch below.
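A minimal sketch of the pitfall and the fix (variable names are illustrative, not the actual Regularizer code):

```python
import torch

# Pitfall: tensors hash by object identity, so equal-valued keys don't collide.
k_env0 = torch.tensor(1)  # group label observed in environment 0
k_env1 = torch.tensor(1)  # the "same" label observed in environment 1
groups = {}
groups.setdefault(k_env0, []).append("sample from env0")
groups.setdefault(k_env1, []).append("sample from env1")
assert len(groups) == 2  # one logical group split into two MMD pools

# Fix: torch.unique assigns one integer id per distinct label row.
labels = torch.tensor([[1], [0], [1], [1]])  # pooled labels from all environments
unique_rows, inverse = torch.unique(labels, dim=0, return_inverse=True)
pools = [torch.where(inverse == g)[0] for g in range(len(unique_rows))]
# pools now merges rows 0, 2, 3 into one group and row 1 into another.
```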
  3. Algebraic Optimization & Throughput:
    • Refactored the MMD Penalty calculation to use an algebraically optimized form instead of nested Python loops, which significantly reduces control flow overhead and improves GPU throughput.
    • Formula:

$$ \sum_{i=1}^n\sum_{j=i+1}^n (K_{ii}+K_{jj}-2K_{ij})=(n-1)\sum_{i=1}^n K_{ii}-2\sum_{i=1}^n\sum_{j=i+1}^n K_{ij} $$
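The identity can be checked numerically; the sketch below uses a standard Gaussian kernel for illustration (gamma and sizes are arbitrary, not the PR's defaults):

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4, dtype=torch.float64)
gamma = 1e-2
# Gaussian kernel matrix: K[i, j] = exp(-gamma * ||x_i - x_j||^2)
K = torch.exp(-gamma * torch.cdist(x, x).pow(2))

n = K.shape[0]
# Left-hand side: explicit nested loops over all pairs i < j.
lhs = sum(K[i, i] + K[j, j] - 2 * K[i, j]
          for i in range(n) for j in range(i + 1, n))
# Right-hand side: closed form, (n-1) * trace(K) - 2 * sum_{i<j} K_ij.
rhs = (n - 1) * K.diagonal().sum() - 2 * torch.triu(K, diagonal=1).sum()
assert torch.allclose(lhs, rhs)
```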

  4. Numerical Stability (Enforced fp64):
    • Change: Forced MMD accumulation to use float64 precision, casting back to the environment's default dtype (e.g., float32) only after calculation. Empirical evidence and standard parameter search spaces suggest gamma is often very small ($10^{-5}$ to $10^{-7}$). Calculating Gaussian kernels with such small values in float32 can lead to vanishing penalty terms or precision loss. float64 ensures sufficient precision for the penalty accumulation.
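The float32 concern is easy to reproduce in isolation; this is a small demonstration with hand-picked gamma and distance values, not the PR's code:

```python
import torch

gamma = 1e-7            # within the typical search range cited above
d2 = torch.tensor(1.0)  # squared distance between two sample points

for dtype in (torch.float32, torch.float64):
    k_ii = torch.exp(torch.tensor(0.0, dtype=dtype))  # self-similarity, exactly 1
    k_ij = torch.exp(-gamma * d2.to(dtype))           # ~= 1 - gamma
    term = k_ii + k_ii - 2 * k_ij                     # true value ~= 2e-7
    print(dtype, term.item())

# Near 1.0 the float32 grid spacing is ~6e-8, so exp(-1e-7) rounds to a
# neighboring representable value and the pairwise term can be off by tens
# of percent; float64 resolves it accurately, and the PR casts back to the
# default dtype only after accumulation.
```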

Benchmark: In local testing, this PR yielded an approximate 40% speedup in training throughput (from 2.5 it/s to 3.5 it/s). All 6 test cases pass.

JPZ4-5 · Nov 27 '25 15:11

Thank you very much @JPZ4-5 for this contribution. The PR looks promising. @jivatneet could you take a look as well? (thank you!)

emrekiciman · Nov 29 '25 06:11

  1. Regarding the E_eq_A=True Logic: Replacing torch.full(..., i) with attribute_labels[i] was suggested. However, according to the original CACM paper, when E_eq_A=True the algorithm is explicitly designed to use the environment index as the sensitive attribute, regardless of what constitutes the raw attribute_labels. Constructing the labels manually from the environment index i is therefore the intended behavior.

  2. Code Fixes Applied:

    • Corrected features.dtype access.
    • Fixed the initialization of the covariance matrix in the else branch (using torch.zeros instead of .diag() to ensure correct shape for $N=1$).
    • Completed the missing docstrings.
    • Standardized the usage of torch.tensor vs tensor.
  3. Clarification on MMD Calculation & use_optimization: To be clear, the critical bug was solely in the grouping stage (using Tensors as dictionary keys), and this PR fixes it. I retained the unoptimized MMD path because it offers higher readability and is easier to extend with future custom kernels: not every kernel has a straightforward vectorized implementation for pooled data, and developers may prioritize readability and development speed over trivial algebraic optimizations for complex kernels (as I did). The use_optimization flag lets developers opt in when using the standard gaussian_kernel (or other kernels with clear gains from vectorization); it can easily be toggled in the CACM class if needed. A sketch of the intended dispatch follows below.
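A rough sketch of the dispatch implied by this flag (the function layout and the mmd_penalty name are illustrative; use_optimization and gaussian_kernel follow the discussion above):

```python
import torch

def gaussian_kernel(x, y, gamma=1e-6):
    # K[i, j] = exp(-gamma * ||x_i - y_j||^2)
    return torch.exp(-gamma * torch.cdist(x, y).pow(2))

def mmd_penalty(x, kernel=gaussian_kernel, use_optimization=False):
    """Sum of K_ii + K_jj - 2*K_ij over all pairs i < j within one pool."""
    K = kernel(x, x)
    n = K.shape[0]
    if use_optimization:
        # Vectorized closed form: (n-1) * trace(K) - 2 * sum_{i<j} K_ij.
        return (n - 1) * K.diagonal().sum() - 2 * torch.triu(K, diagonal=1).sum()
    # Readable reference path: easy to adapt for kernels without a
    # straightforward vectorized form; accumulate in float64 for stability.
    total = K.new_zeros((), dtype=torch.float64)
    for i in range(n):
        for j in range(i + 1, n):
            total = total + (K[i, i] + K[j, j] - 2 * K[i, j]).double()
    return total.to(K.dtype)
```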

JPZ4-5 · Nov 29 '25 15:11

It seems CI failed with System.IO.IOException: No space left on device; the runner appears to have run out of disk space. Could you please trigger a re-run?

JPZ4-5 · Dec 05 '25 10:12