[AutoDiff] WIP: Use owned callee convention for linear maps.
Switch to the @callee_owned callee convention for all linear map functions (differentials and pullbacks) returned from derivative functions. This eliminates half of the reference counting operations in compiler-generated derivatives and enables child linear maps that are called within linear maps to be destroyed right after the call.
Background: Before this patch, linear map functions took an @owned context and had the @callee_guaranteed callee convention. This was a suboptimal design because:
- All linear maps in AD-generated code are called exactly once. The context's parameter convention is @owned because we want to consume the context as early as possible, but combining it with the @callee_guaranteed convention leads to an unnecessary pair of retain (in the partial application forwarder) and release (in the caller). As a result, the entire context was kept alive until the entire outer pullback returned.
- Pullbacks' allocation and deallocation follow a strict stack discipline, so we really want to consume pullbacks as early as possible and not retain unused memory.
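The retain/release pair described in the first bullet can be illustrated with a toy bookkeeping model. This is a sketch under stated assumptions, not how SIL or IRGen actually represent closures: CountingContext, linearMapBody, callGuaranteed, and callOwned are hypothetical names, and the explicit counter merely stands in for the real refcount in the object header.

```swift
// Toy bookkeeping model (hypothetical names, not compiler or runtime APIs).
final class CountingContext {
    private(set) var refCount = 1      // +1 held by the caller's closure value
    private(set) var operations = 0    // retains + releases performed so far

    func retain() { refCount += 1; operations += 1 }
    func release() { refCount -= 1; operations += 1 }
}

/// Underlying linear map function: it takes its context @owned,
/// so it consumes (releases) the context when it is done.
func linearMapBody(_ x: Float, context: CountingContext) -> Float {
    defer { context.release() }
    return 2 * x
}

/// @callee_guaranteed: the caller keeps its reference during the call, so the
/// forwarder must retain the context before passing it to an @owned parameter,
/// and the caller must release its own reference after the call.
func callGuaranteed(_ x: Float, context: CountingContext) -> Float {
    context.retain()                                  // partial application forwarder
    let result = linearMapBody(x, context: context)
    context.release()                                 // caller, after the call
    return result
}

/// @callee_owned: the caller's reference moves into the call,
/// so the forwarder simply passes the context through.
func callOwned(_ x: Float, context: CountingContext) -> Float {
    return linearMapBody(x, context: context)
}

let guaranteedContext = CountingContext()
let guaranteedResult = callGuaranteed(1, context: guaranteedContext)
let ownedContext = CountingContext()
let ownedResult = callOwned(1, context: ownedContext)
print("@callee_guaranteed:", guaranteedContext.operations, "refcount ops")
print("@callee_owned:", ownedContext.operations, "refcount op(s)")
```

In this model the @callee_guaranteed path performs three refcount operations for one call versus one for @callee_owned, and the context only dies once the caller performs its release; in AD-generated code that release may not happen until the outer pullback returns.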
Resolves rdar://71892494.
Note: There are two places where reabstracting back to @callee_guaranteed is needed:
- Builtin.applyDerivative* call sites. This is because the caller expects the result to have the formal lowered type. There isn't a way to control callee convention in the AST. This reabstraction should be optimized away.
- Linear map struct fields. Similar to the case above, struct fields have AST types and thus cannot store @callee_owned closures directly. Two ways to eliminate the reabstraction are a) bitcasting and storing these closures as $(Builtin.RawPointer, Builtin.NativeObject) and b) replacing linear map structs with tuples.
How can I complete this PR? I'm looking at completing SR-15580 as well as gaining more knowledge about the compiler in general to prepare for debugging autodiff.
@CodaFi I know this is a stretch, but do you have the capacity to help me out here? It seems like I'll have to wait a long time before anybody in the old S4TF team does.
I'm not sure why this PR is a work in progress. Perhaps some tests didn't pass? Try rebasing it and we can get CI running to see what its current state is.
@CodaFi I'm having trouble figuring out how to rebase. I tried making a new PR from this branch to main and there are conflicts. However, instead of letting me make a new PR, it only gives me the link to this PR. I'm probably doing something wrong because I have limited experience with Git.
https://github.com/apple/swift/compare/main...rxwei:callee-owned-linear-map
Thanks for the ping. I'm giving this a shot this weekend.
Just rebased on top of main and most things are working. For anyone interested in picking up this work: the reason this was still WIP is that some unexpected memory leaks cause validation tests to fail.
Swift(macosx-x86_64) :: AutoDiff/validation-test/differentiable_protocol_requirements.swift
Swift(macosx-x86_64) :: AutoDiff/validation-test/existential.swift
Swift(macosx-x86_64) :: AutoDiff/validation-test/forward_mode_simple.swift
For example:
[ RUN ] ProtocolRequirementDifferentiation.func
stdout>>> check failed at /Volumes/Media/Development/Swift/swift-source/swift/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift, line 55
stdout>>> expected: 0 (of type Swift.Int)
stdout>>> actual: 12 (of type Swift.Int)
stdout>>> Leaks detected: 12
[ FAIL ] ProtocolRequirementDifferentiation.func
[ RUN ] ProtocolRequirementDifferentiation.constructor, accessor, subscript
stdout>>> check failed at /Volumes/Media/Development/Swift/swift-source/swift/test/AutoDiff/validation-test/differentiable_protocol_requirements.swift, line 117
stdout>>> expected: 0 (of type Swift.Int)
stdout>>> actual: 1 (of type Swift.Int)
stdout>>> Leaks detected: 1
[ FAIL ] ProtocolRequirementDifferentiation.constructor, accessor, subscript
ProtocolRequirementDifferentiation: Some tests failed, aborting
UXPASS: []
FAIL: ["func", "constructor, accessor, subscript"]
SKIP: []
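The log suggests that the harness counts live tracked instances after each test body and compares that count against 0. A minimal sketch of that idea, assuming a simplified counter; LiveCounter, TrackedValue, and leakCount are hypothetical names, not the actual DifferentiationUnittest API.

```swift
/// Counts how many TrackedValue instances are currently alive.
final class LiveCounter {
    private(set) var live = 0
    func increment() { live += 1 }
    func decrement() { live -= 1 }
}

/// Hypothetical stand-in for a tracked value: registers itself on init
/// and unregisters on deinit.
final class TrackedValue {
    let value: Float
    private let counter: LiveCounter

    init(_ value: Float, counter: LiveCounter) {
        self.value = value
        self.counter = counter
        counter.increment()
    }

    deinit { counter.decrement() }
}

/// Runs `body` and returns how many TrackedValues it left alive,
/// mirroring the "expected: 0 / actual: N" check in the log.
func leakCount(_ body: (LiveCounter) -> Void) -> Int {
    let counter = LiveCounter()
    body(counter)
    return counter.live
}

// Stands in for "something holding 3 somewhere".
var stash: [TrackedValue] = []

let cleanLeaks = leakCount { counter in
    let x = TrackedValue(3, counter: counter)
    _ = x.value * x.value          // x is released when the body returns
}
let leakyLeaks = leakCount { counter in
    stash.append(TrackedValue(3, counter: counter))  // outlives the body
}
print("Leaks detected: \(cleanLeaks) (clean), \(leakyLeaks) (leaky)")
```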
TODO:
- Debug and fix memory leaks.
- Pick up FileCheck test changes from 5b3a3e3.
cc @BradLarson
Thanks for investigating this! I don’t know if my current experience with the compiler is sufficient to fix the memory leak, though. Maybe after implementing the cross-import overlay.
I’d like to take on this bug myself in the future, because it’s a great stepping stone for future debugging experience. It’s small, specific and tractable. Since changing the calling convention is just an optimization, would it be okay if this work were delayed until I can fix it myself?
@philipturner This is really not a "small, specific and tractable" task for Swift compiler beginners. It requires good understanding of SIL and the differentiation transform, because the leak is likely due to incorrectly generated code within SILGen and differentiation transform. Also, it's not "just an optimization", as it's a major ABI change.
I don't have the bandwidth to tackle this in the near future, and was hoping that @BradLarson (and team) could take a look when they need to work on further optimizations that are blocked by this. @BradLarson should also be able to suggest some starter tasks for you.
Okay, thanks for the advice! I think the cross-import overlay implementation might go quickly, so I might need starter tasks soon. In the meantime, I'll look for ways to get familiar with the Swift compiler on my own.
The SIL for the generated pullback looks a bit suspicious to me:
// pullback of B.a(_:)
sil private [ossa] @$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHFTJpSUpSr : $@convention(thin) (@in_guaranteed Tracked<Float>, @owned _AD__$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHF_bb0__PB__src_0_wrt_0) -> @out Tracked<Float> {
// %0 // user: %19
// %1 // user: %8
// %2 // user: %9
bb0(%0 : $*Tracked<Float>, %1 : $*Tracked<Float>, %2 : @owned $_AD__$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHF_bb0__PB__src_0_wrt_0):
%3 = alloc_stack $Tracked<Float>, let, name "x", argno 1, expr op_deref // users: %22, %19, %16, %6
%4 = witness_method $Tracked<Float>, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %6
%5 = metatype $@thick Tracked<Float>.Type // user: %6
%6 = apply %4<Tracked<Float>>(%3, %5) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
%7 = alloc_stack $Tracked<Float> // users: %21, %20, %12, %8
copy_addr %1 to [initialization] %7 : $*Tracked<Float> // id: %8
%9 = destructure_struct %2 : $_AD__$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHF_bb0__PB__src_0_wrt_0 // user: %11
%10 = alloc_stack $Tracked<Float> // users: %18, %17, %16, %12
%11 = unchecked_value_cast %9 : $@callee_guaranteed (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> to $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %12
%12 = apply %11(%10, %7) : $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %13
destructure_tuple %12 : $() // id: %13
%14 = witness_method $Tracked<Float>, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %16
%15 = metatype $@thick Tracked<Float>.Type // user: %16
%16 = apply %14<Tracked<Float>>(%3, %10, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
destroy_addr %10 : $*Tracked<Float> // id: %17
dealloc_stack %10 : $*Tracked<Float> // id: %18
copy_addr [take] %3 to [initialization] %0 : $*Tracked<Float> // id: %19
destroy_addr %7 : $*Tracked<Float> // id: %20
dealloc_stack %7 : $*Tracked<Float> // id: %21
dealloc_stack %3 : $*Tracked<Float> // id: %22
%23 = tuple () // user: %24
return %23 : $() // id: %24
} // end sil function '$s4main1BV1ay23DifferentiationUnittest7TrackedVySfGAHFTJpSUpSr'
Note the following:
%11 = unchecked_value_cast %9 : $@callee_guaranteed (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> to $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %12
%12 = apply %11(%10, %7) : $@callee_owned (@in_guaranteed Tracked<Float>) -> @out Tracked<Float> // user: %13
So it looks like we're changing the calling convention here from @callee_guaranteed to @callee_owned. As a result, no release is generated after the function call, and since the callee assumes that the caller will release the context, it does not bother releasing it itself. The LLVM IR generated on main indeed has an extra @swift_release call here.
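To make the suspected mismatch concrete, here is a toy model of the two callee conventions. The names are hypothetical, an explicit counter stands in for real refcounting, and nothing here reflects actual SIL or IRGen code. If the caller's type says @callee_owned but the callee body follows @callee_guaranteed, nobody releases the context and it leaks; the opposite mismatch over-releases it.

```swift
// Toy model only: explicit counter instead of real ARC.
final class ToyContext {
    private(set) var refCount = 1   // +1 held by the caller's closure value
    func release() { refCount -= 1 }
}

enum CalleeConvention { case guaranteed, owned }

/// How the callee was actually compiled.
func callee(_ context: ToyContext, compiledAs convention: CalleeConvention) {
    // ... the body would read from the context here ...
    if convention == .owned {
        context.release()  // an owned callee consumes its context
    }
}

/// How the caller's (possibly bitcast) type says the callee must be called.
func caller(_ context: ToyContext, assumes assumed: CalleeConvention,
            calleeActual actual: CalleeConvention) {
    callee(context, compiledAs: actual)
    if assumed == .guaranteed {
        context.release()  // a guaranteed callee leaves the release to the caller
    }
}

let conventions: [CalleeConvention] = [.guaranteed, .owned]
for assumed in conventions {
    for actual in conventions {
        let context = ToyContext()
        caller(context, assumes: assumed, calleeActual: actual)
        // 0 = balanced, 1 = leak, -1 = over-release
        print("caller assumes \(assumed), callee is \(actual): refCount \(context.refCount)")
    }
}
```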
These pullbacks were actually @callee_owned in the first place, bitcast to @callee_guaranteed by VJPCloner (VJPCloner.cpp:679-694). The bitcast is necessary for us to be able to store the pullbacks inside an AST struct, as callee conventions cannot be specified in AST function types.
Yes, makes sense. Ok, in terms of leaked values in existential.swift, there are two allocs of "3" from:
func b(g: A) -> Tracked<Float> {
return gradient(at: 3) { x in g.a(x) }
}
but we release only one
Weird... If there's nothing I missed from applying the pullback bitcasting workaround, I'd expect any leaks to trigger an OSSA verification failure. Something apparently slipped through. We can check if this 3 was captured by pullbacks.
Things might be a bit more interesting. The corresponding Tracked<Float> should be released at the very end of b(g: A), and this is what happens on main. But not here. Stay tuned :)
So, yes. On main, at the end of b(g: A) we have just one unowned refcount for 3, so we release the object and deallocate it. On the branch, we have an additional strong refcount, so we cannot deinit it. Something is holding 3 somewhere.
So, the value is captured in the pullback of * here:
extension ${Self}
where
T: Differentiable & SignedNumeric, T == T.Magnitude,
T == T.TangentVector
{
@usableFromInline
@derivative(of: *)
internal static func _vjpMultiply(lhs: Self, rhs: Self)
-> (value: Self, pullback: (Self) -> (Self, Self))
{
return (lhs * rhs, { v in (v * rhs, v * lhs) })
}
}
It seems like this pullback is never released.
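A plain-Swift analogue of this situation (a sketch, not the gyb template above; Operand, vjpMultiply, and observeCapturedOperand are hypothetical names): a pullback that captures its operands keeps them alive exactly as long as the pullback itself, so a pullback that is never released keeps the captured 3 alive too. A weak reference makes that lifetime observable.

```swift
/// Reference-typed operand so that its lifetime is observable via `weak`.
final class Operand {
    let value: Float
    init(_ value: Float) { self.value = value }
}

/// Same shape as _vjpMultiply(lhs:rhs:): the pullback captures both operands.
func vjpMultiply(lhs: Operand, rhs: Operand)
    -> (value: Float, pullback: (Float) -> (Float, Float))
{
    return (lhs.value * rhs.value, { v in (v * rhs.value, v * lhs.value) })
}

func observeCapturedOperand()
    -> (gradients: (Float, Float), aliveWhileHeld: Bool, freedAfterRelease: Bool)
{
    weak var observed: Operand?
    var pullback: ((Float) -> (Float, Float))?
    do {
        let three = Operand(3)
        observed = three
        pullback = vjpMultiply(lhs: three, rhs: Operand(4)).pullback
    }
    // `three` is out of scope, but the pullback's context still references it.
    let aliveWhileHeld = observed != nil
    let gradients = pullback!(1)
    pullback = nil  // releasing the pullback releases everything it captured
    return (gradients, aliveWhileHeld, observed == nil)
}

let result = observeCapturedOperand()
print(result)
```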
Could this be an IRGen issue when it's emitting a non-[callee_guaranteed] partial_apply?
It's safe to assume that this library-defined pullback is emitted correctly since its closure lowering is not changed by this PR. Library-defined pullbacks are being captured by parent pullbacks generated by the differentiation transform via non-[callee_guaranteed] partial_applys.
The partial apply is [callee_guaranteed] here, yes. And it's pretty much the same on main; the values are also captured there.
Right, this particular partial_apply (library defined VJP) is [callee_guaranteed] and I don't think there's any issues with it. But the resulting pullback may be captured by outer pullbacks using a non-[callee_guaranteed] partial_apply, forming @callee_owned closures.
Yes, it's one level up:
// reverse-mode derivative of static Tracked<A>.* infix(_:_:)
sil [thunk] [always_inline] [ossa] @$s23DifferentiationUnittest7TrackedVAASjRzlE1moiyACyxGAE_AEtFZs13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAgHPQzAJRSlTJrSSUpSr : $@convention(method) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @in_guaranteed Tracked<τ_0_0>, @thin Tracked<τ_0_0>.Type) -> (@out Tracked<τ_0_0>, @owned @callee_owned @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Tracked<τ_0_0.TangentVector>, Tracked<τ_0_0.TangentVector>, Tracked<τ_0_0.TangentVector>>) {
// %0 // user: %5
// %1 // user: %5
// %2 // user: %5
// %3 // user: %5
bb0(%0 : $*Tracked<τ_0_0>, %1 : $*Tracked<τ_0_0>, %2 : $*Tracked<τ_0_0>, %3 : $@thin Tracked<τ_0_0>.Type):
// function_ref static Tracked<A>._vjpMultiply(lhs:rhs:)
%4 = function_ref @$s23DifferentiationUnittest7TrackedVAAs13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAeFPQzAHRSlE12_vjpMultiply3lhs3rhsACyxG5value_AO_AOtAOc8pullbacktAO_AOtFZ : $@convention(method) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @in_guaranteed Tracked<τ_0_0>, @thin Tracked<τ_0_0>.Type) -> (@out Tracked<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3, τ_0_4, τ_0_5 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3, τ_0_4 == τ_0_5> (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_2>, @out Tracked<τ_0_4>) for <τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0>) // user: %5
%5 = apply %4<τ_0_0>(%0, %1, %2, %3) : $@convention(method) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @in_guaranteed Tracked<τ_0_0>, @thin Tracked<τ_0_0>.Type) -> (@out Tracked<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3, τ_0_4, τ_0_5 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3, τ_0_4 == τ_0_5> (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_2>, @out Tracked<τ_0_4>) for <τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0>) // user: %6
%6 = convert_function %5 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3, τ_0_4, τ_0_5 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3, τ_0_4 == τ_0_5> (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_2>, @out Tracked<τ_0_4>) for <τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0, τ_0_0> to $@callee_guaranteed (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) // user: %8
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Tracked<A>) -> (@out Tracked<A>, @out Tracked<A>)
%7 = function_ref @$s23DifferentiationUnittest7TrackedVyxGA2DIegnrr_A3DIexnrr_s13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAfGPQzAIRSlTR : $@convention(thin) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @guaranteed @callee_guaranteed (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>)) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) // user: %8
%8 = partial_apply %7<τ_0_0>(%6) : $@convention(thin) <τ_0_0 where τ_0_0 : SignedNumeric, τ_0_0 : Differentiable, τ_0_0 == τ_0_0.Magnitude, τ_0_0.Magnitude == τ_0_0.TangentVector> (@in_guaranteed Tracked<τ_0_0>, @guaranteed @callee_guaranteed (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>)) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) // user: %9
%9 = convert_function %8 : $@callee_owned (@in_guaranteed Tracked<τ_0_0>) -> (@out Tracked<τ_0_0>, @out Tracked<τ_0_0>) to $@callee_owned @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Tracked<τ_0_0>, Tracked<τ_0_0>, Tracked<τ_0_0>> // user: %10
return %9 : $@callee_owned @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Tracked<τ_0_0>, Tracked<τ_0_0>, Tracked<τ_0_0>> // id: %10
} // end sil function '$s23DifferentiationUnittest7TrackedVAASjRzlE1moiyACyxGAE_AEtFZs13SignedNumericRz01_A014DifferentiableRz9MagnitudeSjQzRsz13TangentVectorAgHPQzAJRSlTJrSSUpSr'
On main, we have a @callee_guaranteed closure here. And I don't see any differences in the generated LLVM IR here.
What about the generated partial application forwarders in LLVM IR?
@rxwei I rebased this PR onto main and ran into a bunch of ABI-compatibility assertions. It looks like, instead of reabstraction thunks (removed here: https://github.com/apple/swift/pull/34935/files#diff-01cf87f81a8c47d84be8508fac2cb0f1a4ba15919ad86d3c8904bccee60151b5L915), this PR emits just unchecked bitcasts.
However, the bitcasts are now from / to ABI-incompatible types, e.g. from
$@callee_owned (@in_guaranteed Generic<Float>.TangentVector) -> (@out Float, @out Float, @out Float) to
$@callee_guaranteed (Generic<Float>.TangentVector) -> (Float, Float, Float). This certainly will not work.
Returning reabstraction thunks back resolved the issues (though leaks remained, but this is a separate issue) :)
That's a bit surprising though, as we should only be changing the callee convention, not the result abstraction.
Yes. And there was nothing like this in my previous rebase...
I'd recommend debugging this with the previous commit. My understanding from @gottesmm was that @callee_owned convention hasn't been used for a few years, so there may be bugs in IRGen.
@rxwei Yes, this is the plan. I'm just documenting some things that might help us :)
@rxwei OK, the problem exists on main as well :) Here is what happens: for captured values, we have the following pullback type: $@callee_guaranteed (Float, @inout_aliasable Float) -> Float. The lowered pullback type is $@callee_guaranteed (Float, @inout Float) -> Float (apparently the aliasable bit is silently ignored by lowering, which simply creates an inout parameter convention).
On main, we simply emit a reabstraction thunk for the pullback conversion. On the branch, we do a bitcast instead. And now the assertion is triggered because it considers these two types ($@callee_guaranteed (Float, @inout_aliasable Float) -> Float and $@callee_guaranteed (Float, @inout Float) -> Float) ABI-incompatible.
We started to handle captured arguments only very recently, this is why this issue did not occur previously.
There are actually multiple issues here. A few of them are directly related to this PR and I fixed them locally (some are just typos and others are related to recent changes in the autodiff code). The root cause is somewhere around the reabstraction thunks from @callee_owned to @callee_guaranteed and the corresponding partial apply forwarders. The bad news is that we're leaking much more than the tests report: we are leaking the whole context of one of the functions and everything captured in it. We were just lucky that it triggered in one of these tests.
I managed to significantly reduce one of testcases stripping layers and layers of abstractions :) Here is the piece of optimized LLVM IR showing the issue:
%13 = tail call noalias %swift.refcounted* @swift_allocObject(%swift.type* getelementptr inbounds (%swift.full_boxmetadata, %swift.full_boxmetadata* @metadata.75, i64 0, i32 2), i64 32, i64 7) #4, !noalias !101
%14 = getelementptr inbounds %swift.refcounted, %swift.refcounted* %13, i64 1
%.fn.i.i = bitcast %swift.refcounted* %14 to i8**
store i8* %11, i8** %.fn.i.i, align 8, !noalias !101
%.data.i.i = getelementptr inbounds %swift.refcounted, %swift.refcounted* %13, i64 1, i32 1
%15 = bitcast i64* %.data.i.i to %swift.refcounted**
store %swift.refcounted* %12, %swift.refcounted** %15, align 8, !noalias !101
%16 = bitcast i8* %11 to float (float, %swift.refcounted*)*
%17 = tail call %swift.refcounted* @swift_retain(%swift.refcounted* returned %12) #4, !noalias !108
%18 = tail call swiftcc float %16(float 1.000000e+00, %swift.refcounted* swiftself %12) #17, !noalias !112
tail call swiftcc void @"$s5main23fooyySf_SftF"(float %10, float %18)
tail call void @swift_release(%swift.refcounted* %0) #4
tail call void @swift_release(%swift.refcounted* %3) #4
tail call void @swift_release(%swift.refcounted* %6) #4
ret void
Here %16 is a differential call. Note that %13 is never released, and %12 is not consumed either because of an extra retain. This retain comes from the @callee_owned => @callee_guaranteed reabstraction thunk (which is pretty reasonable). Normally, the differential would consume the context from %12. However, we're missing a release of %13 here.