swift-apis
Duplication of axes in reduction causes several problems
The reduction functions appear to accept duplicated axes such as [0, 0] (the _Raw.* ops handle this under the hood).
However, the _vjp* functions do not account for duplicates, so the gradient computation can be incorrect or can crash.
import TensorFlow
let tensor = Tensor<Float>(zeros: [3])
print(tensor.mean(alongAxes: 0)) // [0.0]
print(tensor.mean(alongAxes: 0, 0)) // [0.0]
print(gradient(at: tensor) { tensor in tensor.mean(alongAxes: 0).scalarized() }) // [0.33333334, 0.33333334, 0.33333334]
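// Incorrect gradient: the forward result matches the call above, so this should also be [0.33333334, 0.33333334, 0.33333334]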
print(gradient(at: tensor) { tensor in tensor.mean(alongAxes: 0, 0).scalarized() }) // [0.11111111, 0.11111111, 0.11111111]
print(tensor.sum(squeezingAxes: 0)) // 0.0
print(tensor.sum(squeezingAxes: 0, 0)) // 0.0
print(gradient(at: tensor) { tensor in tensor.sum(squeezingAxes: 0).scalarized() }) // [1.0, 1.0, 1.0]
// The following call crashes, so it is commented out
// print(gradient(at: tensor) { tensor in tensor.sum(squeezingAxes: 0, 0).scalarized() })
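
The incorrect mean gradient above is 1/9 rather than 1/3, which would be consistent with the pullback dividing by the product of the dimension sizes over every listed axis, counting axis 0 twice. That is an assumption about the cause, not something confirmed from the _vjp* source. The plain-Swift sketch below (no TensorFlow dependency) illustrates the arithmetic and one possible mitigation: deduplicating the axes before computing the count. The helpers `uniqueAxes` and `reducedCount` are hypothetical names introduced for illustration and are not part of swift-apis.

```swift
// Hypothetical helper (not part of swift-apis): resolve negative axes
// against the rank and drop duplicates, preserving first-seen order.
func uniqueAxes(_ axes: [Int], rank: Int) -> [Int] {
    var seen = Set<Int>()
    var result: [Int] = []
    for axis in axes {
        let resolved = axis < 0 ? axis + rank : axis
        precondition(resolved >= 0 && resolved < rank, "Axis out of range")
        if seen.insert(resolved).inserted {
            result.append(resolved)
        }
    }
    return result
}

// Hypothetical helper: number of elements a mean averages over, computed as
// the product of the dimension sizes along the given axes.
func reducedCount(shape: [Int], axes: [Int]) -> Int {
    return axes.map { shape[$0] }.reduce(1, *)
}

let shape = [3]

// Assumed behaviour of the current pullback: axis 0 is counted twice.
let naive = reducedCount(shape: shape, axes: [0, 0])

// With deduplicated axes, axis 0 is counted once.
let deduped = reducedCount(shape: shape, axes: uniqueAxes([0, 0], rank: shape.count))

print(1.0 / Float(naive))    // 0.11111111, matches the incorrect gradient
print(1.0 / Float(deduped))  // 0.33333334, matches the expected gradient
```

Whether duplicates should be deduplicated inside the _vjp* functions or rejected up front with a precondition is a design choice this sketch does not settle.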