pytorch_scatter

Fix nullptr math & clean segment_csr_cpu.cpp

Open r-barnes opened this issue 3 years ago • 2 comments

When I run this code with LLVM-12's undefined behaviour sanitizer enabled, I see:

pytorch/torch-scatter/csrc/cpu/segment_csr_cpu.cpp:60:3: runtime error: applying non-zero offset 8 to null pointer
    #0 0x7fa95da15396 in segment_csr_cpu(at::Tensor, at::Tensor, c10::optional<at::Tensor>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::$_0::operator()() const::'lambda2'()::operator()() const::'lambda'()::operator()() const::'lambda'()::operator()() const pytorch/torch-scatter/csrc/cpu/segment_csr_cpu.cpp:60

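This is because line 41 declares int64_t *arg_out_data = nullptr. The pointer is only assigned a real buffer conditionally, but the "segment_csr" kernel uses it unconditionally. Computing an offset from it (for example arg_out_data + n * K + k) while it is still null is undefined behaviour, even if the result is never dereferenced. Adding an if-statement guards that use.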

I've also added const qualifiers and renamed shadowed variables in a few places to make the code more readable.
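A minimal sketch of the kind of guard described above, assuming a flat int64_t index buffer like arg_out_data that may be nullptr. The helper name arg_slot is hypothetical and is not part of torch-scatter; it only shows that the pointer offset has to be formed inside the null check, not before it.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper illustrating the guard; not part of torch-scatter.
// `arg_out_data` may be nullptr (it is only allocated for some reductions),
// and `idx` is an element index such as n * K + k.
int64_t *arg_slot(int64_t *arg_out_data, int64_t idx) {
  assert(idx >= 0);
  // `arg_out_data + idx` with a null base and idx != 0 is undefined
  // behaviour even if the result is never dereferenced, so only form
  // the pointer when there is an array to point into.
  return arg_out_data != nullptr ? arg_out_data + idx : nullptr;
}
```

A caller would then write through the returned pointer only when it is non-null.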

r-barnes, May 03 '22 20:05

Thanks; do you happen to know how to fix that? (My current solution is a bit of a guess since I have almost no familiarity with this code.)

On Tue, May 3, 2022, 23:48 Matthias Fey @.***> wrote:

@.**** commented on this pull request.

Thanks for the PR!

In csrc/cpu/segment_csr_cpu.cpp https://github.com/rusty1s/pytorch_scatter/pull/292#discussion_r864473876 :

    for (auto k = 0; k < K; k++)
      Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
    if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {

It looks like we now skip writing to the destination for every reduction other than min and max (which is likely unintended; see the test failures).
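A sketch of the loop shape the review points toward, under stated assumptions: the ReduceOp enum and write_segment function are hypothetical (not torch-scatter's code), vals already holds final per-column values, and mean division and empty-segment handling are left out. The value output is written for every reduction; only the arg output is gated on min/max.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical subset of reductions; torch-scatter's own enum may differ.
enum class ReduceOp { Sum, Mean, Min, Max };

// Sketch only: writes one reduced row `n` of width K.
void write_segment(ReduceOp op, double *out, int64_t *arg_out,
                   const double *vals, const int64_t *args, int64_t n,
                   int64_t K) {
  const bool has_arg = (op == ReduceOp::Min || op == ReduceOp::Max);
  // For min/max the arg buffer is expected to be allocated.
  assert(!has_arg || arg_out != nullptr);
  for (int64_t k = 0; k < K; k++) {
    out[n * K + k] = vals[k];  // written for every reduction, never skipped
    if (has_arg)
      arg_out[n * K + k] = args[k];  // offset formed only on a real buffer
  }
}
```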


r-barnes, May 04 '22 07:05

I think we could overload Reducer::write so that it can be called either with or without arg_out_data. We could then pass arg_out_data only when reduce is min or max, and omit it otherwise.
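The overload idea above might look roughly like this. It is a sketch under assumptions: ReducerSketch and ReduceOp are hypothetical stand-ins, not torch-scatter's actual Reducer definition, and empty-segment handling for min/max is omitted. The overload without an arg pointer serves every reduction, and the one with an arg pointer is meant for min/max only.

```cpp
#include <cassert>
#include <cstdint>

enum class ReduceOp { Sum, Mean, Min, Max };

// Hypothetical stand-in for torch-scatter's Reducer<scalar_t, REDUCE>;
// it only illustrates the proposed pair of write overloads.
template <typename scalar_t, ReduceOp REDUCE> struct ReducerSketch {
  // Overload without an arg output: usable for every reduction.
  static void write(scalar_t *address, scalar_t val, int count) {
    if (REDUCE == ReduceOp::Mean)
      *address = val / static_cast<scalar_t>(count > 0 ? count : 1);
    else
      *address = val;
  }

  // Overload with an arg output: intended for min/max only.
  static void write(scalar_t *address, scalar_t val, int64_t *arg_address,
                    int64_t arg, int count) {
    assert(arg_address != nullptr);
    write(address, val, count);  // reuse the value-only overload
    if (count > 0)
      *arg_address = arg;  // record the index only for non-empty segments
  }
};
```

With this split, the kernel would form arg_out_data + n * K + k only in the branch that calls the five-argument overload, so the null pointer is never offset for sum or mean.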

rusty1s, May 04 '22 07:05

This pull request had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity.

github-actions[bot], Nov 01 '22 02:11