MiDi
Incorrect Value Replacement in edge_types
Description
The line:
edge_types[edge_types == 4] = 1.5
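does not do what it appears to: edge_types is a long (int64) tensor, so the assigned value 1.5 is silently truncated to 1. Every 4 is therefore replaced with 1 instead of 1.5, which changes the logic and the results that depend on it.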
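The truncation can be reproduced in isolation with a minimal sketch like the one below. The sample values are invented for illustration and are not taken from the MiDi code; only the assignment pattern matches the line quoted above.

```python
import torch

# Invented sample codes for illustration; integer literals give dtype torch.int64.
edge_types = torch.tensor([1, 2, 4, 4, 3])

# Same pattern as the reported line, applied to a copy of the long tensor.
truncated = edge_types.clone()
truncated[truncated == 4] = 1.5  # 1.5 is cast to int64 before it is stored

# Same replacement after converting to floating point first.
as_float = edge_types.float()
as_float[as_float == 4] = 1.5

print(truncated.dtype, truncated.tolist())
print(as_float.dtype, as_float.tolist())
```

On the long tensor the 4s should come back as 1, with no error or warning raised, while the float copy keeps 1.5.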
Impact: The truncated values propagate into the metrics computed from edge_types, such as molecule and atom stability, and so also affect downstream work such as EQGAT-Diff and Semla-Flow.
Suggested Fix:
Convert edge_types to a floating-point tensor before replacement:
edge_types = edge_types.float()
edge_types[edge_types == 4] = 1.5
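If mutating edge_types in place is undesirable, the same fix can be written as a function that returns a new float tensor. This is only a sketch of one possible variant: the helper name bond_orders is hypothetical and does not exist in MiDi.

```python
import torch

def bond_orders(edge_types: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: float copy of edge_types with code 4 mapped to 1.5."""
    as_float = edge_types.float()
    # torch.where picks 1.5 where the original code is 4, else the float-cast value.
    return torch.where(edge_types == 4, torch.full_like(as_float, 1.5), as_float)

codes = torch.tensor([0, 1, 4, 2])  # invented sample input
orders = bond_orders(codes)
print(orders.tolist())
```

Because the result is a separate tensor, the original integer codes remain available to any code that still indexes by them.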
https://github.com/cvignac/MiDi/blob/775b731c38967a1e49615b2ad70ac6b5db24909a/midi/metrics/molecular_metrics.py#L304