Fix hessian sret type
Requires https://github.com/EnzymeAD/Enzyme/pull/2592.
Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.
Click here to view the suggested changes.
diff --git a/src/compiler.jl b/src/compiler.jl
index bec4673b..6fad4576 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -1059,15 +1059,15 @@ end
EnumAttribute("willreturn"),
EnumAttribute("nosync"),
EnumAttribute("nofree"),
- StringAttribute("enzyme_preserve_primal", "*"),
+ StringAttribute("enzyme_preserve_primal", "*"),
]
else
LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"),
EnumAttribute("willreturn"),
EnumAttribute("nosync"),
- EnumAttribute("nofree"),
- StringAttribute("enzyme_preserve_primal", "*"),
- ]
+ EnumAttribute("nofree"),
+ StringAttribute("enzyme_preserve_primal", "*"),
+ ]
end
handleCustom(state, custom, k_name, llvmfn, name, attrs)
return
@@ -6751,11 +6751,11 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
for f in functions(mod)
for i in 1:length(parameters(f))
for a in collect(parameter_attributes(f, i))
- if kind(a) == "enzyme_sret"
- API.EnzymeDumpValueRef(f)
- end
- @assert kind(a) != "enzyme_sret"
- @assert kind(a) != "enzyme_sret_v"
+ if kind(a) == "enzyme_sret"
+ API.EnzymeDumpValueRef(f)
+ end
+ @assert kind(a) != "enzyme_sret"
+ @assert kind(a) != "enzyme_sret_v"
end
end
end
@@ -6765,7 +6765,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
if DumpPrePostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
- post_optimize!(mod, JIT.get_tm(); callconv=false)
+ post_optimize!(mod, JIT.get_tm(); callconv = false)
if DumpPostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 6cfb5959..da3b1895 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -380,7 +380,7 @@ const DumpPostCallConv = Ref(false)
function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine)
addr13NoAlias(mod)
- removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
+ removeDeadArgs!(mod, tm, #=post_gc_fixup=# false)
memcpy_sret_split!(mod)
# if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
@@ -436,17 +436,17 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
if callconv
fixup_callconv!(mod, tm)
end
-
+
for f in functions(mod)
- if isempty(blocks(f))
- continue
- end
- if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
- delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
- end
+ if isempty(blocks(f))
+ continue
+ end
+ if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
+ delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
+ end
end
- removeDeadArgs!(mod, tm, #=post_gc_fixup=#true)
+ removeDeadArgs!(mod, tm, #=post_gc_fixup=# true)
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 00381caf..d557abf9 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2527,14 +2527,14 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
)
for u in LLVM.uses(fn)
u = LLVM.user(u)
- if !isa(u, LLVM.CallInst)
- msg = sprint() do io
- println(io, "Unknown user of fn: ", string(u))
- println(io, "fn: ", string(fn))
- println(io, "mod: ", string(parent(fn)))
- end
- throw(AssertionError(msg))
- end
+ if !isa(u, LLVM.CallInst)
+ msg = sprint() do io
+ println(io, "Unknown user of fn: ", string(u))
+ println(io, "fn: ", string(fn))
+ println(io, "mod: ", string(parent(fn)))
+ end
+ throw(AssertionError(msg))
+ end
B = IRBuilder()
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:white_check_mark: Project coverage is 48.00%. Comparing base (d64497a) to head (dd3f4eb).
:warning: Report is 5 commits behind head on main.
:exclamation: There is a different number of reports uploaded between BASE (d64497a) and HEAD (dd3f4eb). Click for more details.
HEAD has 34 fewer uploads than BASE.
| Flag | BASE (d64497a) | HEAD (dd3f4eb) |
|---|---|---|
| | 38 | 4 |
Additional details and impacted files
@@ Coverage Diff @@
## main #2824 +/- ##
===========================================
- Coverage 67.53% 48.00% -19.53%
===========================================
Files 58 13 -45
Lines 21051 1256 -19795
===========================================
- Hits 14217 603 -13614
+ Misses 6834 653 -6181
:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.
:rocket: New features to boost your workflow:
- :snowflake: Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
Benchmark Results
| | main | dd3f4eb2d82acd... | main / dd3f4eb2d82acd... |
|---|---|---|---|
| basics/make_zero/namedtuple | 0.049 ± 0.0096 μs | 0.0492 ± 0.011 μs | 0.996 ± 0.29 |
| basics/make_zero/struct | 0.252 ± 0.0078 μs | 0.251 ± 0.0067 μs | 1.01 ± 0.041 |
| basics/overhead | 3.18 ± 0.003 ns | 3.46 ± 0.003 ns | 0.917 ± 0.0012 |
| basics/remake_zero!/namedtuple | 0.222 ± 0.007 μs | 0.219 ± 0.011 μs | 1.01 ± 0.059 |
| basics/remake_zero!/struct | 0.223 ± 0.011 μs | 0.225 ± 0.0091 μs | 0.992 ± 0.063 |
| fold_broadcast/multidim_sum_bcast/1D | 10.9 ± 0.24 μs | 10.9 ± 0.24 μs | 1 ± 0.031 |
| fold_broadcast/multidim_sum_bcast/2D | 12.2 ± 0.32 μs | 12.2 ± 0.35 μs | 1 ± 0.039 |
| time_to_load | 0.983 ± 0.0081 s | 0.978 ± 0.0058 s | 1 ± 0.01 |
Benchmark Plots
A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/20237959906/artifacts/4873816709.
@copilot isolate an MWE of whatever test is timing out
@wsmoses I've opened a new pull request, #2848, to work on those changes. Once the pull request is ready, I'll request review from you.