xtensor
fix: correct negative axis handling in roll function
Checklist
- [x] The title and commit message(s) are descriptive.
- [x] Small commits made to fix your PR have been squashed to avoid history pollution.
- [x] Tests have been added for new features or bug fixes.
- [ ] API of new functions and classes are documented.
Description
This PR fixes a bug in xt::roll(e, shift, axis) where negative axis indices (e.g., -1 for the last axis) were incorrectly rejected.
The Bug
include/xtensor/misc/xmanipulation.hpp
The original code converted axis to size_t before normalization:
std::size_t saxis = static_cast<std::size_t>(axis); // -1 becomes SIZE_MAX
if (axis < 0)
{
axis += std::ptrdiff_t(cpy.dimension());
}
if (saxis >= cpy.dimension() || axis < 0) // SIZE_MAX >= dim is always true!
{
XTENSOR_THROW(... );
}
This caused valid negative indices like -1 to incorrectly trigger the bounds check exception.
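The failure can be reproduced outside xtensor. The following is a minimal sketch, not the library code: `old_check_rejects` is a hypothetical stand-in for the original check, and `dim` plays the role of `cpy.dimension()`.

```cpp
#include <cstddef>

// Hypothetical stand-in for the original bounds check in roll().
// Returns true when the check would throw.
bool old_check_rejects(std::ptrdiff_t axis, std::size_t dim)
{
    // Conversion happens before normalization, so a negative axis wraps
    // around to a very large unsigned value.
    std::size_t saxis = static_cast<std::size_t>(axis);
    if (axis < 0)
    {
        axis += static_cast<std::ptrdiff_t>(dim);
    }
    // saxis still holds the wrapped value, so the first operand is true
    // for every negative input, even after axis has been normalized.
    return saxis >= dim || axis < 0;
}
```

With `dim == 3`, `old_check_rejects(-1, 3)` returns `true` even though -1 refers to the last axis, while `old_check_rejects(2, 3)` returns `false`.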
The Fix
include/xtensor/misc/xmanipulation.hpp
auto cpy = empty_like(e);
const auto& shape = cpy.shape();
- std::size_t saxis = static_cast<std::size_t>(axis);
- if (axis < 0)
- {
- axis += std::ptrdiff_t(cpy.dimension());
- }
+ const auto dim = cpy.dimension();
- if (saxis >= cpy.dimension() || axis < 0)
+ if (axis < -static_cast<std::ptrdiff_t>(dim) || axis >= static_cast<std::ptrdiff_t>(dim))
{
- XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension.");
+ XTENSOR_THROW(std::runtime_error, "axis is not within shape dimension.");
}
+ std::size_t saxis = normalize_axis(dim, axis);
+
const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]);
- Validate axis bounds before conversion
- Use normalize_axis() for consistency with other functions (swapaxes, moveaxis, etc.)
- Fix typo: "axis is no within" → "axis is not within"
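The validate-then-convert order can be sketched in isolation. This is an illustration under assumptions, not the xtensor implementation: `checked_axis` is a hypothetical helper, and the inline wrap-around stands in for what the PR delegates to `normalize_axis()`.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical helper: check the signed axis against [-dim, dim) first,
// then convert it to an unsigned index.
std::size_t checked_axis(std::ptrdiff_t axis, std::size_t dim)
{
    const auto sdim = static_cast<std::ptrdiff_t>(dim);
    if (axis < -sdim || axis >= sdim)
    {
        throw std::runtime_error("axis is not within shape dimension.");
    }
    // The value is now known to be in range, so the wrapped result is
    // non-negative and the cast is safe.
    return static_cast<std::size_t>(axis < 0 ? axis + sdim : axis);
}
```

For a three-dimensional expression, `checked_axis(-1, 3)` yields 2 and `checked_axis(-3, 3)` yields 0, while 3 and -4 both throw, matching the boundary cases covered by the new tests.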
Tests Added
test/test_xmanipulation.cpp
xarray<double> expected8 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}};
ASSERT_EQ(expected8, xt::roll(e2, -2, /*axis*/ 2));
+ // Boundary error cases
+ EXPECT_THROW(xt::roll(e2, 1, /*axis*/ 3), std::runtime_error);
+ EXPECT_THROW(xt::roll(e2, 1, /*axis*/ -4), std::runtime_error);
+
+ // Negative axis indices
+ xarray<double> expected9 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}};
+ ASSERT_EQ(expected9, xt::roll(e2, -2, /*axis*/ -1));
+
+ xarray<double> expected10 = {{{1, 2, 3}}, {{4, 5, 6}}, {{7, 8, 9}}};
+ ASSERT_EQ(expected10, xt::roll(e2, -2, /*axis*/ -2));
+
+ xarray<double> expected11 = {{{4, 5, 6}}, {{7, 8, 9}}, {{1, 2, 3}}};
+ ASSERT_EQ(expected11, xt::roll(e2, 2, /*axis*/ -3));
}
Note: This bug has existed since #1823 (2019).