feat: VJP utility based on `autodiff_thunk`
Fixes #1853
Todo:
- [ ] Add docs of what I learned
Related:
- https://github.com/JuliaDiff/DifferentiationInterface.jl/discussions/721
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:white_check_mark: Project coverage is 28.01%. Comparing base (e8cec0c) to head (533fd4e).
Coverage Diff:

| | main | #2309 |
|---|---:|---:|
| Coverage | 28.01% | 28.01% |
| Files | 2 | 2 |
| Lines | 207 | 207 |
| Hits | 58 | 58 |
| Misses | 149 | 149 |
In a way, this feels equivalent to implementing `autodiff` for non-scalar returns, and to fixing my old mistake of always passing in `one(T)` as the seed.
I think of `autodiff(Reverse, ...)` generally as a VJP (with the convention that the output is updated in place).
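Concretely, for a function with a vector output, reverse mode seeded with an output adjoint `dz` computes `J' * dz`; for a scalar return, the hard-coded `one(T)` seed is just the special case `dz = 1`. The Julia sketch below does this by hand, without Enzyme: `f`, `jacobian_f`, `vjp_f`, and the finite-difference reference `fd_vjp` are hypothetical names used only for illustration.

```julia
using LinearAlgebra  # for dot

# Hypothetical example function with a non-scalar (2-element) output.
f(x) = [x[1]^2 * x[2], sin(x[1]) + x[2]]

# Hand-written Jacobian of f: row i is the gradient of output i.
function jacobian_f(x)
    j11 = 2 * x[1] * x[2]
    j12 = x[1]^2
    j21 = cos(x[1])
    j22 = 1.0
    return [j11 j12; j21 j22]
end

# VJP: pull the output adjoint (seed) dz back to the input, i.e. J' * dz.
vjp_f(x, dz) = jacobian_f(x)' * dz

# Reference: gradient of x -> dot(dz, f(x)) by central finite differences.
function fd_vjp(f, x, dz; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (dot(dz, f(x + e)) - dot(dz, f(x - e))) / (2h)
    end
    return g
end

x = [1.5, -0.5]
dz = [0.3, 2.0]
vjp_f(x, dz)       # VJP for an arbitrary output seed
vjp_f(x, ones(2))  # all-ones seed: gradient of sum(f(x))
```

Seeding every output with one only recovers the gradient of `sum(f(x))`, which is one way to see why a fixed `one(T)` seed cannot express a general VJP.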
We could also call it `autodiff`, but then we'd need to figure out how the output seed is passed. It can't be inside a `Duplicated` or `BatchDuplicated` because we have no primal.
> It can't be inside a `Duplicated` or `BatchDuplicated` because we have no primal
`Seed` and `BatchSeed`? Just throwing some ideas out there.
And what order of arguments do you have in mind?
`autodiff(Reverse, f, seed, args...)`
`autodiff(Reverse, f, args...; seed=...)`
The first variant, since that is already the convention used for the activity of the return.
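With the seed in the return-activity slot, the existing behaviour can be phrased in terms of the seeded call: an active scalar return is just a seed of `one(T)` in the same position. A minimal sketch of that dispatch, assuming hypothetical stand-ins (`ReturnActive`, `ReturnSeed`, `WithPullback`, `seeded_autodiff`) rather than Enzyme's types, and a hand-written pullback in place of generated reverse code:

```julia
# Hypothetical stand-ins for the return-activity slot (not Enzyme's types).
struct ReturnActive end
struct ReturnSeed{T}
    dval::T
end

# A function bundled with a hand-written pullback (x, dz) -> dx.
struct WithPullback{F, P}
    f::F
    pullback::P
end

# Seeded method: the caller supplies the adjoint of the return value.
seeded_autodiff(w::WithPullback, s::ReturnSeed, x) = w.pullback(x, s.dval)

# Unseeded method: forwards to the seeded one with one(y) for a scalar return y.
function seeded_autodiff(w::WithPullback, ::ReturnActive, x)
    y = w.f(x)  # evaluated here only to obtain the seed's type
    return seeded_autodiff(w, ReturnSeed(one(y)), x)
end

# Example: sumsq(x) = sum(abs2, x), whose pullback is dz * 2x.
sumsq = WithPullback(x -> sum(abs2, x), (x, dz) -> dz .* 2 .* x)

seeded_autodiff(sumsq, ReturnActive(), [1.0, 2.0])   # returns [2.0, 4.0]
seeded_autodiff(sumsq, ReturnSeed(3.0), [1.0, 2.0])  # returns [6.0, 12.0]
```

Evaluating the primal just to get the seed's type is wasteful; a real implementation would presumably use the inferred return type instead.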
Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (`git runic main`) to apply these changes.
diff --git a/src/sugar.jl b/src/sugar.jl
index f15e6e2..3ed38de 100644
--- a/src/sugar.jl
+++ b/src/sugar.jl
@@ -1367,10 +1367,10 @@ julia> Enzyme.batchify_activity(Duplicated{Vector{Float64}}, Val(2))
 BatchDuplicated{Vector{Float64}, 2}
 """
-batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
-batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}
-batchify_activity(::Type{DuplicatedNoNeed{T}}, ::Val{B}) where {T,B} = BatchDuplicatedNoNeed{T,B}
-batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T,B} = BatchMixedDuplicated{T,B}
+batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B} = Active{T}
+batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B} = BatchDuplicated{T, B}
+batchify_activity(::Type{DuplicatedNoNeed{T}}, ::Val{B}) where {T, B} = BatchDuplicatedNoNeed{T, B}
+batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T, B} = BatchMixedDuplicated{T, B}
 """
@@ -1378,12 +1378,12 @@ batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T,B} = BatchMixed
 Wrapper for a single adjoint to the return value in reverse mode.
 """
-struct Seed{T,A}
+struct Seed{T, A}
     dval::T
-    function Seed(dval::T) where T
+    function Seed(dval::T) where {T}
         A = guess_activity(T, Reverse)
-        return new{T,A}(dval)
+        return new{T, A}(dval)
     end
 end
@@ -1395,10 +1395,10 @@ Wrapper for a tuple of adjoints to the return value in reverse mode.
 struct BatchSeed{N, T, AB}
     dvals::NTuple{N, T}
-    function BatchSeed(dvals::NTuple{N,T}) where {N,T}
+    function BatchSeed(dvals::NTuple{N, T}) where {N, T}
         A = guess_activity(T, Reverse)
         AB = batchify_activity(A, Val(N))
-        return new{N,T,AB}(dvals)
+        return new{N, T, AB}(dvals)
     end
 end
@@ -1417,7 +1417,7 @@ Useful for computing pullbacks / VJPs for functions whose output is not a scalar
 function autodiff(
     rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
     f::FA,
-    dresult::Seed{RT,RA},
+    dresult::Seed{RT, RA},
     args::Vararg{Annotation, N},
 ) where {ReturnPrimal, FA <: Annotation, RT, RA, N}
     rmode_split = Split(rmode)
@@ -1454,7 +1454,7 @@ Useful for computing pullbacks / VJPs for functions whose output is not a scalar
 function autodiff(
     rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
     f::FA,
-    dresults::BatchSeed{B,RT,RA},
+    dresults::BatchSeed{B, RT, RA},
     args::Vararg{Annotation, N},
 ) where {ReturnPrimal, B, FA <: Annotation, RT, RA, N}
     rmode_split_rightwidth = ReverseSplitWidth(Split(rmode), Val(B))
diff --git a/test/seeded.jl b/test/seeded.jl
index 68c3962..664a6fa 100644
--- a/test/seeded.jl
+++ b/test/seeded.jl
@@ -4,9 +4,9 @@ using Test
 @testset "Batchify activity" begin
     @test batchify_activity(Active{Float64}, Val(2)) == Active{Float64}
-    @test batchify_activity(Duplicated{Vector{Float64}}, Val(2)) == BatchDuplicated{Vector{Float64},2}
-    @test batchify_activity(DuplicatedNoNeed{Vector{Float64}}, Val(2)) == BatchDuplicatedNoNeed{Vector{Float64},2}
-    @test batchify_activity(MixedDuplicated{Tuple{Float64,Vector{Float64}}}, Val(2)) == BatchMixedDuplicated{Tuple{Float64,Vector{Float64}},2}
+    @test batchify_activity(Duplicated{Vector{Float64}}, Val(2)) == BatchDuplicated{Vector{Float64}, 2}
+    @test batchify_activity(DuplicatedNoNeed{Vector{Float64}}, Val(2)) == BatchDuplicatedNoNeed{Vector{Float64}, 2}
+    @test batchify_activity(MixedDuplicated{Tuple{Float64, Vector{Float64}}}, Val(2)) == BatchMixedDuplicated{Tuple{Float64, Vector{Float64}}, 2}
 end
 # the base case is a function returning (a(x, y), b(x, y))
@@ -56,11 +56,11 @@ dx_ref = da * 2x * y .+ db * abs2(y)
 dy_ref = da * sum(abs2, x) + db * sum(x) * 2y
 dxs_ref = (
     das[1] * 2x * y .+ dbs[1] * abs2(y),
-    das[2] * 2x * y .+ dbs[2] * abs2(y)
+    das[2] * 2x * y .+ dbs[2] * abs2(y),
 )
 dys_ref = (
     das[1] * sum(abs2, x) + dbs[1] * sum(x) * 2y,
-    das[2] * sum(abs2, x) + dbs[2] * sum(x) * 2y
+    das[2] * sum(abs2, x) + dbs[2] * sum(x) * 2y,
 )
 # input derivatives, (a+b) case
@@ -69,11 +69,11 @@ dx1_ref = (da + db) * (2x * y .+ abs2(y))
 dy1_ref = (da + db) * (sum(abs2, x) + sum(x) * 2y)
 dxs1_ref = (
     (das[1] + dbs[1]) * (2x * y .+ abs2(y)),
-    (das[2] + dbs[2]) * (2x * y .+ abs2(y))
+    (das[2] + dbs[2]) * (2x * y .+ abs2(y)),
 )
 dys1_ref = (
     (das[1] + dbs[1]) * (sum(abs2, x) + sum(x) * 2y),
-    (das[2] + dbs[2]) * (sum(abs2, x) + sum(x) * 2y)
+    (das[2] + dbs[2]) * (sum(abs2, x) + sum(x) * 2y),
 )
 # output seeds, weird cases
@@ -99,7 +99,7 @@ dzs6 = (MyMixedStruct(das[1], [dbs[1]]), MyMixedStruct(das[2], [dbs[2]]))
 # validation
 function validate_seeded_autodiff(f, dz, dzs)
-    @testset for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
+    return @testset for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
         @testset "Simple" begin
             dx = make_zero(x)
             dinputs_and_maybe_result = autodiff(mode, Const(f), Seed(dz), Duplicated(x, dx), Active(y))
</details>
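The reference derivatives in the `test/seeded.jl` hunks above are consistent with `a(x, y) = sum(abs2, x) * y` and `b(x, y) = sum(x) * abs2(y)`; that pairing is inferred from the `dx_ref` / `dy_ref` formulas, not taken from the test file. A finite-difference sanity check of those formulas, independent of Enzyme (all function names here are hypothetical):

```julia
# Assumed base case, inferred from dx_ref / dy_ref: f(x, y) = (a(x, y), b(x, y)).
a(x, y) = sum(abs2, x) * y
b(x, y) = sum(x) * abs2(y)

# Reference VJP for the output seed (da, db), as written in test/seeded.jl.
dx_ref(x, y, da, db) = da * 2x * y .+ db * abs2(y)
dy_ref(x, y, da, db) = da * sum(abs2, x) + db * sum(x) * 2y

# The VJP is the gradient of the seeded scalar da * a + db * b.
seeded_scalar(x, y, da, db) = da * a(x, y) + db * b(x, y)

# Central finite differences of the seeded scalar with respect to x and y.
function fd_grads(x, y, da, db; h = 1e-6)
    dx = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        dx[i] = (seeded_scalar(x + e, y, da, db) - seeded_scalar(x - e, y, da, db)) / (2h)
    end
    dy = (seeded_scalar(x, y + h, da, db) - seeded_scalar(x, y - h, da, db)) / (2h)
    return dx, dy
end

dx_fd, dy_fd = fd_grads([1.0, 2.0, 3.0], 0.7, 1.3, -0.4)
```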
@vchuravy how do I generalize addition beyond arrays to accumulate the adjoint into the shadow?
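One possible approach, sketched under assumptions: recurse structurally, returning the sum for numbers, accumulating arrays in place, rebuilding tuples and immutable structs, and updating mutable structs field by field. `add_adjoint` is a hypothetical helper, not an Enzyme function; it assumes both arguments have the same type and that immutable structs have the default field-wise constructor (NamedTuples and other special types are not handled).

```julia
# Hypothetical helper: accumulate adjoint `b` into shadow `a`, returning the result.
# Numbers are immutable, so the sum is returned by value.
add_adjoint(a::Number, b::Number) = a + b

# Arrays: accumulate element-wise in place (elements may themselves be nested).
function add_adjoint(a::AbstractArray, b::AbstractArray)
    for i in eachindex(a, b)
        a[i] = add_adjoint(a[i], b[i])
    end
    return a
end

# Tuples are immutable: rebuild them element-wise.
add_adjoint(a::Tuple, b::Tuple) = map(add_adjoint, a, b)

# Fallback for structs (assumes typeof(a) == typeof(b)).
function add_adjoint(a, b)
    T = typeof(a)
    if ismutable(a)
        # Mutable struct: update each field in place so aliases see the change.
        for name in fieldnames(T)
            setfield!(a, name, add_adjoint(getfield(a, name), getfield(b, name)))
        end
        return a
    else
        # Immutable struct: rebuild from accumulated fields (default constructor assumed).
        return T(ntuple(i -> add_adjoint(getfield(a, i), getfield(b, i)), fieldcount(T))...)
    end
end

# Example output types in the spirit of the ones discussed in this thread.
struct MixedOut
    a::Float64
    b::Vector{Float64}
end
mutable struct MutableOut
    a::Float64
    b::Float64
end
```

Immutable values cannot be updated in place, so callers must use the returned value; this mirrors the split between adjoints that are returned (`Active`) and shadows that are accumulated in place (`Duplicated`).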
@vchuravy any further comments?
One more thing: looking at the examples, I feel like we should fuse `Active` and `Seed(dresult)`, maybe as `Seed{Active}(dresult)`, and then we could add `Seed(x::Float64) = Seed{Active}(x)`.
I guess that requires some form of automatic activity detection in the general case. Should I use `guess_activity`?
Bump @vchuravy, in general do we want automatic activity detection for the seed?
Sorry, was on vacation.
> automatic activity detection for the seed?
Yes, I think so. It is at least consistent with the current activity deduction for the return, while still allowing the seed value to be specified.
I put the activity detection inside `autodiff` itself because `guess_activity` needs the mode object too. And I added a little `batch_activity` routine because `guess_activity` never returns a `BatchDuplicated`. Let me know what you think, @vchuravy.
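A stripped-down model of the two pieces described here: a seed whose activity is guessed from its type at construction (as in the `Seed`/`BatchSeed` constructors in the Runic diff above), and a `batchify_activity`-style map to the batched annotation that leaves `Active` unchanged. Everything below is a hypothetical stand-in (the `Stub*` types, `guess_activity_sketch`, `batchify_sketch`); the activity rule is a deliberate simplification of `guess_activity` (it would, for example, call an `Int` active), and `ismutabletype` needs Julia 1.7 or later.

```julia
# Hypothetical stand-ins for Enzyme's activity annotations (types only).
abstract type StubActivity end
struct StubActive{T} <: StubActivity end
struct StubDuplicated{T} <: StubActivity end
struct StubMixedDuplicated{T} <: StubActivity end
struct StubBatchDuplicated{T, B} <: StubActivity end
struct StubBatchMixedDuplicated{T, B} <: StubActivity end

# Simplified activity rule (an assumption, not Enzyme's guess_activity); concrete types only.
contains_mutable(T::Type) = ismutabletype(T) || any(contains_mutable, fieldtypes(T))

function guess_activity_sketch(::Type{T}) where {T}
    isbitstype(T) && return StubActive{T}                 # e.g. Float64, NTuple{2, Float64}
    ismutabletype(T) && return StubDuplicated{T}          # e.g. Vector{Float64}, mutable structs
    contains_mutable(T) && return StubMixedDuplicated{T}  # e.g. Tuple{Float64, Vector{Float64}}
    return StubActive{T}
end

# Map a single-seed activity to its batched (width B) counterpart.
batchify_sketch(::Type{StubActive{T}}, ::Val{B}) where {T, B} = StubActive{T}
batchify_sketch(::Type{StubDuplicated{T}}, ::Val{B}) where {T, B} = StubBatchDuplicated{T, B}
batchify_sketch(::Type{StubMixedDuplicated{T}}, ::Val{B}) where {T, B} = StubBatchMixedDuplicated{T, B}

# Seeds that record their guessed activity as a type parameter.
struct StubSeed{T, A}
    dval::T
    StubSeed(dval::T) where {T} = new{T, guess_activity_sketch(T)}(dval)
end

struct StubBatchSeed{N, T, AB}
    dvals::NTuple{N, T}
    function StubBatchSeed(dvals::NTuple{N, T}) where {N, T}
        AB = batchify_sketch(guess_activity_sketch(T), Val(N))
        return new{N, T, AB}(dvals)
    end
end
```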
OK, the last thing would be to bump the version in EnzymeCore, since we are exporting new types.
Are we making `batchify_activity` public or not? Right now I haven't exported it.
Oops, I still need to bump the EnzymeCore compat bound inside Enzyme.
@vchuravy the benchmark and integration tests fail because I imposed an updated EnzymeCore compat bound in Enzyme, but these CI jobs try to install EnzymeCore v0.8.9 from the General registry, where it is not available.
The only CI failures that are new compared to main are the integration tests. Those fail because they try to download EnzymeCore from the General registry instead of installing the local version, as they do for Enzyme. That's probably something we want to fix separately.
If you're satisfied with the content of this PR, perhaps performance optimizations like https://github.com/EnzymeAD/Enzyme.jl/pull/2309#discussion_r2067905040 could be done at a later stage? The thunk interface is not documented anywhere anyway, so I don't know exactly what you expect, and I think that kind of change is better done by the maintainers.
Can you at least address the other comments (e.g. moving the utility out of EnzymeCore, since its semantics may change when we add batchactive and we don't want to force a breaking change), and also add more extensive tests for seeds that would normally be active (e.g. a float), duplicated (e.g. an array), and mixed (e.g. a tuple of both)?
> moving the utility out of enzymecore since the semantics may change when we add batchactive and don’t want to force a breaking change
Are you talking about batchify_activity?
What difference does it make if it lives in EnzymeCore or Enzyme?
> add more extensive tests for seeds that would normally be active (eg float), duplicated (eg array) and mixed (eg tuple of both)

Tests of what, `batchify_activity`? What do you mean by "would normally be"?
I mean having tests with `Seed(1.0)`, `Seed([1.0])`, `Seed((1.0, [1.0]))`, and similar (since each of these ends up with a different ABI and I can imagine failures).
And yes, move `batchify_activity` into Enzyme proper as an internal utility for now (so we can change its semantics without a breaking change).
> I mean having tests with `Seed(1.0)`, `Seed([1.0])`, `Seed((1.0, [1.0]))`, and similar (since each of these ends up with a different ABI and I can imagine failures).
Okay, I can add tests for the tuple seed, but there are already tests for the number and vector seeds. Do you think those are insufficient?
Yeah, it's the mixed case I'm more concerned about, and I'd sprinkle in a bunch of different seed types (e.g. a tuple of floats, an immutable struct of floats, a mutable struct of floats, a vector of tuples of floats, etc.). These all end up with different ABIs and could crash, so we should test them (especially since, unless the old `autodiff` is made to call this one with a default seed of one as proposed above, this path won't otherwise get tested across different ABIs).
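For the mixed cases, one way to get a reference that does not depend on the seed's ABI is to note that the VJP with seed `dz` is the gradient of the scalar `<dz, f(x)>`, with the inner product recursing through tuples, arrays, and struct fields. A sketch under those assumptions: `inner_product`, `reference_vjp`, and the example struct types are hypothetical, and central finite differences stand in for Enzyme, so this is only a cross-check for closed-form references, not a replacement.

```julia
# Hypothetical recursive inner product <a, b> over nested outputs and seeds.
inner_product(a::Number, b::Number) = a * b

# Arrays and tuples: sum over paired elements (assumes equal shapes; zip truncates otherwise).
inner_product(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple}) =
    sum((inner_product(ai, bi) for (ai, bi) in zip(a, b)); init = 0.0)

# Structs, mutable or not: sum the inner products of corresponding fields.
function inner_product(a, b)
    n = fieldcount(typeof(a))
    return sum((inner_product(getfield(a, i), getfield(b, i)) for i in 1:n); init = 0.0)
end

# Reference VJP: gradient of x -> <dz, f(x)> by central finite differences.
function reference_vjp(f, x::Vector{Float64}, dz; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (inner_product(dz, f(x + e)) - inner_product(dz, f(x - e))) / (2h)
    end
    return g
end

# Example output types in the spirit of the cases discussed here.
struct FloatStruct
    a::Float64
    b::Float64
end
mutable struct MutableFloatStruct
    a::Float64
    b::Float64
end
struct FloatVectorStruct
    a::Float64
    b::Vector{Float64}
end

x0 = [1.0, 2.0]
reference_vjp(x -> FloatVectorStruct(sum(abs2, x), [sum(x)]), x0, FloatVectorStruct(1.0, [2.0]))
```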
I added tests for the following 6 output types: scalar, vector, tuple, struct of floats, mutable struct of floats, and struct of a float and a vector. The first 5 work well; the last one errors in `create_abi_wrapper` (example stack trace below). Can you take a look?
Got exception outside of a @test
AssertionError: ; Function Attrs: alwaysinline mustprogress nofree
define internal { [2 x double] } @diffe2julia_f5_34949({} addrspace(10)* nocapture nofree readonly align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@double, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" "enzymejl_parmtype"="5268576208" "enzymejl_parmtype_ref"="2" %0, [2 x {} addrspace(10)*] %"'", double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="5322450160" "enzymejl_parmtype_ref"="0" %1, [2 x {} addrspace(10)*] %differeturn, { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg) local_unnamed_addr #18 !dbg !902 {
top:
%"iv'ac" = alloca i64, align 8
%"'de" = alloca [2 x {} addrspace(10)*], align 8
%2 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*]* %"'de", i64 0, i32 0
store {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)** %2, align 8
%3 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*]* %"'de", i64 0, i32 1
store {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)** %3, align 8
%"'de7" = alloca [2 x double], align 8
%4 = getelementptr [2 x double], [2 x double]* %"'de7", i64 0, i32 0
store double 0.000000e+00, double* %4, align 8
%5 = getelementptr [2 x double], [2 x double]* %"'de7", i64 0, i32 1
store double 0.000000e+00, double* %5, align 8
%"'de9" = alloca [2 x double], align 8
%6 = getelementptr [2 x double], [2 x double]* %"'de9", i64 0, i32 0
store double 0.000000e+00, double* %6, align 8
%7 = getelementptr [2 x double], [2 x double]* %"'de9", i64 0, i32 1
store double 0.000000e+00, double* %7, align 8
%"'de10" = alloca [2 x double], align 8
%8 = getelementptr [2 x double], [2 x double]* %"'de10", i64 0, i32 0
store double 0.000000e+00, double* %8, align 8
%9 = getelementptr [2 x double], [2 x double]* %"'de10", i64 0, i32 1
store double 0.000000e+00, double* %9, align 8
%"'ip_phi_cache" = alloca [2 x {} addrspace(10)* addrspace(13)*], align 8
%"'de38" = alloca [2 x double], align 8
%10 = getelementptr [2 x double], [2 x double]* %"'de38", i64 0, i32 0
store double 0.000000e+00, double* %10, align 8
%11 = getelementptr [2 x double], [2 x double]* %"'de38", i64 0, i32 1
store double 0.000000e+00, double* %11, align 8
%"'de39" = alloca [2 x double], align 8
%12 = getelementptr [2 x double], [2 x double]* %"'de39", i64 0, i32 0
store double 0.000000e+00, double* %12, align 8
%13 = getelementptr [2 x double], [2 x double]* %"'de39", i64 0, i32 1
store double 0.000000e+00, double* %13, align 8
%"'de40" = alloca [2 x double], align 8
%14 = getelementptr [2 x double], [2 x double]* %"'de40", i64 0, i32 0
store double 0.000000e+00, double* %14, align 8
%15 = getelementptr [2 x double], [2 x double]* %"'de40", i64 0, i32 1
store double 0.000000e+00, double* %15, align 8
%"'ip_phi4_cache" = alloca [2 x {} addrspace(10)* addrspace(13)*], align 8
%"'de83" = alloca [2 x double], align 8
%16 = getelementptr [2 x double], [2 x double]* %"'de83", i64 0, i32 0
store double 0.000000e+00, double* %16, align 8
%17 = getelementptr [2 x double], [2 x double]* %"'de83", i64 0, i32 1
store double 0.000000e+00, double* %17, align 8
%"value_phi325'de" = alloca [2 x double], align 8
%18 = getelementptr [2 x double], [2 x double]* %"value_phi325'de", i64 0, i32 0
store double 0.000000e+00, double* %18, align 8
%19 = getelementptr [2 x double], [2 x double]* %"value_phi325'de", i64 0, i32 1
store double 0.000000e+00, double* %19, align 8
%"'de84" = alloca [2 x double], align 8
%20 = getelementptr [2 x double], [2 x double]* %"'de84", i64 0, i32 0
store double 0.000000e+00, double* %20, align 8
%21 = getelementptr [2 x double], [2 x double]* %"'de84", i64 0, i32 1
store double 0.000000e+00, double* %21, align 8
%"'de109" = alloca [2 x double], align 8
%22 = getelementptr [2 x double], [2 x double]* %"'de109", i64 0, i32 0
store double 0.000000e+00, double* %22, align 8
%23 = getelementptr [2 x double], [2 x double]* %"'de109", i64 0, i32 1
store double 0.000000e+00, double* %23, align 8
%"'de119" = alloca [2 x double], align 8
%24 = getelementptr [2 x double], [2 x double]* %"'de119", i64 0, i32 0
store double 0.000000e+00, double* %24, align 8
%25 = getelementptr [2 x double], [2 x double]* %"'de119", i64 0, i32 1
store double 0.000000e+00, double* %25, align 8
%"'de128" = alloca [2 x double], align 8
%26 = getelementptr [2 x double], [2 x double]* %"'de128", i64 0, i32 0
store double 0.000000e+00, double* %26, align 8
%27 = getelementptr [2 x double], [2 x double]* %"'de128", i64 0, i32 1
store double 0.000000e+00, double* %27, align 8
%"'de134" = alloca [2 x double], align 8
%28 = getelementptr [2 x double], [2 x double]* %"'de134", i64 0, i32 0
store double 0.000000e+00, double* %28, align 8
%29 = getelementptr [2 x double], [2 x double]* %"'de134", i64 0, i32 1
store double 0.000000e+00, double* %29, align 8
%"value_phi'de" = alloca [2 x double], align 8
%30 = getelementptr [2 x double], [2 x double]* %"value_phi'de", i64 0, i32 0
store double 0.000000e+00, double* %30, align 8
%31 = getelementptr [2 x double], [2 x double]* %"value_phi'de", i64 0, i32 1
store double 0.000000e+00, double* %31, align 8
%_cache = alloca i8, align 1
%_cache139 = alloca i8, align 1
%pgcstack = call {}*** @julia.get_pgcstack() #20
%tapeArg6 = extractvalue { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg, 0, !dbg !903
%32 = extractvalue { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg, 1, !dbg !903
%33 = bitcast {} addrspace(10)* %0 to i8 addrspace(10)*, !dbg !912
%34 = addrspacecast i8 addrspace(10)* %33 to i8 addrspace(11)*, !dbg !912
%35 = getelementptr inbounds i8, i8 addrspace(11)* %34, i64 16, !dbg !912
%36 = bitcast i8 addrspace(11)* %35 to i64 addrspace(11)*, !dbg !912
%37 = load i64, i64 addrspace(11)* %36, align 8, !dbg !912, !tbaa !30, !alias.scope !926, !noalias !929, !enzyme_type !41, !enzymejl_source_type_Int64 !0, !enzymejl_byref_BITS_VALUE !0, !enzyme_inactive !0
store i8 1, i8* %_cache, align 1, !dbg !932, !invariant.group !933
store i8 4, i8* %_cache139, align 1, !dbg !932, !invariant.group !934
switch i64 %37, label %L34 [
i64 0, label %L111
i64 1, label %L29
], !dbg !932
L29: ; preds = %top
%38 = extractvalue [2 x {} addrspace(10)*] %"'", 0, !dbg !935
%"'ipc18" = bitcast {} addrspace(10)* %38 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !935
%39 = extractvalue [2 x {} addrspace(10)*] %"'", 1, !dbg !935
%"'ipc19" = bitcast {} addrspace(10)* %39 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !935
%"'ipc20" = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %"'ipc18" to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !935
%"'ipc21" = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %"'ipc19" to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !935
%40 = extractvalue [2 x {} addrspace(10)*] %"'", 0, !dbg !935
%"'ipc26" = bitcast {} addrspace(10)* %40 to {} addrspace(10)** addrspace(10)*, !dbg !935
%41 = extractvalue [2 x {} addrspace(10)*] %"'", 1, !dbg !935
%"'ipc27" = bitcast {} addrspace(10)* %41 to {} addrspace(10)** addrspace(10)*, !dbg !935
%"'ipc28" = addrspacecast {} addrspace(10)** addrspace(10)* %"'ipc26" to {} addrspace(10)** addrspace(11)*, !dbg !935
%"'ipc29" = addrspacecast {} addrspace(10)** addrspace(10)* %"'ipc27" to {} addrspace(10)** addrspace(11)*, !dbg !935
%"'ipl30" = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %"'ipc28", align 8, !dbg !935, !tbaa !49, !alias.scope !937, !noalias !938
%"'ipl31" = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %"'ipc29", align 8, !dbg !935, !tbaa !49, !alias.scope !939, !noalias !940
%"'ipg" = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %"'ipc20", i64 0, i32 1, !dbg !935
%"'ipg22" = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %"'ipc21", i64 0, i32 1, !dbg !935
%"'ipl" = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %"'ipg", align 8, !dbg !935, !tbaa !49, !alias.scope !937, !noalias !938, !dereferenceable_or_null !54
%"'ipl23" = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %"'ipg22", align 8, !dbg !935, !tbaa !49, !alias.scope !939, !noalias !940, !dereferenceable_or_null !54
%42 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %"'ipl", {} addrspace(10)** %"'ipl30"), !dbg !935
%43 = insertvalue [2 x {} addrspace(10)* addrspace(13)*] undef, {} addrspace(10)* addrspace(13)* %42, 0, !dbg !935
%44 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %"'ipl23", {} addrspace(10)** %"'ipl31"), !dbg !935
%45 = insertvalue [2 x {} addrspace(10)* addrspace(13)*] %43, {} addrspace(10)* addrspace(13)* %44, 1, !dbg !935
store [2 x {} addrspace(10)* addrspace(13)*] %45, [2 x {} addrspace(10)* addrspace(13)*]* %"'ip_phi_cache", align 8, !dbg !935, !invariant.group !941
%"'ipc" = bitcast {} addrspace(10)* addrspace(13)* %42 to double addrspace(13)*, !dbg !935
%"'ipc11" = bitcast {} addrspace(10)* addrspace(13)* %44 to double addrspace(13)*, !dbg !935
store i8 4, i8* %_cache, align 1, !dbg !942, !invariant.group !933
store i8 0, i8* %_cache139, align 1, !dbg !942, !invariant.group !934
br label %L111, !dbg !942
L34: ; preds = %top
%46 = icmp sgt i64 %37, 15, !dbg !943
br i1 %46, label %L99, label %L50, !dbg !944
L50: ; preds = %L34
%47 = extractvalue [2 x {} addrspace(10)*] %"'", 0, !dbg !945
%"'ipc61" = bitcast {} addrspace(10)* %47 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !945
%48 = extractvalue [2 x {} addrspace(10)*] %"'", 1, !dbg !945
%"'ipc62" = bitcast {} addrspace(10)* %48 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !945
%"'ipc63" = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %"'ipc61" to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !945
%"'ipc64" = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %"'ipc62" to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !945
%49 = extractvalue [2 x {} addrspace(10)*] %"'", 0, !dbg !945
%"'ipc71" = bitcast {} addrspace(10)* %49 to {} addrspace(10)** addrspace(10)*, !dbg !945
%50 = extractvalue [2 x {} addrspace(10)*] %"'", 1, !dbg !945
%"'ipc72" = bitcast {} addrspace(10)* %50 to {} addrspace(10)** addrspace(10)*, !dbg !945
%"'ipc73" = addrspacecast {} addrspace(10)** addrspace(10)* %"'ipc71" to {} addrspace(10)** addrspace(11)*, !dbg !945
%"'ipc74" = addrspacecast {} addrspace(10)** addrspace(10)* %"'ipc72" to {} addrspace(10)** addrspace(11)*, !dbg !945
%"'ipl75" = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %"'ipc73", align 8, !dbg !945, !tbaa !49, !alias.scope !937, !noalias !938, !invariant.group !947
%"'ipl76" = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %"'ipc74", align 8, !dbg !945, !tbaa !49, !alias.scope !939, !noalias !940, !invariant.group !948
%"'ipg65" = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %"'ipc63", i64 0, i32 1, !dbg !945
%"'ipg66" = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %"'ipc64", i64 0, i32 1, !dbg !945
%"'ipl67" = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %"'ipg65", align 8, !dbg !945, !tbaa !49, !alias.scope !937, !noalias !938, !invariant.group !949
%"'ipl68" = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %"'ipg66", align 8, !dbg !945, !tbaa !49, !alias.scope !939, !noalias !940, !invariant.group !950
%51 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %"'ipl67", {} addrspace(10)** %"'ipl75"), !dbg !945
%52 = insertvalue [2 x {} addrspace(10)* addrspace(13)*] undef, {} addrspace(10)* addrspace(13)* %51, 0, !dbg !945
%53 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %"'ipl68", {} addrspace(10)** %"'ipl76"), !dbg !945
%54 = insertvalue [2 x {} addrspace(10)* addrspace(13)*] %52, {} addrspace(10)* addrspace(13)* %53, 1, !dbg !945
store [2 x {} addrspace(10)* addrspace(13)*] %54, [2 x {} addrspace(10)* addrspace(13)*]* %"'ip_phi4_cache", align 8, !dbg !945, !invariant.group !951
%"'ipc52" = bitcast {} addrspace(10)* addrspace(13)* %51 to double addrspace(13)*, !dbg !945
%"'ipc53" = bitcast {} addrspace(10)* addrspace(13)* %53 to double addrspace(13)*, !dbg !945
%"'ipg41" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %51, i64 1, !dbg !952
%"'ipg42" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %53, i64 1, !dbg !952
%"'ipc43" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg41" to double addrspace(13)*, !dbg !952
%"'ipc44" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg42" to double addrspace(13)*, !dbg !952
%.not2324 = icmp sgt i64 %37, 2, !dbg !954
store i8 3, i8* %_cache, align 1, !dbg !955, !invariant.group !933
store i8 3, i8* %_cache139, align 1, !dbg !955, !invariant.group !934
br i1 %.not2324, label %L77.preheader, label %L111, !dbg !955
L77.preheader: ; preds = %L50
%55 = add i64 %37, -3, !dbg !955
br label %L77, !dbg !955
L77: ; preds = %L77, %L77.preheader
%iv = phi i64 [ 0, %L77.preheader ], [ %iv.next, %L77 ]
%iv.next = add nuw nsw i64 %iv, 1, !dbg !956
%56 = add nuw nsw i64 %iv, 2, !dbg !956
%57 = add nuw nsw i64 %56, 1, !dbg !956
%"'ipg85" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %51, i64 %56, !dbg !958
%"'ipg86" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %53, i64 %56, !dbg !958
%"'ipc87" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg85" to double addrspace(13)*, !dbg !958
%"'ipc88" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg86" to double addrspace(13)*, !dbg !958
%exitcond.not = icmp eq i64 %57, %37, !dbg !954
br i1 %exitcond.not, label %L111.loopexit, label %L77, !dbg !955
L99: ; preds = %L34
store i8 0, i8* %_cache, align 1, !dbg !959, !invariant.group !933
store i8 1, i8* %_cache139, align 1, !dbg !959, !invariant.group !934
br label %L111, !dbg !959
L111.loopexit: ; preds = %L77
store i8 2, i8* %_cache, align 1, !dbg !960, !invariant.group !933
store i8 2, i8* %_cache139, align 1, !dbg !960, !invariant.group !934
br label %L111, !dbg !960
L111: ; preds = %L111.loopexit, %L99, %L50, %L29, %top
%value_phi = extractvalue { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg, 3
%current_task118 = getelementptr inbounds {}**, {}*** %pgcstack, i64 -14
%58 = bitcast {}*** %current_task118 to {}*
%59 = fmul double %1, %1, !dbg !961
%"'mi" = call noalias nonnull align 8 dereferenceable(16) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" {} addrspace(10)* @julia.gc_alloc_obj({}* nonnull %58, i64 noundef 16, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4546560336 to {}*) to {} addrspace(10)*)) #21, !dbg !963
%"'mi131" = call noalias nonnull align 8 dereferenceable(16) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" {} addrspace(10)* @julia.gc_alloc_obj({}* nonnull %58, i64 noundef 16, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4546560336 to {}*) to {} addrspace(10)*)) #21, !dbg !963
%60 = extractvalue { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg, 2, !dbg !963
%61 = extractvalue [2 x {} addrspace(10)*] %60, 0, !dbg !963
%"'ipc124" = bitcast {} addrspace(10)* %61 to double addrspace(10)*, !dbg !963
%62 = extractvalue [2 x {} addrspace(10)*] %60, 1, !dbg !963
%"'ipc125" = bitcast {} addrspace(10)* %62 to double addrspace(10)*, !dbg !963
%"'ipc126" = addrspacecast double addrspace(10)* %"'ipc124" to double addrspace(11)*, !dbg !963
%"'ipc127" = addrspacecast double addrspace(10)* %"'ipc125" to double addrspace(11)*, !dbg !963
%63 = extractvalue [2 x {} addrspace(10)*] %60, 0, !dbg !963
%"'ipc111" = bitcast {} addrspace(10)* %63 to i8 addrspace(10)*, !dbg !963
%64 = extractvalue [2 x {} addrspace(10)*] %60, 1, !dbg !963
%"'ipc112" = bitcast {} addrspace(10)* %64 to i8 addrspace(10)*, !dbg !963
%"'ipc113" = addrspacecast i8 addrspace(10)* %"'ipc111" to i8 addrspace(11)*, !dbg !963
%"'ipc114" = addrspacecast i8 addrspace(10)* %"'ipc112" to i8 addrspace(11)*, !dbg !963
%"'ipg115" = getelementptr inbounds i8, i8 addrspace(11)* %"'ipc113", i64 8, !dbg !963
%"'ipg116" = getelementptr inbounds i8, i8 addrspace(11)* %"'ipc114", i64 8, !dbg !963
%"'ipc117" = bitcast i8 addrspace(11)* %"'ipg115" to double addrspace(11)*, !dbg !963
%"'ipc118" = bitcast i8 addrspace(11)* %"'ipg116" to double addrspace(11)*, !dbg !963
br label %invertL111, !dbg !963
inverttop: ; preds = %invertL111, %invertL34, %invertL29
%65 = load [2 x double], [2 x double]* %"'de7", align 8, !dbg !903
call fastcc void @diffe2julia__mapreduce_34996({} addrspace(10)* nocapture nofree readonly align 8 %0, [2 x {} addrspace(10)*] %"'", [2 x double] %65, { {} addrspace(10)*, double, double, double, double* } %tapeArg6), !dbg !903
store [2 x double] zeroinitializer, [2 x double]* %"'de7", align 8, !dbg !903
%66 = load [2 x double], [2 x double]* %"'de9", align 8
%67 = insertvalue { [2 x double] } undef, [2 x double] %66, 0
ret { [2 x double] } %67
invertL29: ; preds = %invertL111
%68 = load [2 x double], [2 x double]* %"'de10", align 8, !dbg !935
store [2 x double] zeroinitializer, [2 x double]* %"'de10", align 8, !dbg !935
%69 = load [2 x {} addrspace(10)* addrspace(13)*], [2 x {} addrspace(10)* addrspace(13)*]* %"'ip_phi_cache", align 8, !dbg !935, !invariant.group !941
%_unwrap = extractvalue [2 x {} addrspace(10)* addrspace(13)*] %69, 1, !dbg !935
%"'ipc11_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %_unwrap to double addrspace(13)*, !dbg !935
%_unwrap12 = extractvalue [2 x {} addrspace(10)* addrspace(13)*] %69, 0, !dbg !935
%"'ipc_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %_unwrap12 to double addrspace(13)*, !dbg !935
%70 = extractvalue [2 x double] %68, 0, !dbg !935
%71 = load double, double addrspace(13)* %"'ipc_unwrap", align 8, !dbg !935, !tbaa !58, !alias.scope !964, !noalias !967
%72 = fadd fast double %71, %70, !dbg !935
store double %72, double addrspace(13)* %"'ipc_unwrap", align 8, !dbg !935, !tbaa !58, !alias.scope !964, !noalias !967
%73 = extractvalue [2 x double] %68, 1, !dbg !935
%74 = load double, double addrspace(13)* %"'ipc11_unwrap", align 8, !dbg !935, !tbaa !58, !alias.scope !970, !noalias !971
%75 = fadd fast double %74, %73, !dbg !935
store double %75, double addrspace(13)* %"'ipc11_unwrap", align 8, !dbg !935, !tbaa !58, !alias.scope !970, !noalias !971
br label %inverttop
invertL34: ; preds = %invertL99, %invertL50
br label %inverttop
invertL50: ; preds = %invertL111, %invertL77.preheader
%76 = load [2 x double], [2 x double]* %"'de38", align 8, !dbg !972
store [2 x double] zeroinitializer, [2 x double]* %"'de38", align 8, !dbg !972
%77 = extractvalue [2 x double] %76, 0, !dbg !972
%78 = extractvalue [2 x double] %76, 1, !dbg !972
%79 = load [2 x double], [2 x double]* %"'de39", align 8, !dbg !972
%80 = getelementptr inbounds [2 x double], [2 x double]* %"'de39", i32 0, i32 0, !dbg !972
%81 = load double, double* %80, align 8, !dbg !972
%82 = fadd fast double %81, %77, !dbg !972
store double %82, double* %80, align 8, !dbg !972
%83 = getelementptr inbounds [2 x double], [2 x double]* %"'de39", i32 0, i32 1, !dbg !972
%84 = load double, double* %83, align 8, !dbg !972
%85 = fadd fast double %84, %78, !dbg !972
store double %85, double* %83, align 8, !dbg !972
%86 = extractvalue [2 x double] %76, 0, !dbg !972
%87 = extractvalue [2 x double] %76, 1, !dbg !972
%88 = load [2 x double], [2 x double]* %"'de40", align 8, !dbg !972
%89 = getelementptr inbounds [2 x double], [2 x double]* %"'de40", i32 0, i32 0, !dbg !972
%90 = load double, double* %89, align 8, !dbg !972
%91 = fadd fast double %90, %86, !dbg !972
store double %91, double* %89, align 8, !dbg !972
%92 = getelementptr inbounds [2 x double], [2 x double]* %"'de40", i32 0, i32 1, !dbg !972
%93 = load double, double* %92, align 8, !dbg !972
%94 = fadd fast double %93, %87, !dbg !972
store double %94, double* %92, align 8, !dbg !972
%95 = load [2 x double], [2 x double]* %"'de40", align 8, !dbg !952
store [2 x double] zeroinitializer, [2 x double]* %"'de40", align 8, !dbg !952
%96 = load [2 x {} addrspace(10)* addrspace(13)*], [2 x {} addrspace(10)* addrspace(13)*]* %"'ip_phi4_cache", align 8, !dbg !952, !invariant.group !951
%_unwrap45 = extractvalue [2 x {} addrspace(10)* addrspace(13)*] %96, 1, !dbg !952
%"'ipg42_unwrap" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %_unwrap45, i64 1, !dbg !952
%"'ipc44_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg42_unwrap" to double addrspace(13)*, !dbg !952
%_unwrap46 = extractvalue [2 x {} addrspace(10)* addrspace(13)*] %96, 0, !dbg !952
%"'ipg41_unwrap" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %_unwrap46, i64 1, !dbg !952
%"'ipc43_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg41_unwrap" to double addrspace(13)*, !dbg !952
%97 = extractvalue [2 x double] %95, 0, !dbg !952
%98 = load double, double addrspace(13)* %"'ipc43_unwrap", align 8, !dbg !952, !tbaa !58, !alias.scope !975, !noalias !978
%99 = fadd fast double %98, %97, !dbg !952
store double %99, double addrspace(13)* %"'ipc43_unwrap", align 8, !dbg !952, !tbaa !58, !alias.scope !975, !noalias !978
%100 = extractvalue [2 x double] %95, 1, !dbg !952
%101 = load double, double addrspace(13)* %"'ipc44_unwrap", align 8, !dbg !952, !tbaa !58, !alias.scope !981, !noalias !982
%102 = fadd fast double %101, %100, !dbg !952
store double %102, double addrspace(13)* %"'ipc44_unwrap", align 8, !dbg !952, !tbaa !58, !alias.scope !981, !noalias !982
%103 = load [2 x double], [2 x double]* %"'de39", align 8, !dbg !945
store [2 x double] zeroinitializer, [2 x double]* %"'de39", align 8, !dbg !945
%_unwrap54 = extractvalue [2 x {} addrspace(10)* addrspace(13)*] %96, 1, !dbg !945
%"'ipc53_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %_unwrap54 to double addrspace(13)*, !dbg !945
%_unwrap55 = extractvalue [2 x {} addrspace(10)* addrspace(13)*] %96, 0, !dbg !945
%"'ipc52_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %_unwrap55 to double addrspace(13)*, !dbg !945
%104 = extractvalue [2 x double] %103, 0, !dbg !945
%105 = load double, double addrspace(13)* %"'ipc52_unwrap", align 8, !dbg !945, !tbaa !58, !alias.scope !975, !noalias !978
%106 = fadd fast double %105, %104, !dbg !945
store double %106, double addrspace(13)* %"'ipc52_unwrap", align 8, !dbg !945, !tbaa !58, !alias.scope !975, !noalias !978
%107 = extractvalue [2 x double] %103, 1, !dbg !945
%108 = load double, double addrspace(13)* %"'ipc53_unwrap", align 8, !dbg !945, !tbaa !58, !alias.scope !981, !noalias !982
%109 = fadd fast double %108, %107, !dbg !945
store double %109, double addrspace(13)* %"'ipc53_unwrap", align 8, !dbg !945, !tbaa !58, !alias.scope !981, !noalias !982
br label %invertL34
invertL77.preheader: ; preds = %invertL77
br label %invertL50
invertL77: ; preds = %mergeinvertL77_L111.loopexit, %incinvertL77
%110 = load [2 x double], [2 x double]* %"'de83", align 8, !dbg !983
store [2 x double] zeroinitializer, [2 x double]* %"'de83", align 8, !dbg !983
%111 = extractvalue [2 x double] %110, 0, !dbg !983
%112 = extractvalue [2 x double] %110, 1, !dbg !983
%113 = load [2 x double], [2 x double]* %"value_phi325'de", align 8, !dbg !983
%114 = getelementptr inbounds [2 x double], [2 x double]* %"value_phi325'de", i32 0, i32 0, !dbg !983
%115 = load double, double* %114, align 8, !dbg !983
%116 = fadd fast double %115, %111, !dbg !983
store double %116, double* %114, align 8, !dbg !983
%117 = getelementptr inbounds [2 x double], [2 x double]* %"value_phi325'de", i32 0, i32 1, !dbg !983
%118 = load double, double* %117, align 8, !dbg !983
%119 = fadd fast double %118, %112, !dbg !983
store double %119, double* %117, align 8, !dbg !983
%120 = extractvalue [2 x double] %110, 0, !dbg !983
%121 = extractvalue [2 x double] %110, 1, !dbg !983
%122 = load [2 x double], [2 x double]* %"'de84", align 8, !dbg !983
%123 = getelementptr inbounds [2 x double], [2 x double]* %"'de84", i32 0, i32 0, !dbg !983
%124 = load double, double* %123, align 8, !dbg !983
%125 = fadd fast double %124, %120, !dbg !983
store double %125, double* %123, align 8, !dbg !983
%126 = getelementptr inbounds [2 x double], [2 x double]* %"'de84", i32 0, i32 1, !dbg !983
%127 = load double, double* %126, align 8, !dbg !983
%128 = fadd fast double %127, %121, !dbg !983
store double %128, double* %126, align 8, !dbg !983
%129 = load [2 x double], [2 x double]* %"'de84", align 8, !dbg !958
store [2 x double] zeroinitializer, [2 x double]* %"'de84", align 8, !dbg !958
%130 = load i64, i64* %"iv'ac", align 8, !dbg !958
%_unwrap89 = extractvalue [2 x {} addrspace(10)*] %"'", 1, !dbg !958
%"'ipc62_unwrap" = bitcast {} addrspace(10)* %_unwrap89 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !958
%"'ipc64_unwrap" = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %"'ipc62_unwrap" to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !958
%"'ipg66_unwrap" = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %"'ipc64_unwrap", i64 0, i32 1, !dbg !958
%"'ipl68_unwrap" = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %"'ipg66_unwrap", align 8, !dbg !945, !tbaa !49, !alias.scope !939, !noalias !940, !invariant.group !950
%_unwrap90 = extractvalue [2 x {} addrspace(10)*] %"'", 0, !dbg !958
%"'ipc61_unwrap" = bitcast {} addrspace(10)* %_unwrap90 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !958
%"'ipc63_unwrap" = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %"'ipc61_unwrap" to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !958
%"'ipg65_unwrap" = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %"'ipc63_unwrap", i64 0, i32 1, !dbg !958
%"'ipl67_unwrap" = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %"'ipg65_unwrap", align 8, !dbg !945, !tbaa !49, !alias.scope !937, !noalias !938, !invariant.group !949
%_unwrap94 = extractvalue [2 x {} addrspace(10)*] %"'", 1, !dbg !958
%"'ipc72_unwrap" = bitcast {} addrspace(10)* %_unwrap94 to {} addrspace(10)** addrspace(10)*, !dbg !958
%"'ipc74_unwrap" = addrspacecast {} addrspace(10)** addrspace(10)* %"'ipc72_unwrap" to {} addrspace(10)** addrspace(11)*, !dbg !958
%"'ipl76_unwrap" = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %"'ipc74_unwrap", align 8, !dbg !945, !tbaa !49, !alias.scope !939, !noalias !940, !invariant.group !948
%_unwrap95 = extractvalue [2 x {} addrspace(10)*] %"'", 0, !dbg !958
%"'ipc71_unwrap" = bitcast {} addrspace(10)* %_unwrap95 to {} addrspace(10)** addrspace(10)*, !dbg !958
%"'ipc73_unwrap" = addrspacecast {} addrspace(10)** addrspace(10)* %"'ipc71_unwrap" to {} addrspace(10)** addrspace(11)*, !dbg !958
%"'ipl75_unwrap" = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %"'ipc73_unwrap", align 8, !dbg !945, !tbaa !49, !alias.scope !937, !noalias !938, !invariant.group !947
%131 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %"'ipl68_unwrap", {} addrspace(10)** %"'ipl76_unwrap"), !dbg !945
%_unwrap99 = add nuw nsw i64 %130, 2, !dbg !958
%"'ipg86_unwrap" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %131, i64 %_unwrap99, !dbg !958
%"'ipc88_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg86_unwrap" to double addrspace(13)*, !dbg !958
%132 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %"'ipl67_unwrap", {} addrspace(10)** %"'ipl75_unwrap"), !dbg !945
%"'ipg85_unwrap" = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %132, i64 %_unwrap99, !dbg !958
%"'ipc87_unwrap" = bitcast {} addrspace(10)* addrspace(13)* %"'ipg85_unwrap" to double addrspace(13)*, !dbg !958
%133 = extractvalue [2 x double] %129, 0, !dbg !958
%134 = load double, double addrspace(13)* %"'ipc87_unwrap", align 8, !dbg !958, !tbaa !58, !alias.scope !975, !noalias !978
%135 = fadd fast double %134, %133, !dbg !958
store double %135, double addrspace(13)* %"'ipc87_unwrap", align 8, !dbg !958, !tbaa !58, !alias.scope !975, !noalias !978
%136 = extractvalue [2 x double] %129, 1, !dbg !958
%137 = load double, double addrspace(13)* %"'ipc88_unwrap", align 8, !dbg !958, !tbaa !58, !alias.scope !981, !noalias !982
%138 = fadd fast double %137, %136, !dbg !958
store double %138, double addrspace(13)* %"'ipc88_unwrap", align 8, !dbg !958, !tbaa !58, !alias.scope !981, !noalias !982
%139 = load [2 x double], [2 x double]* %"value_phi325'de", align 8
store [2 x double] zeroinitializer, [2 x double]* %"value_phi325'de", align 8
%140 = load i64, i64* %"iv'ac", align 8
%141 = icmp eq i64 %140, 0
%142 = xor i1 %141, true
%143 = extractvalue [2 x double] %139, 0
%144 = select fast i1 %141, double %143, double 0.000000e+00
%145 = extractvalue [2 x double] %139, 1
%146 = select fast i1 %141, double %145, double 0.000000e+00
%147 = load [2 x double], [2 x double]* %"'de38", align 8
%148 = getelementptr inbounds [2 x double], [2 x double]* %"'de38", i32 0, i32 0
%149 = load double, double* %148, align 8
%150 = fadd fast double %149, %143
%151 = select fast i1 %141, double %150, double %149
store double %151, double* %148, align 8
%152 = getelementptr inbounds [2 x double], [2 x double]* %"'de38", i32 0, i32 1
%153 = load double, double* %152, align 8
%154 = fadd fast double %153, %145
%155 = select fast i1 %141, double %154, double %153
store double %155, double* %152, align 8
%156 = extractvalue [2 x double] %139, 0
%157 = select fast i1 %142, double %156, double 0.000000e+00
%158 = extractvalue [2 x double] %139, 1
%159 = select fast i1 %142, double %158, double 0.000000e+00
%160 = load [2 x double], [2 x double]* %"'de83", align 8
%161 = getelementptr inbounds [2 x double], [2 x double]* %"'de83", i32 0, i32 0
%162 = load double, double* %161, align 8
%163 = fadd fast double %162, %156
%164 = select fast i1 %141, double %162, double %163
store double %164, double* %161, align 8
%165 = getelementptr inbounds [2 x double], [2 x double]* %"'de83", i32 0, i32 1
%166 = load double, double* %165, align 8
%167 = fadd fast double %166, %158
%168 = select fast i1 %141, double %166, double %167
store double %168, double* %165, align 8
br i1 %141, label %invertL77.preheader, label %incinvertL77
incinvertL77: ; preds = %invertL77
%169 = load i64, i64* %"iv'ac", align 8
%170 = add nsw i64 %169, -1
store i64 %170, i64* %"iv'ac", align 8
br label %invertL77
invertL99: ; preds = %invertL111
%171 = load [2 x double], [2 x double]* %"'de109", align 8, !dbg !986
call fastcc void @diffe2julia_mapreduce_impl_34975({} addrspace(10)* nocapture nofree readonly align 8 %0, [2 x {} addrspace(10)*] %"'", i64 signext 1, i64 signext %37, [2 x double] %171), !dbg !986
store [2 x double] zeroinitializer, [2 x double]* %"'de109", align 8, !dbg !986
br label %invertL34
invertL111.loopexit: ; preds = %invertL111
%_unwrap110 = add i64 %37, -3
br label %mergeinvertL77_L111.loopexit
mergeinvertL77_L111.loopexit: ; preds = %invertL111.loopexit
store i64 %_unwrap110, i64* %"iv'ac", align 8
br label %invertL77
invertL111: ; preds = %L111
store [2 x {} addrspace(10)*] %differeturn, [2 x {} addrspace(10)*]* %"'de", align 8
%172 = load double, double addrspace(11)* %"'ipc117", align 8, !dbg !963, !tbaa !180, !alias.scope !987, !noalias !990
%173 = load double, double addrspace(11)* %"'ipc118", align 8, !dbg !963, !tbaa !180, !alias.scope !993, !noalias !994
store double 0.000000e+00, double addrspace(11)* %"'ipc117", align 8, !dbg !963, !tbaa !180, !alias.scope !987, !noalias !990
store double 0.000000e+00, double addrspace(11)* %"'ipc118", align 8, !dbg !963, !tbaa !180, !alias.scope !993, !noalias !994
%174 = getelementptr inbounds [2 x double], [2 x double]* %"'de119", i32 0, i32 0, !dbg !963
%175 = load double, double* %174, align 8, !dbg !963
%176 = fadd fast double %175, %172, !dbg !963
store double %176, double* %174, align 8, !dbg !963
%177 = getelementptr inbounds [2 x double], [2 x double]* %"'de119", i32 0, i32 1, !dbg !963
%178 = load double, double* %177, align 8, !dbg !963
%179 = fadd fast double %178, %173, !dbg !963
store double %179, double* %177, align 8, !dbg !963
%180 = load double, double addrspace(11)* %"'ipc126", align 8, !dbg !963, !tbaa !180, !alias.scope !987, !noalias !990
%181 = load double, double addrspace(11)* %"'ipc127", align 8, !dbg !963, !tbaa !180, !alias.scope !993, !noalias !994
store double 0.000000e+00, double addrspace(11)* %"'ipc126", align 8, !dbg !963, !tbaa !180, !alias.scope !987, !noalias !990
store double 0.000000e+00, double addrspace(11)* %"'ipc127", align 8, !dbg !963, !tbaa !180, !alias.scope !993, !noalias !994
%182 = getelementptr inbounds [2 x double], [2 x double]* %"'de128", i32 0, i32 0, !dbg !963
%183 = load double, double* %182, align 8, !dbg !963
%184 = fadd fast double %183, %180, !dbg !963
store double %184, double* %182, align 8, !dbg !963
%185 = getelementptr inbounds [2 x double], [2 x double]* %"'de128", i32 0, i32 1, !dbg !963
%186 = load double, double* %185, align 8, !dbg !963
%187 = fadd fast double %186, %181, !dbg !963
store double %187, double* %185, align 8, !dbg !963
%188 = extractvalue { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg, 2, 0, !dbg !963
%189 = extractvalue { { {} addrspace(10)*, double, double, double, double* }, double, [2 x {} addrspace(10)*], double } %tapeArg, 2, 1, !dbg !963
%190 = load [2 x double], [2 x double]* %"'de119", align 8, !dbg !995
store [2 x double] zeroinitializer, [2 x double]* %"'de119", align 8, !dbg !995
%191 = extractvalue [2 x double] %190, 0, !dbg !995
%192 = fmul fast double %191, %value_phi, !dbg !995
%193 = extractvalue [2 x double] %190, 1, !dbg !995
%194 = fmul fast double %193, %value_phi, !dbg !995
%195 = load [2 x double], [2 x double]* %"'de134", align 8, !dbg !995
%196 = getelementptr inbounds [2 x double], [2 x double]* %"'de134", i32 0, i32 0, !dbg !995
%197 = load double, double* %196, align 8, !dbg !995
%198 = fadd fast double %197, %192, !dbg !995
store double %198, double* %196, align 8, !dbg !995
%199 = getelementptr inbounds [2 x double], [2 x double]* %"'de134", i32 0, i32 1, !dbg !995
%200 = load double, double* %199, align 8, !dbg !995
%201 = fadd fast double %200, %194, !dbg !995
store double %201, double* %199, align 8, !dbg !995
%202 = extractvalue [2 x double] %190, 0, !dbg !995
%203 = fmul fast double %202, %59, !dbg !995
%204 = extractvalue [2 x double] %190, 1, !dbg !995
%205 = fmul fast double %204, %59, !dbg !995
%206 = load [2 x double], [2 x double]* %"value_phi'de", align 8, !dbg !995
%207 = getelementptr inbounds [2 x double], [2 x double]* %"value_phi'de", i32 0, i32 0, !dbg !995
%208 = load double, double* %207, align 8, !dbg !995
%209 = fadd fast double %208, %203, !dbg !995
store double %209, double* %207, align 8, !dbg !995
%210 = getelementptr inbounds [2 x double], [2 x double]* %"value_phi'de", i32 0, i32 1, !dbg !995
%211 = load double, double* %210, align 8, !dbg !995
%212 = fadd fast double %211, %205, !dbg !995
store double %212, double* %210, align 8, !dbg !995
%213 = load [2 x double], [2 x double]* %"'de134", align 8, !dbg !961
store [2 x double] zeroinitializer, [2 x double]* %"'de134", align 8, !dbg !961
%214 = extractvalue [2 x double] %213, 0, !dbg !961
%215 = fmul fast double %214, %1, !dbg !961
%216 = extractvalue [2 x double] %213, 1, !dbg !961
%217 = fmul fast double %216, %1, !dbg !961
%218 = load [2 x double], [2 x double]* %"'de9", align 8, !dbg !961
%219 = getelementptr inbounds [2 x double], [2 x double]* %"'de9", i32 0, i32 0, !dbg !961
%220 = load double, double* %219, align 8, !dbg !961
%221 = fadd fast double %220, %215, !dbg !961
store double %221, double* %219, align 8, !dbg !961
%222 = getelementptr inbounds [2 x double], [2 x double]* %"'de9", i32 0, i32 1, !dbg !961
%223 = load double, double* %222, align 8, !dbg !961
%224 = fadd fast double %223, %217, !dbg !961
store double %224, double* %222, align 8, !dbg !961
%225 = extractvalue [2 x double] %213, 0, !dbg !961
%226 = fmul fast double %225, %1, !dbg !961
%227 = extractvalue [2 x double] %213, 1, !dbg !961
%228 = fmul fast double %227, %1, !dbg !961
%229 = load [2 x double], [2 x double]* %"'de9", align 8, !dbg !961
%230 = getelementptr inbounds [2 x double], [2 x double]* %"'de9", i32 0, i32 0, !dbg !961
%231 = load double, double* %230, align 8, !dbg !961
%232 = fadd fast double %231, %226, !dbg !961
store double %232, double* %230, align 8, !dbg !961
%233 = getelementptr inbounds [2 x double], [2 x double]* %"'de9", i32 0, i32 1, !dbg !961
%234 = load double, double* %233, align 8, !dbg !961
%235 = fadd fast double %234, %228, !dbg !961
store double %235, double* %233, align 8, !dbg !961
%236 = load [2 x double], [2 x double]* %"'de128", align 8, !dbg !960
store [2 x double] zeroinitializer, [2 x double]* %"'de128", align 8, !dbg !960
%237 = extractvalue [2 x double] %236, 0, !dbg !960
%238 = fmul fast double %237, %1, !dbg !960
%239 = extractvalue [2 x double] %236, 1, !dbg !960
%240 = fmul fast double %239, %1, !dbg !960
%241 = load [2 x double], [2 x double]* %"'de7", align 8, !dbg !960
%242 = getelementptr inbounds [2 x double], [2 x double]* %"'de7", i32 0, i32 0, !dbg !960
%243 = load double, double* %242, align 8, !dbg !960
%244 = fadd fast double %243, %238, !dbg !960
store double %244, double* %242, align 8, !dbg !960
%245 = getelementptr inbounds [2 x double], [2 x double]* %"'de7", i32 0, i32 1, !dbg !960
%246 = load double, double* %245, align 8, !dbg !960
%247 = fadd fast double %246, %240, !dbg !960
store double %247, double* %245, align 8, !dbg !960
%248 = extractvalue [2 x double] %236, 0, !dbg !960
%249 = fmul fast double %248, %32, !dbg !960
%250 = extractvalue [2 x double] %236, 1, !dbg !960
%251 = fmul fast double %250, %32, !dbg !960
%252 = load [2 x double], [2 x double]* %"'de9", align 8, !dbg !960
%253 = getelementptr inbounds [2 x double], [2 x double]* %"'de9", i32 0, i32 0, !dbg !960
%254 = load double, double* %253, align 8, !dbg !960
%255 = fadd fast double %254, %249, !dbg !960
store double %255, double* %253, align 8, !dbg !960
%256 = getelementptr inbounds [2 x double], [2 x double]* %"'de9", i32 0, i32 1, !dbg !960
%257 = load double, double* %256, align 8, !dbg !960
%258 = fadd fast double %257, %251, !dbg !960
store double %258, double* %256, align 8, !dbg !960
%259 = load [2 x double], [2 x double]* %"value_phi'de", align 8
store [2 x double] zeroinitializer, [2 x double]* %"value_phi'de", align 8
%260 = load i8, i8* %_cache, align 1, !invariant.group !933
%261 = icmp eq i8 0, %260
%262 = icmp eq i8 2, %260
%263 = icmp eq i8 3, %260
%264 = icmp eq i8 4, %260
%265 = extractvalue [2 x double] %259, 0
%266 = select fast i1 %262, double %265, double 0.000000e+00
%267 = extractvalue [2 x double] %259, 1
%268 = select fast i1 %262, double %267, double 0.000000e+00
%269 = load [2 x double], [2 x double]* %"'de83", align 8
%270 = getelementptr inbounds [2 x double], [2 x double]* %"'de83", i32 0, i32 0
%271 = load double, double* %270, align 8
%272 = fadd fast double %271, %265
%273 = select fast i1 %262, double %272, double %271
store double %273, double* %270, align 8
%274 = getelementptr inbounds [2 x double], [2 x double]* %"'de83", i32 0, i32 1
%275 = load double, double* %274, align 8
%276 = fadd fast double %275, %267
%277 = select fast i1 %262, double %276, double %275
store double %277, double* %274, align 8
%278 = extractvalue [2 x double] %259, 0
%279 = select fast i1 %261, double %278, double 0.000000e+00
%280 = extractvalue [2 x double] %259, 1
%281 = select fast i1 %261, double %280, double 0.000000e+00
%282 = load [2 x double], [2 x double]* %"'de109", align 8
%283 = getelementptr inbounds [2 x double], [2 x double]* %"'de109", i32 0, i32 0
%284 = load double, double* %283, align 8
%285 = fadd fast double %284, %278
%286 = select fast i1 %261, double %285, double %284
store double %286, double* %283, align 8
%287 = getelementptr inbounds [2 x double], [2 x double]* %"'de109", i32 0, i32 1
%288 = load double, double* %287, align 8
%289 = fadd fast double %288, %280
%290 = select fast i1 %261, double %289, double %288
store double %290, double* %287, align 8
%291 = extractvalue [2 x double] %259, 0
%292 = select fast i1 %263, double %291, double 0.000000e+00
%293 = extractvalue [2 x double] %259, 1
%294 = select fast i1 %263, double %293, double 0.000000e+00
%295 = load [2 x double], [2 x double]* %"'de38", align 8
%296 = getelementptr inbounds [2 x double], [2 x double]* %"'de38", i32 0, i32 0
%297 = load double, double* %296, align 8
%298 = fadd fast double %297, %291
%299 = select fast i1 %263, double %298, double %297
store double %299, double* %296, align 8
%300 = getelementptr inbounds [2 x double], [2 x double]* %"'de38", i32 0, i32 1
%301 = load double, double* %300, align 8
%302 = fadd fast double %301, %293
%303 = select fast i1 %263, double %302, double %301
store double %303, double* %300, align 8
%304 = extractvalue [2 x double] %259, 0
%305 = select fast i1 %264, double %304, double 0.000000e+00
%306 = extractvalue [2 x double] %259, 1
%307 = select fast i1 %264, double %306, double 0.000000e+00
%308 = load [2 x double], [2 x double]* %"'de10", align 8
%309 = getelementptr inbounds [2 x double], [2 x double]* %"'de10", i32 0, i32 0
%310 = load double, double* %309, align 8
%311 = fadd fast double %310, %304
%312 = select fast i1 %264, double %311, double %310
store double %312, double* %309, align 8
%313 = getelementptr inbounds [2 x double], [2 x double]* %"'de10", i32 0, i32 1
%314 = load double, double* %313, align 8
%315 = fadd fast double %314, %306
%316 = select fast i1 %264, double %315, double %314
store double %316, double* %313, align 8
%317 = load i8, i8* %_cache139, align 1, !invariant.group !934
switch i8 %317, label %inverttop [
i8 0, label %invertL29
i8 1, label %invertL99
i8 2, label %invertL111.loopexit
i8 3, label %invertL50
]
}
Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = MyMutableStruct, literal_rt = MyMixedStruct, rettype = BatchMixedDuplicated{MyMixedStruct, 2}, sret_union=false, pactualRetType=MyMutableStruct
Stacktrace:
[1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{Nothing}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing})
@ Enzyme.Compiler ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:1992
[2] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{N, Bool} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{Int64}, boxedArgs::Set{Int64})
@ Enzyme.Compiler ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:1737
[3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:4669
[4] codegen
@ ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:3455 [inlined]
[5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:5533
[6] _thunk
@ ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:5533 [inlined]
[7] cached_compilation
@ ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:5585 [inlined]
[8] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{<:Annotation}, A::Type{<:Annotation}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{N, Bool} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{Any})
@ Enzyme.Compiler ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:5696
[9] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{N, Bool} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/Documents/GitHub/Julia/Enzyme.jl/src/compiler.jl:5881
[10] autodiff_thunk
@ ~/Documents/GitHub/Julia/Enzyme.jl/src/Enzyme.jl:981 [inlined]
[11] autodiff(::ReverseModeSplit{true, true, false, 0, true, FFIABI, false, false, false}, ::Const{typeof(f5)}, ::BatchSeed{2, MyMixedStruct}, ::BatchDuplicated{Vector{Float64}, 2}, ::Active{Float64})
@ Enzyme ~/Documents/GitHub/Julia/Enzyme.jl/src/sugar.jl:1240
[12] macro expansion
@ ~/Documents/GitHub/Julia/Enzyme.jl/test/seeded.jl:120 [inlined]
[13] macro expansion
@ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1704 [inlined]
[14] macro expansion
@ ~/Documents/GitHub/Julia/Enzyme.jl/test/seeded.jl:119 [inlined]
[15] macro expansion
@ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1793 [inlined]
[16] validate_seeded_autodiff(f::typeof(f5), dz::MyMixedStruct, dzs::Tuple{MyMixedStruct, MyMixedStruct})
@ Main ~/Documents/GitHub/Julia/Enzyme.jl/test/seeded.jl:100
[17] macro expansion
@ ~/Documents/GitHub/Julia/Enzyme.jl/test/seeded.jl:162 [inlined]
[18] macro expansion
@ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1704 [inlined]
[19] top-level scope
@ ~/Documents/GitHub/Julia/Enzyme.jl/test/seeded.jl:162
Something very strange is happening: it thinks it is going to return a `MyMutableStruct` but is actually returning a `MyMixedStruct` (or vice versa).
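The error above compares `Base.allocatedinline` for the type Enzyme expected to return (`actualRetType`) against the literal return type. Below is a minimal sketch of why the two structs disagree. It assumes `MyMutableStruct` has the same fields as `MyMixedStruct`; its real definition in `test/seeded.jl` is not shown here.

```julia
# MyMixedStruct as defined in the test; MyMutableStruct's fields are assumed.
struct MyMixedStruct
    bar::Float64
    foo::Vector{Float64}
end

mutable struct MyMutableStruct
    bar::Float64
    foo::Vector{Float64}
end

# Mutable structs are always boxed, so containers hold references to them.
# Immutable structs such as MyMixedStruct can be stored inline.
for T in (MyMixedStruct, MyMutableStruct)
    println(T, ": mutable = ", ismutabletype(T),
            ", allocatedinline = ", Base.allocatedinline(T))
end
```

One layout is inline and the other is boxed. That is presumably why `create_abi_wrapper` rejects the mismatch outright, so a seed whose type disagrees with the return type surfaces as this low-level error.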
My bad, that was a typo; it should be fixed now.
But the fact that I was thrown off by it tells me that we should probably error more nicely when the seed is not consistent with the return type. The thing is, we can't know the return type in advance, can we?
Go for it. I think that return-type check would then need to live in your new `autodiff` method.
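Here is a minimal sketch of such a check. It assumes the check runs inside the new seeded `autodiff` method, before the thunk is built. `check_seed_type` and `g` are hypothetical names, not Enzyme API. The return type comes from `Base.promote_op`, which relies on type inference. When inference cannot pin the type down, it returns a wide type such as `Any`, and the check then passes every seed instead of rejecting valid ones. This is one way around not knowing the return type exactly in advance.

```julia
struct MyMixedStruct
    bar::Float64
    foo::Vector{Float64}
end

# Hypothetical stand-in for a function that returns a MyMixedStruct.
g(x::Vector{Float64}, y::Float64) = MyMixedStruct(y * sum(x), 2 .* x)

# Hypothetical helper, not part of Enzyme: compare the seed's type with the
# inferred return type of f and throw a readable error on mismatch.
function check_seed_type(f, seed, argtypes::Type...)
    RT = Base.promote_op(f, argtypes...)  # inferred; may be a wide type like Any
    if !(typeof(seed) <: RT)
        throw(ArgumentError(
            "seed of type $(typeof(seed)) is not compatible with the return type $RT of $f"))
    end
    return RT
end

check_seed_type(g, MyMixedStruct(1.0, [1.0, 1.0]), Vector{Float64}, Float64)  # passes
```

Passing, for example, a plain `Float64` as the seed for `g` throws an `ArgumentError` that names both types. A failure in `create_abi_wrapper` would not.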
I'm still struggling with the last test case, the one whose output type is:
```julia
struct MyMixedStruct
    bar::Float64
    foo::Vector{Float64}
end
```
Right now, the logic is that for mutable output types (`RA <: Duplicated`), I use `Compiler.recursive_accumulate` to update the shadow result in place. What should we do with `MixedDuplicated`, though?
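Below is one possible approach for `MixedDuplicated`, sketched in plain Julia rather than with Enzyme internals. It assumes the shadow of an immutable but non-bitstype return like `MyMixedStruct` is held in a `Ref` around the struct. The seed is accumulated field by field: mutable parts such as the `Vector` are updated in place, immutable parts are rebuilt, and the rebuilt struct is written back into the `Ref`. `accumulate_seed` is a hypothetical helper, not `Compiler.recursive_accumulate`.

```julia
struct MyMixedStruct
    bar::Float64
    foo::Vector{Float64}
end

# Hypothetical helper: add `seed` into `shadow` and return the updated shadow.
# Mutable parts (arrays, mutable structs) are modified in place; immutable
# parts (numbers, immutable structs) are rebuilt and returned.
function accumulate_seed(shadow, seed)
    if shadow isa Number
        return shadow + seed
    elseif shadow isa AbstractArray
        # Arrays are mutable: accumulate element by element in place.
        for i in eachindex(shadow, seed)
            shadow[i] = accumulate_seed(shadow[i], seed[i])
        end
        return shadow
    end
    T = typeof(shadow)
    if ismutabletype(T)
        # Mutable struct: overwrite each field of the existing object.
        for name in fieldnames(T)
            setfield!(shadow, name, accumulate_seed(getfield(shadow, name), getfield(seed, name)))
        end
        return shadow
    else
        # Immutable struct: recurse on each field, then build a new value.
        vals = map(name -> accumulate_seed(getfield(shadow, name), getfield(seed, name)), fieldnames(T))
        return T(vals...)
    end
end

# MixedDuplicated-style shadow (assumed layout): a Ref around an immutable
# struct that owns a mutable Vector.
shadow_ref = Ref(MyMixedStruct(0.0, [0.0, 0.0]))
inner = shadow_ref[].foo
seed = MyMixedStruct(1.0, [2.0, 3.0])
shadow_ref[] = accumulate_seed(shadow_ref[], seed)
```

After the call, `shadow_ref[].bar` is `1.0`, and `shadow_ref[].foo` is still the same `Vector` object, now holding `[2.0, 3.0]`. The `Ref` is what makes the immutable `bar` field updatable. Keeping the inner array's identity means anything else aliasing that array sees the accumulated values.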