Potential bug in ott.geometry.segment._segment_interface
Describe the bug
segment._segment_interface() raises a TypeError when the total number of points in one of the point clouds is smaller than max_measure_size. (I am not sure whether the function is supposed to cover this case, though.) In the code snippet below, y contains 70 points in total while max_measure_size is 80. Here is the traceback:
Traceback (most recent call last):
  File "/p/project/dynadis/soeren.becker/repos/inverse_cot/drafts/test_segment_sinkhorn.py", line 47, in <module>
    main()
  File "/p/project/dynadis/soeren.becker/repos/inverse_cot/drafts/test_segment_sinkhorn.py", line 27, in main
    segment._segment_interface(
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/ott/geometry/segment.py", line 171, in _segment_interface
    segmented_y, segmented_weights_y = segment_point_cloud(
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/ott/geometry/segment.py", line 118, in segment_point_cloud
    idx = jax.lax.dynamic_slice(jnp.sort(idx), (0,), (max_measure_size,))
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 167, in dynamic_slice
    return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (80,) for operand shape (70,).
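The last frames of the traceback can be isolated from ott. The following is a minimal sketch that assumes only JAX is installed; the `idx` array is a hypothetical stand-in for the index array built inside segment_point_cloud(), not the actual ott internals. It shows that the slice is rejected as soon as the requested size exceeds the operand length.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the index array of a point cloud that has
# 70 points in total (the real array is built inside segment_point_cloud).
idx = jnp.arange(70)
max_measure_size = 80

# Requesting 80 entries from a length-70 operand is rejected eagerly.
try:
    jax.lax.dynamic_slice(jnp.sort(idx), (0,), (max_measure_size,))
    slice_failed = False
except (TypeError, ValueError) as err:  # TypeError in the traceback above
    slice_failed = True
    print(f"{type(err).__name__}: {err}")

# The same call is accepted once the slice size fits into the operand.
ok = jax.lax.dynamic_slice(jnp.sort(idx), (0,), (idx.shape[0],))
print(ok.shape)
```

So the failure does not depend on the segment layout itself, only on the total number of points being smaller than the requested slice size.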
To Reproduce
Steps to reproduce the behavior:
import jax.numpy as jnp
from ott.geometry import costs, pointcloud, segment


def main():
    # Dummy evaluation function; the error is raised before it is ever called.
    def eval_fn(*args):
        print("eval_fn")

    dim = 10
    num_per_segment_x = jnp.array([80, 70, 50])
    num_per_segment_y = jnp.array([50, 10, 10])  # ERROR: 70 points in total < max_measure_size (80)
    # num_per_segment_y = jnp.array([60, 10, 10])  # WORKS: 80 points in total
    x = jnp.arange(num_per_segment_x.sum() * dim).reshape(-1, dim)
    y = jnp.arange(num_per_segment_y.sum() * dim).reshape(-1, dim)
    num_segments = len(num_per_segment_x)
    # Shared across x and y: the largest segment in either point cloud.
    max_measure_size = max(num_per_segment_x.max(), num_per_segment_y.max())
    indices_are_sorted = False
    segment_ids_x = segment_ids_y = None
    weights_x = weights_y = None
    padding_vector = None
    segment._segment_interface(
        x,
        y,
        eval_fn,
        num_segments=num_segments,
        max_measure_size=max_measure_size,
        segment_ids_x=segment_ids_x,
        segment_ids_y=segment_ids_y,
        indices_are_sorted=indices_are_sorted,
        num_per_segment_x=num_per_segment_x,
        num_per_segment_y=num_per_segment_y,
        weights_x=weights_x,
        weights_y=weights_y,
        padding_vector=padding_vector,
    )
    print("done")


if __name__ == "__main__":
    main()
Desktop
- OS: Linux
Additional context
I think the error occurs because segment._segment_interface() internally calls segment_point_cloud() twice (https://github.com/ott-jax/ott/blob/main/src/ott/geometry/segment.py#L160:L180), once for x and once for y, and passes the same max_measure_size to both calls, even though that value may have been computed globally over both x and y. As the traceback shows, segment_point_cloud() then slices max_measure_size entries out of the sorted index array, which fails when that array (here of shape (70,) for y) is shorter than max_measure_size (here 80). The solution might be as simple as computing max_measure_size separately for x and y, but I am not fully sure what segment._segment_interface() is supposed to do, or whether using the same max_measure_size for both is required, here or elsewhere.
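To make the arithmetic behind this suggestion concrete, here is a small sketch using the segment sizes from the script above. The helper `slice_fits` and the names `max_size_x` and `max_size_y` are hypothetical and not part of ott; the sketch only checks whether a slice of a given size would fit into each point cloud, under the assumption that the sliced array has one entry per point.

```python
import jax.numpy as jnp

num_per_segment_x = jnp.array([80, 70, 50])
num_per_segment_y = jnp.array([50, 10, 10])


def slice_fits(num_per_segment, max_measure_size):
    """Return True if max_measure_size entries fit into the point cloud."""
    return int(num_per_segment.sum()) >= int(max_measure_size)


# Shared size, as in the reproduction script: largest segment of x or y.
shared = max(int(num_per_segment_x.max()), int(num_per_segment_y.max()))
print(slice_fits(num_per_segment_x, shared))  # x: 200 points vs. size 80
print(slice_fits(num_per_segment_y, shared))  # y: 70 points vs. size 80

# Per-cloud sizes (hypothetical names): a cloud's largest segment can
# never exceed that cloud's total number of points.
max_size_x = int(num_per_segment_x.max())
max_size_y = int(num_per_segment_y.max())
print(slice_fits(num_per_segment_x, max_size_x))
print(slice_fits(num_per_segment_y, max_size_y))
```

With a shared size, the check fails for y only; with per-cloud sizes it holds for both by construction. Whether eval_fn or other callers rely on x and y being padded to the same size is the open question raised above, and this sketch does not answer it.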
Here is a potential fix: https://github.com/soerenab/ott/commit/39b7209a52ba612e53f616502b19ea1f498bf3fa
Closed via https://github.com/ott-jax/ott/commit/2011fe4d2fa5a456983c24e0ed83e21f3dd4388a