
[SR-15793] [AutoDiff] Incorrect behavior with derivatives

Open philipturner opened this issue 3 years ago • 3 comments

Previous ID SR-15793
Radar None
Original Reporter @philipturner
Type Bug

Additional Detail from JIRA
Votes 0
Component/s
Labels Bug
Assignee None
Priority Medium

md5: 7bdb8a5d99f3bfcb51ba371fd5018a93

Issue Description:

Automatic differentiation gives incorrect results when differentiating a mutating function: it reorders the components of the gradient. The custom derivatives of `addingThree` and `addThree` below define the following partial derivatives:

  • Value being mutated (x = 10): 1

  • Each argument (y = 2, z = 3, w = 4): its own value, in order

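```swift
import _Differentiation

extension Double {
  // Non-mutating variant, used as a control case.
  func addingThree(_ lhs: Self, _ mhs: Self, _ rhs: Self) -> Self {
    self + lhs + mhs + rhs
  }

  // Deliberately unusual derivative: d/dself = 1, and the derivative with
  // respect to each argument is that argument's own value, so any reordering
  // of the gradient components is easy to spot.
  @derivative(of: addingThree)
  func _vjpAddingThree(
    _ lhs: Self,
    _ mhs: Self,
    _ rhs: Self
  ) -> (value: Self, pullback: (Self) -> (Self, Self, Self, Self)) {
    // Pullback results are ordered (self, lhs, mhs, rhs).
    return (addingThree(lhs, mhs, rhs), { v in (v, lhs, mhs, rhs) })
  }

  // Mutating variant: `self` is updated in place.
  mutating func addThree(_ lhs: Self, _ mhs: Self, _ rhs: Self) {
    self += lhs + mhs + rhs
  }

  // Same derivative as above. For a mutating function, the pullback receives
  // the cotangent of `self` as `inout` (left unchanged here, so d/dself = 1)
  // and returns the cotangents of the other arguments.
  @derivative(of: addThree)
  mutating func _vjpAddThree(
    _ lhs: Self,
    _ mhs: Self,
    _ rhs: Self
  ) -> (value: Void, pullback: (inout Self) -> (Self, Self, Self)) {
    addThree(lhs, mhs, rhs)
    return ((), { v in (lhs, mhs, rhs) })
  }
}

// Applies the mutating method to a local copy of `x`.
@differentiable(reverse)
func altAddingThree(_ x: Double, _ y: Double, _ z: Double, _ w: Double) -> Double {
  var output = x
  output.addThree(y, z, w)
  return output
}

// Passes: the gradient with respect to the three arguments is (2, 3, 4).
assert((2, 3, 4) == gradient(at: 2, 3, 4, of: { 10.addingThree($0, $1, $2) }))

// Fails: the gradient comes back as (3, 4, 2) instead of (2, 3, 4).
assert((2, 3, 4) == gradient(at: 2, 3, 4, of: { altAddingThree(10, $0, $1, $2) }))
```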
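For reference, the expected gradient can be worked out by hand. The sketch below replays the reverse pass of `altAddingThree` with plain functions, so it runs without `_Differentiation`; `addThreePullback` and `expectedGradient` are hypothetical names chosen to mirror the custom derivative above, not anything the compiler generates.

```swift
// Hypothetical hand-written stand-in for the custom pullback of `addThree`:
// the cotangent of the mutated value passes through unchanged (derivative 1),
// and each argument's derivative is its own value.
func addThreePullback(
  _ v: inout Double, _ lhs: Double, _ mhs: Double, _ rhs: Double
) -> (Double, Double, Double) {
  return (lhs, mhs, rhs)
}

// Reverse pass for `altAddingThree(x, y, z, w)`, written out by hand:
//   var output = x
//   output.addThree(y, z, w)
//   return output
func expectedGradient(
  _ x: Double, _ y: Double, _ z: Double, _ w: Double
) -> (dx: Double, dy: Double, dz: Double, dw: Double) {
  // Seed the reverse pass with d(output)/d(output) = 1.
  var dOutput = 1.0
  // Pull back through `output.addThree(y, z, w)`.
  let (dy, dz, dw) = addThreePullback(&dOutput, y, z, w)
  // `output` was initialized from `x`, so `x` receives output's cotangent.
  let dx = dOutput
  return (dx, dy, dz, dw)
}

let expected = expectedGradient(10, 2, 3, 4)
print(expected)
```

With (x, y, z, w) = (10, 2, 3, 4) this yields (1, 2, 3, 4), matching the expected values listed below for each differentiated argument.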
| Input | Expected | Actual output |
|---|---|---|
| (x=10, y=2, z=3) | (dx=1, dy=2, dz=3) | (dx=1, dy=3, dz=4) |
| (y=2, x=10, z=3) | (dy=2, dx=1, dz=3) | (dy=3, dx=1, dz=4) |
| (y=2, z=3, x=10) | (dy=2, dz=3, dx=1) | (dy=3, dz=4, dx=1) |
| (y=2, z=3, w=4) | (dy=2, dz=3, dw=4) | (dy=3, dz=4, dw=2) |
| (x=10, y=2) | (dx=1, dy=2) | (dx=1, dy=3) |
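Every row above is consistent with the derivative of the mutated value being correct while the argument cotangents returned by the pullback, (lhs, mhs, rhs), are assigned rotated by one position: dy receives z's value, dz receives w's, and dw receives y's. The sketch below only restates that observation as code so new toolchain results can be checked against it; `rotatedLeft` is a hypothetical helper and says nothing about where in the compiler the reordering happens.

```swift
// Hypothetical helper modeling the observed behavior: rotates a 3-tuple of
// argument cotangents left by one position, (a, b, c) -> (b, c, a).
func rotatedLeft(_ t: (Double, Double, Double)) -> (Double, Double, Double) {
  return (t.1, t.2, t.0)
}

// Expected cotangents for (y, z, w) = (2, 3, 4) under the custom pullback.
let expectedArgs: (Double, Double, Double) = (2, 3, 4)
// Cotangents reported for the same call in the "Actual output" column.
let observedArgs: (Double, Double, Double) = (3, 4, 2)

print(rotatedLeft(expectedArgs) == observedArgs)  // true
print(expectedArgs == observedArgs)               // false
```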

philipturner, Jan 31 '22 20:01

Fix submitted as https://github.com/apple/swift/pull/58437

philipturner, Apr 30 '22 19:04

This behavior is still present on top-of-tree Swift and needs further investigation, so I'm reopening this issue.

BradLarson, May 09 '23 15:05

The issue still exists in the 05/24 toolchain.

jkshtj, May 03 '24 19:05