RFC-0016: Masked reductions and normalizations
This RFC discusses the semantics and implementation details of masked reduction and normalization operators.
cc @IvanYashchuk @pearu @ngimel @mruberry @ezyang @jbschlosser
Feels like it should be prototypable with `__torch_function__` (or maybe `__torch_dispatch__`?)
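For a sense of scale, a minimal sketch of what a `__torch_function__` prototype could look like (the class name, mask convention, and sum-only dispatch here are illustrative assumptions, not a proposed design):

```python
import torch

class MaskedTensor:
    """Toy prototype: a data tensor plus a boolean mask (True = include)."""
    def __init__(self, data, mask):
        self.data = data
        self.mask = mask

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.sum and not kwargs:
            mt = args[0]
            # Masked-out elements are replaced by sum's identity, 0.
            return torch.where(mt.mask, mt.data, torch.zeros_like(mt.data)).sum()
        return NotImplemented

t = MaskedTensor(torch.tensor([1., 2., 3.]), torch.tensor([True, False, True]))
print(torch.sum(t))  # tensor(4.)
```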
@ezyang - does this mean you'd prefer to see this released and packaged out of tree first before considering inclusion in the core?
> does this mean you'd prefer to see this released and packaged out of tree first before considering inclusion in the core?
Not necessarily; I'm referring to this part of the spec:
> Indeed the best way to describe the behavior is to implement it. Please note that this is only meant to describe semantics and is not an actual implementation.
It wouldn't be a long step to have an executable specification that people can play around with.
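For instance, a rough executable rendering of a masked sum (the `masked_sum` signature and the identity-fill strategy are assumptions for illustration, not the RFC's reference implementation):

```python
import torch

def masked_sum(input, mask, dim=None):
    # Elements where mask is False are replaced by the identity of sum (0),
    # so they cannot affect the result.
    filled = torch.where(mask, input, torch.zeros_like(input))
    return filled.sum() if dim is None else filled.sum(dim)

t = torch.tensor([[1., 2.], [3., 4.]])
m = torch.tensor([[True, False], [True, True]])
print(masked_sum(t, m))         # tensor(8.)
print(masked_sum(t, m, dim=1))  # tensor([1., 7.])
```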
Since the nan* reductions, like `nansum`, are existing masked reductions, we should be sure the semantics are equivalent. This proposal just allows the mask to be specified directly rather than by value. Supporting more general value-based masking might be interesting in the future, too.
cc @heitorschueroff
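A quick spot check of that equivalence, using a plain identity-fill reduction in place of a real `masked_sum` (an assumption, matching the sketch earlier in this thread):

```python
import torch

t = torch.tensor([1., float('nan'), 3.])
mask = ~t.isnan()  # include exactly the non-nan elements

masked = torch.where(mask, t, torch.zeros_like(t)).sum()
assert torch.equal(masked, torch.nansum(t))  # both tensor(4.)
```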
@ezyang - agreed, I'm wondering whether or when we should create an out-of-tree Python-only prototype for a MaskedTensor.
@mruberry - you can always express a value-based mask (say, for the value 4) as e.g. `masked_sum(input, input != 4)`, or `masked_sum(input, ~(input != input))` for nan.
Nit: `masked_sum(input, input == input)` would work for the nan case as well.
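To make both constructions concrete (plain boolean indexing stands in for `masked_sum` here, purely as an illustration device):

```python
import torch

# Value-based mask: keep everything except the value 4.
t = torch.tensor([4., 1., 2., 4.])
print(t[t != 4].sum())  # tensor(3.)

# The two nan masks are the same thing: nan is the only value != itself.
u = torch.tensor([1., float('nan'), 3.])
assert torch.equal(~(u != u), u == u)
print(u[u == u].sum())  # tensor(4.)
```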
> agreed, I'm wondering whether or when we should create an out-of-tree Python-only prototype for a MaskedTensor.
If it's just one person, probably sticking it in a Colab is good enough. Multiple people wanting to work on the semantics ~> put it in GitHub somewhere.