
Problem with `where` function

Open • zarif98sjs opened this issue 1 year ago • 4 comments

Shouldn't the where function be this?

import torch

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + torch.logical_not(q) * b

Otherwise, if we use `~q`, isn't the result technically incorrect for the intended behavior?

For example, with `~q`, `where(arange(4) * 0, 0, 1)` returns `tensor([-1, -1, -1, -1])`, but the desired output is `tensor([1, 1, 1, 1])`.
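The discrepancy can be reproduced without PyTorch. This is a minimal sketch that assumes plain Python ints stand in for the elements of an integer tensor; `where_bitwise` and `where_logical` are hypothetical names that apply the two formulas element by element to the example above.

```python
# Pure-Python sketch: plain ints stand in for the elements of an integer tensor.
# where_bitwise / where_logical are hypothetical names mirroring the two formulas.

def where_bitwise(q, a, b):
    """Original formula, element by element: (q * a) + (~q) * b."""
    return [(qi * a) + (~qi) * b for qi in q]

def where_logical(q, a, b):
    """Proposed formula, element by element: (q * a) + logical_not(q) * b."""
    # `not qi` is a bool; multiplying by it keeps b (True) or zeroes it (False).
    return [(qi * a) + (not qi) * b for qi in q]

q = [i * 0 for i in range(4)]  # plays the role of arange(4) * 0
print(where_bitwise(q, 0, 1))  # [-1, -1, -1, -1]
print(where_logical(q, 0, 1))  # [1, 1, 1, 1]
```

For an integer `0`, `~0` is `-1`, which is exactly where the `-1` entries in the reported output come from.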

zarif98sjs • Mar 16 '24 05:03

I agree. `~` is bitwise NOT, so the behavior is unexpected when `q` is an integer tensor rather than a boolean one: `~0` evaluates to `-1` and `~1` to `-2`, instead of `1` and `0`.
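To make the bitwise-vs-logical distinction concrete, here is a small pure-Python sketch, again assuming plain ints stand in for integer tensor elements. It prints `~x` next to logical NOT for a few values, then shows a hypothetical `where_normalized` helper that casts the condition to bool before the arithmetic. In PyTorch the analogous step would be something like `q.bool()` before applying the formula; that is a suggestion, not the repository's fix.

```python
# Pure-Python sketch: plain ints stand in for integer tensor elements.

# Bitwise NOT flips every bit (two's complement), so ~x == -x - 1:
for x in (0, 1, 2):
    print(x, ~x, int(not x))  # value, bitwise NOT, logical NOT as 0/1
# 0 -1 1
# 1 -2 0
# 2 -3 0

def where_normalized(q, a, b):
    """Hypothetical helper: cast each condition to bool before the arithmetic.

    Without the bool() step, (q * a) would yield 2 * a for q == 2,
    so this version treats any nonzero value as "true".
    """
    out = []
    for qi in q:
        qb = bool(qi)                      # nonzero -> True, 0 -> False
        out.append(qb * a + (not qb) * b)  # bool * number acts as a 0/1 mask
    return out

print(where_normalized([2, 0, -1], 5, 7))  # [5, 7, 5]
```

Casting first also covers integer conditions other than 0 and 1, which the `q * a` term alone would scale rather than select.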

shunzh • Apr 22 '24 22:04

Oops, will fix if I do a new version.

srush • Apr 23 '24 13:04

Ah, nice! When creating the issue, I was wondering why nobody had noticed all these years 😅. I can send a PR if you want.

zarif98sjs • Apr 23 '24 13:04

No, I should just do a v2 this summer. Lots of small fixes abound.

srush • Apr 23 '24 14:04