fix: fix tensor grouping bug & optimize MMD calculation in causal_prediction/algorithms/regularization.py
This PR refactors the `Regularizer` class in `dowhy/causal_prediction/algorithms/regularization.py` to fix a critical logic error in cross-environment grouping and to significantly improve computational efficiency. Since this issue renders the current implementation mathematically incorrect and potentially harmful to model performance without raising any errors, prompt review is highly recommended.
Key Improvements
- Replaced Grouping Logic:
  - Legacy: Relied on a manual hashing approach using a `factors` vector and a dot product (`grouping_data @ factors`). Since GPU matrix multiplication (`@`) does not support `long` types, this required inefficient type casting between `float` and `long`.
  - New: Adopted `torch.unique(dim=0, return_inverse=True)` to handle grouping. This method is more robust, concise, and leverages native PyTorch optimizations without unnecessary type conversions.
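A minimal sketch of the two approaches (variable names and the example data are hypothetical, not taken from the actual module):

```python
import torch

# Hypothetical grouping data: each row is a (label, attribute) pair
# collected across environments.
grouping_data = torch.tensor([[0, 1],
                              [1, 0],
                              [0, 1],
                              [1, 0]])

# Legacy-style approach: hash each row via a dot product with a factors
# vector. On CUDA, '@' does not support integer (long) dtypes, which is
# what forced the float/long casting in the old code.
factors = torch.tensor([10, 1])
legacy_keys = grouping_data @ factors  # tensor([ 1, 10,  1, 10])

# New approach: torch.unique with return_inverse yields a stable group
# index per row, with no manual hashing and no dtype juggling.
unique_rows, group_idx = torch.unique(grouping_data, dim=0, return_inverse=True)
# group_idx -> tensor([0, 1, 0, 1]): rows 0/2 and 1/3 share a group
```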
- Bug Fix (Dictionary Key Issue):
  - Issue: The legacy implementation used PyTorch tensors as keys for Python dictionaries. In cross-environment settings, identical scalar tensors from different environments (e.g., `tensor(1)` from Env 0 and `tensor(1)` from Env 1) were treated as distinct objects. Consequently, incorrect MMD noise was added to the penalty because keys failed to collide across environments (as shown in the debug screenshot, identical keys from different `envs` were treated as different groups, leading to an erroneously larger penalty).
  - Fix: The new implementation resolves this by using `torch.unique` indices (or by ensuring scalar keys are handled by value), so that data from different environments is correctly merged into the same pool.
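The pitfall is easy to reproduce: `torch.Tensor` hashes by object identity, not by value, so equal scalar tensors never collide as dictionary keys (a toy demonstration, not the module's actual code):

```python
import torch

# Tensors hash by id(), so two value-equal scalar tensors coming from
# different environments create two separate dictionary entries.
pool = {}
pool[torch.tensor(1)] = "data from Env 0"
pool[torch.tensor(1)] = "data from Env 1"
print(len(pool))  # 2 -- two entries instead of one merged group

# Keying by the plain Python value (or by a torch.unique group index)
# merges the environments correctly.
pool_by_value = {}
pool_by_value[int(torch.tensor(1))] = "data from Env 0"
pool_by_value[int(torch.tensor(1))] = "data from Env 1"
print(len(pool_by_value))  # 1 -- both environments land in one group
```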
- Algebraic Optimization & Throughput:
  - Refactored the MMD penalty calculation to use an algebraically optimized form instead of nested Python loops, which significantly reduces control-flow overhead and improves GPU throughput.
  - Formula:
    $$ \sum_{i=1}^n\sum_{j=i+1}^n (K_{ii}+K_{jj}-2K_{ij})=(n-1)\sum_{i=1}^n K_{ii}-2\sum_{i=1}^n\sum_{j=i+1}^n K_{ij} $$
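The identity holds because each diagonal term $K_{ii}$ appears in exactly $n-1$ of the pairs $i<j$. It can be checked numerically for a random Gaussian kernel matrix (the kernel and `gamma` below are illustrative):

```python
import torch

torch.manual_seed(0)
n = 8
x = torch.randn(n, 3, dtype=torch.float64)
gamma = 0.5  # hypothetical bandwidth, just for the check
K = torch.exp(-gamma * torch.cdist(x, x) ** 2)

# Left-hand side: explicit double loop over pairs i < j.
lhs = sum(K[i, i] + K[j, j] - 2 * K[i, j]
          for i in range(n) for j in range(i + 1, n))

# Right-hand side: closed form used by the optimized path.
rhs = (n - 1) * torch.diagonal(K).sum() - 2 * torch.triu(K, diagonal=1).sum()

print(torch.allclose(lhs, rhs))  # True
```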
- Numerical Stability (Enforced fp64):
  - Change: Forced MMD accumulation to use `float64` precision, casting back to the environment's default dtype (e.g., `float32`) only after the calculation. Empirical evidence and standard parameter search spaces suggest `gamma` is often very small ($10^{-5}$ to $10^{-7}$). Calculating Gaussian kernels with such small values in `float32` can lead to vanishing penalty terms or precision loss; `float64` ensures sufficient precision for the penalty accumulation.
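The accumulate-in-fp64, cast-back-at-the-end pattern looks roughly like this (function name, signature, and the penalty form are illustrative, not the module's actual kernel):

```python
import torch

def accumulate_penalty(pairwise_sq_dists: torch.Tensor,
                       gamma: float) -> torch.Tensor:
    """Sketch of the stability pattern: accumulate the MMD penalty in
    float64, then cast back to the input dtype only at the end."""
    out_dtype = pairwise_sq_dists.dtype
    d64 = pairwise_sq_dists.to(torch.float64)
    # With gamma near 1e-7, terms like 1 - exp(-gamma * d) are on the
    # order of gamma * d; float32's ~7 decimal digits lose most of that
    # signal, while float64 preserves it.
    penalty = (1.0 - torch.exp(-gamma * d64)).sum()
    return penalty.to(out_dtype)

dists = torch.full((4, 4), 2.0, dtype=torch.float32)
p = accumulate_penalty(dists, gamma=1e-7)
print(p.dtype)  # torch.float32 -- cast back after float64 accumulation
```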
Benchmark: In local testing, this PR yielded an approximate 40% speedup in training throughput (from 2.5 it/s to 3.5 it/s). All 6 test cases have passed.
Thank you very much @JPZ4-5 for this contribution. The PR looks promising. @jivatneet could you take a look as well? (thank you!)
- Regarding `E_eq_A=True` Logic: Replacing `torch.full(..., i)` with `attribute_labels[i]` was suggested. However, according to the original CACM paper, when `E_eq_A=True` the algorithm is explicitly designed to use the environment index as the sensitive attribute, regardless of what constitutes the raw `attribute_labels`. Therefore, constructing the labels manually from the environment index `i` is the intended behavior.
- Code Fixes Applied:
  - Corrected `features.dtype` access.
  - Fixed the initialization of the covariance matrix in the `else` branch (using `torch.zeros` instead of `.diag()` to ensure the correct shape for $N=1$).
  - Completed the missing docstrings.
  - Standardized the usage of `torch.tensor` vs `tensor`.
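The `E_eq_A=True` point can be sketched as follows: labels are built from the environment loop index `i`, not from the raw attributes (the names mirror the discussion above, but the snippet is illustrative, not the actual CACM code):

```python
import torch

# When E_eq_A=True, the environment index itself serves as the sensitive
# attribute, so each environment's label tensor is filled with its index i.
batch_size = 5
env_batches = [object(), object()]  # two dummy environment batches

for i, _env_batch in enumerate(env_batches):
    attr = torch.full((batch_size,), i, dtype=torch.long)
    # env 0 -> tensor([0, 0, 0, 0, 0]); env 1 -> tensor([1, 1, 1, 1, 1])
```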
- Clarification on MMD Calculation & `use_optimization`: To clarify, the critical bug was solely in the grouping stage (using tensors as dictionary keys), which I have fixed. I retained the unoptimized path because it offers higher readability and is easier to extend with future custom kernels: not all kernels have a straightforward vectorized implementation for pooled data, and developers may prioritize readability and development efficiency over algebraic optimizations for complex kernels. The `use_optimization` flag allows developers to opt in when using the standard `gaussian_kernel` (or other kernels with clear efficiency gains from vectorization). This parameter can be easily toggled in the `CACM` class if needed.
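The two-path design can be sketched as below (function name, signature, and kernel are illustrative, not the module's actual API); both paths compute the same pairwise sum, and the flag simply selects the vectorized closed form:

```python
import torch

def mmd_penalty(x: torch.Tensor, gamma: float,
                use_optimization: bool = False) -> torch.Tensor:
    """Illustrative pairwise penalty: sum over i<j of
    K_ii + K_jj - 2*K_ij for a Gaussian kernel."""
    K = torch.exp(-gamma * torch.cdist(x, x) ** 2)
    n = K.shape[0]
    if use_optimization:
        # Vectorized closed form of the pairwise sum.
        return (n - 1) * torch.diagonal(K).sum() - 2 * torch.triu(K, 1).sum()
    # Readable reference path: trivial to adapt to custom kernels that
    # lack an obvious vectorized form.
    total = torch.zeros((), dtype=K.dtype)
    for i in range(n):
        for j in range(i + 1, n):
            total = total + K[i, i] + K[j, j] - 2 * K[i, j]
    return total

x = torch.randn(6, 4, dtype=torch.float64)
slow = mmd_penalty(x, gamma=0.1, use_optimization=False)
fast = mmd_penalty(x, gamma=0.1, use_optimization=True)
print(torch.allclose(slow, fast))  # True
```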
It seems CI failed with `System.IO.IOException: No space left on device`; the runner ran out of disk space. Could you please trigger a re-run?