swift-apis
Some mutating tensor operations are not `@differentiable`
`+=`, `-=`, and `*=` in https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Operators/Math.swift are not differentiable, but it should be pretty straightforward to make them differentiable.
I'm just starting out with Swift for TensorFlow. So do we only need to add `@differentiable(where Scalar: TensorFlowFloatingPoint)`, or am I missing something?
Oh yes, I think that is all that you need to do. (If the mutating operations were implemented in terms of other functions that were not differentiable, then you would need to actually specify what the derivative is, but it looks like the mutating operations are defined in terms of functions that are differentiable.)
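For the case in the parenthetical, where an operation is not built from differentiable pieces, the derivative has to be written out by hand. A minimal sketch of what that involves, using plain `Double` scalars instead of `Tensor` so it runs without the TensorFlow module: `multiplyAssignWithPullback` is a hypothetical helper (not part of swift-apis) that performs `x *= y` and returns a pullback mapping an incoming gradient to the partial derivatives with respect to the original `x` and `y`.

```swift
// Hypothetical helper for illustration only; not part of swift-apis.
// It stands in for the custom derivative you would have to provide if `*=`
// were NOT implemented in terms of differentiable operations.
//
// For x_new = x * y the partial derivatives are:
//   d(x_new)/dx = y   and   d(x_new)/dy = x (the value of x BEFORE the update).
func multiplyAssignWithPullback(
    _ x: inout Double, _ y: Double
) -> (Double) -> (dx: Double, dy: Double) {
    let oldX = x  // capture the pre-update value; the pullback needs it
    x *= y        // perform the actual in-place update
    // Scale each partial derivative by the incoming gradient `v`.
    return { v in (dx: v * y, dy: v * oldX) }
}

var x = 3.0
let pullback = multiplyAssignWithPullback(&x, 4.0)
let (dx, dy) = pullback(1.0)
print(x, dx, dy)  // 12.0 4.0 3.0
```

The pullback has to capture `x` as it was before the update; reading the already-mutated value would give the wrong partial derivative with respect to `y`.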
It's also important to add some tests that check the derivatives are correct.
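One common way to check derivatives in a test is to compare them against central finite differences. The sketch below does this for scalar stand-ins of `-=` and `*=`. `numericalPartials` and `asPure` are hypothetical helpers, and a real test in this repo would exercise `Tensor` and the library's own gradient functions instead.

```swift
/// Central-difference estimate of the partial derivatives of f at (x, y).
func numericalPartials(
    _ f: (Double, Double) -> Double, at x: Double, _ y: Double, h: Double = 1e-5
) -> (dx: Double, dy: Double) {
    let dx = (f(x + h, y) - f(x - h, y)) / (2 * h)
    let dy = (f(x, y + h) - f(x, y - h)) / (2 * h)
    return (dx: dx, dy: dy)
}

/// Wraps a mutating operator as a pure function (lhs, rhs) -> updated lhs,
/// so it can be differentiated numerically.
func asPure(_ op: @escaping (inout Double, Double) -> Void) -> (Double, Double) -> Double {
    return { lhs, rhs in
        var result = lhs
        op(&result, rhs)
        return result
    }
}

let tolerance = 1e-6
let (x, y) = (3.0, 4.0)

// Expected partials of each in-place update:
//   x -= y  =>  d/dx = 1,  d/dy = -1
//   x *= y  =>  d/dx = y,  d/dy = x
let sub = numericalPartials(asPure { $0 -= $1 }, at: x, y)
let mul = numericalPartials(asPure { $0 *= $1 }, at: x, y)

print(abs(sub.dx - 1) < tolerance, abs(sub.dy + 1) < tolerance)  // true true
print(abs(mul.dx - y) < tolerance, abs(mul.dy - x) < tolerance)  // true true
```

Wrapping the mutating operator in a pure function keeps the check independent of how `inout` is handled, so the same helper works for every compound assignment operator.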
Is there any way to test the changes without compiling the entire Swift for TensorFlow toolchain?
Yes, you can use the toolchain binaries at https://github.com/tensorflow/swift/blob/master/Installation.md if you are on one of the supported operating systems.
I meant: if I modify the file mentioned above, won't I need to recompile the entire thing?
Ah, I see what you mean. The quickest way to test the change is to add a test somewhere in Tests/... and then run `swift test` in the root directory of this repo. This will compile the repo and run the tests, which should only take a few minutes.
I would like to work on this. Please let me know if I can pick it up.